#![deny(missing_docs)]
use std::fmt;
use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
use serde::Serialize;
use serde_json::Value;
/// Image media types accepted by [`Blocks::image_b64`] and
/// [`Blocks::image_b64_with_cache`].
///
/// Matching is exact and case-sensitive.
pub const VALID_IMAGE_MEDIA_TYPES: &[&str] =
    &["image/jpeg", "image/png", "image/gif", "image/webp"];
/// Document media types accepted by [`Blocks::document_b64`].
///
/// Matching is exact and case-sensitive.
pub const VALID_DOCUMENT_MEDIA_TYPES: &[&str] = &["application/pdf", "text/plain"];
/// Cache-control marker that can be attached to text, image and document blocks.
///
/// Serialized as an internally tagged object, e.g. `{"type":"ephemeral"}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// Serialized as `{"type":"ephemeral"}`.
    Ephemeral,
}
/// Where the data of an image block comes from.
///
/// Serialized as an internally tagged object whose `"type"` is `"base64"` or
/// `"url"`. Constructing a variant directly performs no validation. The
/// [`Blocks`] builder checks `media_type` against [`VALID_IMAGE_MEDIA_TYPES`]
/// before producing a `Base64` source.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ImageSource {
    /// Image bytes embedded inline as base64 text.
    Base64 {
        /// MIME type of the image, e.g. `"image/png"`.
        media_type: String,
        /// The base64-encoded image bytes.
        data: String,
    },
    /// Image referenced by URL.
    Url {
        /// Location of the image; not validated here.
        url: String,
    },
}
/// Where the data of a document block comes from.
///
/// Serialized as an internally tagged object with `"type": "base64"`.
/// Constructing the variant directly performs no validation. The [`Blocks`]
/// builder checks `media_type` against [`VALID_DOCUMENT_MEDIA_TYPES`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DocumentSource {
    /// Document bytes embedded inline as base64 text.
    Base64 {
        /// MIME type of the document, e.g. `"application/pdf"`.
        media_type: String,
        /// The base64-encoded document bytes.
        data: String,
    },
}
/// A single piece of message content.
///
/// Serialized as an internally tagged object whose `"type"` field is the
/// snake_case variant name: `"text"`, `"image"`, `"tool_use"`,
/// `"tool_result"` or `"document"`. Unset optional fields are omitted from the
/// output.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    /// Plain text.
    Text {
        /// The text itself.
        text: String,
        /// Optional cache-control marker; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image, either inline or referenced by URL.
    Image {
        /// Where the image data comes from.
        source: ImageSource,
        /// Optional cache-control marker; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: call `id`, tool `name` and JSON `input`.
    ToolUse {
        /// Identifier of this tool invocation.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: Value,
    },
    /// The outcome of a tool invocation.
    ToolResult {
        /// Presumably the `id` of the corresponding [`ContentBlock::ToolUse`].
        /// The pairing is not enforced here.
        tool_use_id: String,
        /// Result payload as arbitrary JSON.
        content: Value,
        /// Marks the result as an error; serialized only when `true`.
        #[serde(skip_serializing_if = "is_false")]
        is_error: bool,
    },
    /// An inline document.
    Document {
        /// Where the document data comes from.
        source: DocumentSource,
        /// Optional cache-control marker; omitted from the output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
/// `skip_serializing_if` predicate: yields `true` exactly when the flag is
/// unset, so a `false` value is left out of the serialized output.
fn is_false(flag: &bool) -> bool {
    let set = *flag;
    !set
}
/// A role-tagged list of content blocks, as produced by [`Blocks::build_message`].
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct Message {
    /// Role of the message author, stored verbatim (no validation is done here).
    pub role: String,
    /// The message content, in insertion order.
    pub content: Vec<ContentBlock>,
}
/// Error returned by the validating [`Blocks`] methods when a media type is
/// not on the corresponding allow-list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockError {
    /// The given image media type (carried verbatim) is not in
    /// [`VALID_IMAGE_MEDIA_TYPES`].
    UnsupportedImageMediaType(String),
    /// The given document media type (carried verbatim) is not in
    /// [`VALID_DOCUMENT_MEDIA_TYPES`].
    UnsupportedDocumentMediaType(String),
}

impl fmt::Display for BlockError {
    /// Names the rejected media type and lists the accepted ones.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BlockError::UnsupportedImageMediaType(t) => write!(
                f,
                "unsupported image media_type {:?}; expected one of {:?}",
                t, VALID_IMAGE_MEDIA_TYPES
            ),
            BlockError::UnsupportedDocumentMediaType(t) => write!(
                f,
                "unsupported document media_type {:?}; expected one of {:?}",
                t, VALID_DOCUMENT_MEDIA_TYPES
            ),
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for BlockError {}
/// Incremental builder for a list of [`ContentBlock`]s.
///
/// Adder methods take `&mut self` and return `&mut Self`, or
/// `Result<&mut Self, BlockError>` for inputs that are validated, so calls can
/// be chained. [`Blocks::build`] and [`Blocks::build_message`] move the
/// accumulated blocks out and leave the builder empty and reusable.
#[derive(Debug, Default, Clone)]
pub struct Blocks {
    // Blocks accumulated so far, in insertion order.
    inner: Vec<ContentBlock>,
}
impl Blocks {
    /// Creates an empty builder (same as `Blocks::default()`).
    pub fn new() -> Self {
        Self { inner: Vec::new() }
    }

    /// Appends a text block without cache control.
    pub fn text(&mut self, s: impl Into<String>) -> &mut Self {
        self.push_text(s.into(), None)
    }

    /// Appends a text block carrying `cache_control`.
    pub fn text_with_cache(
        &mut self,
        s: impl Into<String>,
        cache_control: CacheControl,
    ) -> &mut Self {
        self.push_text(s.into(), Some(cache_control))
    }

    /// Base64-encodes `data` with the standard alphabet and appends it as an
    /// image block without cache control.
    ///
    /// # Errors
    ///
    /// Returns [`BlockError::UnsupportedImageMediaType`] if `media_type` is not
    /// an exact match for an entry of [`VALID_IMAGE_MEDIA_TYPES`]. Nothing is
    /// appended in that case.
    pub fn image_b64(
        &mut self,
        data: &[u8],
        media_type: impl Into<String>,
    ) -> Result<&mut Self, BlockError> {
        self.push_image_b64(data, media_type.into(), None)
    }

    /// Like [`Blocks::image_b64`], but tags the block with `cache_control`.
    ///
    /// # Errors
    ///
    /// Returns [`BlockError::UnsupportedImageMediaType`] if `media_type` is not
    /// in [`VALID_IMAGE_MEDIA_TYPES`]. Nothing is appended in that case.
    pub fn image_b64_with_cache(
        &mut self,
        data: &[u8],
        media_type: impl Into<String>,
        cache_control: CacheControl,
    ) -> Result<&mut Self, BlockError> {
        self.push_image_b64(data, media_type.into(), Some(cache_control))
    }

    /// Appends an image block that references the image by URL, without cache
    /// control. The URL is not validated.
    pub fn image_url(&mut self, url: impl Into<String>) -> &mut Self {
        self.push_image_url(url.into(), None)
    }

    /// Like [`Blocks::image_url`], but tags the block with `cache_control`.
    pub fn image_url_with_cache(
        &mut self,
        url: impl Into<String>,
        cache_control: CacheControl,
    ) -> &mut Self {
        self.push_image_url(url.into(), Some(cache_control))
    }

    /// Appends a [`ContentBlock::ToolUse`] block with the given call `id`, tool
    /// `name` and JSON `input`.
    pub fn tool_use(
        &mut self,
        id: impl Into<String>,
        name: impl Into<String>,
        input: Value,
    ) -> &mut Self {
        self.inner.push(ContentBlock::ToolUse {
            id: id.into(),
            name: name.into(),
            input,
        });
        self
    }

    /// Appends a [`ContentBlock::ToolResult`] block (see
    /// [`Blocks::tool_result_block`] for a standalone version).
    pub fn tool_result(
        &mut self,
        tool_use_id: impl Into<String>,
        content: Value,
        is_error: bool,
    ) -> &mut Self {
        self.inner
            .push(Self::tool_result_block(tool_use_id, content, is_error));
        self
    }

    /// Base64-encodes `data` with the standard alphabet and appends it as a
    /// document block without cache control.
    ///
    /// # Errors
    ///
    /// Returns [`BlockError::UnsupportedDocumentMediaType`] if `media_type` is
    /// not an exact match for an entry of [`VALID_DOCUMENT_MEDIA_TYPES`].
    /// Nothing is appended in that case.
    pub fn document_b64(
        &mut self,
        data: &[u8],
        media_type: impl Into<String>,
    ) -> Result<&mut Self, BlockError> {
        self.push_document_b64(data, media_type.into(), None)
    }

    /// Like [`Blocks::document_b64`], but tags the block with `cache_control`.
    ///
    /// # Errors
    ///
    /// Returns [`BlockError::UnsupportedDocumentMediaType`] if `media_type` is
    /// not in [`VALID_DOCUMENT_MEDIA_TYPES`]. Nothing is appended in that case.
    pub fn document_b64_with_cache(
        &mut self,
        data: &[u8],
        media_type: impl Into<String>,
        cache_control: CacheControl,
    ) -> Result<&mut Self, BlockError> {
        self.push_document_b64(data, media_type.into(), Some(cache_control))
    }

    /// Base64-encodes `data` and appends it as an `application/pdf` document
    /// block without cache control.
    ///
    /// # Panics
    ///
    /// Panics only if `"application/pdf"` has been removed from
    /// [`VALID_DOCUMENT_MEDIA_TYPES`]. That is a programming error, and
    /// ignoring it would silently drop the document.
    pub fn document_pdf_b64(&mut self, data: &[u8]) -> &mut Self {
        self.document_b64(data, "application/pdf")
            .expect("\"application/pdf\" must be listed in VALID_DOCUMENT_MEDIA_TYPES")
    }

    /// Appends every block yielded by `blocks`, in order. No validation is
    /// performed on the appended blocks.
    pub fn extend<I>(&mut self, blocks: I) -> &mut Self
    where
        I: IntoIterator<Item = ContentBlock>,
    {
        self.inner.extend(blocks);
        self
    }

    /// Moves all accumulated blocks out, leaving the builder empty and reusable.
    pub fn build(&mut self) -> Vec<ContentBlock> {
        std::mem::take(&mut self.inner)
    }

    /// Like [`Blocks::build`], but wraps the blocks in a [`Message`] with the
    /// given `role`.
    pub fn build_message(&mut self, role: impl Into<String>) -> Message {
        Message {
            role: role.into(),
            content: self.build(),
        }
    }

    /// Number of blocks accumulated so far.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if no blocks have been accumulated.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Builds a standalone [`ContentBlock::ToolResult`] without touching any
    /// builder, e.g. for use with [`Blocks::extend`].
    pub fn tool_result_block(
        tool_use_id: impl Into<String>,
        content: Value,
        is_error: bool,
    ) -> ContentBlock {
        ContentBlock::ToolResult {
            tool_use_id: tool_use_id.into(),
            content,
            is_error,
        }
    }

    // Shared body of `text` / `text_with_cache`.
    fn push_text(&mut self, text: String, cache_control: Option<CacheControl>) -> &mut Self {
        self.inner.push(ContentBlock::Text { text, cache_control });
        self
    }

    // Validates `media_type` against the image allow-list, then encodes `data`
    // and appends an inline image block. Validation happens first so nothing
    // is encoded or pushed on error.
    fn push_image_b64(
        &mut self,
        data: &[u8],
        media_type: String,
        cache_control: Option<CacheControl>,
    ) -> Result<&mut Self, BlockError> {
        if !VALID_IMAGE_MEDIA_TYPES.contains(&media_type.as_str()) {
            return Err(BlockError::UnsupportedImageMediaType(media_type));
        }
        self.inner.push(ContentBlock::Image {
            source: ImageSource::Base64 {
                media_type,
                data: BASE64_STANDARD.encode(data),
            },
            cache_control,
        });
        Ok(self)
    }

    // Shared body of `image_url` / `image_url_with_cache`.
    fn push_image_url(&mut self, url: String, cache_control: Option<CacheControl>) -> &mut Self {
        self.inner.push(ContentBlock::Image {
            source: ImageSource::Url { url },
            cache_control,
        });
        self
    }

    // Validates `media_type` against the document allow-list, then encodes
    // `data` and appends an inline document block.
    fn push_document_b64(
        &mut self,
        data: &[u8],
        media_type: String,
        cache_control: Option<CacheControl>,
    ) -> Result<&mut Self, BlockError> {
        if !VALID_DOCUMENT_MEDIA_TYPES.contains(&media_type.as_str()) {
            return Err(BlockError::UnsupportedDocumentMediaType(media_type));
        }
        self.inner.push(ContentBlock::Document {
            source: DocumentSource::Base64 {
                media_type,
                data: BASE64_STANDARD.encode(data),
            },
            cache_control,
        });
        Ok(self)
    }
}