// ai_lib_rust/types/message.rs

1//! Unified message format based on AI-Protocol standard_schema
2
3use base64::Engine as _;
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6
7/// Unified message structure
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Message {
10    pub role: MessageRole,
11    pub content: MessageContent,
12}
13
14impl Message {
15    pub fn system(text: impl Into<String>) -> Self {
16        Self {
17            role: MessageRole::System,
18            content: MessageContent::Text(text.into()),
19        }
20    }
21
22    pub fn user(text: impl Into<String>) -> Self {
23        Self {
24            role: MessageRole::User,
25            content: MessageContent::Text(text.into()),
26        }
27    }
28
29    pub fn assistant(text: impl Into<String>) -> Self {
30        Self {
31            role: MessageRole::Assistant,
32            content: MessageContent::Text(text.into()),
33        }
34    }
35
36    pub fn with_content(role: MessageRole, content: MessageContent) -> Self {
37        Self { role, content }
38    }
39
40    pub fn contains_image(&self) -> bool {
41        match &self.content {
42            MessageContent::Text(_) => false,
43            MessageContent::Blocks(bs) => {
44                bs.iter().any(|b| matches!(b, ContentBlock::Image { .. }))
45            }
46        }
47    }
48
49    pub fn contains_audio(&self) -> bool {
50        match &self.content {
51            MessageContent::Text(_) => false,
52            MessageContent::Blocks(bs) => {
53                bs.iter().any(|b| matches!(b, ContentBlock::Audio { .. }))
54            }
55        }
56    }
57}
58
59/// Message role
60#[derive(Debug, Clone, Serialize, Deserialize)]
61#[serde(rename_all = "lowercase")]
62pub enum MessageRole {
63    System,
64    User,
65    Assistant,
66}
67
68/// Message content (can be string or array of content blocks)
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(untagged)]
71pub enum MessageContent {
72    Text(String),
73    Blocks(Vec<ContentBlock>),
74}
75
76impl MessageContent {
77    pub fn text(text: impl Into<String>) -> Self {
78        MessageContent::Text(text.into())
79    }
80
81    pub fn blocks(blocks: Vec<ContentBlock>) -> Self {
82        MessageContent::Blocks(blocks)
83    }
84}
85
86/// Content block (for multimodal or tool results)
87#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(tag = "type")]
89pub enum ContentBlock {
90    #[serde(rename = "text")]
91    Text { text: String },
92    #[serde(rename = "image")]
93    Image { source: ImageSource },
94    #[serde(rename = "audio")]
95    Audio { source: AudioSource },
96    #[serde(rename = "tool_use")]
97    ToolUse {
98        id: String,
99        name: String,
100        input: serde_json::Value,
101    },
102    #[serde(rename = "tool_result")]
103    ToolResult {
104        tool_use_id: String,
105        content: serde_json::Value,
106    },
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct ImageSource {
111    #[serde(rename = "type")]
112    pub source_type: String,
113    #[serde(default, skip_serializing_if = "Option::is_none")]
114    pub media_type: Option<String>,
115    pub data: String, // base64 encoded or URL
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct AudioSource {
120    #[serde(rename = "type")]
121    pub source_type: String,
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub media_type: Option<String>,
124    pub data: String, // base64 encoded or URL
125}
126
127impl ContentBlock {
128    pub fn text(text: impl Into<String>) -> Self {
129        ContentBlock::Text { text: text.into() }
130    }
131
132    pub fn image_base64(data: String, media_type: Option<String>) -> Self {
133        ContentBlock::Image {
134            source: ImageSource {
135                source_type: "base64".to_string(),
136                media_type,
137                data,
138            },
139        }
140    }
141
142    pub fn audio_base64(data: String, media_type: Option<String>) -> Self {
143        ContentBlock::Audio {
144            source: AudioSource {
145                source_type: "base64".to_string(),
146                media_type,
147                data,
148            },
149        }
150    }
151
152    pub fn image_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
153        let path = path.as_ref();
154        let bytes = std::fs::read(path)?;
155        let media_type = guess_media_type(path);
156        let data = base64::engine::general_purpose::STANDARD.encode(bytes);
157        Ok(Self::image_base64(data, media_type))
158    }
159
160    pub fn audio_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
161        let path = path.as_ref();
162        let bytes = std::fs::read(path)?;
163        let media_type = guess_media_type(path);
164        let data = base64::engine::general_purpose::STANDARD.encode(bytes);
165        Ok(Self::audio_base64(data, media_type))
166    }
167}
168
/// Guesses a MIME media type from the extension of `path`.
///
/// The comparison is case-insensitive. Returns `None` when the path has no
/// extension, when the extension is not valid UTF-8, or when it is not one
/// of the recognised image (`png`, `jpg`/`jpeg`, `webp`, `gif`) or audio
/// (`mp3`, `wav`, `ogg`, `m4a`) extensions. Only the file name is inspected;
/// the file contents are never read.
fn guess_media_type(path: &Path) -> Option<String> {
    // Every recognised extension is ASCII, so ASCII case folding is enough.
    let ext = path.extension()?.to_str()?.to_ascii_lowercase();
    let media_type = match ext.as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "webp" => "image/webp",
        "gif" => "image/gif",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" => "audio/ogg",
        "m4a" => "audio/mp4",
        _ => return None,
    };
    Some(media_type.to_owned())
}