// oxi_ai/context.rs
//! Conversation context management.
2
3use super::{Message, Tool};
4use serde::{Deserialize, Serialize};
5
6/// Conversation context for LLM interactions.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Context {
9 /// System prompt sent with each request
10 #[serde(skip_serializing_if = "Option::is_none")]
11 pub system_prompt: Option<String>,
12
13 /// Conversation history
14 pub messages: Vec<Message>,
15
16 /// Available tools for this context
17 #[serde(default)]
18 pub tools: Vec<Tool>,
19}
20
21impl Default for Context {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl Context {
28 /// Create a new empty context
29 ///
30 /// # Examples
31 ///
32 /// ```
33 /// use oxi_ai::Context;
34 /// let mut ctx = Context::new();
35 /// ctx.set_system_prompt("You are a helpful assistant.");
36 /// ```
37 pub fn new() -> Self {
38 Self {
39 system_prompt: None,
40 messages: Vec::new(),
41 tools: Vec::new(),
42 }
43 }
44
45 /// Create a context with a system prompt
46 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
47 self.system_prompt = Some(prompt.into());
48 self
49 }
50
51 /// Add a message to the context
52 ///
53 /// # Examples
54 ///
55 /// ```
56 /// use oxi_ai::{Context, Message, UserMessage};
57 /// let mut ctx = Context::new();
58 /// ctx.add_message(Message::User(UserMessage::new("Hello!")));
59 /// assert_eq!(ctx.len(), 1);
60 /// ```
61 pub fn add_message(&mut self, message: Message) {
62 self.messages.push(message);
63 }
64
65 /// Get a message by index
66 pub fn message(&self, index: usize) -> Option<&Message> {
67 self.messages.get(index)
68 }
69
70 /// Get the last message
71 ///
72 /// # Examples
73 ///
74 /// ```
75 /// use oxi_ai::{Context, Message, UserMessage};
76 /// let mut ctx = Context::new();
77 /// ctx.add_message(Message::User(UserMessage::new("First")));
78 /// ctx.add_message(Message::User(UserMessage::new("Second")));
79 /// assert!(ctx.last_message().is_some());
80 /// ```
81 pub fn last_message(&self) -> Option<&Message> {
82 self.messages.last()
83 }
84
85 /// Check if context has any messages
86 pub fn is_empty(&self) -> bool {
87 self.messages.is_empty()
88 }
89
90 /// Get number of messages
91 pub fn len(&self) -> usize {
92 self.messages.len()
93 }
94
95 /// Set the system prompt
96 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
97 self.system_prompt = Some(prompt.into());
98 }
99
100 /// Clear the system prompt
101 pub fn clear_system_prompt(&mut self) {
102 self.system_prompt = None;
103 }
104
105 /// Set available tools
106 ///
107 /// # Examples
108 ///
109 /// ```
110 /// use oxi_ai::{Context, Tool};
111 /// let mut ctx = Context::new();
112 /// let tool = Tool::new(
113 /// "search",
114 /// "Search the web",
115 /// serde_json::json!({"type": "object", "properties": {}}),
116 /// );
117 /// ctx.set_tools(vec![tool]);
118 /// assert_eq!(ctx.tools.len(), 1);
119 /// ```
120 pub fn set_tools(&mut self, tools: Vec<Tool>) {
121 self.tools = tools;
122 }
123
124 /// Add a tool
125 pub fn add_tool(&mut self, tool: Tool) {
126 self.tools.push(tool);
127 }
128
129 /// Get the system prompt for this context.
130 pub fn system_prompt(&self) -> Option<&str> {
131 self.system_prompt.as_deref()
132 }
133
134 /// Serialize context to a JSON string.
135 pub fn to_json(&self) -> Result<String, serde_json::Error> {
136 serde_json::to_string(self)
137 }
138
139 /// Deserialize a context from a JSON string.
140 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
141 serde_json::from_str(json)
142 }
143
144 /// Duplicate this context
145 pub fn duplicate(&self) -> Self {
146 Self {
147 system_prompt: self.system_prompt.clone(),
148 messages: self.messages.clone(),
149 tools: self.tools.clone(),
150 }
151 }
152}
153
154impl From<Vec<Message>> for Context {
155 fn from(messages: Vec<Message>) -> Self {
156 Self {
157 system_prompt: None,
158 messages,
159 tools: Vec::new(),
160 }
161 }
162}
163
164impl From<Message> for Context {
165 fn from(message: Message) -> Self {
166 Self {
167 system_prompt: None,
168 messages: vec![message],
169 tools: Vec::new(),
170 }
171 }
172}