1use base64::Engine as _;
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct Message {
10 pub role: MessageRole,
11 pub content: MessageContent,
12 #[serde(default, skip_serializing_if = "Option::is_none")]
14 pub tool_call_id: Option<String>,
15}
16
17impl Message {
18 pub fn system(text: impl Into<String>) -> Self {
19 Self {
20 role: MessageRole::System,
21 content: MessageContent::Text(text.into()),
22 tool_call_id: None,
23 }
24 }
25
26 pub fn user(text: impl Into<String>) -> Self {
27 Self {
28 role: MessageRole::User,
29 content: MessageContent::Text(text.into()),
30 tool_call_id: None,
31 }
32 }
33
34 pub fn assistant(text: impl Into<String>) -> Self {
35 Self {
36 role: MessageRole::Assistant,
37 content: MessageContent::Text(text.into()),
38 tool_call_id: None,
39 }
40 }
41
42 pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
46 Self {
47 role: MessageRole::Tool,
48 content: MessageContent::Text(content.into()),
49 tool_call_id: Some(tool_call_id.into()),
50 }
51 }
52
53 pub fn with_content(role: MessageRole, content: MessageContent) -> Self {
54 Self {
55 role,
56 content,
57 tool_call_id: None,
58 }
59 }
60
61 pub fn contains_image(&self) -> bool {
62 match &self.content {
63 MessageContent::Text(_) => false,
64 MessageContent::Blocks(bs) => {
65 bs.iter().any(|b| matches!(b, ContentBlock::Image { .. }))
66 }
67 }
68 }
69
70 pub fn contains_audio(&self) -> bool {
71 match &self.content {
72 MessageContent::Text(_) => false,
73 MessageContent::Blocks(bs) => {
74 bs.iter().any(|b| matches!(b, ContentBlock::Audio { .. }))
75 }
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82#[serde(rename_all = "lowercase")]
83pub enum MessageRole {
84 System,
85 User,
86 Assistant,
87 Tool,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93#[serde(untagged)]
94pub enum MessageContent {
95 Text(String),
96 Blocks(Vec<ContentBlock>),
97}
98
99impl MessageContent {
100 pub fn text(text: impl Into<String>) -> Self {
101 MessageContent::Text(text.into())
102 }
103
104 pub fn blocks(blocks: Vec<ContentBlock>) -> Self {
105 MessageContent::Blocks(blocks)
106 }
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111#[serde(tag = "type")]
112pub enum ContentBlock {
113 #[serde(rename = "text")]
114 Text { text: String },
115 #[serde(rename = "image")]
116 Image { source: ImageSource },
117 #[serde(rename = "audio")]
118 Audio { source: AudioSource },
119 #[serde(rename = "tool_use")]
120 ToolUse {
121 id: String,
122 name: String,
123 input: serde_json::Value,
124 },
125 #[serde(rename = "tool_result")]
126 ToolResult {
127 tool_use_id: String,
128 content: serde_json::Value,
129 },
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ImageSource {
134 #[serde(rename = "type")]
135 pub source_type: String,
136 #[serde(default, skip_serializing_if = "Option::is_none")]
137 pub media_type: Option<String>,
138 pub data: String, }
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct AudioSource {
143 #[serde(rename = "type")]
144 pub source_type: String,
145 #[serde(default, skip_serializing_if = "Option::is_none")]
146 pub media_type: Option<String>,
147 pub data: String, }
149
150impl ContentBlock {
151 pub fn text(text: impl Into<String>) -> Self {
152 ContentBlock::Text { text: text.into() }
153 }
154
155 pub fn image_base64(data: String, media_type: Option<String>) -> Self {
156 ContentBlock::Image {
157 source: ImageSource {
158 source_type: "base64".to_string(),
159 media_type,
160 data,
161 },
162 }
163 }
164
165 pub fn audio_base64(data: String, media_type: Option<String>) -> Self {
166 ContentBlock::Audio {
167 source: AudioSource {
168 source_type: "base64".to_string(),
169 media_type,
170 data,
171 },
172 }
173 }
174
175 pub fn image_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
176 let path = path.as_ref();
177 let bytes = std::fs::read(path)?;
178 let media_type = guess_media_type(path);
179 let data = base64::engine::general_purpose::STANDARD.encode(bytes);
180 Ok(Self::image_base64(data, media_type))
181 }
182
183 pub fn audio_from_file(path: impl AsRef<Path>) -> crate::Result<Self> {
184 let path = path.as_ref();
185 let bytes = std::fs::read(path)?;
186 let media_type = guess_media_type(path);
187 let data = base64::engine::general_purpose::STANDARD.encode(bytes);
188 Ok(Self::audio_base64(data, media_type))
189 }
190}
191
/// Maps a file's extension (case-insensitively) to a MIME media type.
///
/// Returns `None` when the path has no extension, the extension is not valid
/// UTF-8, or it is not one of the recognised image/audio formats.
fn guess_media_type(path: &Path) -> Option<String> {
    let extension = path.extension().and_then(|raw| raw.to_str())?.to_lowercase();
    let known = match extension.as_str() {
        "png" => Some("image/png"),
        "jpg" | "jpeg" => Some("image/jpeg"),
        "webp" => Some("image/webp"),
        "gif" => Some("image/gif"),
        "mp3" => Some("audio/mpeg"),
        "wav" => Some("audio/wav"),
        "ogg" => Some("audio/ogg"),
        "m4a" => Some("audio/mp4"),
        _ => None,
    };
    known.map(str::to_owned)
}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_tool() {
        let msg = Message::tool("call_abc123", "42");
        assert!(matches!(msg.role, MessageRole::Tool));
        assert_eq!(msg.tool_call_id.as_deref(), Some("call_abc123"));
        if let MessageContent::Text(s) = msg.content {
            assert_eq!(s, "42");
        } else {
            panic!("expected Text content");
        }
    }

    #[test]
    fn test_message_role_serialization() {
        let msg = Message::tool("call_xyz", "result");
        let json = serde_json::to_value(&msg).unwrap();
        assert_eq!(json["role"], "tool");
        assert_eq!(json["content"], "result");
        assert_eq!(json["tool_call_id"], "call_xyz");
    }

    /// Non-tool messages must not emit a `tool_call_id` key at all.
    #[test]
    fn test_tool_call_id_omitted_when_none() {
        let json = serde_json::to_value(Message::user("hi")).unwrap();
        assert_eq!(json["role"], "user");
        assert!(json.get("tool_call_id").is_none());
    }

    #[test]
    fn test_contains_image_and_audio() {
        let text_only = Message::user("hello");
        assert!(!text_only.contains_image());
        assert!(!text_only.contains_audio());

        let with_image = Message::with_content(
            MessageRole::User,
            MessageContent::blocks(vec![
                ContentBlock::text("look"),
                ContentBlock::image_base64("AAAA".to_string(), Some("image/png".to_string())),
            ]),
        );
        assert!(with_image.contains_image());
        assert!(!with_image.contains_audio());

        let with_audio = Message::with_content(
            MessageRole::User,
            MessageContent::blocks(vec![ContentBlock::audio_base64("AAAA".to_string(), None)]),
        );
        assert!(!with_audio.contains_image());
        assert!(with_audio.contains_audio());
    }

    /// Pins the wire shape: internal `"type"` tag plus nested source object.
    #[test]
    fn test_content_block_serialization() {
        let image = ContentBlock::image_base64("AAAA".to_string(), Some("image/png".to_string()));
        let json = serde_json::to_value(&image).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "type": "image",
                "source": { "type": "base64", "media_type": "image/png", "data": "AAAA" }
            })
        );

        let audio = ContentBlock::audio_base64("BBBB".to_string(), None);
        let json = serde_json::to_value(&audio).unwrap();
        assert_eq!(json["type"], "audio");
        assert_eq!(json["source"]["type"], "base64");
        assert!(json["source"].get("media_type").is_none());

        let tool_use = ContentBlock::ToolUse {
            id: "t1".to_string(),
            name: "calc".to_string(),
            input: serde_json::json!({ "x": 1 }),
        };
        assert_eq!(serde_json::to_value(&tool_use).unwrap()["type"], "tool_use");
    }

    /// Untagged content: a bare string is `Text`, an array is `Blocks`.
    #[test]
    fn test_message_content_untagged_deserialization() {
        let text: MessageContent = serde_json::from_str(r#""hi""#).unwrap();
        assert!(matches!(text, MessageContent::Text(ref s) if s == "hi"));

        let blocks: MessageContent =
            serde_json::from_str(r#"[{"type":"text","text":"hi"}]"#).unwrap();
        match blocks {
            MessageContent::Blocks(bs) => {
                assert_eq!(bs.len(), 1);
                assert!(matches!(&bs[0], ContentBlock::Text { text } if text == "hi"));
            }
            MessageContent::Text(_) => panic!("expected Blocks content"),
        }
    }

    #[test]
    fn test_guess_media_type() {
        use std::path::Path as StdPath;
        assert_eq!(
            guess_media_type(StdPath::new("photo.JPG")).as_deref(),
            Some("image/jpeg")
        );
        assert_eq!(
            guess_media_type(StdPath::new("clip.wav")).as_deref(),
            Some("audio/wav")
        );
        assert_eq!(guess_media_type(StdPath::new("notes.txt")), None);
        assert_eq!(guess_media_type(StdPath::new("README")), None);
    }
}