helios_engine/
chat.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
4#[serde(rename_all = "lowercase")]
5pub enum Role {
6    System,
7    User,
8    Assistant,
9    Tool,
10}
11
12impl From<&str> for Role {
13    fn from(s: &str) -> Self {
14        match s.to_lowercase().as_str() {
15            "system" => Role::System,
16            "user" => Role::User,
17            "assistant" => Role::Assistant,
18            "tool" => Role::Tool,
19            _ => Role::Assistant, // Default to assistant
20        }
21    }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ChatMessage {
26    pub role: Role,
27    #[serde(default)]
28    pub content: String,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub name: Option<String>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub tool_calls: Option<Vec<ToolCall>>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub tool_call_id: Option<String>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ToolCall {
39    pub id: String,
40    #[serde(rename = "type")]
41    pub call_type: String,
42    pub function: FunctionCall,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct FunctionCall {
47    pub name: String,
48    pub arguments: String,
49}
50
51impl ChatMessage {
52    pub fn system(content: impl Into<String>) -> Self {
53        Self {
54            role: Role::System,
55            content: content.into(),
56            name: None,
57            tool_calls: None,
58            tool_call_id: None,
59        }
60    }
61
62    pub fn user(content: impl Into<String>) -> Self {
63        Self {
64            role: Role::User,
65            content: content.into(),
66            name: None,
67            tool_calls: None,
68            tool_call_id: None,
69        }
70    }
71
72    pub fn assistant(content: impl Into<String>) -> Self {
73        Self {
74            role: Role::Assistant,
75            content: content.into(),
76            name: None,
77            tool_calls: None,
78            tool_call_id: None,
79        }
80    }
81
82    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
83        Self {
84            role: Role::Tool,
85            content: content.into(),
86            name: None,
87            tool_calls: None,
88            tool_call_id: Some(tool_call_id.into()),
89        }
90    }
91}
92
93#[derive(Debug, Clone)]
94pub struct ChatSession {
95    pub messages: Vec<ChatMessage>,
96    pub system_prompt: Option<String>,
97    pub metadata: std::collections::HashMap<String, String>,
98}
99
100impl ChatSession {
101    pub fn new() -> Self {
102        Self {
103            messages: Vec::new(),
104            system_prompt: None,
105            metadata: std::collections::HashMap::new(),
106        }
107    }
108
109    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
110        self.system_prompt = Some(prompt.into());
111        self
112    }
113
114    pub fn add_message(&mut self, message: ChatMessage) {
115        self.messages.push(message);
116    }
117
118    pub fn add_user_message(&mut self, content: impl Into<String>) {
119        self.messages.push(ChatMessage::user(content));
120    }
121
122    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
123        self.messages.push(ChatMessage::assistant(content));
124    }
125
126    pub fn get_messages(&self) -> Vec<ChatMessage> {
127        let mut messages = Vec::new();
128
129        if let Some(ref system_prompt) = self.system_prompt {
130            messages.push(ChatMessage::system(system_prompt.clone()));
131        }
132
133        messages.extend(self.messages.clone());
134        messages
135    }
136
137    pub fn clear(&mut self) {
138        self.messages.clear();
139    }
140    
141    // Session memory methods
142    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
143        self.metadata.insert(key.into(), value.into());
144    }
145    
146    pub fn get_metadata(&self, key: &str) -> Option<&String> {
147        self.metadata.get(key)
148    }
149    
150    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
151        self.metadata.remove(key)
152    }
153    
154    pub fn get_summary(&self) -> String {
155        let mut summary = String::new();
156        summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
157        
158        let user_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::User)).count();
159        let assistant_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Assistant)).count();
160        let tool_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Tool)).count();
161        
162        summary.push_str(&format!("User messages: {}\n", user_msgs));
163        summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
164        summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
165        
166        if !self.metadata.is_empty() {
167            summary.push_str("\nSession metadata:\n");
168            for (key, value) in &self.metadata {
169                summary.push_str(&format!("  {}: {}\n", key, value));
170            }
171        }
172        
173        summary
174    }
175}
176
177impl Default for ChatSession {
178    fn default() -> Self {
179        Self::new()
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        assert_eq!(Role::from("unknown"), Role::Assistant); // default case
        assert_eq!(Role::from(""), Role::Assistant); // empty input falls back too
        assert_eq!(Role::from("SYSTEM"), Role::System); // case insensitive
        assert_eq!(Role::from("UsEr"), Role::User); // mixed case
    }

    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");
        assert!(user_msg.tool_call_id.is_none());

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");
        assert!(assistant_msg.tool_calls.is_none());

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
        assert!(tool_msg.name.is_none());
        assert!(tool_msg.tool_calls.is_none());
    }

    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_default_matches_new() {
        let session = ChatSession::default();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3); // system + user + assistant
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");
    }

    #[test]
    fn test_chat_session_get_messages_without_system_prompt() {
        let mut session = ChatSession::new();
        assert!(session.get_messages().is_empty());

        session.add_user_message("Hello");
        let messages = session.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "Hello");
    }

    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_chat_session_clear_keeps_system_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("Prompt");
        session.set_metadata("key", "value");
        session.add_user_message("Message");

        session.clear();

        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt.as_deref(), Some("Prompt"));
        assert_eq!(session.get_metadata("key").map(String::as_str), Some("value"));
        // Only the system prompt remains in the outgoing message list.
        assert_eq!(session.get_messages().len(), 1);
    }

    #[test]
    fn test_chat_session_metadata() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("user").is_none());

        session.set_metadata("user", "alice");
        assert_eq!(session.get_metadata("user").map(String::as_str), Some("alice"));

        // Setting an existing key overwrites it.
        session.set_metadata("user", "bob");
        assert_eq!(session.get_metadata("user").map(String::as_str), Some("bob"));

        assert_eq!(session.remove_metadata("user"), Some("bob".to_string()));
        assert!(session.get_metadata("user").is_none());
        assert_eq!(session.remove_metadata("user"), None);
    }

    #[test]
    fn test_chat_session_summary_empty() {
        let session = ChatSession::new();
        assert_eq!(
            session.get_summary(),
            concat!(
                "Total messages: 0\n",
                "User messages: 0\n",
                "Assistant messages: 0\n",
                "Tool messages: 0\n"
            )
        );
    }

    #[test]
    fn test_chat_session_summary_counts_and_metadata() {
        let mut session = ChatSession::new();
        session.add_message(ChatMessage::system("sys"));
        session.add_user_message("u");
        session.add_assistant_message("a");
        session.add_message(ChatMessage::tool("t", "call_1"));
        session.set_metadata("topic", "rust");

        // System messages count toward the total but have no per-role line.
        assert_eq!(
            session.get_summary(),
            concat!(
                "Total messages: 4\n",
                "User messages: 1\n",
                "Assistant messages: 1\n",
                "Tool messages: 1\n",
                "\n",
                "Session metadata:\n",
                "  topic: rust\n"
            )
        );
    }
}
287}