use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Author of a [`Message`].
///
/// On the wire each variant is a lowercase string: `"user"` or `"assistant"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    /// Serialized as `"user"`.
    #[serde(rename = "user")]
    User,
    /// Serialized as `"assistant"`.
    #[serde(rename = "assistant")]
    Assistant,
}
/// One conversation turn: its author plus an ordered list of content blocks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this turn.
    pub role: Role,
    /// The blocks making up the turn, in order (see [`Content`]).
    pub content: Vec<Content>,
}
/// A single content block inside a [`Message`].
///
/// Internally tagged: the JSON object carries a `"type"` field whose value is the
/// snake_case variant name, i.e. `"text"`, `"tool_use"`, or `"tool_result"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text, optionally carrying a [`CacheControl`] marker.
    Text {
        text: String,
        /// Left out of the output when `None`; a missing field reads back as `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: its id, the tool name, and arbitrary JSON input.
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    /// The result of a tool invocation, linked back to it via `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// Written only when `true`; a missing field reads back as `false`.
        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
        is_error: bool,
        /// Left out of the output when `None`; a missing field reads back as `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
/// Cache-control marker that can be attached to content blocks.
///
/// Serialized internally tagged, e.g. `{"type":"ephemeral"}` or
/// `{"type":"ephemeral","ttl":"1h"}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum CacheControl {
    /// Ephemeral marker; `ttl` is omitted from the output when `None`.
    #[serde(rename = "ephemeral")]
    Ephemeral {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        ttl: Option<CacheTtl>,
    },
}
impl CacheControl {
    /// Ephemeral marker without an explicit TTL (no `ttl` field is serialized).
    pub fn ephemeral() -> Self {
        Self::Ephemeral { ttl: None }
    }

    /// Ephemeral marker with an explicit one-hour TTL (serialized as `"ttl":"1h"`).
    pub fn ephemeral_1h() -> Self {
        let ttl = Some(CacheTtl::OneHour);
        Self::Ephemeral { ttl }
    }
}
/// Time-to-live of an ephemeral cache marker, serialized as `"5m"` or `"1h"`.
///
/// Each variant carries its own explicit wire name, so no container-level
/// `rename_all` is needed.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CacheTtl {
    /// Serialized as `"5m"`.
    #[serde(rename = "5m")]
    FiveMinutes,
    /// Serialized as `"1h"`.
    #[serde(rename = "1h")]
    OneHour,
}
/// Reason generation stopped, serialized as a snake_case string.
///
/// NOTE(review): serde rejects any string not listed below, so an unfamiliar
/// value makes deserialization of the enclosing payload fail — confirm this set
/// covers every value the upstream service can send.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StopReason {
    #[serde(rename = "end_turn")]
    EndTurn,
    #[serde(rename = "tool_use")]
    ToolUse,
    #[serde(rename = "max_tokens")]
    MaxTokens,
    #[serde(rename = "stop_sequence")]
    StopSequence,
    #[serde(rename = "pause_turn")]
    PauseTurn,
    #[serde(rename = "cancelled")]
    Cancelled,
}
/// Token counts reported for a request.
///
/// `#[serde(default)]` on the container fills every missing field from
/// `Usage::default()` (i.e. `0`). Previously only the two cache counters had a
/// default, so a partial usage object lacking `input_tokens` or `output_tokens`
/// was rejected outright. Such partial objects are exactly what
/// [`Usage::merge_max`] / [`Usage::add`] are shaped to combine.
/// Serialization is unchanged: all four fields are always written.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct Usage {
    pub input_tokens: u32,
    pub output_tokens: u32,
    pub cache_creation_input_tokens: u32,
    pub cache_read_input_tokens: u32,
}
impl Usage {
pub fn merge_max(&mut self, other: &Usage) {
self.input_tokens = self.input_tokens.max(other.input_tokens);
self.output_tokens = self.output_tokens.max(other.output_tokens);
self.cache_creation_input_tokens = self
.cache_creation_input_tokens
.max(other.cache_creation_input_tokens);
self.cache_read_input_tokens = self
.cache_read_input_tokens
.max(other.cache_read_input_tokens);
}
pub fn add(&mut self, other: &Usage) {
self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
self.cache_creation_input_tokens = self
.cache_creation_input_tokens
.saturating_add(other.cache_creation_input_tokens);
self.cache_read_input_tokens = self
.cache_read_input_tokens
.saturating_add(other.cache_read_input_tokens);
}
}
impl Message {
pub fn user(content: Vec<Content>) -> Self {
Self {
role: Role::User,
content,
}
}
pub fn user_text(text: impl Into<String>) -> Self {
Self {
role: Role::User,
content: vec![Content::text(text)],
}
}
pub fn assistant(content: Vec<Content>) -> Self {
Self {
role: Role::Assistant,
content,
}
}
pub fn text(&self) -> String {
self.content
.iter()
.filter_map(|c| match c {
Content::Text { text, .. } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("")
}
pub fn tool_uses(&self) -> Vec<(&str, &str, &Value)> {
self.content
.iter()
.filter_map(|c| match c {
Content::ToolUse { id, name, input } => Some((id.as_str(), name.as_str(), input)),
_ => None,
})
.collect()
}
}
impl Content {
    /// Text block without a cache-control marker.
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text {
            text: text.into(),
            cache_control: None,
        }
    }

    /// Text block tagged with `CacheControl::ephemeral()` (no explicit TTL).
    pub fn text_cached(text: impl Into<String>) -> Self {
        Self::Text {
            text: text.into(),
            cache_control: Some(CacheControl::ephemeral()),
        }
    }

    /// Tool-result block answering the tool call identified by `tool_use_id`.
    ///
    /// No cache-control marker is set; `is_error` is stored exactly as given.
    pub fn tool_result(
        tool_use_id: impl Into<String>,
        content: impl Into<String>,
        is_error: bool,
    ) -> Self {
        let tool_use_id = tool_use_id.into();
        let content = content.into();
        Self::ToolResult {
            tool_use_id,
            content,
            is_error,
            cache_control: None,
        }
    }
}