1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
4#[serde(rename_all = "lowercase")]
5pub enum Role {
6 System,
7 User,
8 Assistant,
9 Tool,
10}
11
12impl From<&str> for Role {
13 fn from(s: &str) -> Self {
14 match s.to_lowercase().as_str() {
15 "system" => Role::System,
16 "user" => Role::User,
17 "assistant" => Role::Assistant,
18 "tool" => Role::Tool,
19 _ => Role::Assistant, }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ChatMessage {
26 pub role: Role,
27 #[serde(default)]
28 pub content: String,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub name: Option<String>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub tool_calls: Option<Vec<ToolCall>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub tool_call_id: Option<String>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ToolCall {
39 pub id: String,
40 #[serde(rename = "type")]
41 pub call_type: String,
42 pub function: FunctionCall,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct FunctionCall {
47 pub name: String,
48 pub arguments: String,
49}
50
51impl ChatMessage {
52 pub fn system(content: impl Into<String>) -> Self {
53 Self {
54 role: Role::System,
55 content: content.into(),
56 name: None,
57 tool_calls: None,
58 tool_call_id: None,
59 }
60 }
61
62 pub fn user(content: impl Into<String>) -> Self {
63 Self {
64 role: Role::User,
65 content: content.into(),
66 name: None,
67 tool_calls: None,
68 tool_call_id: None,
69 }
70 }
71
72 pub fn assistant(content: impl Into<String>) -> Self {
73 Self {
74 role: Role::Assistant,
75 content: content.into(),
76 name: None,
77 tool_calls: None,
78 tool_call_id: None,
79 }
80 }
81
82 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
83 Self {
84 role: Role::Tool,
85 content: content.into(),
86 name: None,
87 tool_calls: None,
88 tool_call_id: Some(tool_call_id.into()),
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
94pub struct ChatSession {
95 pub messages: Vec<ChatMessage>,
96 pub system_prompt: Option<String>,
97 pub metadata: std::collections::HashMap<String, String>,
98}
99
100impl ChatSession {
101 pub fn new() -> Self {
102 Self {
103 messages: Vec::new(),
104 system_prompt: None,
105 metadata: std::collections::HashMap::new(),
106 }
107 }
108
109 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
110 self.system_prompt = Some(prompt.into());
111 self
112 }
113
114 pub fn add_message(&mut self, message: ChatMessage) {
115 self.messages.push(message);
116 }
117
118 pub fn add_user_message(&mut self, content: impl Into<String>) {
119 self.messages.push(ChatMessage::user(content));
120 }
121
122 pub fn add_assistant_message(&mut self, content: impl Into<String>) {
123 self.messages.push(ChatMessage::assistant(content));
124 }
125
126 pub fn get_messages(&self) -> Vec<ChatMessage> {
127 let mut messages = Vec::new();
128
129 if let Some(ref system_prompt) = self.system_prompt {
130 messages.push(ChatMessage::system(system_prompt.clone()));
131 }
132
133 messages.extend(self.messages.clone());
134 messages
135 }
136
137 pub fn clear(&mut self) {
138 self.messages.clear();
139 }
140
141 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
143 self.metadata.insert(key.into(), value.into());
144 }
145
146 pub fn get_metadata(&self, key: &str) -> Option<&String> {
147 self.metadata.get(key)
148 }
149
150 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
151 self.metadata.remove(key)
152 }
153
154 pub fn get_summary(&self) -> String {
155 let mut summary = String::new();
156 summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
157
158 let user_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::User)).count();
159 let assistant_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Assistant)).count();
160 let tool_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Tool)).count();
161
162 summary.push_str(&format!("User messages: {}\n", user_msgs));
163 summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
164 summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
165
166 if !self.metadata.is_empty() {
167 summary.push_str("\nSession metadata:\n");
168 for (key, value) in &self.metadata {
169 summary.push_str(&format!(" {}: {}\n", key, value));
170 }
171 }
172
173 summary
174 }
175}
176
177impl Default for ChatSession {
178 fn default() -> Self {
179 Self::new()
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Known role names parse, matching ignores case, and anything else falls
    /// back to `Role::Assistant`.
    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        assert_eq!(Role::from("unknown"), Role::Assistant);
        assert_eq!(Role::from("SYSTEM"), Role::System);
        assert_eq!(Role::from("User"), Role::User);
        assert_eq!(Role::from(""), Role::Assistant);
    }

    /// Each constructor sets its role and content; only `tool` fills in
    /// `tool_call_id`.
    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");
        assert!(user_msg.tool_call_id.is_none());

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");
        assert!(assistant_msg.tool_call_id.is_none());

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
        assert!(tool_msg.name.is_none());
        assert!(tool_msg.tool_calls.is_none());
    }

    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    /// `Default` must produce the same empty state as `new`.
    #[test]
    fn test_chat_session_default() {
        let session = ChatSession::default();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    /// The system prompt comes back as the first message, followed by the
    /// history in insertion order.
    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");

        // The stored history itself does not gain the system message.
        assert_eq!(session.messages.len(), 2);
    }

    /// Without a system prompt only the stored history is returned.
    #[test]
    fn test_chat_session_get_messages_without_system_prompt() {
        let mut session = ChatSession::new();
        assert!(session.get_messages().is_empty());

        session.add_user_message("User message");
        let messages = session.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "User message");
    }

    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    /// `clear` drops the history only; prompt and metadata survive.
    #[test]
    fn test_chat_session_clear_keeps_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.set_metadata("model", "test");
        session.add_user_message("Test message");

        session.clear();

        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt, Some("System prompt".to_string()));
        assert_eq!(session.get_metadata("model"), Some(&"test".to_string()));
        assert_eq!(session.get_messages().len(), 1);
    }

    #[test]
    fn test_chat_session_metadata() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("model").is_none());

        session.set_metadata("model", "first");
        assert_eq!(session.get_metadata("model"), Some(&"first".to_string()));

        // Setting the same key again replaces the value.
        session.set_metadata("model", "second");
        assert_eq!(session.get_metadata("model"), Some(&"second".to_string()));
        assert_eq!(session.metadata.len(), 1);

        assert_eq!(session.remove_metadata("model"), Some("second".to_string()));
        assert_eq!(session.remove_metadata("model"), None);
        assert!(session.metadata.is_empty());
    }

    /// The summary counts stored messages per role (the system prompt is not
    /// counted) and adds a metadata section only when metadata exists.
    #[test]
    fn test_chat_session_get_summary() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("question");
        session.add_assistant_message("answer");
        session.add_message(ChatMessage::tool("result", "call_1"));

        let counts =
            "Total messages: 3\nUser messages: 1\nAssistant messages: 1\nTool messages: 1\n";
        assert_eq!(session.get_summary(), counts);

        session.set_metadata("model", "test");
        let summary = session.get_summary();
        assert!(summary.starts_with(counts));
        assert!(summary.contains("\nSession metadata:\n"));
        assert!(summary.contains("model: test\n"));
    }
}