// helios_engine/chat.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
4#[serde(rename_all = "lowercase")]
5pub enum Role {
6    System,
7    User,
8    Assistant,
9    Tool,
10}
11
12impl From<&str> for Role {
13    fn from(s: &str) -> Self {
14        match s.to_lowercase().as_str() {
15            "system" => Role::System,
16            "user" => Role::User,
17            "assistant" => Role::Assistant,
18            "tool" => Role::Tool,
19            _ => Role::Assistant, // Default to assistant
20        }
21    }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ChatMessage {
26    pub role: Role,
27    #[serde(default, deserialize_with = "deserialize_null_as_empty_string")]
28    pub content: String,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub name: Option<String>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub tool_calls: Option<Vec<ToolCall>>,
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub tool_call_id: Option<String>,
35}
36
37fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
38where
39    D: serde::Deserializer<'de>,
40{
41    use serde::Deserialize;
42    Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ToolCall {
47    pub id: String,
48    #[serde(rename = "type")]
49    pub call_type: String,
50    pub function: FunctionCall,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct FunctionCall {
55    pub name: String,
56    pub arguments: String,
57}
58
59impl ChatMessage {
60    pub fn system(content: impl Into<String>) -> Self {
61        Self {
62            role: Role::System,
63            content: content.into(),
64            name: None,
65            tool_calls: None,
66            tool_call_id: None,
67        }
68    }
69
70    pub fn user(content: impl Into<String>) -> Self {
71        Self {
72            role: Role::User,
73            content: content.into(),
74            name: None,
75            tool_calls: None,
76            tool_call_id: None,
77        }
78    }
79
80    pub fn assistant(content: impl Into<String>) -> Self {
81        Self {
82            role: Role::Assistant,
83            content: content.into(),
84            name: None,
85            tool_calls: None,
86            tool_call_id: None,
87        }
88    }
89
90    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
91        Self {
92            role: Role::Tool,
93            content: content.into(),
94            name: None,
95            tool_calls: None,
96            tool_call_id: Some(tool_call_id.into()),
97        }
98    }
99}
100
101#[derive(Debug, Clone)]
102pub struct ChatSession {
103    pub messages: Vec<ChatMessage>,
104    pub system_prompt: Option<String>,
105    pub metadata: std::collections::HashMap<String, String>,
106}
107
108impl ChatSession {
109    pub fn new() -> Self {
110        Self {
111            messages: Vec::new(),
112            system_prompt: None,
113            metadata: std::collections::HashMap::new(),
114        }
115    }
116
117    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
118        self.system_prompt = Some(prompt.into());
119        self
120    }
121
122    pub fn add_message(&mut self, message: ChatMessage) {
123        self.messages.push(message);
124    }
125
126    pub fn add_user_message(&mut self, content: impl Into<String>) {
127        self.messages.push(ChatMessage::user(content));
128    }
129
130    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
131        self.messages.push(ChatMessage::assistant(content));
132    }
133
134    pub fn get_messages(&self) -> Vec<ChatMessage> {
135        let mut messages = Vec::new();
136
137        if let Some(ref system_prompt) = self.system_prompt {
138            messages.push(ChatMessage::system(system_prompt.clone()));
139        }
140
141        messages.extend(self.messages.clone());
142        messages
143    }
144
145    pub fn clear(&mut self) {
146        self.messages.clear();
147    }
148    
149    // Session memory methods
150    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
151        self.metadata.insert(key.into(), value.into());
152    }
153    
154    pub fn get_metadata(&self, key: &str) -> Option<&String> {
155        self.metadata.get(key)
156    }
157    
158    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
159        self.metadata.remove(key)
160    }
161    
162    pub fn get_summary(&self) -> String {
163        let mut summary = String::new();
164        summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
165        
166        let user_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::User)).count();
167        let assistant_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Assistant)).count();
168        let tool_msgs = self.messages.iter().filter(|m| matches!(m.role, Role::Tool)).count();
169        
170        summary.push_str(&format!("User messages: {}\n", user_msgs));
171        summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
172        summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
173        
174        if !self.metadata.is_empty() {
175            summary.push_str("\nSession metadata:\n");
176            for (key, value) in &self.metadata {
177                summary.push_str(&format!("  {}: {}\n", key, value));
178            }
179        }
180        
181        summary
182    }
183}
184
185impl Default for ChatSession {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        assert_eq!(Role::from("unknown"), Role::Assistant); // default case
        assert_eq!(Role::from("SYSTEM"), Role::System); // case insensitive
        assert_eq!(Role::from("Tool"), Role::Tool); // mixed case
        assert_eq!(Role::from(""), Role::Assistant); // empty input also falls back
    }

    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
    }

    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_default_matches_new() {
        let session = ChatSession::default();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3); // system + user + assistant
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");
    }

    #[test]
    fn test_chat_session_get_messages_without_system_prompt() {
        let mut session = ChatSession::new();
        session.add_user_message("User message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 1); // no synthetic system message
        assert_eq!(messages[0].role, Role::User);
    }

    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    #[test]
    fn test_chat_session_clear_keeps_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.set_metadata("k", "v");
        session.add_user_message("Hi");

        session.clear();
        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt.as_deref(), Some("System prompt"));
        assert_eq!(session.get_metadata("k").map(String::as_str), Some("v"));
        // The system prompt is still emitted after a clear.
        assert_eq!(session.get_messages().len(), 1);
    }

    #[test]
    fn test_chat_session_metadata_roundtrip() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("user_id").is_none());

        session.set_metadata("user_id", "42");
        assert_eq!(
            session.get_metadata("user_id").map(String::as_str),
            Some("42")
        );

        // Setting an existing key overwrites its value.
        session.set_metadata("user_id", "43");
        assert_eq!(session.remove_metadata("user_id"), Some("43".to_string()));
        assert!(session.get_metadata("user_id").is_none());
        assert!(session.remove_metadata("user_id").is_none());
    }

    #[test]
    fn test_chat_session_summary() {
        let mut session = ChatSession::new().with_system_prompt("not counted");
        session.add_user_message("u1");
        session.add_user_message("u2");
        session.add_assistant_message("a1");
        session.add_message(ChatMessage::tool("result", "call_1"));
        session.set_metadata("topic", "weather");

        assert_eq!(
            session.get_summary(),
            concat!(
                "Total messages: 4\n",
                "User messages: 2\n",
                "Assistant messages: 1\n",
                "Tool messages: 1\n",
                "\n",
                "Session metadata:\n",
                "  topic: weather\n"
            )
        );
    }

    #[test]
    fn test_chat_session_summary_without_metadata() {
        let session = ChatSession::new();
        assert_eq!(
            session.get_summary(),
            "Total messages: 0\nUser messages: 0\nAssistant messages: 0\nTool messages: 0\n"
        );
    }
}