use std::time::Duration;
use serde::{Deserialize, Serialize};
/// A caching hint that can be attached to individual content blocks.
///
/// Intentionally not `Serialize`/`Deserialize`: the block fields that hold it
/// are `#[serde(skip)]`, so hints never survive a JSON round trip.
/// NOTE(review): presumably translated into a provider prompt-caching
/// directive by the request builder — confirm.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CacheControl {
    /// Ephemeral caching with no explicit lifetime.
    Ephemeral,
    /// Ephemeral caching with an explicit time-to-live.
    EphemeralWithTtl(Duration),
}
/// One turn of a conversation, tagged by who produced it.
///
/// Uses serde's default externally tagged representation, so a user turn
/// serializes as `{"User":{"blocks":[...]}}`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub enum Message {
    /// Content supplied by the user side (text, tool results, attachments).
    User { blocks: Vec<UserBlock> },
    /// Content produced by the assistant (text, tool calls, reasoning).
    Assistant { blocks: Vec<AssistantBlock> },
}
impl Message {
    /// Builds a user turn consisting of one text block with no cache hint.
    pub fn user(text: impl Into<String>) -> Message {
        let block = UserBlock::text(text);
        Message::User {
            blocks: vec![block],
        }
    }

    /// Builds an assistant turn consisting of one text block with no cache hint.
    pub fn assistant_text(text: impl Into<String>) -> Message {
        let block = AssistantBlock::text(text);
        Message::Assistant {
            blocks: vec![block],
        }
    }
}
/// A single piece of content inside a [`Message::User`] turn.
///
/// Every variant carries an optional [`CacheControl`] hint. The hint is skipped
/// by serde: it is never written out, and a deserialized block always has
/// `cache_control: None`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub enum UserBlock {
    /// Plain text from the user.
    Text {
        text: String,
        #[serde(skip, default)]
        cache_control: Option<CacheControl>,
    },
    /// Output of a tool invocation.
    /// NOTE(review): `call_id` presumably matches the `id` of an earlier
    /// `AssistantBlock::ToolCall` — confirm.
    ToolResult {
        call_id: String,
        content: ToolResultContent,
        #[serde(skip, default)]
        cache_control: Option<CacheControl>,
    },
    /// An image attachment.
    Image {
        source: Source,
        #[serde(skip, default)]
        cache_control: Option<CacheControl>,
    },
    /// A document attachment.
    Document {
        source: Source,
        #[serde(skip, default)]
        cache_control: Option<CacheControl>,
    },
}
impl UserBlock {
    /// Builds a text block without a cache-control hint.
    pub fn text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self::Text {
            text,
            cache_control: None,
        }
    }

    /// Builds a tool-result block for `call_id` without a cache-control hint.
    pub fn tool_result(call_id: impl Into<String>, content: impl Into<ToolResultContent>) -> Self {
        let call_id = call_id.into();
        let content = content.into();
        Self::ToolResult {
            call_id,
            content,
            cache_control: None,
        }
    }

    /// Builds an image block without a cache-control hint.
    pub fn image(source: Source) -> Self {
        Self::Image {
            cache_control: None,
            source,
        }
    }

    /// Builds a document block without a cache-control hint.
    pub fn document(source: Source) -> Self {
        Self::Document {
            cache_control: None,
            source,
        }
    }

    /// Returns this block with its cache-control hint replaced by
    /// `cache_control`; pass `None` to clear it.
    ///
    /// Every variant has a hint slot, so the replacement always takes effect.
    pub fn with_cache_control(mut self, cache_control: Option<CacheControl>) -> Self {
        let slot = match &mut self {
            Self::Text {
                cache_control: slot, ..
            }
            | Self::ToolResult {
                cache_control: slot, ..
            }
            | Self::Image {
                cache_control: slot, ..
            }
            | Self::Document {
                cache_control: slot, ..
            } => slot,
        };
        *slot = cache_control;
        self
    }

    /// Borrows the block's cache-control hint, if one is set.
    pub fn cache_control(&self) -> Option<&CacheControl> {
        let slot = match self {
            Self::Text {
                cache_control: slot, ..
            }
            | Self::ToolResult {
                cache_control: slot, ..
            }
            | Self::Image {
                cache_control: slot, ..
            }
            | Self::Document {
                cache_control: slot, ..
            } => slot,
        };
        slot.as_ref()
    }
}
/// A single piece of content inside a [`Message::Assistant`] turn.
///
/// Only `Text` and `ToolCall` carry a [`CacheControl`] hint. The hint is skipped
/// by serde, so it is `None` after deserialization. The reasoning variants
/// have no hint at all.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub enum AssistantBlock {
    /// Text produced by the assistant.
    Text {
        text: String,
        #[serde(skip, default)]
        cache_control: Option<CacheControl>,
    },
    /// A request to invoke tool `name` with JSON `args`; `id` identifies the call.
    ToolCall {
        id: String,
        name: String,
        args: serde_json::Value,
        #[serde(skip, default)]
        cache_control: Option<CacheControl>,
    },
    /// Reasoning text with an optional signature string.
    Reasoning {
        text: String,
        signature: Option<String>,
    },
    /// Reasoning available only as a `data` string.
    /// NOTE(review): presumably opaque provider data — confirm.
    RedactedReasoning {
        data: String,
    },
}
impl AssistantBlock {
    /// Builds a text block without a cache-control hint.
    pub fn text(text: impl Into<String>) -> Self {
        let text = text.into();
        Self::Text {
            text,
            cache_control: None,
        }
    }

    /// Builds a tool-call block without a cache-control hint.
    pub fn tool_call(
        id: impl Into<String>,
        name: impl Into<String>,
        args: serde_json::Value,
    ) -> Self {
        let (id, name) = (id.into(), name.into());
        Self::ToolCall {
            id,
            name,
            args,
            cache_control: None,
        }
    }

    /// Returns this block with its cache-control hint replaced by
    /// `cache_control`; pass `None` to clear it.
    ///
    /// `Reasoning` and `RedactedReasoning` blocks have no hint slot. For them
    /// the hint is discarded and the block is returned unchanged.
    pub fn with_cache_control(mut self, cache_control: Option<CacheControl>) -> Self {
        match &mut self {
            Self::Text {
                cache_control: slot, ..
            }
            | Self::ToolCall {
                cache_control: slot, ..
            } => *slot = cache_control,
            // No slot to write to: the hint is silently dropped.
            Self::Reasoning { .. } | Self::RedactedReasoning { .. } => {}
        }
        self
    }

    /// Borrows the block's cache-control hint; always `None` for reasoning blocks.
    pub fn cache_control(&self) -> Option<&CacheControl> {
        match self {
            Self::Reasoning { .. } | Self::RedactedReasoning { .. } => None,
            Self::Text {
                cache_control: slot, ..
            }
            | Self::ToolCall {
                cache_control: slot, ..
            } => slot.as_ref(),
        }
    }
}
/// Where attached binary content (images, documents) comes from.
///
/// Internally tagged with a `"type"` field whose value is the snake_case variant
/// name (`"base64"`, `"url"`, `"file_id"`), e.g. `{"type":"url","url":"..."}`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Source {
    /// Inline data plus its media type (e.g. `"image/png"`).
    /// NOTE(review): `data` is presumably base64-encoded text — confirm.
    Base64 { media_type: String, data: String },
    /// Content referenced by URL.
    Url { url: String },
    /// Content referenced by a file identifier.
    FileId { id: String },
}
/// One piece of a tool's output.
///
/// Internally tagged with a snake_case `"type"` field, e.g.
/// `{"type":"text","text":"..."}` or `{"type":"image","source":{...}}`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultBlock {
    /// Textual output.
    Text { text: String },
    /// Image output.
    Image { source: Source },
}
impl ToolResultBlock {
pub fn text(text: impl Into<String>) -> Self {
Self::Text { text: text.into() }
}
pub fn image(source: Source) -> Self {
Self::Image { source }
}
}
/// The complete result of a tool call: ordered blocks plus an error flag.
///
/// `is_error` is omitted from the serialized form when `false`, and defaults
/// to `false` when absent on input.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[non_exhaustive]
pub struct ToolResultContent {
    /// The result payload, in order.
    pub blocks: Vec<ToolResultBlock>,
    /// Whether the result represents a failure.
    #[serde(default, skip_serializing_if = "is_false")]
    pub is_error: bool,
}
/// `skip_serializing_if` predicate: returns `true` when the flag is unset.
///
/// Serde hands the field over by reference, hence `&bool`.
fn is_false(flag: &bool) -> bool {
    !flag
}
impl ToolResultContent {
    /// Builds a successful result holding a single text block.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        Self::from_blocks(vec![ToolResultBlock::text(text)])
    }

    /// Builds a failed result (`is_error == true`) holding a single text block
    /// with the given message.
    #[must_use]
    pub fn error(text: impl Into<String>) -> Self {
        Self::text(text).with_is_error(true)
    }

    /// Builds a successful result holding a single image block.
    #[must_use]
    pub fn image(source: Source) -> Self {
        Self::from_blocks(vec![ToolResultBlock::image(source)])
    }

    /// Builds a successful result (`is_error == false`) from an arbitrary list
    /// of blocks.
    #[must_use]
    pub fn from_blocks(blocks: Vec<ToolResultBlock>) -> Self {
        Self {
            blocks,
            is_error: false,
        }
    }

    /// Builder-style setter for the error flag.
    #[must_use]
    pub fn with_is_error(mut self, is_error: bool) -> Self {
        self.is_error = is_error;
        self
    }

    /// Returns the text of the first text block, skipping any non-text blocks
    /// before it. Returns `None` when there is no text block.
    pub fn as_text(&self) -> Option<&str> {
        self.text_blocks().next()
    }

    /// Joins the text of every text block with `'\n'`, skipping non-text
    /// blocks. Returns an empty string when there is no text block.
    pub fn collect_text(&self) -> String {
        self.text_blocks().collect::<Vec<_>>().join("\n")
    }

    /// Yields the text of each [`ToolResultBlock::Text`] block, in order.
    ///
    /// The match is exhaustive on purpose, with no `_` arm. Adding a new
    /// `ToolResultBlock` variant then forces a decision here, instead of the
    /// variant being silently treated as non-text.
    fn text_blocks(&self) -> impl Iterator<Item = &str> {
        self.blocks.iter().filter_map(|block| match block {
            ToolResultBlock::Text { text } => Some(text.as_str()),
            ToolResultBlock::Image { .. } => None,
        })
    }
}
impl From<String> for ToolResultContent {
fn from(value: String) -> Self {
Self::text(value)
}
}
impl From<&str> for ToolResultContent {
fn from(value: &str) -> Self {
Self::text(value)
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SystemPrompt {
Plain(String),
Blocks(Vec<SystemBlock>),
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SystemBlock {
pub text: String,
pub cache_control: Option<CacheControl>,
}
impl SystemBlock {
    /// Creates a segment with the given text and no cache-control hint.
    pub fn new(text: impl Into<String>) -> Self {
        let text = text.into();
        Self {
            text,
            cache_control: None,
        }
    }

    /// Returns this segment with `cache_control` attached, replacing any
    /// previous hint.
    pub fn with_cache_control(self, cache_control: CacheControl) -> Self {
        Self {
            cache_control: Some(cache_control),
            ..self
        }
    }
}
impl From<String> for SystemPrompt {
fn from(value: String) -> Self {
Self::Plain(value)
}
}
impl From<&str> for SystemPrompt {
fn from(value: &str) -> Self {
Self::Plain(value.to_string())
}
}
impl SystemPrompt {
    /// Flattens the prompt into a single owned string.
    ///
    /// A `Plain` prompt is cloned as-is. A `Blocks` prompt has its segment texts
    /// joined with a blank line (`"\n\n"`) between them; cache-control hints are
    /// not reflected in the output.
    pub fn as_text(&self) -> String {
        let segments = match self {
            Self::Plain(whole) => return whole.clone(),
            Self::Blocks(segments) => segments,
        };
        let texts: Vec<&str> = segments.iter().map(|seg| seg.text.as_str()).collect();
        texts.join("\n\n")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `msg` to JSON and parses it back into a [`Message`].
    fn round_trip(msg: &Message) -> Message {
        let json = serde_json::to_string(msg).expect("serialize");
        serde_json::from_str(&json).expect("deserialize")
    }

    /// Asserts that `msg` is a user turn whose first block is text equal to
    /// `expected` and carries no cache-control hint.
    fn assert_user_text_eq(msg: &Message, expected: &str) {
        match msg {
            Message::User { blocks } => match &blocks[0] {
                UserBlock::Text {
                    text,
                    cache_control,
                } => {
                    assert_eq!(text, expected);
                    assert!(cache_control.is_none(), "cache_control must not round-trip");
                }
                other => panic!("expected UserBlock::Text, got {other:?}"),
            },
            other => panic!("expected Message::User, got {other:?}"),
        }
    }

    #[test]
    fn round_trip_user_text_drops_cache_control() {
        let msg = Message::User {
            blocks: vec![
                UserBlock::text("hello").with_cache_control(Some(CacheControl::Ephemeral)),
            ],
        };
        let restored = round_trip(&msg);
        assert_user_text_eq(&restored, "hello");
    }

    #[test]
    fn round_trip_user_tool_result() {
        let msg = Message::User {
            blocks: vec![UserBlock::tool_result(
                "call-1",
                ToolResultContent::text("ok"),
            )],
        };
        let restored = round_trip(&msg);
        match &restored {
            Message::User { blocks } => match &blocks[0] {
                UserBlock::ToolResult {
                    call_id,
                    content,
                    cache_control,
                } => {
                    assert_eq!(call_id, "call-1");
                    assert_eq!(content.as_text(), Some("ok"));
                    assert!(!content.is_error);
                    assert!(cache_control.is_none());
                }
                other => panic!("expected ToolResult, got {other:?}"),
            },
            other => panic!("expected User, got {other:?}"),
        }
    }

    #[test]
    fn round_trip_assistant_text_and_tool_call() {
        let msg = Message::Assistant {
            blocks: vec![
                AssistantBlock::text("thinking out loud"),
                AssistantBlock::tool_call("c1", "fetch", json!({"q": "x"})),
            ],
        };
        let restored = round_trip(&msg);
        match &restored {
            Message::Assistant { blocks } => {
                assert_eq!(blocks.len(), 2);
                match &blocks[0] {
                    AssistantBlock::Text { text, .. } => assert_eq!(text, "thinking out loud"),
                    other => panic!("expected Text, got {other:?}"),
                }
                match &blocks[1] {
                    AssistantBlock::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "c1");
                        assert_eq!(name, "fetch");
                        assert_eq!(args, &json!({"q": "x"}));
                    }
                    other => panic!("expected ToolCall, got {other:?}"),
                }
            }
            other => panic!("expected Assistant, got {other:?}"),
        }
        // No cache hints were set, so the whole message must survive unchanged.
        assert_eq!(restored, msg);
    }

    #[test]
    fn round_trip_assistant_reasoning_variants() {
        let msg = Message::Assistant {
            blocks: vec![
                AssistantBlock::Reasoning {
                    text: "consider X".into(),
                    signature: Some("sig-1".into()),
                },
                AssistantBlock::RedactedReasoning {
                    data: "opaque".into(),
                },
            ],
        };
        let restored = round_trip(&msg);
        match &restored {
            Message::Assistant { blocks } => match (&blocks[0], &blocks[1]) {
                (
                    AssistantBlock::Reasoning { text, signature },
                    AssistantBlock::RedactedReasoning { data },
                ) => {
                    assert_eq!(text, "consider X");
                    assert_eq!(signature.as_deref(), Some("sig-1"));
                    assert_eq!(data, "opaque");
                }
                other => panic!("unexpected blocks: {other:?}"),
            },
            other => panic!("expected Assistant, got {other:?}"),
        }
    }

    #[test]
    fn round_trip_tool_result_error_variant() {
        let content = ToolResultContent::error("boom");
        let json = serde_json::to_string(&content).unwrap();
        let back: ToolResultContent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.as_text(), Some("boom"));
        assert!(back.is_error);
    }

    #[test]
    fn tool_result_content_is_error_omitted_when_false() {
        let content = ToolResultContent::text("ok");
        let json = serde_json::to_value(&content).unwrap();
        assert!(
            json.get("is_error").is_none(),
            "is_error must be skipped when false, got {json}"
        );
    }

    #[test]
    fn tool_result_content_multi_block_round_trip() {
        let content = ToolResultContent::from_blocks(vec![
            ToolResultBlock::text("see chart"),
            ToolResultBlock::image(Source::Url {
                url: "https://example.com/chart.png".into(),
            }),
        ]);
        let json = serde_json::to_string(&content).unwrap();
        let back: ToolResultContent = serde_json::from_str(&json).unwrap();
        assert_eq!(back.blocks.len(), 2);
        assert!(!back.is_error);
        // Compare the whole value so block order and the image source are checked too.
        assert_eq!(back, content);
    }

    #[test]
    fn round_trip_user_image_block() {
        let msg = Message::User {
            blocks: vec![UserBlock::image(Source::Base64 {
                media_type: "image/png".into(),
                data: "AAAA".into(),
            })],
        };
        let restored = round_trip(&msg);
        match &restored {
            Message::User { blocks } => match &blocks[0] {
                UserBlock::Image {
                    source,
                    cache_control,
                } => {
                    assert!(matches!(
                        source,
                        Source::Base64 { media_type, data }
                            if media_type == "image/png" && data == "AAAA"
                    ));
                    assert!(cache_control.is_none());
                }
                other => panic!("expected Image, got {other:?}"),
            },
            other => panic!("expected User, got {other:?}"),
        }
    }

    #[test]
    fn user_block_cache_control_set_and_clear() {
        let block = UserBlock::image(Source::FileId { id: "f1".into() })
            .with_cache_control(Some(CacheControl::Ephemeral));
        assert_eq!(block.cache_control(), Some(&CacheControl::Ephemeral));
        let cleared = block.with_cache_control(None);
        assert_eq!(cleared.cache_control(), None);
    }

    #[test]
    fn assistant_reasoning_ignores_cache_control() {
        let ttl = CacheControl::EphemeralWithTtl(std::time::Duration::from_secs(300));
        let reasoning = AssistantBlock::Reasoning {
            text: "r".into(),
            signature: None,
        }
        .with_cache_control(Some(ttl.clone()));
        assert_eq!(reasoning.cache_control(), None);

        let call = AssistantBlock::tool_call("c1", "fetch", json!({}))
            .with_cache_control(Some(ttl.clone()));
        assert_eq!(call.cache_control(), Some(&ttl));
    }

    #[test]
    fn tool_result_text_accessors_skip_non_text_blocks() {
        let content = ToolResultContent::from_blocks(vec![
            ToolResultBlock::image(Source::Url {
                url: "https://example.com/a.png".into(),
            }),
            ToolResultBlock::text("first"),
            ToolResultBlock::text("second"),
        ]);
        assert_eq!(content.as_text(), Some("first"));
        assert_eq!(content.collect_text(), "first\nsecond");

        let image_only = ToolResultContent::image(Source::FileId { id: "f".into() });
        assert_eq!(image_only.as_text(), None);
        assert_eq!(image_only.collect_text(), "");
    }

    #[test]
    fn tool_result_from_str_is_success_text() {
        let content: ToolResultContent = "done".into();
        assert_eq!(content, ToolResultContent::text("done"));
        assert!(!content.is_error);
    }

    #[test]
    fn system_prompt_as_text_joins_blocks() {
        assert_eq!(SystemPrompt::from("plain").as_text(), "plain");
        let blocks = SystemPrompt::Blocks(vec![
            SystemBlock::new("one").with_cache_control(CacheControl::Ephemeral),
            SystemBlock::new("two"),
        ]);
        assert_eq!(blocks.as_text(), "one\n\ntwo");
        assert_eq!(SystemPrompt::Blocks(vec![]).as_text(), "");
    }
}