1use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13 System,
15 User,
17 Assistant,
19 Tool,
21}
22
23impl From<&str> for Role {
24 fn from(s: &str) -> Self {
26 match s.to_lowercase().as_str() {
27 "system" => Role::System,
28 "user" => Role::User,
29 "assistant" => Role::Assistant,
30 "tool" => Role::Tool,
31 _ => Role::Assistant, }
33 }
34}
35
/// A single message in a chat conversation.
///
/// Optional fields are omitted from serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: Role,
    /// Message text. A missing field deserializes to `""` (serde `default`),
    /// and an explicit `null` also becomes `""` via `deserialize_null_as_empty_string`.
    #[serde(default, deserialize_with = "deserialize_null_as_empty_string")]
    pub content: String,
    /// Optional author name; not set by any constructor in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls attached to this message (presumably on assistant messages — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Set by `ChatMessage::tool`; presumably matches the `ToolCall::id` this
    /// message answers — verify against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
54
55fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
57where
58 D: serde::Deserializer<'de>,
59{
60 use serde::Deserialize;
61 Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
62}
63
/// A tool invocation attached to a [`ChatMessage`] (via `tool_calls`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of call; serialized under the JSON key `"type"` (a Rust keyword,
    /// hence the rename). Expected values are not visible here — confirm.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function name and arguments to invoke.
    pub function: FunctionCall,
}
75
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments kept as a raw string; presumably JSON-encoded — confirm with producers.
    pub arguments: String,
}
84
85impl ChatMessage {
86 pub fn system(content: impl Into<String>) -> Self {
88 Self {
89 role: Role::System,
90 content: content.into(),
91 name: None,
92 tool_calls: None,
93 tool_call_id: None,
94 }
95 }
96
97 pub fn user(content: impl Into<String>) -> Self {
99 Self {
100 role: Role::User,
101 content: content.into(),
102 name: None,
103 tool_calls: None,
104 tool_call_id: None,
105 }
106 }
107
108 pub fn assistant(content: impl Into<String>) -> Self {
110 Self {
111 role: Role::Assistant,
112 content: content.into(),
113 name: None,
114 tool_calls: None,
115 tool_call_id: None,
116 }
117 }
118
119 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
121 Self {
122 role: Role::Tool,
123 content: content.into(),
124 name: None,
125 tool_calls: None,
126 tool_call_id: Some(tool_call_id.into()),
127 }
128 }
129}
130
/// An in-memory chat conversation: message history, an optional system
/// prompt kept separately from the history, and free-form string metadata.
#[derive(Debug, Clone)]
pub struct ChatSession {
    /// Conversation history in insertion order (does not include `system_prompt`).
    pub messages: Vec<ChatMessage>,
    /// Optional system prompt; prepended as a system message by `get_messages`.
    pub system_prompt: Option<String>,
    /// Arbitrary key/value annotations; shown in `get_summary`.
    pub metadata: std::collections::HashMap<String, String>,
}
141
142impl ChatSession {
143 pub fn new() -> Self {
145 Self {
146 messages: Vec::new(),
147 system_prompt: None,
148 metadata: std::collections::HashMap::new(),
149 }
150 }
151
152 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
154 self.system_prompt = Some(prompt.into());
155 self
156 }
157
158 pub fn add_message(&mut self, message: ChatMessage) {
160 self.messages.push(message);
161 }
162
163 pub fn add_user_message(&mut self, content: impl Into<String>) {
165 self.messages.push(ChatMessage::user(content));
166 }
167
168 pub fn add_assistant_message(&mut self, content: impl Into<String>) {
170 self.messages.push(ChatMessage::assistant(content));
171 }
172
173 pub fn get_messages(&self) -> Vec<ChatMessage> {
175 let mut messages = Vec::new();
176
177 if let Some(ref system_prompt) = self.system_prompt {
178 messages.push(ChatMessage::system(system_prompt.clone()));
179 }
180
181 messages.extend(self.messages.clone());
182 messages
183 }
184
185 pub fn clear(&mut self) {
187 self.messages.clear();
188 }
189
190 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
192 self.metadata.insert(key.into(), value.into());
193 }
194
195 pub fn get_metadata(&self, key: &str) -> Option<&String> {
197 self.metadata.get(key)
198 }
199
200 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
202 self.metadata.remove(key)
203 }
204
205 pub fn get_summary(&self) -> String {
207 let mut summary = String::new();
208 summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
209
210 let user_msgs = self
211 .messages
212 .iter()
213 .filter(|m| matches!(m.role, Role::User))
214 .count();
215 let assistant_msgs = self
216 .messages
217 .iter()
218 .filter(|m| matches!(m.role, Role::Assistant))
219 .count();
220 let tool_msgs = self
221 .messages
222 .iter()
223 .filter(|m| matches!(m.role, Role::Tool))
224 .count();
225
226 summary.push_str(&format!("User messages: {}\n", user_msgs));
227 summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
228 summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
229
230 if !self.metadata.is_empty() {
231 summary.push_str("\nSession metadata:\n");
232 for (key, value) in &self.metadata {
233 summary.push_str(&format!(" {}: {}\n", key, value));
234 }
235 }
236
237 summary
238 }
239}
240
241impl Default for ChatSession {
242 fn default() -> Self {
244 Self::new()
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        // Unknown names fall back to Assistant instead of failing.
        assert_eq!(Role::from("unknown"), Role::Assistant);
        // Matching is case-insensitive.
        assert_eq!(Role::from("SYSTEM"), Role::System);
    }

    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
    }

    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");
    }

    #[test]
    fn test_chat_session_get_messages_without_system_prompt() {
        let mut session = ChatSession::new();
        session.add_user_message("Only message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "Only message");
    }

    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_chat_session_clear_keeps_system_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("Prompt");
        session.set_metadata("key", "value");
        session.add_user_message("Hello");

        session.clear();

        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt.as_deref(), Some("Prompt"));
        assert_eq!(
            session.get_metadata("key").map(String::as_str),
            Some("value")
        );
        // Only the system prompt remains in the outgoing conversation.
        assert_eq!(session.get_messages().len(), 1);
    }

    #[test]
    fn test_chat_session_metadata_roundtrip() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("model").is_none());

        session.set_metadata("model", "first");
        session.set_metadata("model", "second");
        assert_eq!(
            session.get_metadata("model").map(String::as_str),
            Some("second")
        );

        assert_eq!(session.remove_metadata("model"), Some("second".to_string()));
        assert!(session.get_metadata("model").is_none());
        assert_eq!(session.remove_metadata("model"), None);
    }

    #[test]
    fn test_chat_session_summary() {
        let mut session = ChatSession::new().with_system_prompt("Not counted");
        session.add_user_message("u");
        session.add_assistant_message("a");
        session.add_message(ChatMessage::tool("t", "call_1"));

        let summary = session.get_summary();
        assert!(summary.contains("Total messages: 3\n"));
        assert!(summary.contains("User messages: 1\n"));
        assert!(summary.contains("Assistant messages: 1\n"));
        assert!(summary.contains("Tool messages: 1\n"));
        assert!(!summary.contains("Session metadata"));

        session.set_metadata("model", "test-model");
        let summary = session.get_summary();
        assert!(summary.contains("Session metadata:\n"));
        assert!(summary.contains("model: test-model\n"));
    }
}