1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct Message {
5 pub role: Role,
6 #[serde(skip_serializing_if = "Option::is_none")]
7 pub content: Option<String>,
8 #[serde(skip_serializing_if = "Option::is_none")]
9 pub tool_calls: Option<Vec<ToolCall>>,
10 #[serde(skip_serializing_if = "Option::is_none")]
11 pub tool_call_id: Option<String>,
12}
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum Role {
17 User,
18 Assistant,
19 System,
20 Tool,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ToolCall {
25 pub id: String,
26 #[serde(rename = "type")]
27 pub tool_type: String,
28 pub function: FunctionCall,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct FunctionCall {
33 pub name: String,
34 pub arguments: String,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Tool {
40 #[serde(rename = "type")]
41 pub tool_type: String,
42 pub function: ToolFunction,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ToolFunction {
47 pub name: String,
48 pub description: String,
49 pub parameters: serde_json::Value,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct Response {
54 pub content: String,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub tool_calls: Option<Vec<ToolCall>>,
57 pub usage: Usage,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct Usage {
62 pub input_tokens: u64,
63 pub output_tokens: u64,
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single function-type tool call shared by several tests.
    fn sample_tool_call(arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "test_tool".to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `None` fields must be omitted entirely, not written as `null`.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert!(deserialized.tool_call_id.is_none());
    }

    #[test]
    fn test_message_deserializes_with_missing_optional_fields() {
        let msg: Message = serde_json::from_str(r#"{"role":"system"}"#).unwrap();
        assert_eq!(msg.role, Role::System);
        assert!(msg.content.is_none());
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some(String::new()),
            tool_calls: Some(vec![sample_tool_call(serde_json::json!({"arg": "value"}))]),
            tool_call_id: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Assistant);
        assert_eq!(deserialized.content.as_deref(), Some(""));
        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive a round trip");
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, "call_123");
        assert_eq!(call.tool_type, "function");
        assert_eq!(call.function.name, "test_tool");
        // `arguments` is a JSON string; compare structurally, not textually.
        let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap();
        assert_eq!(args, serde_json::json!({"arg": "value"}));
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some("result output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""tool_call_id":"call_123""#));
        assert!(!json.contains("\"tool_calls\""));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.content.as_deref(), Some("result output"));
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![sample_tool_call(serde_json::json!({}))]),
            tool_call_id: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // The `content` key must be absent altogether, not merely non-null.
        assert!(!json.contains("\"content\""));
        assert!(!json.contains("\"tool_call_id\""));
        assert!(json.contains("\"tool_calls\""));
        // `tool_type` is renamed to `type` on the wire.
        assert!(json.contains(r#""type":"function""#));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in &cases {
            let json = serde_json::to_string(role).unwrap();
            assert_eq!(json, *expected);
            let back: Role = serde_json::from_str(expected).unwrap();
            assert_eq!(&back, role);
        }
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));
        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_type, tool.tool_type);
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.description, deserialized.function.description);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("\"tool_calls\""));
        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
        };
        let json = serde_json::to_string(&usage).unwrap();
        assert_eq!(json, r#"{"input_tokens":100,"output_tokens":50}"#);
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
    }
}