1use crate::{
29 CerseiError, ContentBlock, DocumentSource, ImageSource, Message, MessageContent, Result, Role,
30};
31use base64::Engine;
32use std::path::Path;
33
/// Coarse category of a piece of media, derived from the top-level part of its MIME type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MediaKind {
    /// Any `image/*` type.
    Image,
    /// Any `video/*` type.
    Video,
    /// Any `audio/*` type.
    Audio,
    /// Everything that is not `image/*`, `video/*`, or `audio/*`.
    Document,
}

impl MediaKind {
    /// Classifies a MIME type string by its top-level type.
    ///
    /// MIME type names are case-insensitive, so `"IMAGE/PNG"` and `"image/png"` both map
    /// to [`MediaKind::Image`]. Leading whitespace is ignored. A string without a `/`
    /// separator, or with an unrecognised top-level type, is a [`MediaKind::Document`].
    pub fn from_mime(mime: &str) -> Self {
        // Only the part before the first `/` decides the kind; the subtype is irrelevant.
        let top_level = match mime.trim_start().split_once('/') {
            Some((top_level, _subtype)) => top_level,
            None => return MediaKind::Document,
        };

        if top_level.eq_ignore_ascii_case("image") {
            MediaKind::Image
        } else if top_level.eq_ignore_ascii_case("video") {
            MediaKind::Video
        } else if top_level.eq_ignore_ascii_case("audio") {
            MediaKind::Audio
        } else {
            MediaKind::Document
        }
    }
}
59
60pub fn detect_mime(bytes: &[u8], path: Option<&Path>) -> Option<String> {
64 if let Some(m) = sniff_magic(bytes) {
65 return Some(m.to_string());
66 }
67 path.and_then(mime_from_extension).map(|s| s.to_string())
68}
69
/// Identifies a MIME type from the leading signature ("magic") bytes of `b`.
///
/// Recognises RIFF containers (WebP, WAV, AVI), ISO base media files (`ftyp` box:
/// QuickTime, MP4 audio/video, HEIC/HEIF/AVIF), PNG, JPEG, GIF, PDF, WebM (EBML header),
/// Ogg, and MP3. Returns `None` when no known signature matches.
fn sniff_magic(b: &[u8]) -> Option<&'static str> {
    // Container formats keep their discriminating four-byte tag at offset 8,
    // so they can only be inspected when at least 12 bytes are available.
    if b.len() >= 12 {
        if &b[0..4] == b"RIFF" {
            match &b[8..12] {
                b"WEBP" => return Some("image/webp"),
                b"WAVE" => return Some("audio/wav"),
                b"AVI " => return Some("video/x-msvideo"),
                // Unknown RIFF form type: fall through to the remaining checks.
                _ => {}
            }
        }
        if &b[4..8] == b"ftyp" {
            return Some(match &b[8..12] {
                // The QuickTime major brand is `qt` padded with TWO spaces to four
                // bytes; written with escapes so the padding cannot be lost.
                b"qt\x20\x20" => "video/quicktime",
                b"M4A\x20" | b"M4B\x20" => "audio/mp4",
                // HEIF-family still images share the `ftyp` box with MP4 and must
                // not be reported as video.
                b"heic" | b"heix" => "image/heic",
                b"mif1" => "image/heif",
                b"avif" => "image/avif",
                // Any other brand is treated as generic MP4 video.
                _ => "video/mp4",
            });
        }
    }
    if b.starts_with(&[0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A]) {
        return Some("image/png");
    }
    if b.starts_with(&[0xFF, 0xD8, 0xFF]) {
        return Some("image/jpeg");
    }
    if b.starts_with(b"GIF87a") || b.starts_with(b"GIF89a") {
        return Some("image/gif");
    }
    if b.starts_with(b"%PDF-") {
        return Some("application/pdf");
    }
    // EBML header; reported as WebM.
    if b.starts_with(&[0x1A, 0x45, 0xDF, 0xA3]) {
        return Some("video/webm");
    }
    if b.starts_with(b"OggS") {
        return Some("audio/ogg");
    }
    // MP3: either an ID3v2 tag or a bare MPEG audio frame sync.
    if b.starts_with(b"ID3") || b.starts_with(&[0xFF, 0xFB]) || b.starts_with(&[0xFF, 0xF3]) {
        return Some("audio/mpeg");
    }
    None
}
114
/// Maps the extension of `path` to a MIME type, ignoring ASCII case.
///
/// Returns `None` when the path has no extension, the extension is not valid UTF-8,
/// or the extension is not in the table below.
fn mime_from_extension(path: &Path) -> Option<&'static str> {
    // Each entry pairs a group of lowercase extensions (without the dot) with the MIME
    // type they all share.
    const KNOWN_EXTENSIONS: &[(&[&str], &str)] = &[
        (&["png"], "image/png"),
        (&["jpg", "jpeg", "jpe"], "image/jpeg"),
        (&["gif"], "image/gif"),
        (&["webp"], "image/webp"),
        (&["bmp"], "image/bmp"),
        (&["svg"], "image/svg+xml"),
        (&["heic"], "image/heic"),
        (&["heif"], "image/heif"),
        (&["mp4", "m4v"], "video/mp4"),
        (&["mov"], "video/quicktime"),
        (&["webm"], "video/webm"),
        (&["avi"], "video/x-msvideo"),
        (&["mpeg", "mpg"], "video/mpeg"),
        (&["flv"], "video/x-flv"),
        (&["wmv"], "video/x-ms-wmv"),
        (&["3gp"], "video/3gpp"),
        (&["mp3"], "audio/mpeg"),
        (&["wav"], "audio/wav"),
        (&["ogg", "oga"], "audio/ogg"),
        (&["flac"], "audio/flac"),
        (&["aac"], "audio/aac"),
        (&["m4a"], "audio/mp4"),
        (&["pdf"], "application/pdf"),
        (&["txt", "text", "log"], "text/plain"),
        (&["md", "markdown"], "text/markdown"),
        (&["csv"], "text/csv"),
        (&["html", "htm"], "text/html"),
    ];

    let extension = path.extension()?.to_str()?;
    KNOWN_EXTENSIONS
        .iter()
        .find(|(aliases, _)| {
            aliases
                .iter()
                .any(|alias| alias.eq_ignore_ascii_case(extension))
        })
        .map(|&(_, mime)| mime)
}
148
149fn b64(bytes: &[u8]) -> String {
150 base64::engine::general_purpose::STANDARD.encode(bytes)
151}
152
153impl ContentBlock {
154 pub fn image_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
156 ContentBlock::Image {
157 source: ImageSource {
158 source_type: "base64".into(),
159 media_type: Some(media_type.into()),
160 data: Some(data.into()),
161 url: None,
162 },
163 }
164 }
165
166 pub fn image_bytes(media_type: impl Into<String>, bytes: &[u8]) -> Self {
168 Self::image_base64(media_type, b64(bytes))
169 }
170
171 pub fn image_url(url: impl Into<String>) -> Self {
173 ContentBlock::Image {
174 source: ImageSource {
175 source_type: "url".into(),
176 media_type: None,
177 data: None,
178 url: Some(url.into()),
179 },
180 }
181 }
182
183 pub fn document_base64(media_type: impl Into<String>, data: impl Into<String>) -> Self {
185 ContentBlock::Document {
186 source: DocumentSource {
187 source_type: "base64".into(),
188 media_type: Some(media_type.into()),
189 data: Some(data.into()),
190 url: None,
191 },
192 title: None,
193 context: None,
194 citations: None,
195 }
196 }
197
198 pub fn document_bytes(media_type: impl Into<String>, bytes: &[u8]) -> Self {
200 Self::document_base64(media_type, b64(bytes))
201 }
202
203 pub fn document_url(url: impl Into<String>) -> Self {
205 ContentBlock::Document {
206 source: DocumentSource {
207 source_type: "url".into(),
208 media_type: None,
209 data: None,
210 url: Some(url.into()),
211 },
212 title: None,
213 context: None,
214 citations: None,
215 }
216 }
217
218 pub fn media_bytes(media_type: impl Into<String>, bytes: &[u8]) -> Self {
222 let mt = media_type.into();
223 match MediaKind::from_mime(&mt) {
224 MediaKind::Document => Self::document_bytes(mt, bytes),
225 _ => Self::image_bytes(mt, bytes),
226 }
227 }
228
229 pub fn from_path(path: impl AsRef<Path>) -> Result<Self> {
236 let path = path.as_ref();
237 let bytes = std::fs::read(path)?;
238 let mime = detect_mime(&bytes, Some(path)).ok_or_else(|| {
239 CerseiError::Config(format!(
240 "could not determine a media type for `{}`; pass one explicitly via ContentBlock::media_bytes",
241 path.display()
242 ))
243 })?;
244 Ok(Self::media_bytes(mime, &bytes))
245 }
246}
247
248impl Message {
249 pub fn user_with_media(text: impl Into<String>, media: Vec<ContentBlock>) -> Self {
251 let mut blocks = Vec::with_capacity(media.len() + 1);
252 let text = text.into();
253 if !text.is_empty() {
254 blocks.push(ContentBlock::Text { text });
255 }
256 blocks.extend(media);
257 Message {
258 role: Role::User,
259 content: MessageContent::Blocks(blocks),
260 id: None,
261 metadata: None,
262 }
263 }
264
265 pub fn user_with_files<P: AsRef<Path>>(text: impl Into<String>, paths: &[P]) -> Result<Self> {
268 let mut media = Vec::with_capacity(paths.len());
269 for p in paths {
270 media.push(ContentBlock::from_path(p)?);
271 }
272 Ok(Self::user_with_media(text, media))
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Magic-byte sniffing recognises PNG, PDF, JPEG, and a generic MP4 `ftyp` header.
    #[test]
    fn sniffs_common_formats() {
        let png_signature = [0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A];
        assert_eq!(sniff_magic(&png_signature), Some("image/png"));

        assert_eq!(sniff_magic(b"%PDF-1.7"), Some("application/pdf"));

        let jpeg_prefix = [0xFF, 0xD8, 0xFF, 0xE0];
        assert_eq!(sniff_magic(&jpeg_prefix), Some("image/jpeg"));

        // Four zero bytes standing in for the box size, then `ftyp` and the brand.
        let mut mp4_header = vec![0u8; 12];
        mp4_header[4..8].copy_from_slice(b"ftyp");
        mp4_header[8..12].copy_from_slice(b"isom");
        assert_eq!(sniff_magic(&mp4_header), Some("video/mp4"));
    }

    /// When the content is unrecognised, the file extension decides (or nothing does).
    #[test]
    fn extension_fallback() {
        let known = detect_mime(b"not-a-known-signature", Some(Path::new("a.webp")));
        assert_eq!(known, Some("image/webp".to_string()));

        let unknown = detect_mime(b"???", Some(Path::new("a.unknownext")));
        assert_eq!(unknown, None);
    }

    /// Image and video types go to image blocks; other types go to document blocks.
    #[test]
    fn media_bytes_routes_by_kind() {
        let image = ContentBlock::media_bytes("image/png", b"x");
        assert!(matches!(image, ContentBlock::Image { .. }));

        let video = ContentBlock::media_bytes("video/mp4", b"x");
        assert!(matches!(video, ContentBlock::Image { .. }));

        let pdf = ContentBlock::media_bytes("application/pdf", b"x");
        assert!(matches!(pdf, ContentBlock::Document { .. }));
    }

    /// Raw bytes are stored base64-encoded alongside the media type.
    #[test]
    fn image_bytes_base64_roundtrip() {
        match ContentBlock::image_bytes("image/png", b"hello") {
            ContentBlock::Image { source } => {
                assert_eq!(source.source_type, "base64");
                assert_eq!(source.media_type.as_deref(), Some("image/png"));
                assert_eq!(source.data.as_deref(), Some("aGVsbG8="));
            }
            _ => panic!("expected image block"),
        }
    }

    /// Serialized blocks expose `type` plus a nested `source` object.
    #[test]
    fn image_block_serializes_to_anthropic_shape() {
        let image = ContentBlock::image_base64("image/png", "QUJD");
        let image_json = serde_json::to_value(&image).unwrap();
        assert_eq!(image_json["type"], "image");
        assert_eq!(image_json["source"]["type"], "base64");
        assert_eq!(image_json["source"]["media_type"], "image/png");
        assert_eq!(image_json["source"]["data"], "QUJD");

        let document = ContentBlock::document_base64("application/pdf", "UERG");
        let document_json = serde_json::to_value(&document).unwrap();
        assert_eq!(document_json["type"], "document");
        assert_eq!(document_json["source"]["type"], "base64");
        assert_eq!(document_json["source"]["media_type"], "application/pdf");
    }

    /// The text block comes first, followed by the media blocks.
    #[test]
    fn user_with_media_prepends_text() {
        let attachments = vec![ContentBlock::image_url("http://x/y.png")];
        let message = Message::user_with_media("hi", attachments);

        let blocks = match message.content {
            MessageContent::Blocks(blocks) => blocks,
            _ => panic!("expected blocks"),
        };
        assert_eq!(blocks.len(), 2);
        assert!(matches!(blocks[0], ContentBlock::Text { .. }));
        assert!(matches!(blocks[1], ContentBlock::Image { .. }));
    }
}