Skip to main content

ai_lib_rust/types/
message.rs

//! Unified message format based on the AI-Protocol `standard_schema`.
2
3use base64::Engine as _;
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6
7/// Unified message structure
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Message {
10    pub role: MessageRole,
11    pub content: MessageContent,
12    /// Required when role is Tool (OpenAI API: tool_call_id).
13    #[serde(default, skip_serializing_if = "Option::is_none")]
14    pub tool_call_id: Option<String>,
15}
16
17impl Message {
18    pub fn system(text: impl Into<String>) -> Self {
19        Self {
20            role: MessageRole::System,
21            content: MessageContent::Text(text.into()),
22            tool_call_id: None,
23        }
24    }
25
26    pub fn user(text: impl Into<String>) -> Self {
27        Self {
28            role: MessageRole::User,
29            content: MessageContent::Text(text.into()),
30            tool_call_id: None,
31        }
32    }
33
34    pub fn assistant(text: impl Into<String>) -> Self {
35        Self {
36            role: MessageRole::Assistant,
37            content: MessageContent::Text(text.into()),
38            tool_call_id: None,
39        }
40    }
41
42    /// Create a tool result message for multi-turn tool calling.
43    ///
44    /// OpenAI and similar APIs expect `role: "tool"` with `tool_call_id` and `content`.
45    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
46        Self {
47            role: MessageRole::Tool,
48            content: MessageContent::Text(content.into()),
49            tool_call_id: Some(tool_call_id.into()),
50        }
51    }
52
53    pub fn with_content(role: MessageRole, content: MessageContent) -> Self {
54        Self {
55            role,
56            content,
57            tool_call_id: None,
58        }
59    }
60
61    pub fn contains_image(&self) -> bool {
62        match &self.content {
63            MessageContent::Text(_) => false,
64            MessageContent::Blocks(bs) => {
65                bs.iter().any(|b| matches!(b, ContentBlock::Image { .. }))
66            }
67        }
68    }
69
70    pub fn contains_audio(&self) -> bool {
71        match &self.content {
72            MessageContent::Text(_) => false,
73            MessageContent::Blocks(bs) => {
74                bs.iter().any(|b| matches!(b, ContentBlock::Audio { .. }))
75            }
76        }
77    }
78}
79
80/// Message role
81#[derive(Debug, Clone, Serialize, Deserialize)]
82#[serde(rename_all = "lowercase")]
83pub enum MessageRole {
84    System,
85    User,
86    Assistant,
87    /// Tool result message (OpenAI API: role "tool").
88    Tool,
89}
90
91/// Message content (can be string or array of content blocks)
92#[derive(Debug, Clone, Serialize, Deserialize)]
93#[serde(untagged)]
94pub enum MessageContent {
95    Text(String),
96    Blocks(Vec<ContentBlock>),
97}
98
99impl MessageContent {
100    pub fn text(text: impl Into<String>) -> Self {
101        MessageContent::Text(text.into())
102    }
103
104    pub fn blocks(blocks: Vec<ContentBlock>) -> Self {
105        MessageContent::Blocks(blocks)
106    }
107}
108
109/// Content block (for multimodal or tool results)
110#[derive(Debug, Clone, Serialize, Deserialize)]
111#[serde(tag = "type")]
112pub enum ContentBlock {
113    #[serde(rename = "text")]
114    Text { text: String },
115    #[serde(rename = "image")]
116    Image { source: ImageSource },
117    #[serde(rename = "audio")]
118    Audio { source: AudioSource },
119    #[serde(rename = "tool_use")]
120    ToolUse {
121        id: String,
122        name: String,
123        input: serde_json::Value,
124    },
125    #[serde(rename = "tool_result")]
126    ToolResult {
127        tool_use_id: String,
128        content: serde_json::Value,
129    },
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ImageSource {
134    #[serde(rename = "type")]
135    pub source_type: String,
136    #[serde(default, skip_serializing_if = "Option::is_none")]
137    pub media_type: Option<String>,
138    pub data: String, // base64 encoded or URL
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct AudioSource {
143    #[serde(rename = "type")]
144    pub source_type: String,
145    #[serde(default, skip_serializing_if = "Option::is_none")]
146    pub media_type: Option<String>,
147    pub data: String, // base64 encoded or URL
148}
149
150impl ContentBlock {
151    pub fn text(text: impl Into<String>) -> Self {
152        ContentBlock::Text { text: text.into() }
153    }
154
155    pub fn image_base64(data: String, media_type: Option<String>) -> Self {
156        ContentBlock::Image {
157            source: ImageSource {
158                source_type: "base64".to_string(),
159                media_type,
160                data,
161            },
162        }
163    }
164
165    pub fn audio_base64(data: String, media_type: Option<String>) -> Self {
166        ContentBlock::Audio {
167            source: AudioSource {
168                source_type: "base64".to_string(),
169                media_type,
170                data,
171            },
172        }
173    }
174
175    pub fn image_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
176        let path = path.as_ref();
177        let bytes = std::fs::read(path)?;
178        let media_type = guess_media_type(path);
179        let data = base64::engine::general_purpose::STANDARD.encode(bytes);
180        Ok(Self::image_base64(data, media_type))
181    }
182
183    pub fn audio_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
184        let path = path.as_ref();
185        let bytes = std::fs::read(path)?;
186        let media_type = guess_media_type(path);
187        let data = base64::engine::general_purpose::STANDARD.encode(bytes);
188        Ok(Self::audio_base64(data, media_type))
189    }
190}
191
/// Guesses a media (MIME) type from the extension of `path`.
///
/// The extension is compared case-insensitively against a fixed table of
/// image and audio formats. Returns `None` when the path has no extension,
/// the extension is not valid UTF-8, or it is not in the table.
fn guess_media_type(path: &Path) -> Option<String> {
    /// Lowercase extension paired with the media type it maps to.
    const KNOWN_TYPES: &[(&str, &str)] = &[
        ("png", "image/png"),
        ("jpg", "image/jpeg"),
        ("jpeg", "image/jpeg"),
        ("webp", "image/webp"),
        ("gif", "image/gif"),
        ("mp3", "audio/mpeg"),
        ("wav", "audio/wav"),
        ("ogg", "audio/ogg"),
        ("m4a", "audio/mp4"),
    ];

    let wanted = path.extension()?.to_str()?.to_lowercase();
    KNOWN_TYPES
        .iter()
        .find(|(extension, _)| *extension == wanted.as_str())
        .map(|(_, media_type)| (*media_type).to_string())
}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// `Message::tool` sets the Tool role, keeps the call id and stores the
    /// result as plain text.
    #[test]
    fn test_message_tool() {
        let Message {
            role,
            content,
            tool_call_id,
        } = Message::tool("call_abc123", "42");

        assert!(matches!(role, MessageRole::Tool));
        assert_eq!(tool_call_id.as_deref(), Some("call_abc123"));
        match content {
            MessageContent::Text(body) => assert_eq!(body, "42"),
            MessageContent::Blocks(_) => panic!("expected Text content"),
        }
    }

    /// A tool message serializes with a lowercase role, a string content and
    /// its `tool_call_id`.
    #[test]
    fn test_message_role_serialization() {
        let encoded = serde_json::to_value(Message::tool("call_xyz", "result")).unwrap();

        let expected = [
            ("role", "tool"),
            ("content", "result"),
            ("tool_call_id", "call_xyz"),
        ];
        for (field, want) in expected {
            assert_eq!(encoded[field], want, "unexpected value for `{}`", field);
        }
    }
}