1use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18#[serde(tag = "type", rename_all = "lowercase")]
19pub enum ContentPart {
20 Text { text: String },
22 #[serde(rename = "image_url")]
24 ImageUrl { image_url: ImageUrl },
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
29pub struct ImageUrl {
30 pub url: String,
32 #[serde(skip_serializing_if = "Option::is_none")]
34 pub detail: Option<String>,
35}
36
37impl ContentPart {
38 pub fn text(text: impl Into<String>) -> Self {
40 ContentPart::Text { text: text.into() }
41 }
42
43 pub fn image_url(url: impl Into<String>) -> Self {
45 ContentPart::ImageUrl {
46 image_url: ImageUrl {
47 url: url.into(),
48 detail: None,
49 },
50 }
51 }
52
53 pub fn image_base64(media_type: &str, base64_data: &str) -> Self {
55 ContentPart::ImageUrl {
56 image_url: ImageUrl {
57 url: format!("data:{};base64,{}", media_type, base64_data),
58 detail: None,
59 },
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct Message {
158 pub role: Role,
160
161 #[serde(skip_serializing_if = "Option::is_none")]
166 pub content: Option<MessageContent>,
167
168 #[serde(skip_serializing_if = "Option::is_none")]
172 pub tool_calls: Option<Vec<ToolCall>>,
173
174 #[serde(skip_serializing_if = "Option::is_none")]
178 pub tool_call_id: Option<String>,
179
180 #[serde(skip_serializing_if = "Option::is_none")]
182 pub cache_control: Option<CacheControl>,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
187#[serde(untagged)]
188pub enum MessageContent {
189 Text(String),
191 Parts(Vec<ContentPart>),
193}
194
195impl MessageContent {
196 pub fn text(text: impl Into<String>) -> Self {
198 MessageContent::Text(text.into())
199 }
200
201 pub fn parts(parts: Vec<ContentPart>) -> Self {
203 MessageContent::Parts(parts)
204 }
205
206 pub fn as_text(&self) -> Option<&str> {
208 match self {
209 MessageContent::Text(text) => Some(text),
210 MessageContent::Parts(_) => None,
211 }
212 }
213
214 pub fn to_text(&self) -> String {
216 match self {
217 MessageContent::Text(text) => text.clone(),
218 MessageContent::Parts(parts) => parts
219 .iter()
220 .filter_map(|part| match part {
221 ContentPart::Text { text } => Some(text.clone()),
222 _ => None,
223 })
224 .collect::<Vec<_>>()
225 .join(""),
226 }
227 }
228}
229
230impl std::fmt::Display for MessageContent {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 match self {
233 MessageContent::Text(text) => write!(f, "{}", text),
234 MessageContent::Parts(parts) => {
235 let text = parts
236 .iter()
237 .filter_map(|p| match p {
238 ContentPart::Text { text } => Some(text.as_str()),
239 _ => None,
240 })
241 .collect::<Vec<_>>()
242 .join("");
243 write!(f, "{}", text)
244 }
245 }
246 }
247}
248
249impl From<String> for MessageContent {
250 fn from(text: String) -> Self {
251 MessageContent::Text(text)
252 }
253}
254
255impl From<&str> for MessageContent {
256 fn from(text: &str) -> Self {
257 MessageContent::Text(text.to_string())
258 }
259}
260
261#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
271#[serde(rename_all = "lowercase")]
272pub enum Role {
273 User,
275
276 Assistant,
278
279 System,
281
282 Tool,
284}
285
286#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
291pub struct CacheControl {
292 #[serde(rename = "type")]
294 pub cache_type: String,
295
296 #[serde(skip_serializing_if = "Option::is_none")]
299 pub ttl: Option<String>,
300}
301
302impl CacheControl {
303 pub fn ephemeral() -> Self {
305 Self {
306 cache_type: "ephemeral".to_string(),
307 ttl: None,
308 }
309 }
310
311 pub fn ephemeral_long() -> Self {
313 Self {
314 cache_type: "ephemeral".to_string(),
315 ttl: Some("1h".to_string()),
316 }
317 }
318}
319
320#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct ToolCall {
326 pub id: String,
328
329 #[serde(rename = "type")]
331 pub tool_type: String,
332
333 pub function: FunctionCall,
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct FunctionCall {
343 pub name: String,
345
346 pub arguments: String,
351}
352
353#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct Tool {
384 #[serde(rename = "type")]
386 pub tool_type: String,
387
388 pub function: ToolFunction,
390}
391
392#[derive(Debug, Clone, Serialize, Deserialize)]
394pub struct ToolFunction {
395 pub name: String,
397
398 pub description: String,
401
402 pub parameters: serde_json::Value,
406}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
410pub struct Response {
411 pub content: String,
413
414 #[serde(skip_serializing_if = "Option::is_none")]
416 pub tool_calls: Option<Vec<ToolCall>>,
417
418 pub usage: Usage,
420}
421
422#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct Usage {
429 pub input_tokens: u64,
431
432 pub output_tokens: u64,
434
435 #[serde(default, alias = "cache_read_input_tokens")]
437 pub cache_read_tokens: u64,
438
439 #[serde(default, alias = "cache_creation_input_tokens")]
441 pub cache_write_tokens: u64,
442}
443
444impl Usage {
445 pub fn total_tokens(&self) -> u64 {
447 self.input_tokens + self.output_tokens + self.cache_read_tokens + self.cache_write_tokens
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `"function"` [`ToolCall`] with the given id, name and JSON arguments.
    fn function_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: arguments.to_string(),
            },
        }
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: Some(MessageContent::text("Hello")),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // Unset optional fields must be omitted rather than emitted as `null`.
        assert_eq!(json, r#"{"role":"user","content":"Hello"}"#);
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg.role, deserialized.role);
        assert_eq!(msg.content, deserialized.content);
    }

    #[test]
    fn test_message_with_tool_calls() {
        let msg = Message {
            role: Role::Assistant,
            content: Some(MessageContent::text("")),
            tool_calls: Some(vec![function_call(
                "call_123",
                "test_tool",
                serde_json::json!({"arg": "value"}),
            )]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        let calls = deserialized
            .tool_calls
            .expect("tool_calls should survive a round trip");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_123");
        assert_eq!(calls[0].tool_type, "function");
        assert_eq!(calls[0].function.name, "test_tool");
        let args: serde_json::Value = serde_json::from_str(&calls[0].function.arguments).unwrap();
        assert_eq!(args, serde_json::json!({"arg": "value"}));
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message {
            role: Role::Tool,
            content: Some(MessageContent::text("result output")),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""tool_call_id":"call_123""#));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.role, Role::Tool);
        assert_eq!(deserialized.tool_call_id.as_deref(), Some("call_123"));
    }

    #[test]
    fn test_assistant_with_tool_calls_serialization() {
        let msg = Message {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![function_call(
                "call_123",
                "test_tool",
                serde_json::json!({}),
            )]),
            tool_call_id: None,
            cache_control: None,
        };
        let json = serde_json::to_string(&msg).unwrap();
        // A `None` body must drop the key entirely, not serialize `"content":null`.
        assert!(!json.contains("\"content\""));
        assert!(json.contains("tool_calls"));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert!(deserialized.content.is_none());
    }

    #[test]
    fn test_role_serialization() {
        for (role, expected) in [
            (Role::User, "\"user\""),
            (Role::Assistant, "\"assistant\""),
            (Role::System, "\"system\""),
            (Role::Tool, "\"tool\""),
        ] {
            assert_eq!(serde_json::to_string(&role).unwrap(), expected);
            let back: Role = serde_json::from_str(expected).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn test_content_part_serialization() {
        let text = serde_json::to_string(&ContentPart::text("hi")).unwrap();
        assert_eq!(text, r#"{"type":"text","text":"hi"}"#);

        let image =
            serde_json::to_string(&ContentPart::image_url("https://example.com/a.png")).unwrap();
        assert_eq!(
            image,
            r#"{"type":"image_url","image_url":{"url":"https://example.com/a.png"}}"#
        );

        assert_eq!(
            ContentPart::image_base64("image/png", "AAAA"),
            ContentPart::ImageUrl {
                image_url: ImageUrl {
                    url: "data:image/png;base64,AAAA".to_string(),
                    detail: None,
                },
            }
        );
    }

    #[test]
    fn test_message_content_untagged_forms() {
        let text: MessageContent = serde_json::from_str(r#""hello""#).unwrap();
        assert_eq!(text, MessageContent::text("hello"));
        assert_eq!(text.as_text(), Some("hello"));

        let parts: MessageContent = serde_json::from_str(
            r#"[{"type":"text","text":"a"},{"type":"image_url","image_url":{"url":"u"}},{"type":"text","text":"b"}]"#,
        )
        .unwrap();
        assert_eq!(parts.as_text(), None);
        assert_eq!(parts.to_text(), "ab");
        assert_eq!(parts.to_string(), "ab");
    }

    #[test]
    fn test_tool_serialization() {
        let tool = Tool {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            },
        };
        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains(r#""type":"function""#));
        let deserialized: Tool = serde_json::from_str(&json).unwrap();
        assert_eq!(tool.function.name, deserialized.function.name);
        assert_eq!(tool.function.parameters, deserialized.function.parameters);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            content: "Hello, world!".to_string(),
            tool_calls: None,
            usage: Usage {
                input_tokens: 10,
                output_tokens: 5,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(!json.contains("tool_calls"));
        let deserialized: Response = serde_json::from_str(&json).unwrap();
        assert_eq!(response.content, deserialized.content);
        assert_eq!(response.usage.input_tokens, deserialized.usage.input_tokens);
        assert_eq!(response.usage.output_tokens, deserialized.usage.output_tokens);
    }

    #[test]
    fn test_usage_serialization() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        };
        let json = serde_json::to_string(&usage).unwrap();
        let deserialized: Usage = serde_json::from_str(&json).unwrap();
        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
    }

    #[test]
    fn test_usage_cache_fields_default_when_missing() {
        let usage: Usage =
            serde_json::from_str(r#"{"input_tokens": 3, "output_tokens": 4}"#).unwrap();
        assert_eq!(usage.cache_read_tokens, 0);
        assert_eq!(usage.cache_write_tokens, 0);
        assert_eq!(usage.total_tokens(), 7);
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);

        let cache_long = CacheControl::ephemeral_long();
        let json_long = serde_json::to_string(&cache_long).unwrap();
        assert_eq!(json_long, r#"{"type":"ephemeral","ttl":"1h"}"#);
    }

    #[test]
    fn test_message_with_cache_control() {
        let msg = Message {
            role: Role::User,
            content: Some(MessageContent::text("Hello")),
            tool_calls: None,
            tool_call_id: None,
            cache_control: Some(CacheControl::ephemeral()),
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("cache_control"));
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.cache_control, Some(CacheControl::ephemeral()));
    }

    #[test]
    fn test_usage_with_cache_fields() {
        let usage = Usage {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 80,
            cache_write_tokens: 20,
        };
        assert_eq!(usage.total_tokens(), 250);

        // Serialization uses this crate's field names, not the input aliases.
        let json = serde_json::to_string(&usage).unwrap();
        assert!(json.contains("cache_read_tokens"));
        assert!(json.contains("cache_write_tokens"));
    }

    #[test]
    fn test_usage_anthropic_aliases() {
        let json = r#"{
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 80,
            "cache_creation_input_tokens": 20
        }"#;
        let usage: Usage = serde_json::from_str(json).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_read_tokens, 80);
        assert_eq!(usage.cache_write_tokens, 20);
        assert_eq!(usage.total_tokens(), 250);
    }
}