1use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13 System,
15 User,
17 Assistant,
19 Tool,
21}
22
23impl From<&str> for Role {
24 fn from(s: &str) -> Self {
26 match s.to_lowercase().as_str() {
27 "system" => Role::System,
28 "user" => Role::User,
29 "assistant" => Role::Assistant,
30 "tool" => Role::Tool,
31 _ => Role::Assistant, }
33 }
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ChatMessage {
39 pub role: Role,
41 #[serde(default, deserialize_with = "deserialize_null_as_empty_string")]
43 pub content: String,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub name: Option<String>,
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub tool_calls: Option<Vec<ToolCall>>,
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub tool_call_id: Option<String>,
53}
54
55fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
57where
58 D: serde::Deserializer<'de>,
59{
60 use serde::Deserialize;
61 Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ToolCall {
67 pub id: String,
69 #[serde(rename = "type")]
71 pub call_type: String,
72 pub function: FunctionCall,
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct FunctionCall {
79 pub name: String,
81 pub arguments: String,
83}
84
85impl ChatMessage {
86 pub fn system(content: impl Into<String>) -> Self {
88 Self {
89 role: Role::System,
90 content: content.into(),
91 name: None,
92 tool_calls: None,
93 tool_call_id: None,
94 }
95 }
96
97 pub fn sys(content: impl Into<String>) -> Self {
99 Self::system(content)
100 }
101
102 pub fn user(content: impl Into<String>) -> Self {
104 Self {
105 role: Role::User,
106 content: content.into(),
107 name: None,
108 tool_calls: None,
109 tool_call_id: None,
110 }
111 }
112
113 pub fn msg(content: impl Into<String>) -> Self {
115 Self::user(content)
116 }
117
118 pub fn assistant(content: impl Into<String>) -> Self {
120 Self {
121 role: Role::Assistant,
122 content: content.into(),
123 name: None,
124 tool_calls: None,
125 tool_call_id: None,
126 }
127 }
128
129 pub fn reply(content: impl Into<String>) -> Self {
131 Self::assistant(content)
132 }
133
134 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
136 Self {
137 role: Role::Tool,
138 content: content.into(),
139 name: None,
140 tool_calls: None,
141 tool_call_id: Some(tool_call_id.into()),
142 }
143 }
144}
145
146#[derive(Debug, Clone)]
148pub struct ChatSession {
149 pub messages: Vec<ChatMessage>,
151 pub system_prompt: Option<String>,
153 pub metadata: std::collections::HashMap<String, String>,
155}
156
157impl ChatSession {
158 pub fn new() -> Self {
160 Self {
161 messages: Vec::new(),
162 system_prompt: None,
163 metadata: std::collections::HashMap::new(),
164 }
165 }
166
167 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
169 self.system_prompt = Some(prompt.into());
170 self
171 }
172
173 pub fn add_message(&mut self, message: ChatMessage) {
175 self.messages.push(message);
176 }
177
178 pub fn add_user_message(&mut self, content: impl Into<String>) {
180 self.messages.push(ChatMessage::user(content));
181 }
182
183 pub fn add_assistant_message(&mut self, content: impl Into<String>) {
185 self.messages.push(ChatMessage::assistant(content));
186 }
187
188 pub fn add_sys(&mut self, content: impl Into<String>) {
190 self.messages.push(ChatMessage::system(content));
191 }
192
193 pub fn add_msg(&mut self, content: impl Into<String>) {
195 self.messages.push(ChatMessage::user(content));
196 }
197
198 pub fn add_reply(&mut self, content: impl Into<String>) {
200 self.messages.push(ChatMessage::assistant(content));
201 }
202
203 pub fn get_messages(&self) -> Vec<ChatMessage> {
205 let mut messages = Vec::new();
206
207 if let Some(ref system_prompt) = self.system_prompt {
208 messages.push(ChatMessage::system(system_prompt.clone()));
209 }
210
211 messages.extend(self.messages.clone());
212 messages
213 }
214
215 pub fn clear(&mut self) {
217 self.messages.clear();
218 }
219
220 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
222 self.metadata.insert(key.into(), value.into());
223 }
224
225 pub fn get_metadata(&self, key: &str) -> Option<&String> {
227 self.metadata.get(key)
228 }
229
230 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
232 self.metadata.remove(key)
233 }
234
235 pub fn get_summary(&self) -> String {
237 let mut summary = String::new();
238 summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
239
240 let user_msgs = self
241 .messages
242 .iter()
243 .filter(|m| matches!(m.role, Role::User))
244 .count();
245 let assistant_msgs = self
246 .messages
247 .iter()
248 .filter(|m| matches!(m.role, Role::Assistant))
249 .count();
250 let tool_msgs = self
251 .messages
252 .iter()
253 .filter(|m| matches!(m.role, Role::Tool))
254 .count();
255
256 summary.push_str(&format!("User messages: {}\n", user_msgs));
257 summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
258 summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
259
260 if !self.metadata.is_empty() {
261 summary.push_str("\nSession metadata:\n");
262 for (key, value) in &self.metadata {
263 summary.push_str(&format!(" {}: {}\n", key, value));
264 }
265 }
266
267 summary
268 }
269}
270
271impl Default for ChatSession {
272 fn default() -> Self {
274 Self::new()
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Role parsing: exact names, case-insensitivity, and the assistant fallback.
    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        // Unrecognised names fall back to the assistant role.
        assert_eq!(Role::from("unknown"), Role::Assistant);
        // Matching ignores case.
        assert_eq!(Role::from("SYSTEM"), Role::System);
        assert_eq!(Role::from("User"), Role::User);
        assert_eq!(Role::from("TOOL"), Role::Tool);
    }

    /// Each constructor sets its role and content and leaves optional fields unset.
    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
    }

    /// The short aliases build the same roles as the long-form constructors.
    #[test]
    fn test_chat_message_aliases() {
        assert_eq!(ChatMessage::sys("s").role, Role::System);
        assert_eq!(ChatMessage::msg("m").role, Role::User);
        assert_eq!(ChatMessage::reply("r").role, Role::Assistant);
        assert_eq!(ChatMessage::reply("r").content, "r");
    }

    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
    }

    #[test]
    fn test_chat_session_default_matches_new() {
        let session = ChatSession::default();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    /// The alias adders append messages with the expected roles, in order.
    #[test]
    fn test_chat_session_alias_adders() {
        let mut session = ChatSession::new();
        session.add_sys("s");
        session.add_msg("m");
        session.add_reply("r");
        assert_eq!(session.messages.len(), 3);
        assert_eq!(session.messages[0].role, Role::System);
        assert_eq!(session.messages[1].role, Role::User);
        assert_eq!(session.messages[2].role, Role::Assistant);
    }

    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");
    }

    /// Without a system prompt nothing extra is prepended.
    #[test]
    fn test_chat_session_get_messages_without_prompt() {
        let mut session = ChatSession::new();
        session.add_user_message("hi");
        let messages = session.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
    }

    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    /// `clear` only drops messages; the system prompt and metadata survive.
    #[test]
    fn test_chat_session_clear_keeps_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("p");
        session.set_metadata("k", "v");
        session.add_user_message("hi");

        session.clear();
        assert!(session.messages.is_empty());
        assert_eq!(session.get_messages().len(), 1);
        assert_eq!(session.get_metadata("k").map(String::as_str), Some("v"));
    }

    #[test]
    fn test_metadata_round_trip() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("k").is_none());
        session.set_metadata("k", "v1");
        session.set_metadata("k", "v2");
        assert_eq!(session.get_metadata("k").map(String::as_str), Some("v2"));
        assert_eq!(session.remove_metadata("k"), Some("v2".to_string()));
        assert!(session.get_metadata("k").is_none());
        assert!(session.remove_metadata("k").is_none());
    }

    /// Summary counts; checked with `contains` so the assertions do not depend
    /// on metadata ordering.
    #[test]
    fn test_get_summary_counts() {
        let mut session = ChatSession::new().with_system_prompt("p");
        session.add_user_message("u");
        session.add_assistant_message("a");
        session.add_message(ChatMessage::tool("t", "call_1"));
        session.add_sys("s");

        let summary = session.get_summary();
        assert!(summary.contains("Total messages: 4\n"));
        assert!(summary.contains("User messages: 1\n"));
        assert!(summary.contains("Assistant messages: 1\n"));
        assert!(summary.contains("Tool messages: 1\n"));
        assert!(!summary.contains("Session metadata"));

        session.set_metadata("k", "v");
        let summary = session.get_summary();
        assert!(summary.contains("Session metadata:\n"));
        assert!(summary.contains("k: v\n"));
    }
}