1use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Message {
71 pub role: Role,
73
74 #[serde(skip_serializing_if = "Option::is_none")]
78 pub content: Option<String>,
79
80 #[serde(skip_serializing_if = "Option::is_none")]
84 pub tool_calls: Option<Vec<ToolCall>>,
85
86 #[serde(skip_serializing_if = "Option::is_none")]
90 pub tool_call_id: Option<String>,
91
92 #[serde(skip_serializing_if = "Option::is_none")]
94 pub cache_control: Option<CacheControl>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107#[serde(rename_all = "lowercase")]
108pub enum Role {
109 User,
111
112 Assistant,
114
115 System,
117
118 Tool,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
127pub struct CacheControl {
128 #[serde(rename = "type")]
130 pub cache_type: String,
131
132 #[serde(skip_serializing_if = "Option::is_none")]
135 pub ttl: Option<String>,
136}
137
138impl CacheControl {
139 pub fn ephemeral() -> Self {
141 Self {
142 cache_type: "ephemeral".to_string(),
143 ttl: None,
144 }
145 }
146
147 pub fn ephemeral_long() -> Self {
149 Self {
150 cache_type: "ephemeral".to_string(),
151 ttl: Some("1h".to_string()),
152 }
153 }
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct ToolCall {
162 pub id: String,
164
165 #[serde(rename = "type")]
167 pub tool_type: String,
168
169 pub function: FunctionCall,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct FunctionCall {
179 pub name: String,
181
182 pub arguments: String,
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct Tool {
220 #[serde(rename = "type")]
222 pub tool_type: String,
223
224 pub function: ToolFunction,
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct ToolFunction {
231 pub name: String,
233
234 pub description: String,
237
238 pub parameters: serde_json::Value,
242}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct Response {
247 pub content: String,
249
250 #[serde(skip_serializing_if = "Option::is_none")]
252 pub tool_calls: Option<Vec<ToolCall>>,
253
254 pub usage: Usage,
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct Usage {
265 pub input_tokens: u64,
267
268 pub output_tokens: u64,
270
271 #[serde(default, alias = "cache_read_input_tokens")]
273 pub cache_read_tokens: u64,
274
275 #[serde(default, alias = "cache_creation_input_tokens")]
277 pub cache_write_tokens: u64,
278}
279
280impl Usage {
281 pub fn total_tokens(&self) -> u64 {
283 self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single `test_tool` call shared by the tool-call tests.
    fn sample_tool_call(arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "test_tool".to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // `None` fields must be omitted entirely, not written as `null`.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::User);
        assert_eq!(msg.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert!(deserialized.tool_call_id.is_none());
        assert!(deserialized.cache_control.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![sample_tool_call(serde_json::json!({"arg": "value"}))]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.role, Role::Assistant);
        assert_eq!(deserialized.content.as_deref(), Some(""));
        let calls = deserialized.tool_calls.expect("tool_calls should round-trip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        assert_eq!(calls[0].function.arguments, r#"{"arg":"value"}"#);
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some("result output".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""role":"tool""#));
        assert!(json.contains(r#""tool_call_id":"call_123""#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id, Some("call_123".to_string()));
        assert_eq!(deserialized.content.as_deref(), Some("result output"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![sample_tool_call(serde_json::json!({}))]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // With `content: None` the key must be absent, not `"content":null`.
        assert!(!json.contains("\"content\""));
        assert!(!json.contains("tool_call_id"));
        assert!(json.contains("tool_calls"));
    }

    #[test]
    fn test_role_serialization() {
        let cases = [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ];
        for (role, expected) in cases {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
            let back: Role = serde_json::from_str(expected).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));

        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.tool_type, "function");
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.description, deserialized.function.description);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("tool_calls"));

        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert!(deserialized.tool_calls.is_none());
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
        assert_eq!(usage.cache_read_tokens, deserialized.cache_read_tokens);
        assert_eq!(usage.cache_write_tokens, deserialized.cache_write_tokens);
    }

    #[test]
    fn test_usage_missing_cache_fields_default_to_zero() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens":1,"output_tokens":2}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 3);
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);

        let cache_long = CacheControl::ephemeral_long();
        let json_long = serde_json::to_string(&cache_long).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);

        let back: CacheControl = serde_json::from_str(&json_long).unwrap();
        assert_eq!(back, cache_long);
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""cache_control":{"type":"ephemeral"}"#));

        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);

        // Serialization uses the Rust field names, not the Anthropic aliases.
        let json = serde_json::to_string(&usage).unwrap();
        assert!(json.contains("cache_read_tokens"));
        assert!(json.contains("cache_write_tokens"));
        assert!(!json.contains("cache_read_input_tokens"));
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }
}