1use serde::{Deserialize, Serialize};
23
24use crate::errors::{Error, Result};
25
/// Upper bound, in bytes, on the length of the base64-encoded image string
/// accepted by `UserContent::image_base64` (15 MiB of encoded text).
///
/// A payload of exactly this length is accepted; one byte more is rejected.
pub const MAX_IMAGE_BASE64_BYTES: usize = 15 * (1 << 20);
30
/// Image MIME types accepted by `UserContent::image_base64` and
/// `UserContent::image_url`.
///
/// Membership is tested by exact (case-sensitive) string equality.
pub const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
    "image/jpeg",
    "image/png",
    "image/gif",
    "image/webp",
];
34
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45#[serde(tag = "type", rename_all = "snake_case")]
46pub enum ContentBlock {
47 Text(TextBlock),
49
50 ToolUse(ToolUseBlock),
52
53 ToolResult(ToolResultBlock),
55
56 Thinking(ThinkingBlock),
58
59 Image(ImageBlock),
61}
62
63impl ContentBlock {
64 #[inline]
66 #[must_use]
67 pub fn as_text(&self) -> Option<&str> {
68 match self {
69 Self::Text(b) => Some(&b.text),
70 _ => None,
71 }
72 }
73
74 #[inline]
76 #[must_use]
77 pub fn is_tool_use(&self) -> bool {
78 matches!(self, Self::ToolUse(_))
79 }
80}
81
82#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct TextBlock {
87 pub text: String,
89}
90
91#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub struct ToolUseBlock {
96 pub id: String,
99
100 pub name: String,
102
103 #[serde(default)]
105 pub input: serde_json::Value,
106}
107
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct ToolResultBlock {
113 pub tool_use_id: String,
115
116 #[serde(default)]
118 pub is_error: bool,
119
120 #[serde(default)]
122 pub content: ToolResultContent,
123}
124
125#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128#[serde(untagged)]
129pub enum ToolResultContent {
130 Text(String),
132 Blocks(Vec<ContentBlock>),
134}
135
136impl Default for ToolResultContent {
137 fn default() -> Self {
138 Self::Text(String::new())
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
146pub struct ThinkingBlock {
147 pub thinking: String,
149
150 #[serde(default)]
152 pub signature: Option<String>,
153}
154
155#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct ImageBlock {
160 pub source: ImageSource,
162}
163
164#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
166#[serde(tag = "type", rename_all = "snake_case")]
167pub enum ImageSource {
168 Base64(Base64ImageSource),
170 Url(UrlImageSource),
172}
173
174#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
176pub struct Base64ImageSource {
177 pub media_type: String,
179 pub data: String,
181}
182
183#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
185pub struct UrlImageSource {
186 pub url: String,
188 #[serde(default, skip_serializing_if = "Option::is_none")]
190 pub media_type: Option<String>,
191}
192
193#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
199#[serde(tag = "type", rename_all = "snake_case")]
200pub enum UserContent {
201 Text(TextBlock),
203 Image(ImageBlock),
205}
206
207impl UserContent {
208 #[inline]
217 #[must_use]
218 pub fn text(s: impl Into<String>) -> Self {
219 Self::Text(TextBlock { text: s.into() })
220 }
221
222 pub fn image_base64(data: impl Into<String>, media_type: impl Into<String>) -> Result<Self> {
240 let data = data.into();
241 let media_type = media_type.into();
242
243 validate_mime_type(&media_type)?;
244 validate_base64_size(&data)?;
245
246 Ok(Self::Image(ImageBlock {
247 source: ImageSource::Base64(Base64ImageSource { media_type, data }),
248 }))
249 }
250
251 pub fn image_url(url: impl Into<String>, media_type: impl Into<String>) -> Result<Self> {
267 let media_type = media_type.into();
268 validate_mime_type(&media_type)?;
269
270 Ok(Self::Image(ImageBlock {
271 source: ImageSource::Url(UrlImageSource {
272 url: url.into(),
273 media_type: Some(media_type),
274 }),
275 }))
276 }
277
278 #[inline]
280 #[must_use]
281 pub fn image_url_untyped(url: impl Into<String>) -> Self {
282 Self::Image(ImageBlock {
283 source: ImageSource::Url(UrlImageSource {
284 url: url.into(),
285 media_type: None,
286 }),
287 })
288 }
289
290 #[inline]
292 #[must_use]
293 pub fn as_text(&self) -> Option<&str> {
294 match self {
295 Self::Text(b) => Some(&b.text),
296 _ => None,
297 }
298 }
299}
300
301impl From<&str> for UserContent {
304 #[inline]
305 fn from(s: &str) -> Self {
306 Self::text(s)
307 }
308}
309
310impl From<String> for UserContent {
311 #[inline]
312 fn from(s: String) -> Self {
313 Self::text(s)
314 }
315}
316
317fn validate_mime_type(media_type: &str) -> Result<()> {
321 if ALLOWED_IMAGE_MIME_TYPES.contains(&media_type) {
322 Ok(())
323 } else {
324 Err(Error::ImageValidation(format!(
325 "unsupported MIME type '{media_type}'; allowed: {}",
326 ALLOWED_IMAGE_MIME_TYPES.join(", ")
327 )))
328 }
329}
330
331fn validate_base64_size(data: &str) -> Result<()> {
333 if data.len() > MAX_IMAGE_BASE64_BYTES {
334 Err(Error::ImageValidation(format!(
335 "base64 image data exceeds the 15 MiB limit ({} bytes)",
336 data.len()
337 )))
338 } else {
339 Ok(())
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    // ---- UserContent: text -------------------------------------------------

    #[test]
    fn user_content_text_round_trip() {
        let original = UserContent::text("Hello!");
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn user_content_text_serde_shape() {
        let c = UserContent::text("hi");
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        assert_eq!(v["type"], "text");
        assert_eq!(v["text"], "hi");
    }

    #[test]
    fn from_str_produces_text_variant() {
        let c: UserContent = "hello".into();
        assert_eq!(c.as_text(), Some("hello"));
    }

    #[test]
    fn from_string_produces_text_variant() {
        let c: UserContent = String::from("world").into();
        assert_eq!(c.as_text(), Some("world"));
    }

    #[test]
    fn user_content_rejects_non_user_block_types() {
        // UserContent only models text and image; other tags must not parse.
        let json = r#"{"type":"tool_use","id":"a","name":"b","input":{}}"#;
        assert!(serde_json::from_str::<UserContent>(json).is_err());
    }

    // ---- UserContent: base64 images ----------------------------------------

    #[test]
    fn image_base64_valid_mime_types() {
        for mime in ALLOWED_IMAGE_MIME_TYPES {
            let result = UserContent::image_base64("aGVsbG8=", *mime);
            assert!(result.is_ok(), "should accept {mime}");
        }
    }

    #[test]
    fn image_base64_rejects_unsupported_mime() {
        let err = UserContent::image_base64("aGVsbG8=", "image/bmp").unwrap_err();
        assert!(
            matches!(err, Error::ImageValidation(_)),
            "expected ImageValidation, got {err:?}"
        );
        assert!(err.to_string().contains("image/bmp"));
    }

    #[test]
    fn image_base64_rejects_oversized_payload() {
        let len = MAX_IMAGE_BASE64_BYTES + 1;
        let oversized = "A".repeat(len);
        let err = UserContent::image_base64(oversized, "image/png").unwrap_err();
        assert!(matches!(err, Error::ImageValidation(_)));
        let msg = err.to_string();
        assert!(msg.contains("15 MiB"));
        // The offending length is reported too.
        assert!(msg.contains(&len.to_string()));
    }

    #[test]
    fn image_base64_accepts_exactly_at_limit() {
        let at_limit = "A".repeat(MAX_IMAGE_BASE64_BYTES);
        let result = UserContent::image_base64(at_limit, "image/png");
        assert!(result.is_ok());
    }

    #[test]
    fn image_base64_round_trip() {
        let original = UserContent::image_base64("aGVsbG8=", "image/jpeg").unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn image_base64_serde_shape() {
        let c = UserContent::image_base64("abc123", "image/png").unwrap();
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        assert_eq!(v["type"], "image");
        assert_eq!(v["source"]["type"], "base64");
        assert_eq!(v["source"]["media_type"], "image/png");
        assert_eq!(v["source"]["data"], "abc123");
    }

    // ---- UserContent: URL images -------------------------------------------

    #[test]
    fn image_url_valid() {
        let result = UserContent::image_url("https://example.com/img.png", "image/png");
        assert!(result.is_ok());
    }

    #[test]
    fn image_url_rejects_bad_mime() {
        let err =
            UserContent::image_url("https://example.com/img.svg", "image/svg+xml").unwrap_err();
        assert!(matches!(err, Error::ImageValidation(_)));
    }

    #[test]
    fn image_url_round_trip() {
        let original = UserContent::image_url("https://example.com/img.gif", "image/gif").unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn image_url_serde_shape() {
        let c = UserContent::image_url("https://example.com/a.webp", "image/webp").unwrap();
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        assert_eq!(v["type"], "image");
        assert_eq!(v["source"]["type"], "url");
        assert_eq!(v["source"]["url"], "https://example.com/a.webp");
        assert_eq!(v["source"]["media_type"], "image/webp");
    }

    #[test]
    fn image_url_untyped_no_media_type_field() {
        let c = UserContent::image_url_untyped("https://example.com/a.png");
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        // `v[..]["media_type"].is_null()` would also pass for an explicit
        // `"media_type": null`; check that the key is actually absent.
        assert!(
            v["source"].get("media_type").is_none(),
            "media_type should be omitted"
        );
    }

    #[test]
    fn image_url_untyped_round_trip() {
        let original = UserContent::image_url_untyped("https://example.com/a.png");
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    // ---- ContentBlock --------------------------------------------------------

    #[test]
    fn content_block_text_round_trip() {
        let block = ContentBlock::Text(TextBlock {
            text: "response".into(),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn content_block_text_serde_shape() {
        let block = ContentBlock::Text(TextBlock {
            text: "hello".into(),
        });
        let v: serde_json::Value = serde_json::to_value(&block).unwrap();
        assert_eq!(v["type"], "text");
        assert_eq!(v["text"], "hello");
    }

    #[test]
    fn content_block_tool_use_round_trip() {
        let block = ContentBlock::ToolUse(ToolUseBlock {
            id: "call_123".into(),
            name: "bash".into(),
            input: serde_json::json!({ "command": "ls" }),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn content_block_tool_use_serde_shape() {
        let block = ContentBlock::ToolUse(ToolUseBlock {
            id: "id1".into(),
            name: "read_file".into(),
            input: serde_json::json!({ "path": "/tmp/foo" }),
        });
        let v: serde_json::Value = serde_json::to_value(&block).unwrap();
        assert_eq!(v["type"], "tool_use");
        assert_eq!(v["name"], "read_file");
    }

    #[test]
    fn tool_use_missing_input_defaults_to_null() {
        let json = r#"{"type":"tool_use","id":"a","name":"b"}"#;
        match serde_json::from_str::<ContentBlock>(json).unwrap() {
            ContentBlock::ToolUse(t) => assert_eq!(t.input, serde_json::Value::Null),
            other => panic!("expected ToolUse, got {other:?}"),
        }
    }

    #[test]
    fn content_block_tool_result_round_trip() {
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "call_123".into(),
            is_error: false,
            content: ToolResultContent::Text("file contents".into()),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn content_block_tool_result_blocks_round_trip() {
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "call_9".into(),
            is_error: true,
            content: ToolResultContent::Blocks(vec![ContentBlock::Text(TextBlock {
                text: "nested".into(),
            })]),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn tool_result_missing_optional_fields_use_defaults() {
        let json = r#"{"type":"tool_result","tool_use_id":"call_1"}"#;
        let decoded: ContentBlock = serde_json::from_str(json).unwrap();
        assert_eq!(
            decoded,
            ContentBlock::ToolResult(ToolResultBlock {
                tool_use_id: "call_1".into(),
                is_error: false,
                content: ToolResultContent::default(),
            })
        );
    }

    #[test]
    fn tool_result_string_content_deserializes_as_text() {
        let json = r#"{"type":"tool_result","tool_use_id":"c","content":"ok"}"#;
        match serde_json::from_str::<ContentBlock>(json).unwrap() {
            ContentBlock::ToolResult(r) => {
                assert_eq!(r.content, ToolResultContent::Text("ok".into()));
            }
            other => panic!("expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn content_block_thinking_round_trip() {
        let block = ContentBlock::Thinking(ThinkingBlock {
            thinking: "Let me think...".into(),
            signature: Some("sig123".into()),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn content_block_thinking_without_signature_round_trip() {
        let block = ContentBlock::Thinking(ThinkingBlock {
            thinking: "hmm".into(),
            signature: None,
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);

        // An absent signature field also deserializes to `None`.
        let bare: ContentBlock =
            serde_json::from_str(r#"{"type":"thinking","thinking":"hmm"}"#).unwrap();
        assert_eq!(bare, block);
    }

    #[test]
    fn content_block_as_text_helper() {
        let text = ContentBlock::Text(TextBlock {
            text: "hello".into(),
        });
        assert_eq!(text.as_text(), Some("hello"));

        let tool = ContentBlock::ToolUse(ToolUseBlock {
            id: "x".into(),
            name: "bash".into(),
            input: serde_json::Value::Null,
        });
        assert_eq!(tool.as_text(), None);
    }

    #[test]
    fn content_block_is_tool_use_helper() {
        let tool = ContentBlock::ToolUse(ToolUseBlock {
            id: "x".into(),
            name: "bash".into(),
            input: serde_json::Value::Null,
        });
        assert!(tool.is_tool_use());
        assert!(!ContentBlock::Text(TextBlock { text: "hi".into() }).is_tool_use());
    }

    #[test]
    fn tool_result_content_default_is_empty_text() {
        let default = ToolResultContent::default();
        assert_eq!(default, ToolResultContent::Text(String::new()));
    }

    #[test]
    fn image_block_round_trip() {
        let block = ContentBlock::Image(ImageBlock {
            source: ImageSource::Base64(Base64ImageSource {
                media_type: "image/png".into(),
                data: "abc==".into(),
            }),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }
}