use std::collections::BTreeMap;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
use serde_with::{DefaultOnError, VecSkipError, serde_as, skip_serializing_none};
use super::Meta;
use crate::{IntoOption, SkipListener};
/// A single piece of content, discriminated by its `type` property.
///
/// The tagged variants serialize with `type` set to the snake_case variant
/// name (`text`, `image`, `audio`, `resource_link`, `resource`) next to the
/// variant's own fields. Any other `type` value is captured by
/// [`ContentBlock::Other`], so unrecognized blocks survive a
/// deserialize/serialize round trip.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
#[schemars(extend("discriminator" = {"propertyName": "type"}))]
#[non_exhaustive]
pub enum ContentBlock {
    /// Text content (`"type": "text"`).
    Text(TextContent),
    /// Image content (`"type": "image"`).
    Image(ImageContent),
    /// Audio content (`"type": "audio"`).
    Audio(AudioContent),
    /// A link to a resource (`"type": "resource_link"`).
    ResourceLink(ResourceLink),
    /// An embedded resource (`"type": "resource"`).
    Resource(EmbeddedResource),
    /// Fallback for any `type` not handled by the variants above.
    ///
    /// Being `#[serde(untagged)]`, it is only attempted after the tagged
    /// variants fail. Its `Deserialize` impl rejects the known `type` names,
    /// so a malformed known block is an error rather than silently landing here.
    /// This variant must remain last.
    #[serde(untagged)]
    Other(OtherContentBlock),
}
/// A content block whose `type` is not one of the kinds known to this crate.
///
/// The discriminator is kept in `type_` and every other top-level property
/// is kept verbatim in `fields`; serialization flattens `fields` back next to
/// `type`. Deserialization is implemented by hand so that `type` is pulled
/// out of the map and known type names are rejected.
#[derive(Debug, Clone, PartialEq, Serialize, JsonSchema)]
#[schemars(inline)]
// The generated schema is post-processed with the list of known `type` values
// (see `other_content_block_schema`).
#[schemars(transform = other_content_block_schema)]
// NOTE(review): no effect on the current fields — `type_` is renamed
// explicitly and `fields` is flattened.
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct OtherContentBlock {
    /// The block's `type` discriminator, serialized as `type`.
    #[serde(rename = "type")]
    pub type_: String,
    /// All remaining properties of the block, keyed by JSON property name.
    ///
    /// Should not contain a `"type"` key, which would collide with `type_`
    /// when flattened; both `new` and the deserializer strip it.
    #[serde(flatten)]
    pub fields: BTreeMap<String, serde_json::Value>,
}
impl OtherContentBlock {
    /// Builds an unrecognized content block from its `type` and its other
    /// properties.
    ///
    /// Any `"type"` entry already present in `fields` is dropped, so `type_`
    /// is the only source of the discriminator when the block is serialized.
    #[must_use]
    pub fn new(type_: impl Into<String>, fields: BTreeMap<String, serde_json::Value>) -> Self {
        let mut properties = fields;
        let _discarded = properties.remove("type");
        let type_ = type_.into();
        Self {
            type_,
            fields: properties,
        }
    }
}
impl<'de> Deserialize<'de> for OtherContentBlock {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut fields = BTreeMap::<String, serde_json::Value>::deserialize(deserializer)?;
let type_ = fields
.remove("type")
.ok_or_else(|| serde::de::Error::missing_field("type"))?;
let serde_json::Value::String(type_) = type_ else {
return Err(serde::de::Error::custom("`type` must be a string"));
};
if is_known_content_block_type(&type_) {
return Err(serde::de::Error::custom(format!(
"known content block `{type_}` did not match its schema"
)));
}
Ok(Self { type_, fields })
}
}
/// `type` discriminators handled by the tagged variants of [`ContentBlock`].
///
/// Must list exactly the snake_case names of those variants. It is the single
/// source of truth for both the runtime check in `is_known_content_block_type`
/// and the schema transform in `other_content_block_schema`, so the two can no
/// longer drift apart when a variant is added.
const KNOWN_CONTENT_BLOCK_TYPES: [&str; 5] = ["text", "image", "audio", "resource_link", "resource"];

/// Returns `true` if `type_` names one of the tagged [`ContentBlock`] variants.
///
/// Used by `OtherContentBlock`'s deserializer to refuse known types, so a
/// malformed known block surfaces as an error instead of becoming `Other`.
fn is_known_content_block_type(type_: &str) -> bool {
    KNOWN_CONTENT_BLOCK_TYPES.contains(&type_)
}

/// Schema transform for `OtherContentBlock`.
///
/// Passes the `type` property name and the known discriminator values to
/// `schema_util::reject_known_string_discriminators`, so the fallback schema
/// is post-processed with respect to the same list used at runtime.
fn other_content_block_schema(schema: &mut Schema) {
    super::schema_util::reject_known_string_discriminators(
        schema,
        "type",
        &KNOWN_CONTENT_BLOCK_TYPES,
    );
}
/// Text content (`"type": "text"` when wrapped in a [`ContentBlock`]).
///
/// `None` fields are omitted when serializing.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[non_exhaustive]
pub struct TextContent {
    /// Optional annotations. An invalid value is replaced by `None` during
    /// deserialization instead of failing the whole block.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub annotations: Option<Annotations>,
    /// The text itself.
    pub text: String,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl TextContent {
    /// Creates text content with no annotations and no metadata.
    #[must_use]
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            text,
            annotations: None,
            meta: None,
        }
    }

    /// Sets `annotations` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn annotations(self, annotations: impl IntoOption<Annotations>) -> Self {
        Self {
            annotations: annotations.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
impl<T: Into<String>> From<T> for ContentBlock {
fn from(value: T) -> Self {
Self::Text(TextContent::new(value))
}
}
/// Image content (`"type": "image"` when wrapped in a [`ContentBlock`]).
///
/// Field names are camelCase on the wire (`mime_type` → `mimeType`), and
/// `None` fields are omitted when serializing.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ImageContent {
    /// Optional annotations. An invalid value is replaced by `None` during
    /// deserialization instead of failing the whole block.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub annotations: Option<Annotations>,
    /// The encoded image data. The encoding is not validated by this type
    /// (presumably base64 — confirm against the protocol spec).
    pub data: String,
    /// MIME type of the image, serialized as `mimeType`.
    pub mime_type: String,
    /// Optional URI associated with the image.
    pub uri: Option<String>,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl ImageContent {
    /// Creates image content from its data and MIME type; `annotations`,
    /// `uri` and the metadata all start out unset.
    #[must_use]
    pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let data = data.into();
        let mime_type = mime_type.into();
        Self {
            data,
            mime_type,
            annotations: None,
            uri: None,
            meta: None,
        }
    }

    /// Sets `annotations` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn annotations(self, annotations: impl IntoOption<Annotations>) -> Self {
        Self {
            annotations: annotations.into_option(),
            ..self
        }
    }

    /// Sets `uri` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn uri(self, uri: impl IntoOption<String>) -> Self {
        Self {
            uri: uri.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// Audio content (`"type": "audio"` when wrapped in a [`ContentBlock`]).
///
/// Field names are camelCase on the wire (`mime_type` → `mimeType`), and
/// `None` fields are omitted when serializing.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct AudioContent {
    /// Optional annotations. An invalid value is replaced by `None` during
    /// deserialization instead of failing the whole block.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub annotations: Option<Annotations>,
    /// The encoded audio data. The encoding is not validated by this type
    /// (presumably base64 — confirm against the protocol spec).
    pub data: String,
    /// MIME type of the audio, serialized as `mimeType`.
    pub mime_type: String,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl AudioContent {
    /// Creates audio content from its data and MIME type; `annotations` and
    /// the metadata start out unset.
    #[must_use]
    pub fn new(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
        let data = data.into();
        let mime_type = mime_type.into();
        Self {
            data,
            mime_type,
            annotations: None,
            meta: None,
        }
    }

    /// Sets `annotations` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn annotations(self, annotations: impl IntoOption<Annotations>) -> Self {
        Self {
            annotations: annotations.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// A resource embedded directly in the content (`"type": "resource"` when
/// wrapped in a [`ContentBlock`]).
///
/// `None` fields are omitted when serializing.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[non_exhaustive]
pub struct EmbeddedResource {
    /// Optional annotations. An invalid value is replaced by `None` during
    /// deserialization instead of failing the whole block.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub annotations: Option<Annotations>,
    /// The embedded contents: text or blob.
    pub resource: EmbeddedResourceResource,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl EmbeddedResource {
    /// Wraps `resource` with no annotations and no metadata.
    #[must_use]
    pub fn new(resource: EmbeddedResourceResource) -> Self {
        Self {
            resource,
            annotations: None,
            meta: None,
        }
    }

    /// Sets `annotations` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn annotations(self, annotations: impl IntoOption<Annotations>) -> Self {
        Self {
            annotations: annotations.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// The contents carried by an [`EmbeddedResource`]: text contents or blob
/// contents.
///
/// Untagged: the JSON carries no discriminator. Deserialization tries
/// [`TextResourceContents`] first and falls back to [`BlobResourceContents`],
/// so the variant order is significant.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(untagged)]
#[non_exhaustive]
pub enum EmbeddedResourceResource {
    /// Contents with a `text` property.
    TextResourceContents(TextResourceContents),
    /// Contents with a `blob` property.
    BlobResourceContents(BlobResourceContents),
}
/// Text contents of a resource identified by `uri`.
///
/// Field names are camelCase on the wire (`mime_type` → `mimeType`), and
/// `None` fields are omitted when serializing.
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct TextResourceContents {
    /// Optional MIME type, serialized as `mimeType`.
    pub mime_type: Option<String>,
    /// The resource's text.
    pub text: String,
    /// URI of the resource. Stored as a plain string; not validated here.
    pub uri: String,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl TextResourceContents {
    /// Creates text contents for `uri`; the MIME type and metadata start out
    /// unset.
    #[must_use]
    pub fn new(text: impl Into<String>, uri: impl Into<String>) -> Self {
        let text = text.into();
        let uri = uri.into();
        Self {
            text,
            uri,
            mime_type: None,
            meta: None,
        }
    }

    /// Sets `mime_type` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn mime_type(self, mime_type: impl IntoOption<String>) -> Self {
        Self {
            mime_type: mime_type.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// Blob contents of a resource identified by `uri`.
///
/// Field names are camelCase on the wire (`mime_type` → `mimeType`), and
/// `None` fields are omitted when serializing.
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct BlobResourceContents {
    /// The blob payload as a string. The encoding is not validated here
    /// (presumably base64 — confirm against the protocol spec).
    pub blob: String,
    /// Optional MIME type, serialized as `mimeType`.
    pub mime_type: Option<String>,
    /// URI of the resource. Stored as a plain string; not validated here.
    pub uri: String,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl BlobResourceContents {
    /// Creates blob contents for `uri`; the MIME type and metadata start out
    /// unset.
    #[must_use]
    pub fn new(blob: impl Into<String>, uri: impl Into<String>) -> Self {
        let blob = blob.into();
        let uri = uri.into();
        Self {
            blob,
            uri,
            mime_type: None,
            meta: None,
        }
    }

    /// Sets `mime_type` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn mime_type(self, mime_type: impl IntoOption<String>) -> Self {
        Self {
            mime_type: mime_type.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// A link to a resource (`"type": "resource_link"` when wrapped in a
/// [`ContentBlock`]).
///
/// Field names are camelCase on the wire (`mime_type` → `mimeType`), and
/// `None` fields are omitted when serializing.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ResourceLink {
    /// Optional annotations. An invalid value is replaced by `None` during
    /// deserialization instead of failing the whole block.
    #[serde_as(deserialize_as = "DefaultOnError")]
    #[schemars(extend("x-deserialize-default-on-error" = true))]
    #[serde(default)]
    pub annotations: Option<Annotations>,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Optional MIME type, serialized as `mimeType`.
    pub mime_type: Option<String>,
    /// Required name of the linked resource.
    pub name: String,
    /// Optional size; presumably in bytes (not validated here — confirm
    /// against the protocol spec).
    pub size: Option<i64>,
    /// Optional title.
    pub title: Option<String>,
    /// URI of the linked resource. Stored as a plain string; not validated here.
    pub uri: String,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl ResourceLink {
    /// Creates a link from its required `name` and `uri`; every optional field
    /// starts out unset.
    #[must_use]
    pub fn new(name: impl Into<String>, uri: impl Into<String>) -> Self {
        let name = name.into();
        let uri = uri.into();
        Self {
            name,
            uri,
            annotations: None,
            description: None,
            mime_type: None,
            size: None,
            title: None,
            meta: None,
        }
    }

    /// Sets `annotations` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn annotations(self, annotations: impl IntoOption<Annotations>) -> Self {
        Self {
            annotations: annotations.into_option(),
            ..self
        }
    }

    /// Sets `description` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn description(self, description: impl IntoOption<String>) -> Self {
        Self {
            description: description.into_option(),
            ..self
        }
    }

    /// Sets `mime_type` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn mime_type(self, mime_type: impl IntoOption<String>) -> Self {
        Self {
            mime_type: mime_type.into_option(),
            ..self
        }
    }

    /// Sets `size` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn size(self, size: impl IntoOption<i64>) -> Self {
        Self {
            size: size.into_option(),
            ..self
        }
    }

    /// Sets `title` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn title(self, title: impl IntoOption<String>) -> Self {
        Self {
            title: title.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// Optional annotations attached to content blocks and resource links.
///
/// Field names are camelCase on the wire (`last_modified` → `lastModified`),
/// and `None` fields are omitted when serializing.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, JsonSchema, Default)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Annotations {
    /// Optional list of [`Role`]s.
    ///
    /// Deserialization is lenient: individual entries that fail to parse are
    /// skipped (with `SkipListener` as the skip hook — presumably it is
    /// notified of each skipped item; confirm), and if the whole value is
    /// invalid the field becomes `None`.
    #[serde_as(deserialize_as = "DefaultOnError<Option<VecSkipError<_, SkipListener>>>")]
    #[schemars(extend("x-deserialize-default-on-error" = true, "x-deserialize-skip-invalid-items" = true))]
    #[serde(default)]
    pub audience: Option<Vec<Role>>,
    /// Optional last-modified value, serialized as `lastModified`. Stored as an
    /// unparsed string (presumably a timestamp — format not enforced here).
    pub last_modified: Option<String>,
    /// Optional priority. Not range-checked by this type.
    pub priority: Option<f64>,
    /// Extension metadata, serialized under the `_meta` key.
    #[serde(rename = "_meta")]
    pub meta: Option<Meta>,
}
impl Annotations {
    /// Creates annotations with every field unset; same as
    /// [`Annotations::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets `audience` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn audience(self, audience: impl IntoOption<Vec<Role>>) -> Self {
        Self {
            audience: audience.into_option(),
            ..self
        }
    }

    /// Sets `last_modified` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn last_modified(self, last_modified: impl IntoOption<String>) -> Self {
        Self {
            last_modified: last_modified.into_option(),
            ..self
        }
    }

    /// Sets `priority` to the value converted through [`IntoOption`].
    #[must_use]
    pub fn priority(self, priority: impl IntoOption<f64>) -> Self {
        Self {
            priority: priority.into_option(),
            ..self
        }
    }

    /// Sets the metadata (serialized as `_meta`) to the value converted
    /// through [`IntoOption`].
    #[must_use]
    pub fn meta(self, meta: impl IntoOption<Meta>) -> Self {
        Self {
            meta: meta.into_option(),
            ..self
        }
    }
}
/// A role value, as used in [`Annotations::audience`].
///
/// The known roles serialize as the plain strings `"assistant"` and `"user"`.
/// Any other string is preserved in [`Role::Other`] so it round-trips.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub enum Role {
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"user"`.
    User,
    /// Any unrecognized role string, kept verbatim. Being untagged and last,
    /// it is only used when the string matches no named variant; it must stay
    /// last.
    #[serde(untagged)]
    Other(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json` is a JSON object containing none of `keys`.
    ///
    /// Callers must pass wire names (e.g. `_meta`, not `meta`); a key that is
    /// never emitted under that spelling makes the check vacuous.
    fn assert_keys_absent(json: &serde_json::Value, keys: &[&str]) {
        let obj = json
            .as_object()
            .expect("content should serialize to a JSON object");
        for key in keys {
            assert!(!obj.contains_key(*key), "unexpected key `{key}` in {json}");
        }
    }

    #[test]
    fn test_text_content_roundtrip() {
        let content = TextContent::new("hello world");
        let json = serde_json::to_value(&content).unwrap();
        let parsed: TextContent = serde_json::from_value(json).unwrap();
        assert_eq!(content, parsed);
    }

    #[test]
    fn test_text_content_omits_optional_fields() {
        let content = TextContent::new("hello");
        let json = serde_json::to_value(&content).unwrap();
        // `meta` is renamed to `_meta` on the wire; checking for `"meta"` would
        // pass even if the metadata were emitted.
        assert_keys_absent(&json, &["annotations", "_meta"]);
    }

    #[test]
    fn test_text_content_from_string() {
        let block: ContentBlock = "hello".into();
        match block {
            ContentBlock::Text(c) => assert_eq!(c.text, "hello"),
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn content_block_text_serializes_with_type_tag() {
        let block: ContentBlock = "hi".into();
        assert_eq!(
            serde_json::to_value(&block).unwrap(),
            serde_json::json!({"type": "text", "text": "hi"})
        );
    }

    #[test]
    fn invalid_annotations_default_to_none() {
        // `annotations` is deserialized with `DefaultOnError`, so a bad value
        // must not fail the whole block.
        let parsed: TextContent = serde_json::from_value(serde_json::json!({
            "text": "hi",
            "annotations": "not an object"
        }))
        .unwrap();
        assert_eq!(parsed, TextContent::new("hi"));
    }

    #[test]
    fn role_preserves_unknown_variant() {
        let role: Role = serde_json::from_str("\"critic\"").unwrap();
        assert_eq!(role, Role::Other("critic".to_string()));
        assert_eq!(serde_json::to_value(&role).unwrap(), "critic");
    }

    #[test]
    fn content_block_preserves_unknown_variant() {
        let block: ContentBlock = serde_json::from_value(serde_json::json!({
            "type": "_widget",
            "title": "Status",
            "state": {"ok": true}
        }))
        .unwrap();
        let ContentBlock::Other(unknown) = block else {
            panic!("expected unknown content block");
        };
        assert_eq!(unknown.type_, "_widget");
        assert_eq!(
            unknown.fields.get("title"),
            Some(&serde_json::json!("Status"))
        );
        assert_eq!(
            serde_json::to_value(ContentBlock::Other(unknown)).unwrap(),
            serde_json::json!({
                "type": "_widget",
                "title": "Status",
                "state": {"ok": true}
            })
        );
    }

    #[test]
    fn content_block_does_not_hide_malformed_known_variant() {
        assert!(
            serde_json::from_value::<ContentBlock>(serde_json::json!({
                "type": "text"
            }))
            .is_err()
        );
    }

    #[test]
    fn test_image_content_roundtrip() {
        let content = ImageContent::new("base64data", "image/png");
        let json = serde_json::to_value(&content).unwrap();
        let parsed: ImageContent = serde_json::from_value(json).unwrap();
        assert_eq!(content, parsed);
    }

    #[test]
    fn test_image_content_omits_optional_fields() {
        let content = ImageContent::new("data", "image/png");
        let json = serde_json::to_value(&content).unwrap();
        assert_keys_absent(&json, &["uri", "annotations", "_meta"]);
    }

    #[test]
    fn test_image_content_with_uri() {
        let content = ImageContent::new("data", "image/png").uri("https://example.com/image.png");
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["uri"], "https://example.com/image.png");
    }

    #[test]
    fn test_audio_content_roundtrip() {
        let content = AudioContent::new("base64audio", "audio/mp3");
        let json = serde_json::to_value(&content).unwrap();
        let parsed: AudioContent = serde_json::from_value(json).unwrap();
        assert_eq!(content, parsed);
    }

    #[test]
    fn test_audio_content_omits_optional_fields() {
        let content = AudioContent::new("data", "audio/mp3");
        let json = serde_json::to_value(&content).unwrap();
        assert_keys_absent(&json, &["annotations", "_meta"]);
    }
}