ai_lib_rust/types/
message.rs

//! Unified message format based on the AI-Protocol `standard_schema`.
2
3use serde::{Deserialize, Serialize};
4use base64::Engine as _;
5use std::path::Path;
6
7/// Unified message structure
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Message {
10    pub role: MessageRole,
11    pub content: MessageContent,
12}
13
14impl Message {
15    pub fn system(text: impl Into<String>) -> Self {
16        Self {
17            role: MessageRole::System,
18            content: MessageContent::Text(text.into()),
19        }
20    }
21
22    pub fn user(text: impl Into<String>) -> Self {
23        Self {
24            role: MessageRole::User,
25            content: MessageContent::Text(text.into()),
26        }
27    }
28
29    pub fn assistant(text: impl Into<String>) -> Self {
30        Self {
31            role: MessageRole::Assistant,
32            content: MessageContent::Text(text.into()),
33        }
34    }
35
36    pub fn with_content(role: MessageRole, content: MessageContent) -> Self {
37        Self { role, content }
38    }
39
40    pub fn contains_image(&self) -> bool {
41        match &self.content {
42            MessageContent::Text(_) => false,
43            MessageContent::Blocks(bs) => bs.iter().any(|b| matches!(b, ContentBlock::Image { .. })),
44        }
45    }
46
47    pub fn contains_audio(&self) -> bool {
48        match &self.content {
49            MessageContent::Text(_) => false,
50            MessageContent::Blocks(bs) => bs.iter().any(|b| matches!(b, ContentBlock::Audio { .. })),
51        }
52    }
53}
54
55/// Message role
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "lowercase")]
58pub enum MessageRole {
59    System,
60    User,
61    Assistant,
62}
63
64/// Message content (can be string or array of content blocks)
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(untagged)]
67pub enum MessageContent {
68    Text(String),
69    Blocks(Vec<ContentBlock>),
70}
71
72impl MessageContent {
73    pub fn text(text: impl Into<String>) -> Self {
74        MessageContent::Text(text.into())
75    }
76
77    pub fn blocks(blocks: Vec<ContentBlock>) -> Self {
78        MessageContent::Blocks(blocks)
79    }
80}
81
82/// Content block (for multimodal or tool results)
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(tag = "type")]
85pub enum ContentBlock {
86    #[serde(rename = "text")]
87    Text { text: String },
88    #[serde(rename = "image")]
89    Image { source: ImageSource },
90    #[serde(rename = "audio")]
91    Audio { source: AudioSource },
92    #[serde(rename = "tool_use")]
93    ToolUse {
94        id: String,
95        name: String,
96        input: serde_json::Value,
97    },
98    #[serde(rename = "tool_result")]
99    ToolResult {
100        tool_use_id: String,
101        content: serde_json::Value,
102    },
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ImageSource {
107    #[serde(rename = "type")]
108    pub source_type: String,
109    #[serde(default, skip_serializing_if = "Option::is_none")]
110    pub media_type: Option<String>,
111    pub data: String, // base64 encoded or URL
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct AudioSource {
116    #[serde(rename = "type")]
117    pub source_type: String,
118    #[serde(default, skip_serializing_if = "Option::is_none")]
119    pub media_type: Option<String>,
120    pub data: String, // base64 encoded or URL
121}
122
123impl ContentBlock {
124    pub fn text(text: impl Into<String>) -> Self {
125        ContentBlock::Text { text: text.into() }
126    }
127
128    pub fn image_base64(data: String, media_type: Option<String>) -> Self {
129        ContentBlock::Image {
130            source: ImageSource {
131                source_type: "base64".to_string(),
132                media_type,
133                data,
134            },
135        }
136    }
137
138    pub fn audio_base64(data: String, media_type: Option<String>) -> Self {
139        ContentBlock::Audio {
140            source: AudioSource {
141                source_type: "base64".to_string(),
142                media_type,
143                data,
144            },
145        }
146    }
147
148    pub fn image_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
149        let path = path.as_ref();
150        let bytes = std::fs::read(path)?;
151        let media_type = guess_media_type(path);
152        let data = base64::engine::general_purpose::STANDARD.encode(bytes);
153        Ok(Self::image_base64(data, media_type))
154    }
155
156    pub fn audio_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
157        let path = path.as_ref();
158        let bytes = std::fs::read(path)?;
159        let media_type = guess_media_type(path);
160        let data = base64::engine::general_purpose::STANDARD.encode(bytes);
161        Ok(Self::audio_base64(data, media_type))
162    }
163}
164
/// Guesses a media type from the extension of `path`.
///
/// The extension is compared case-insensitively against a fixed table of
/// image and audio formats. Returns `None` when the path has no extension,
/// the extension is not valid UTF-8, or it is not in the table.
fn guess_media_type(path: &Path) -> Option<String> {
    /// Recognized extensions (lowercase) and their media types.
    const KNOWN_TYPES: &[(&str, &str)] = &[
        ("png", "image/png"),
        ("jpg", "image/jpeg"),
        ("jpeg", "image/jpeg"),
        ("webp", "image/webp"),
        ("gif", "image/gif"),
        ("mp3", "audio/mpeg"),
        ("wav", "audio/wav"),
        ("ogg", "audio/ogg"),
        ("m4a", "audio/mp4"),
    ];

    let extension = path.extension()?.to_str()?;
    KNOWN_TYPES
        .iter()
        .find(|&&(known, _)| extension.eq_ignore_ascii_case(known))
        .map(|&(_, media_type)| media_type.to_owned())
}