1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
4#[serde(rename_all = "lowercase")]
5pub enum Role {
6 System,
7 User,
8 Assistant,
9 Tool,
10}
11
12impl From<&str> for Role {
13 fn from(s: &str) -> Self {
14 match s.to_lowercase().as_str() {
15 "system" => Role::System,
16 "user" => Role::User,
17 "assistant" => Role::Assistant,
18 "tool" => Role::Tool,
19 _ => Role::Assistant, }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ChatMessage {
26 pub role: Role,
27 #[serde(default, deserialize_with = "deserialize_null_as_empty_string")]
28 pub content: String,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub name: Option<String>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub tool_calls: Option<Vec<ToolCall>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub tool_call_id: Option<String>,
35}
36
37fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
38where
39 D: serde::Deserializer<'de>,
40{
41 use serde::Deserialize;
42 Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ToolCall {
47 pub id: String,
48 #[serde(rename = "type")]
49 pub call_type: String,
50 pub function: FunctionCall,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct FunctionCall {
55 pub name: String,
56 pub arguments: String,
57}
58
59impl ChatMessage {
60 pub fn system(content: impl Into<String>) -> Self {
61 Self {
62 role: Role::System,
63 content: content.into(),
64 name: None,
65 tool_calls: None,
66 tool_call_id: None,
67 }
68 }
69
70 pub fn user(content: impl Into<String>) -> Self {
71 Self {
72 role: Role::User,
73 content: content.into(),
74 name: None,
75 tool_calls: None,
76 tool_call_id: None,
77 }
78 }
79
80 pub fn assistant(content: impl Into<String>) -> Self {
81 Self {
82 role: Role::Assistant,
83 content: content.into(),
84 name: None,
85 tool_calls: None,
86 tool_call_id: None,
87 }
88 }
89
90 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
91 Self {
92 role: Role::Tool,
93 content: content.into(),
94 name: None,
95 tool_calls: None,
96 tool_call_id: Some(tool_call_id.into()),
97 }
98 }
99}
100
101#[derive(Debug, Clone)]
102pub struct ChatSession {
103 pub messages: Vec<ChatMessage>,
104 pub system_prompt: Option<String>,
105 pub metadata: std::collections::HashMap<String, String>,
106}
107
108impl ChatSession {
109 pub fn new() -> Self {
110 Self {
111 messages: Vec::new(),
112 system_prompt: None,
113 metadata: std::collections::HashMap::new(),
114 }
115 }
116
117 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
118 self.system_prompt = Some(prompt.into());
119 self
120 }
121
122 pub fn add_message(&mut self, message: ChatMessage) {
123 self.messages.push(message);
124 }
125
126 pub fn add_user_message(&mut self, content: impl Into<String>) {
127 self.messages.push(ChatMessage::user(content));
128 }
129
130 pub fn add_assistant_message(&mut self, content: impl Into<String>) {
131 self.messages.push(ChatMessage::assistant(content));
132 }
133
134 pub fn get_messages(&self) -> Vec<ChatMessage> {
135 let mut messages = Vec::new();
136
137 if let Some(ref system_prompt) = self.system_prompt {
138 messages.push(ChatMessage::system(system_prompt.clone()));
139 }
140
141 messages.extend(self.messages.clone());
142 messages
143 }
144
145 pub fn clear(&mut self) {
146 self.messages.clear();
147 }
148
149 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
151 self.metadata.insert(key.into(), value.into());
152 }
153
154 pub fn get_metadata(&self, key: &str) -> Option<&String> {
155 self.metadata.get(key)
156 }
157
158 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
159 self.metadata.remove(key)
160 }
161
162 pub fn get_summary(&self) -> String {
163 let mut summary = String::new();
164 summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
165
166 let user_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::User)).count();
167 let assistant_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Assistant)).count();
168 let tool_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Tool)).count();
169
170 summary.push_str(&format!("User messages: {}\n", user_msgs));
171 summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
172 summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
173
174 if !self.metadata.is_empty() {
175 summary.push_str("\nSession metadata:\n");
176 for (key, value) in &self.metadata {
177 summary.push_str(&format!(" {}: {}\n", key, value));
178 }
179 }
180
181 summary
182 }
183}
184
185impl Default for ChatSession {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        // Unrecognised role names fall back to Assistant.
        assert_eq!(Role::from("unknown"), Role::Assistant);
        // Parsing ignores letter case.
        assert_eq!(Role::from("SYSTEM"), Role::System);
    }

    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
    }

    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_default_is_empty() {
        let session = ChatSession::default();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");
    }

    #[test]
    fn test_chat_session_get_messages_without_system_prompt() {
        let mut session = ChatSession::new();
        session.add_user_message("Only message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "Only message");
    }

    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_chat_session_clear_keeps_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("Prompt");
        session.set_metadata("key", "value");
        session.add_user_message("Test message");

        session.clear();
        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt, Some("Prompt".to_string()));
        assert_eq!(session.get_metadata("key").map(String::as_str), Some("value"));
    }

    #[test]
    fn test_chat_session_metadata() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("key").is_none());

        session.set_metadata("key", "v1");
        assert_eq!(session.get_metadata("key").map(String::as_str), Some("v1"));

        // Setting an existing key overwrites it.
        session.set_metadata("key", "v2");
        assert_eq!(session.get_metadata("key").map(String::as_str), Some("v2"));

        assert_eq!(session.remove_metadata("key"), Some("v2".to_string()));
        assert!(session.get_metadata("key").is_none());
        assert_eq!(session.remove_metadata("key"), None);
    }

    #[test]
    fn test_chat_session_summary() {
        let mut session = ChatSession::new();
        session.add_user_message("hi");
        session.add_assistant_message("hello");
        session.add_message(ChatMessage::tool("result", "call_1"));

        assert_eq!(
            session.get_summary(),
            "Total messages: 3\nUser messages: 1\nAssistant messages: 1\nTool messages: 1\n"
        );

        session.set_metadata("model", "test");
        let summary = session.get_summary();
        assert!(summary.contains("\nSession metadata:\n"));
        assert!(summary.contains("model: test\n"));
    }
}