1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
4#[serde(rename_all = "lowercase")]
5pub enum Role {
6 System,
7 User,
8 Assistant,
9 Tool,
10}
11
12impl From<&str> for Role {
13 fn from(s: &str) -> Self {
14 match s.to_lowercase().as_str() {
15 "system" => Role::System,
16 "user" => Role::User,
17 "assistant" => Role::Assistant,
18 "tool" => Role::Tool,
19 _ => Role::Assistant, }
21 }
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ChatMessage {
26 pub role: Role,
27 #[serde(default)]
28 pub content: String,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub name: Option<String>,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub tool_calls: Option<Vec<ToolCall>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub tool_call_id: Option<String>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ToolCall {
39 pub id: String,
40 #[serde(rename = "type")]
41 pub call_type: String,
42 pub function: FunctionCall,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct FunctionCall {
47 pub name: String,
48 pub arguments: String,
49}
50
51impl ChatMessage {
52 pub fn system(content: impl Into<String>) -> Self {
53 Self {
54 role: Role::System,
55 content: content.into(),
56 name: None,
57 tool_calls: None,
58 tool_call_id: None,
59 }
60 }
61
62 pub fn user(content: impl Into<String>) -> Self {
63 Self {
64 role: Role::User,
65 content: content.into(),
66 name: None,
67 tool_calls: None,
68 tool_call_id: None,
69 }
70 }
71
72 pub fn assistant(content: impl Into<String>) -> Self {
73 Self {
74 role: Role::Assistant,
75 content: content.into(),
76 name: None,
77 tool_calls: None,
78 tool_call_id: None,
79 }
80 }
81
82 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
83 Self {
84 role: Role::Tool,
85 content: content.into(),
86 name: None,
87 tool_calls: None,
88 tool_call_id: Some(tool_call_id.into()),
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
94pub struct ChatSession {
95 pub messages: Vec<ChatMessage>,
96 pub system_prompt: Option<String>,
97}
98
99impl ChatSession {
100 pub fn new() -> Self {
101 Self {
102 messages: Vec::new(),
103 system_prompt: None,
104 }
105 }
106
107 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
108 self.system_prompt = Some(prompt.into());
109 self
110 }
111
112 pub fn add_message(&mut self, message: ChatMessage) {
113 self.messages.push(message);
114 }
115
116 pub fn add_user_message(&mut self, content: impl Into<String>) {
117 self.messages.push(ChatMessage::user(content));
118 }
119
120 pub fn add_assistant_message(&mut self, content: impl Into<String>) {
121 self.messages.push(ChatMessage::assistant(content));
122 }
123
124 pub fn get_messages(&self) -> Vec<ChatMessage> {
125 let mut messages = Vec::new();
126
127 if let Some(ref system_prompt) = self.system_prompt {
128 messages.push(ChatMessage::system(system_prompt.clone()));
129 }
130
131 messages.extend(self.messages.clone());
132 messages
133 }
134
135 pub fn clear(&mut self) {
136 self.messages.clear();
137 }
138}
139
140impl Default for ChatSession {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    /// Every known role name converts to its variant, matching ignores case,
    /// and unrecognized names fall back to `Assistant`.
    #[test]
    fn test_role_from_str() {
        let cases = [
            ("system", Role::System),
            ("user", Role::User),
            ("assistant", Role::Assistant),
            ("tool", Role::Tool),
            ("unknown", Role::Assistant),
            ("SYSTEM", Role::System),
        ];

        for (input, expected) in cases {
            assert_eq!(Role::from(input), expected);
        }
    }

    /// Each constructor sets the matching role and content; only `tool`
    /// fills in `tool_call_id`.
    #[test]
    fn test_chat_message_constructors() {
        let system = ChatMessage::system("System message");
        assert_eq!(system.role, Role::System);
        assert_eq!(system.content, "System message");
        assert!(system.name.is_none());
        assert!(system.tool_calls.is_none());
        assert!(system.tool_call_id.is_none());

        let user = ChatMessage::user("User message");
        assert_eq!(user.role, Role::User);
        assert_eq!(user.content, "User message");

        let assistant = ChatMessage::assistant("Assistant message");
        assert_eq!(assistant.role, Role::Assistant);
        assert_eq!(assistant.content, "Assistant message");

        let tool = ChatMessage::tool("Tool result", "tool_call_123");
        assert_eq!(tool.role, Role::Tool);
        assert_eq!(tool.content, "Tool result");
        assert_eq!(tool.tool_call_id, Some("tool_call_123".to_string()));
    }

    /// A new session has neither messages nor a system prompt.
    #[test]
    fn test_chat_session_new() {
        let fresh = ChatSession::new();
        assert!(fresh.messages.is_empty());
        assert!(fresh.system_prompt.is_none());
    }

    /// The builder stores the prompt it is given.
    #[test]
    fn test_chat_session_with_system_prompt() {
        let prompted = ChatSession::new().with_system_prompt("Test system prompt");
        assert_eq!(
            prompted.system_prompt,
            Some("Test system prompt".to_string())
        );
    }

    /// `add_message` appends exactly one message.
    #[test]
    fn test_chat_session_add_message() {
        let mut session = ChatSession::new();
        session.add_message(ChatMessage::user("Test message"));
        assert_eq!(session.messages.len(), 1);
    }

    /// `add_user_message` appends a user-role message with the given text.
    #[test]
    fn test_chat_session_add_user_message() {
        let mut session = ChatSession::new();
        session.add_user_message("Test user message");

        assert_eq!(session.messages.len(), 1);
        let added = &session.messages[0];
        assert_eq!(added.role, Role::User);
        assert_eq!(added.content, "Test user message");
    }

    /// `add_assistant_message` appends an assistant-role message with the
    /// given text.
    #[test]
    fn test_chat_session_add_assistant_message() {
        let mut session = ChatSession::new();
        session.add_assistant_message("Test assistant message");

        assert_eq!(session.messages.len(), 1);
        let added = &session.messages[0];
        assert_eq!(added.role, Role::Assistant);
        assert_eq!(added.content, "Test assistant message");
    }

    /// `get_messages` puts the system prompt first, then the stored
    /// messages in insertion order.
    #[test]
    fn test_chat_session_get_messages() {
        let mut session = ChatSession::new().with_system_prompt("System prompt");
        session.add_user_message("User message");
        session.add_assistant_message("Assistant message");

        let history = session.get_messages();
        assert_eq!(history.len(), 3);

        let expected = [
            (Role::System, "System prompt"),
            (Role::User, "User message"),
            (Role::Assistant, "Assistant message"),
        ];
        for (message, (role, content)) in history.iter().zip(expected) {
            assert_eq!(message.role, role);
            assert_eq!(message.content, content);
        }
    }

    /// `clear` removes all stored messages.
    #[test]
    fn test_chat_session_clear() {
        let mut session = ChatSession::new();
        session.add_user_message("Test message");
        assert!(!session.messages.is_empty());

        session.clear();
        assert!(session.messages.is_empty());
    }
}