1use serde::{Deserialize, Serialize};
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28#[serde(tag = "type", rename_all = "snake_case")]
29pub enum ContentPart {
30 Text {
34 text: String,
36 },
37 Image {
39 source: ImageSource,
41 mime: String,
43 },
44 Audio {
46 source: AudioSource,
48 mime: String,
50 },
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
55#[serde(tag = "kind", rename_all = "snake_case")]
56pub enum ImageSource {
57 Url {
59 url: String,
61 },
62 Base64 {
64 data: String,
66 },
67}
68
69impl ImageSource {
70 pub fn url(u: impl Into<String>) -> Self {
72 Self::Url { url: u.into() }
73 }
74 pub fn base64(d: impl Into<String>) -> Self {
76 Self::Base64 { data: d.into() }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
82#[serde(tag = "kind", rename_all = "snake_case")]
83pub enum AudioSource {
84 Url {
86 url: String,
88 },
89 Base64 {
91 data: String,
93 },
94}
95
96impl AudioSource {
97 pub fn url(u: impl Into<String>) -> Self {
99 Self::Url { url: u.into() }
100 }
101 pub fn base64(d: impl Into<String>) -> Self {
103 Self::Base64 { data: d.into() }
104 }
105}
106
/// Guess a MIME type from a file path's extension.
///
/// Matching is ASCII case-insensitive (`A.JPEG` → `image/jpeg`). Returns `None`
/// when the path has no extension, the extension is not valid UTF-8, or the
/// extension is not one of the recognised image, audio or PDF types.
pub fn mime_from_path(path: &std::path::Path) -> Option<&'static str> {
    let ext = path.extension()?.to_str()?.to_ascii_lowercase();
    Some(match ext.as_str() {
        // Images
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "bmp" => "image/bmp",
        "tiff" | "tif" => "image/tiff",
        "svg" => "image/svg+xml",
        "avif" => "image/avif",
        "heic" => "image/heic",
        "heif" => "image/heif",
        // Audio
        "wav" => "audio/wav",
        "mp3" => "audio/mpeg",
        "m4a" => "audio/mp4",
        "aac" => "audio/aac",
        "flac" => "audio/flac",
        "ogg" | "oga" => "audio/ogg",
        // Documents
        "pdf" => "application/pdf",
        _ => return None,
    })
}
133
/// Encode `bytes` as standard-alphabet base64 with `=` padding.
///
/// Every group of three input bytes becomes four output characters. A final
/// group of one or two bytes becomes two or three characters followed by `==`
/// or `=` respectively. An empty input yields an empty string.
pub fn base64_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Bit offsets of the four 6-bit symbols inside a packed 24-bit group.
    const SHIFTS: [u32; 4] = [18, 12, 6, 0];

    let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for group in bytes.chunks(3) {
        // Pack the group big-endian into the low 24 bits; absent bytes stay zero.
        let word = group
            .iter()
            .zip([16u32, 8, 0])
            .fold(0u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));
        // k input bytes carry enough bits for k + 1 symbols; the rest is padding.
        let symbols = group.len() + 1;
        for (position, shift) in SHIFTS.into_iter().enumerate() {
            let ch = if position < symbols {
                ALPHABET[((word >> shift) & 0x3f) as usize] as char
            } else {
                '='
            };
            encoded.push(ch);
        }
    }
    encoded
}
168
169pub fn base64_decode(s: &str) -> crate::Result<Vec<u8>> {
172 fn val(c: u8) -> Option<u8> {
173 Some(match c {
174 b'A'..=b'Z' => c - b'A',
175 b'a'..=b'z' => c - b'a' + 26,
176 b'0'..=b'9' => c - b'0' + 52,
177 b'+' => 62,
178 b'/' => 63,
179 _ => return None,
180 })
181 }
182 let bytes: Vec<u8> = s
183 .bytes()
184 .filter(|b| !b.is_ascii_whitespace() && *b != b'=')
185 .collect();
186 let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
187 let mut chunks = bytes.chunks_exact(4);
188 for chunk in &mut chunks {
189 let a = val(chunk[0]).ok_or_else(bad)?;
190 let b = val(chunk[1]).ok_or_else(bad)?;
191 let c = val(chunk[2]).ok_or_else(bad)?;
192 let d = val(chunk[3]).ok_or_else(bad)?;
193 let n = ((a as u32) << 18) | ((b as u32) << 12) | ((c as u32) << 6) | (d as u32);
194 out.push((n >> 16) as u8);
195 out.push((n >> 8) as u8);
196 out.push(n as u8);
197 }
198 let rem = chunks.remainder();
199 match rem.len() {
200 0 => {}
201 2 => {
202 let a = val(rem[0]).ok_or_else(bad)?;
203 let b = val(rem[1]).ok_or_else(bad)?;
204 let n = ((a as u32) << 18) | ((b as u32) << 12);
205 out.push((n >> 16) as u8);
206 }
207 3 => {
208 let a = val(rem[0]).ok_or_else(bad)?;
209 let b = val(rem[1]).ok_or_else(bad)?;
210 let c = val(rem[2]).ok_or_else(bad)?;
211 let n = ((a as u32) << 18) | ((b as u32) << 12) | ((c as u32) << 6);
212 out.push((n >> 16) as u8);
213 out.push((n >> 8) as u8);
214 }
215 _ => {
216 return Err(crate::CognisError::Serialization(
217 "base64: malformed input length".into(),
218 ))
219 }
220 }
221 Ok(out)
222}
223
224fn bad() -> crate::CognisError {
225 crate::CognisError::Serialization("base64: invalid character".into())
226}
227
228pub async fn image_source_from_path(
231 path: impl AsRef<std::path::Path>,
232) -> crate::Result<(ImageSource, String)> {
233 let path = path.as_ref();
234 let bytes = tokio::fs::read(path).await.map_err(|e| {
235 crate::CognisError::Configuration(format!(
236 "image_source_from_path: read `{}`: {e}",
237 path.display()
238 ))
239 })?;
240 let mime = mime_from_path(path)
241 .unwrap_or("application/octet-stream")
242 .to_string();
243 Ok((ImageSource::base64(base64_encode(&bytes)), mime))
244}
245
246impl ContentPart {
251 pub fn to_openai(&self) -> serde_json::Value {
258 match self {
259 ContentPart::Text { text } => {
260 serde_json::json!({"type": "text", "text": text})
261 }
262 ContentPart::Image { source, mime } => {
263 let url = match source {
264 ImageSource::Url { url } => url.clone(),
265 ImageSource::Base64 { data } => format!("data:{mime};base64,{data}"),
266 };
267 serde_json::json!({"type": "image_url", "image_url": {"url": url}})
268 }
269 ContentPart::Audio { source, mime } => {
270 let data = match source {
271 AudioSource::Base64 { data } => data.clone(),
272 AudioSource::Url { url } => url.clone(),
273 };
274 let format = mime.split('/').nth(1).unwrap_or("wav").to_string();
275 serde_json::json!({
276 "type": "input_audio",
277 "input_audio": {"data": data, "format": format},
278 })
279 }
280 }
281 }
282
283 pub fn from_openai(v: &serde_json::Value) -> Option<Self> {
285 let kind = v["type"].as_str()?;
286 match kind {
287 "text" => Some(ContentPart::Text {
288 text: v["text"].as_str()?.to_string(),
289 }),
290 "image_url" => {
291 let url = v["image_url"]["url"].as_str()?;
292 if let Some(rest) = url.strip_prefix("data:") {
293 if let Some((mime_part, b64)) = rest.split_once(";base64,") {
294 return Some(ContentPart::Image {
295 source: ImageSource::base64(b64),
296 mime: mime_part.to_string(),
297 });
298 }
299 }
300 Some(ContentPart::Image {
301 source: ImageSource::url(url),
302 mime: String::new(),
303 })
304 }
305 _ => None,
306 }
307 }
308
309 pub fn to_anthropic(&self) -> serde_json::Value {
317 match self {
318 ContentPart::Text { text } => {
319 serde_json::json!({"type": "text", "text": text})
320 }
321 ContentPart::Image { source, mime } => match source {
322 ImageSource::Url { url } => serde_json::json!({
323 "type": "image",
324 "source": {"type": "url", "url": url},
325 }),
326 ImageSource::Base64 { data } => serde_json::json!({
327 "type": "image",
328 "source": {
329 "type": "base64",
330 "media_type": mime,
331 "data": data,
332 },
333 }),
334 },
335 ContentPart::Audio { source, mime } => {
336 let stub = match source {
337 AudioSource::Url { url } => format!("[audio: {url} ({mime})]"),
338 AudioSource::Base64 { .. } => format!("[audio: base64 ({mime})]"),
339 };
340 serde_json::json!({"type": "text", "text": stub})
341 }
342 }
343 }
344
345 pub fn from_anthropic(v: &serde_json::Value) -> Option<Self> {
347 let kind = v["type"].as_str()?;
348 match kind {
349 "text" => Some(ContentPart::Text {
350 text: v["text"].as_str()?.to_string(),
351 }),
352 "image" => {
353 let source_kind = v["source"]["type"].as_str()?;
354 match source_kind {
355 "url" => Some(ContentPart::Image {
356 source: ImageSource::url(v["source"]["url"].as_str()?),
357 mime: String::new(),
358 }),
359 "base64" => Some(ContentPart::Image {
360 source: ImageSource::base64(v["source"]["data"].as_str()?),
361 mime: v["source"]["media_type"]
362 .as_str()
363 .unwrap_or_default()
364 .to_string(),
365 }),
366 _ => None,
367 }
368 }
369 _ => None,
370 }
371 }
372
373 pub fn to_gemini(&self) -> serde_json::Value {
380 match self {
381 ContentPart::Text { text } => serde_json::json!({"text": text}),
382 ContentPart::Image { source, mime } => match source {
383 ImageSource::Url { url } => serde_json::json!({
384 "file_data": {"mime_type": mime, "file_uri": url},
385 }),
386 ImageSource::Base64 { data } => serde_json::json!({
387 "inline_data": {"mime_type": mime, "data": data},
388 }),
389 },
390 ContentPart::Audio { source, mime } => match source {
391 AudioSource::Url { url } => serde_json::json!({
392 "file_data": {"mime_type": mime, "file_uri": url},
393 }),
394 AudioSource::Base64 { data } => serde_json::json!({
395 "inline_data": {"mime_type": mime, "data": data},
396 }),
397 },
398 }
399 }
400
401 pub fn from_gemini(v: &serde_json::Value) -> Option<Self> {
403 if let Some(t) = v["text"].as_str() {
404 return Some(ContentPart::Text {
405 text: t.to_string(),
406 });
407 }
408 if let Some(inline) = v["inline_data"].as_object() {
409 let mime = inline["mime_type"].as_str()?.to_string();
410 let data = inline["data"].as_str()?.to_string();
411 return Some(if mime.starts_with("audio/") {
412 ContentPart::Audio {
413 source: AudioSource::base64(data),
414 mime,
415 }
416 } else {
417 ContentPart::Image {
418 source: ImageSource::base64(data),
419 mime,
420 }
421 });
422 }
423 if let Some(file) = v["file_data"].as_object() {
424 let mime = file["mime_type"].as_str()?.to_string();
425 let uri = file["file_uri"].as_str()?.to_string();
426 return Some(if mime.starts_with("audio/") {
427 ContentPart::Audio {
428 source: AudioSource::url(uri),
429 mime,
430 }
431 } else {
432 ContentPart::Image {
433 source: ImageSource::url(uri),
434 mime,
435 }
436 });
437 }
438 None
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Build an image part from a source and a MIME string.
    fn image(source: ImageSource, mime: &str) -> ContentPart {
        ContentPart::Image {
            source,
            mime: mime.to_string(),
        }
    }

    #[test]
    fn image_part_roundtrip() {
        let part = image(ImageSource::url("https://x"), "image/png");
        let json = serde_json::to_string(&part).unwrap();
        assert!(json.contains("\"type\":\"image\""));
        assert!(json.contains("\"kind\":\"url\""));
        let decoded: ContentPart = serde_json::from_str(&json).unwrap();
        assert_eq!(part, decoded);
    }

    #[test]
    fn openai_url_image_roundtrip() {
        let wire = image(ImageSource::url("https://x/cat.png"), "image/png").to_openai();
        assert_eq!(wire["type"], "image_url");
        let parsed = ContentPart::from_openai(&wire).unwrap();
        assert!(matches!(parsed, ContentPart::Image { .. }));
    }

    #[test]
    fn openai_base64_image_via_data_uri() {
        let part = image(ImageSource::base64("AAAA"), "image/png");
        let wire = part.to_openai();
        let uri = wire["image_url"]["url"].as_str().unwrap();
        assert!(uri.starts_with("data:image/png;base64,AAAA"));
        assert_eq!(ContentPart::from_openai(&wire).unwrap(), part);
    }

    #[test]
    fn anthropic_base64_image_roundtrip() {
        let part = image(ImageSource::base64("BBBB"), "image/jpeg");
        let wire = part.to_anthropic();
        assert_eq!(wire["type"], "image");
        assert_eq!(wire["source"]["type"], "base64");
        assert_eq!(ContentPart::from_anthropic(&wire).unwrap(), part);
    }

    #[test]
    fn gemini_inline_data_roundtrip() {
        let part = image(ImageSource::base64("CCCC"), "image/jpeg");
        let wire = part.to_gemini();
        assert!(wire["inline_data"]["data"].is_string());
        assert_eq!(ContentPart::from_gemini(&wire).unwrap(), part);
    }

    #[test]
    fn gemini_file_data_roundtrip() {
        let part = image(ImageSource::url("gs://bucket/x.png"), "image/png");
        let wire = part.to_gemini();
        assert!(wire["file_data"]["file_uri"].is_string());
        assert_eq!(ContentPart::from_gemini(&wire).unwrap(), part);
    }

    #[test]
    fn mime_from_path_recognises_common_extensions() {
        let cases = [
            ("a.png", Some("image/png")),
            ("a.jpg", Some("image/jpeg")),
            ("A.JPEG", Some("image/jpeg")),
            ("a.unknown", None),
            ("noext", None),
        ];
        for (name, expected) in cases {
            assert_eq!(mime_from_path(Path::new(name)), expected);
        }
    }

    #[test]
    fn base64_roundtrip() {
        let samples: [&[u8]; 6] = [
            b"",
            b"a",
            b"ab",
            b"abc",
            b"hello, world!",
            &[0, 1, 2, 3, 254, 255],
        ];
        for v in samples {
            let enc = base64_encode(v);
            let dec = base64_decode(&enc).unwrap();
            assert_eq!(dec, v.to_vec(), "roundtrip failed for {v:?}");
        }
    }

    #[test]
    fn base64_known_vector() {
        let check = |raw: &[u8], want: &str| assert_eq!(base64_encode(raw), want);
        check(b"", "");
        check(b"f", "Zg==");
        check(b"fo", "Zm8=");
        check(b"foo", "Zm9v");
        check(b"foob", "Zm9vYg==");
        check(b"fooba", "Zm9vYmE=");
        check(b"foobar", "Zm9vYmFy");
    }

    #[test]
    fn base64_decode_rejects_garbage() {
        let outcome = base64_decode("****");
        assert!(outcome.is_err());
    }
}