use serde::{Deserialize, Serialize};
use crate::errors::{Error, Result};
/// Upper bound, in bytes, on the length of a base64-encoded image payload
/// accepted by [`UserContent::image_base64`]: 15 MiB (15 × 2^20 bytes).
///
/// The bound is compared against the encoded string's length, not the size of
/// the decoded image.
pub const MAX_IMAGE_BASE64_BYTES: usize = 15 * (1 << 20);

/// Image MIME types accepted by the validating image constructors.
///
/// Matching is an exact, case-sensitive string comparison, so e.g.
/// `"image/PNG"` is not accepted.
pub const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
    "image/jpeg",
    "image/png",
    "image/gif",
    "image/webp",
];
/// One block of content inside a message.
///
/// Serialized as an internally tagged JSON object: the variant name in
/// `snake_case` is stored under `"type"` alongside the payload's own fields,
/// e.g. `{"type": "tool_use", "id": ..., "name": ..., "input": ...}`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text ([`TextBlock`]).
    Text(TextBlock),
    /// A tool invocation ([`ToolUseBlock`]).
    ToolUse(ToolUseBlock),
    /// A tool result ([`ToolResultBlock`]).
    ToolResult(ToolResultBlock),
    /// Thinking content ([`ThinkingBlock`]).
    Thinking(ThinkingBlock),
    /// An image ([`ImageBlock`]).
    Image(ImageBlock),
}

impl ContentBlock {
    /// Borrows the text of a [`ContentBlock::Text`] block.
    ///
    /// Every other variant yields `None`.
    #[inline]
    #[must_use]
    pub fn as_text(&self) -> Option<&str> {
        // Listed exhaustively so that adding a variant forces a decision here.
        match self {
            Self::Text(block) => Some(block.text.as_str()),
            Self::ToolUse(_) | Self::ToolResult(_) | Self::Thinking(_) | Self::Image(_) => None,
        }
    }

    /// Returns `true` only for [`ContentBlock::ToolUse`].
    #[inline]
    #[must_use]
    pub fn is_tool_use(&self) -> bool {
        matches!(self, Self::ToolUse(_))
    }
}
/// Payload of a text content block.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct TextBlock {
    /// The text itself, serialized under `"text"`.
    pub text: String,
}
/// Payload of a tool-use content block.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ToolUseBlock {
    /// Identifier of this tool call. Presumably matched by
    /// `ToolResultBlock::tool_use_id` (TODO confirm against callers).
    pub id: String,
    /// Name of the tool being invoked.
    pub name: String,
    /// Arbitrary JSON input. Deserializes to `serde_json::Value::Null` when
    /// the field is absent.
    #[serde(default)]
    pub input: serde_json::Value,
}
/// Payload of a tool-result content block.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ToolResultBlock {
    /// Identifier of the tool call this result belongs to.
    pub tool_use_id: String,
    /// Error flag; deserializes to `false` when absent.
    #[serde(default)]
    pub is_error: bool,
    /// Result body; deserializes to `ToolResultContent::default()` (empty
    /// text) when absent.
    #[serde(default)]
    pub content: ToolResultContent,
}
/// Body of a tool result: a bare string or a list of content blocks.
///
/// Untagged: a JSON string maps to [`ToolResultContent::Text`], and a JSON
/// array of blocks maps to [`ToolResultContent::Blocks`]. Serde tries the
/// variants in declaration order.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolResultContent {
    /// A plain string body.
    Text(String),
    /// A structured body made of nested content blocks.
    Blocks(Vec<ContentBlock>),
}

/// The default body is empty text.
impl Default for ToolResultContent {
    fn default() -> Self {
        Self::Text(String::default())
    }
}
/// Payload of a thinking content block.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ThinkingBlock {
    /// The thinking text.
    pub thinking: String,
    /// Optional signature; `None` when absent from the input.
    ///
    /// Note: there is no `skip_serializing_if`, so `None` is written out as
    /// an explicit `"signature": null`.
    #[serde(default)]
    pub signature: Option<String>,
}
/// Payload of an image content block.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct ImageBlock {
    /// Where the image data comes from.
    pub source: ImageSource,
}

/// Origin of an image, tagged in JSON under `"type"` as `"base64"` or `"url"`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum ImageSource {
    /// Image data embedded inline.
    Base64(Base64ImageSource),
    /// Image referenced by URL.
    Url(UrlImageSource),
}

/// An inline, base64-encoded image.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct Base64ImageSource {
    /// MIME type of the image, e.g. `"image/png"`.
    pub media_type: String,
    /// The base64-encoded image data.
    pub data: String,
}

/// An image referenced by URL.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub struct UrlImageSource {
    /// Location of the image.
    pub url: String,
    /// Optional MIME type. It is left out of serialized output when `None`,
    /// and a missing field deserializes to `None`.
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub media_type: Option<String>,
}
/// Content a user can place in a message: text or an image.
///
/// Uses the same internally tagged JSON shape as [`ContentBlock`]
/// (`{"type": "text", ...}` / `{"type": "image", ...}`).
///
/// Deserialization is derived, so it does NOT apply the MIME-type and size
/// checks performed by [`UserContent::image_base64`] and
/// [`UserContent::image_url`].
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum UserContent {
    /// Plain text.
    Text(TextBlock),
    /// An image.
    Image(ImageBlock),
}
impl UserContent {
#[inline]
#[must_use]
pub fn text(s: impl Into<String>) -> Self {
Self::Text(TextBlock { text: s.into() })
}
pub fn image_base64(data: impl Into<String>, media_type: impl Into<String>) -> Result<Self> {
let data = data.into();
let media_type = media_type.into();
validate_mime_type(&media_type)?;
validate_base64_size(&data)?;
Ok(Self::Image(ImageBlock {
source: ImageSource::Base64(Base64ImageSource { media_type, data }),
}))
}
pub fn image_url(url: impl Into<String>, media_type: impl Into<String>) -> Result<Self> {
let media_type = media_type.into();
validate_mime_type(&media_type)?;
Ok(Self::Image(ImageBlock {
source: ImageSource::Url(UrlImageSource {
url: url.into(),
media_type: Some(media_type),
}),
}))
}
#[inline]
#[must_use]
pub fn image_url_untyped(url: impl Into<String>) -> Self {
Self::Image(ImageBlock {
source: ImageSource::Url(UrlImageSource {
url: url.into(),
media_type: None,
}),
})
}
#[inline]
#[must_use]
pub fn as_text(&self) -> Option<&str> {
match self {
Self::Text(b) => Some(&b.text),
_ => None,
}
}
}
impl From<&str> for UserContent {
#[inline]
fn from(s: &str) -> Self {
Self::text(s)
}
}
impl From<String> for UserContent {
#[inline]
fn from(s: String) -> Self {
Self::text(s)
}
}
fn validate_mime_type(media_type: &str) -> Result<()> {
if ALLOWED_IMAGE_MIME_TYPES.contains(&media_type) {
Ok(())
} else {
Err(Error::ImageValidation(format!(
"unsupported MIME type '{media_type}'; allowed: {}",
ALLOWED_IMAGE_MIME_TYPES.join(", ")
)))
}
}
/// Rejects base64 payloads longer than [`MAX_IMAGE_BASE64_BYTES`].
///
/// The check is applied to the encoded string's length in bytes (`str::len`),
/// not to the decoded image size. A payload exactly at the limit is accepted.
///
/// # Errors
///
/// Returns `Error::ImageValidation` stating the limit and the actual length.
/// The limit is rendered from the constant itself, so the message cannot
/// drift from the value actually enforced. It is shown as whole MiB when the
/// limit is a multiple of 1 MiB and in bytes otherwise.
fn validate_base64_size(data: &str) -> Result<()> {
    const MIB: usize = 1024 * 1024;
    let len = data.len();
    if len <= MAX_IMAGE_BASE64_BYTES {
        return Ok(());
    }
    let limit = if MAX_IMAGE_BASE64_BYTES % MIB == 0 {
        format!("{} MiB", MAX_IMAGE_BASE64_BYTES / MIB)
    } else {
        format!("{MAX_IMAGE_BASE64_BYTES}-byte")
    };
    Err(Error::ImageValidation(format!(
        "base64 image data exceeds the {limit} limit ({len} bytes)"
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn user_content_text_round_trip() {
        let original = UserContent::text("Hello!");
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn user_content_text_serde_shape() {
        let c = UserContent::text("hi");
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        assert_eq!(v["type"], "text");
        assert_eq!(v["text"], "hi");
    }

    #[test]
    fn from_str_produces_text_variant() {
        let c: UserContent = "hello".into();
        assert_eq!(c.as_text(), Some("hello"));
    }

    #[test]
    fn from_string_produces_text_variant() {
        let c: UserContent = String::from("world").into();
        assert_eq!(c.as_text(), Some("world"));
    }

    #[test]
    fn image_base64_valid_mime_types() {
        for mime in ALLOWED_IMAGE_MIME_TYPES {
            let result = UserContent::image_base64("aGVsbG8=", *mime);
            assert!(result.is_ok(), "should accept {mime}");
        }
    }

    #[test]
    fn image_base64_rejects_unsupported_mime() {
        let err = UserContent::image_base64("aGVsbG8=", "image/bmp").unwrap_err();
        assert!(
            matches!(err, Error::ImageValidation(_)),
            "expected ImageValidation, got {err:?}"
        );
        assert!(err.to_string().contains("image/bmp"));
    }

    #[test]
    fn image_base64_rejects_oversized_payload() {
        let oversized = "A".repeat(MAX_IMAGE_BASE64_BYTES + 1);
        let err = UserContent::image_base64(oversized, "image/png").unwrap_err();
        assert!(matches!(err, Error::ImageValidation(_)));
        assert!(err.to_string().contains("15 MiB"));
    }

    #[test]
    fn image_base64_accepts_exactly_at_limit() {
        let at_limit = "A".repeat(MAX_IMAGE_BASE64_BYTES);
        let result = UserContent::image_base64(at_limit, "image/png");
        assert!(result.is_ok());
    }

    #[test]
    fn image_base64_round_trip() {
        let original = UserContent::image_base64("aGVsbG8=", "image/jpeg").unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn image_base64_serde_shape() {
        let c = UserContent::image_base64("abc123", "image/png").unwrap();
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        assert_eq!(v["type"], "image");
        assert_eq!(v["source"]["type"], "base64");
        assert_eq!(v["source"]["media_type"], "image/png");
        assert_eq!(v["source"]["data"], "abc123");
    }

    #[test]
    fn image_url_valid() {
        let result = UserContent::image_url("https://example.com/img.png", "image/png");
        assert!(result.is_ok());
    }

    #[test]
    fn image_url_rejects_bad_mime() {
        let err =
            UserContent::image_url("https://example.com/img.svg", "image/svg+xml").unwrap_err();
        assert!(matches!(err, Error::ImageValidation(_)));
    }

    #[test]
    fn image_url_round_trip() {
        let original = UserContent::image_url("https://example.com/img.gif", "image/gif").unwrap();
        let json = serde_json::to_string(&original).unwrap();
        let decoded: UserContent = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn image_url_serde_shape() {
        let c = UserContent::image_url("https://example.com/a.webp", "image/webp").unwrap();
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        assert_eq!(v["type"], "image");
        assert_eq!(v["source"]["type"], "url");
        assert_eq!(v["source"]["url"], "https://example.com/a.webp");
        assert_eq!(v["source"]["media_type"], "image/webp");
    }

    #[test]
    fn image_url_untyped_no_media_type_field() {
        let c = UserContent::image_url_untyped("https://example.com/a.png");
        let v: serde_json::Value = serde_json::to_value(&c).unwrap();
        // Indexing a missing key also yields `Null`, so `v[..].is_null()` would
        // pass even for `"media_type": null`; assert the key is truly absent.
        assert!(
            v["source"].get("media_type").is_none(),
            "media_type should be omitted"
        );
    }

    #[test]
    fn image_url_untyped_deserializes_without_media_type() {
        let v = serde_json::json!({
            "type": "image",
            "source": { "type": "url", "url": "https://example.com/a.png" }
        });
        let decoded: UserContent = serde_json::from_value(v).unwrap();
        assert_eq!(
            decoded,
            UserContent::image_url_untyped("https://example.com/a.png")
        );
    }

    #[test]
    fn content_block_text_round_trip() {
        let block = ContentBlock::Text(TextBlock {
            text: "response".into(),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn content_block_text_serde_shape() {
        let block = ContentBlock::Text(TextBlock {
            text: "hello".into(),
        });
        let v: serde_json::Value = serde_json::to_value(&block).unwrap();
        assert_eq!(v["type"], "text");
        assert_eq!(v["text"], "hello");
    }

    #[test]
    fn content_block_tool_use_round_trip() {
        let block = ContentBlock::ToolUse(ToolUseBlock {
            id: "call_123".into(),
            name: "bash".into(),
            input: serde_json::json!({ "command": "ls" }),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn content_block_tool_use_serde_shape() {
        let block = ContentBlock::ToolUse(ToolUseBlock {
            id: "id1".into(),
            name: "read_file".into(),
            input: serde_json::json!({ "path": "/tmp/foo" }),
        });
        let v: serde_json::Value = serde_json::to_value(&block).unwrap();
        assert_eq!(v["type"], "tool_use");
        assert_eq!(v["name"], "read_file");
    }

    #[test]
    fn content_block_tool_result_round_trip() {
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "call_123".into(),
            is_error: false,
            content: ToolResultContent::Text("file contents".into()),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn tool_result_blocks_content_round_trip() {
        // Exercises the untagged `Blocks` variant nested inside a tagged enum.
        let block = ContentBlock::ToolResult(ToolResultBlock {
            tool_use_id: "call_9".into(),
            is_error: true,
            content: ToolResultContent::Blocks(vec![ContentBlock::Text(TextBlock {
                text: "boom".into(),
            })]),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn tool_result_missing_optional_fields_use_defaults() {
        let v = serde_json::json!({ "type": "tool_result", "tool_use_id": "call_1" });
        let decoded: ContentBlock = serde_json::from_value(v).unwrap();
        assert_eq!(
            decoded,
            ContentBlock::ToolResult(ToolResultBlock {
                tool_use_id: "call_1".into(),
                is_error: false,
                content: ToolResultContent::default(),
            })
        );
    }

    #[test]
    fn content_block_thinking_round_trip() {
        let block = ContentBlock::Thinking(ThinkingBlock {
            thinking: "Let me think...".into(),
            signature: Some("sig123".into()),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }

    #[test]
    fn thinking_signature_defaults_to_none() {
        let v = serde_json::json!({ "type": "thinking", "thinking": "hmm" });
        let decoded: ContentBlock = serde_json::from_value(v).unwrap();
        assert_eq!(
            decoded,
            ContentBlock::Thinking(ThinkingBlock {
                thinking: "hmm".into(),
                signature: None,
            })
        );
    }

    #[test]
    fn content_block_rejects_unknown_type() {
        let v = serde_json::json!({ "type": "audio", "data": "x" });
        assert!(serde_json::from_value::<ContentBlock>(v).is_err());
    }

    #[test]
    fn content_block_as_text_helper() {
        let text = ContentBlock::Text(TextBlock {
            text: "hello".into(),
        });
        assert_eq!(text.as_text(), Some("hello"));
        let tool = ContentBlock::ToolUse(ToolUseBlock {
            id: "x".into(),
            name: "bash".into(),
            input: serde_json::Value::Null,
        });
        assert_eq!(tool.as_text(), None);
    }

    #[test]
    fn content_block_is_tool_use_helper() {
        let tool = ContentBlock::ToolUse(ToolUseBlock {
            id: "x".into(),
            name: "bash".into(),
            input: serde_json::Value::Null,
        });
        assert!(tool.is_tool_use());
        assert!(!ContentBlock::Text(TextBlock { text: "hi".into() }).is_tool_use());
    }

    #[test]
    fn tool_result_content_default_is_empty_text() {
        let default = ToolResultContent::default();
        assert_eq!(default, ToolResultContent::Text(String::new()));
    }

    #[test]
    fn image_block_round_trip() {
        let block = ContentBlock::Image(ImageBlock {
            source: ImageSource::Base64(Base64ImageSource {
                media_type: "image/png".into(),
                data: "abc==".into(),
            }),
        });
        let json = serde_json::to_string(&block).unwrap();
        let decoded: ContentBlock = serde_json::from_str(&json).unwrap();
        assert_eq!(block, decoded);
    }
}