helios_engine/
chat.rs

1//! # Chat Module
2//!
3//! This module provides the data structures for managing chat conversations.
4//! It defines the roles in a conversation, the structure of a chat message,
5//! and the chat session that holds the conversation history.
6
7use serde::{Deserialize, Serialize};
8
9/// Represents the role of a participant in a chat conversation.
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13    /// The system, providing instructions to the assistant.
14    System,
15    /// The user, asking questions or giving commands.
16    User,
17    /// The assistant, responding to the user.
18    Assistant,
19    /// A tool, providing the result of a function call.
20    Tool,
21}
22
23impl From<&str> for Role {
24    /// Converts a string slice to a `Role`.
25    fn from(s: &str) -> Self {
26        match s.to_lowercase().as_str() {
27            "system" => Role::System,
28            "user" => Role::User,
29            "assistant" => Role::Assistant,
30            "tool" => Role::Tool,
31            _ => Role::Assistant, // Default to assistant
32        }
33    }
34}
35
/// Represents a single message in a chat conversation.
///
/// The serde attributes define the serialized shape: the optional fields are
/// omitted from output when `None`, and `content` tolerates being missing or
/// `null` on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// The role of the message sender.
    pub role: Role,
    /// The content of the message.
    ///
    /// Deserializes to an empty string when the field is missing (`default`)
    /// or `null` (`deserialize_null_as_empty_string`). It has no
    /// `skip_serializing_if`, so it is always serialized, even when empty.
    #[serde(default, deserialize_with = "deserialize_null_as_empty_string")]
    pub content: String,
    /// The optional name of the message sender; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Any tool calls requested by the assistant; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// The ID of the tool call this message is a response to; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
54
55/// Deserializes a null value as an empty string.
56fn deserialize_null_as_empty_string<'de, D>(deserializer: D) -> Result<String, D::Error>
57where
58    D: serde::Deserializer<'de>,
59{
60    use serde::Deserialize;
61    Option::<String>::deserialize(deserializer).map(|opt| opt.unwrap_or_default())
62}
63
64/// Represents a tool call requested by the assistant.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct ToolCall {
67    /// The ID of the tool call.
68    pub id: String,
69    /// The type of the tool call (e.g., "function").
70    #[serde(rename = "type")]
71    pub call_type: String,
72    /// The function call to be executed.
73    pub function: FunctionCall,
74}
75
76/// Represents a function call to be executed by a tool.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct FunctionCall {
79    /// The name of the function to call.
80    pub name: String,
81    /// The arguments to the function, as a JSON string.
82    pub arguments: String,
83}
84
85impl ChatMessage {
86    /// Creates a new system message.
87    pub fn system(content: impl Into<String>) -> Self {
88        Self {
89            role: Role::System,
90            content: content.into(),
91            name: None,
92            tool_calls: None,
93            tool_call_id: None,
94        }
95    }
96
97    /// Creates a new system message. Alias for `system()`.
98    pub fn sys(content: impl Into<String>) -> Self {
99        Self::system(content)
100    }
101
102    /// Creates a new user message.
103    pub fn user(content: impl Into<String>) -> Self {
104        Self {
105            role: Role::User,
106            content: content.into(),
107            name: None,
108            tool_calls: None,
109            tool_call_id: None,
110        }
111    }
112
113    /// Creates a new user message. Alias for `user()`.
114    pub fn msg(content: impl Into<String>) -> Self {
115        Self::user(content)
116    }
117
118    /// Creates a new assistant message.
119    pub fn assistant(content: impl Into<String>) -> Self {
120        Self {
121            role: Role::Assistant,
122            content: content.into(),
123            name: None,
124            tool_calls: None,
125            tool_call_id: None,
126        }
127    }
128
129    /// Creates a new assistant message. Alias for `assistant()`.
130    pub fn reply(content: impl Into<String>) -> Self {
131        Self::assistant(content)
132    }
133
134    /// Creates a new tool message.
135    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
136        Self {
137            role: Role::Tool,
138            content: content.into(),
139            name: None,
140            tool_calls: None,
141            tool_call_id: Some(tool_call_id.into()),
142        }
143    }
144}
145
/// Represents a chat session, including the conversation history and metadata.
#[derive(Debug, Clone)]
pub struct ChatSession {
    /// The messages in the chat session; the `add_*` methods append to the end.
    ///
    /// Does not contain `system_prompt` — `get_messages` prepends it.
    pub messages: Vec<ChatMessage>,
    /// The optional system prompt, stored separately from `messages`.
    pub system_prompt: Option<String>,
    /// Free-form string key/value metadata associated with the chat session.
    pub metadata: std::collections::HashMap<String, String>,
}
156
157impl ChatSession {
158    /// Creates a new, empty chat session.
159    pub fn new() -> Self {
160        Self {
161            messages: Vec::new(),
162            system_prompt: None,
163            metadata: std::collections::HashMap::new(),
164        }
165    }
166
167    /// Sets the system prompt for the chat session.
168    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
169        self.system_prompt = Some(prompt.into());
170        self
171    }
172
173    /// Adds a message to the chat session.
174    pub fn add_message(&mut self, message: ChatMessage) {
175        self.messages.push(message);
176    }
177
178    /// Adds a user message to the chat session.
179    pub fn add_user_message(&mut self, content: impl Into<String>) {
180        self.messages.push(ChatMessage::user(content));
181    }
182
183    /// Adds an assistant message to the chat session.
184    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
185        self.messages.push(ChatMessage::assistant(content));
186    }
187
188    /// Shorthand for adding a system message
189    pub fn add_sys(&mut self, content: impl Into<String>) {
190        self.messages.push(ChatMessage::system(content));
191    }
192
193    /// Shorthand for adding a user message (alias for add_user_message)
194    pub fn add_msg(&mut self, content: impl Into<String>) {
195        self.messages.push(ChatMessage::user(content));
196    }
197
198    /// Shorthand for adding an assistant message (alias for add_assistant_message)
199    pub fn add_reply(&mut self, content: impl Into<String>) {
200        self.messages.push(ChatMessage::assistant(content));
201    }
202
203    /// Returns all messages in the chat session, including the system prompt.
204    pub fn get_messages(&self) -> Vec<ChatMessage> {
205        let mut messages = Vec::new();
206
207        if let Some(ref system_prompt) = self.system_prompt {
208            messages.push(ChatMessage::system(system_prompt.clone()));
209        }
210
211        messages.extend(self.messages.clone());
212        messages
213    }
214
215    /// Clears all messages from the chat session.
216    pub fn clear(&mut self) {
217        self.messages.clear();
218    }
219
220    /// Sets a metadata key-value pair for the session.
221    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
222        self.metadata.insert(key.into(), value.into());
223    }
224
225    /// Gets a metadata value by key.
226    pub fn get_metadata(&self, key: &str) -> Option<&String> {
227        self.metadata.get(key)
228    }
229
230    /// Removes a metadata key-value pair.
231    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
232        self.metadata.remove(key)
233    }
234
235    /// Returns a summary of the chat session.
236    pub fn get_summary(&self) -> String {
237        let mut summary = String::new();
238        summary.push_str(&format!("Total messages: {}\n", self.messages.len()));
239
240        let user_msgs = self
241            .messages
242            .iter()
243            .filter(|m| matches!(m.role, Role::User))
244            .count();
245        let assistant_msgs = self
246            .messages
247            .iter()
248            .filter(|m| matches!(m.role, Role::Assistant))
249            .count();
250        let tool_msgs = self
251            .messages
252            .iter()
253            .filter(|m| matches!(m.role, Role::Tool))
254            .count();
255
256        summary.push_str(&format!("User messages: {}\n", user_msgs));
257        summary.push_str(&format!("Assistant messages: {}\n", assistant_msgs));
258        summary.push_str(&format!("Tool messages: {}\n", tool_msgs));
259
260        if !self.metadata.is_empty() {
261            summary.push_str("\nSession metadata:\n");
262            for (key, value) in &self.metadata {
263                summary.push_str(&format!("  {}: {}\n", key, value));
264            }
265        }
266
267        summary
268    }
269}
270
271impl Default for ChatSession {
272    /// Creates a new, empty chat session.
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Tests the conversion from a string to a `Role`.
    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from("system"), Role::System);
        assert_eq!(Role::from("user"), Role::User);
        assert_eq!(Role::from("assistant"), Role::Assistant);
        assert_eq!(Role::from("tool"), Role::Tool);
        assert_eq!(Role::from("unknown"), Role::Assistant); // default case
        assert_eq!(Role::from("SYSTEM"), Role::System); // case insensitive
        assert_eq!(Role::from("Tool"), Role::Tool); // mixed case
        assert_eq!(Role::from(""), Role::Assistant); // empty input falls back
    }

    /// Tests the constructors for `ChatMessage`.
    #[test]
    fn test_chat_message_constructors() {
        let system_msg = ChatMessage::system("System message");
        assert_eq!(system_msg.role, Role::System);
        assert_eq!(system_msg.content, "System message");
        assert!(system_msg.name.is_none());
        assert!(system_msg.tool_calls.is_none());
        assert!(system_msg.tool_call_id.is_none());

        let user_msg = ChatMessage::user("User message");
        assert_eq!(user_msg.role, Role::User);
        assert_eq!(user_msg.content, "User message");

        let assistant_msg = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant_msg.role, Role::Assistant);
        assert_eq!(assistant_msg.content, "Assistant message");

        let tool_msg = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool_msg.role, Role::Tool);
        assert_eq!(tool_msg.content, "Tool result");
        assert_eq!(tool_msg.tool_call_id, Some("tool_call_123".to_string()));
        assert!(tool_msg.name.is_none());
        assert!(tool_msg.tool_calls.is_none());
    }

    /// Tests the short constructor aliases for `ChatMessage`.
    #[test]
    fn test_chat_message_aliases() {
        let sys = ChatMessage::sys("s");
        assert_eq!(sys.role, Role::System);
        assert_eq!(sys.content, "s");

        let msg = ChatMessage::msg("u");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "u");

        let reply = ChatMessage::reply("a");
        assert_eq!(reply.role, Role::Assistant);
        assert_eq!(reply.content, "a");
    }

    /// Tests the creation of a new `ChatSession`.
    #[test]
    fn test_chat_session_new() {
        let session = ChatSession::new();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    /// Tests that `Default` produces an empty session.
    #[test]
    fn test_chat_session_default() {
        let session = ChatSession::default();
        assert!(session.messages.is_empty());
        assert!(session.system_prompt.is_none());
        assert!(session.metadata.is_empty());
    }

    /// Tests setting the system prompt for a `ChatSession`.
    #[test]
    fn test_chat_session_with_system_prompt() {
        let session = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            session.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    /// Tests adding a message to a `ChatSession`.
    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        let msg = ChatMessage::user("Test message");
        session.add_message(msg);
        assert_eq!(session.messages.len(), 1);
    }

    /// Tests adding a user message to a `ChatSession`.
    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::User);
        assert_eq!(session.messages[0].content, "Test user message");
    }

    /// Tests adding an assistant message to a `ChatSession`.
    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");
        assert_eq!(session.messages.len(), 1);
        assert_eq!(session.messages[0].role, Role::Assistant);
        assert_eq!(session.messages[0].content, "Test assistant message");
    }

    /// Tests the shorthand `add_sys` / `add_msg` / `add_reply` methods.
    #[test]
    fn test_chat_session_shorthand_methods() {
        let mut session = ChatSession::new();
        session.add_sys("s");
        session.add_msg("u");
        session.add_reply("a");
        let roles: Vec<Role> = session.messages.iter().map(|m| m.role.clone()).collect();
        assert_eq!(roles, vec![Role::System, Role::User, Role::Assistant]);
        let contents: Vec<&str> = session.messages.iter().map(|m| m.content.as_str()).collect();
        assert_eq!(contents, vec!["s", "u", "a"]);
    }

    /// Tests getting all messages from a `ChatSession`.
    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let messages = session.get_messages();
        assert_eq!(messages.len(), 3); // system + user + assistant
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(messages[0].content, "System prompt");
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[1].content, "User message");
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[2].content, "Assistant message");
    }

    /// Tests `get_messages` when no system prompt is set.
    #[test]
    fn test_chat_session_get_messages_without_system_prompt() {
        let mut session = ChatSession::new();
        assert!(session.get_messages().is_empty());

        session.add_user_message("only");
        let messages = session.get_messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[0].content, "only");
    }

    /// Tests clearing all messages from a `ChatSession`.
    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }

    /// Tests that `clear` keeps the system prompt and metadata.
    #[test]
    fn test_chat_session_clear_keeps_prompt_and_metadata() {
        let mut session = ChatSession::new().with_system_prompt("prompt");
        session.set_metadata("k", "v");
        session.add_user_message("hello");

        session.clear();
        assert!(session.messages.is_empty());
        assert_eq!(session.system_prompt.as_deref(), Some("prompt"));
        assert_eq!(session.get_metadata("k").map(String::as_str), Some("v"));
    }

    /// Tests setting, reading, overwriting and removing metadata.
    #[test]
    fn test_chat_session_metadata() {
        let mut session = ChatSession::new();
        assert!(session.get_metadata("k").is_none());

        session.set_metadata("k", "v1");
        assert_eq!(session.get_metadata("k").map(String::as_str), Some("v1"));

        session.set_metadata("k", "v2");
        assert_eq!(session.get_metadata("k").map(String::as_str), Some("v2"));

        assert_eq!(session.remove_metadata("k"), Some("v2".to_string()));
        assert!(session.get_metadata("k").is_none());
        assert_eq!(session.remove_metadata("k"), None);
    }

    /// Tests the exact text produced by `get_summary`.
    #[test]
    fn test_chat_session_get_summary() {
        let mut session = ChatSession::new();
        assert_eq!(
            session.get_summary(),
            "Total messages: 0\nUser messages: 0\nAssistant messages: 0\nTool messages: 0\n"
        );

        session.add_user_message("u");
        session.add_assistant_message("a");
        session.add_message(ChatMessage::tool("t", "id"));
        session.set_metadata("model", "test");
        assert_eq!(
            session.get_summary(),
            "Total messages: 3\nUser messages: 1\nAssistant messages: 1\nTool messages: 1\n\
             \nSession metadata:\n  model: test\n"
        );
    }
}