use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
/// Author role attached to a message.
///
/// `rename_all = "lowercase"` makes the serialized names `user`, `assistant`
/// and `system`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
/// Serialized as `user`.
User,
/// Serialized as `assistant`.
Assistant,
/// Serialized as `system`.
System,
}
/// One typed piece of message content.
///
/// Serialized as an internally tagged object: the variant is selected by a
/// `type` field whose value is the snake_case variant name (`text`, `image`,
/// `tool_use`, `tool_result`, `thinking`, `redacted_thinking`, `document`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
/// A run of plain text.
Text {
/// The text itself.
text: String,
},
/// An image described by an [`ImageSource`].
Image {
/// Where the image data comes from.
source: ImageSource,
},
/// A tool invocation carrying an id, a tool name and a JSON input.
ToolUse {
/// Identifier of this invocation.
// NOTE(review): presumably echoed back in `ToolResult::tool_use_id` — confirm with callers.
id: String,
/// Name of the tool being invoked.
name: String,
/// Arbitrary JSON arguments for the tool.
input: Value,
},
/// The outcome of a tool invocation.
ToolResult {
/// Identifier of the tool invocation this result belongs to.
tool_use_id: String,
/// Result payload: either a plain string or nested content blocks.
content: ToolResultContent,
/// Optional error flag; the key is omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
},
/// A thinking block with its accompanying signature.
Thinking {
/// The thinking text.
thinking: String,
/// Signature string; deserializes to an empty string when the key is absent.
#[serde(default)]
signature: String,
},
/// A thinking block whose content is carried only as an opaque string.
RedactedThinking {
/// Opaque payload.
// NOTE(review): presumably provider-encoded and not meant to be inspected — confirm.
data: String,
},
/// A document attachment with optional descriptive fields.
Document {
/// Where the document data comes from.
source: DocumentSource,
/// Optional title; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
/// Optional context string; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
context: Option<String>,
/// Optional citation settings; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
citations: Option<CitationsConfig>,
},
/// Fallback chosen by `#[serde(other)]` when the `type` tag matches no other
/// variant. It has no fields, so the unrecognized block's payload is not kept.
#[serde(other)]
Opaque,
}
/// Payload of a tool result.
///
/// The enum is `untagged`: no discriminator is written, and on deserialize the
/// variants are tried in declaration order (string first, then block list).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolResultContent {
/// The result as a single string.
Text(String),
/// The result as a list of content blocks.
Blocks(Vec<ContentBlock>),
}
/// Describes where an image's data comes from.
///
/// All optional fields are omitted from the serialized output when `None`.
// NOTE(review): `source_type` presumably selects between inline `data` and `url`;
// the accepted values are not visible here — confirm against the provider API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageSource {
/// Source kind; serialized under the key `type`.
#[serde(rename = "type")]
pub source_type: String,
/// Optional media type string.
#[serde(skip_serializing_if = "Option::is_none")]
pub media_type: Option<String>,
/// Optional inline data string.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<String>,
/// Optional URL string.
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
}
/// Describes where a document's data comes from.
///
/// All optional fields are omitted from the serialized output when `None`.
// NOTE(review): `source_type` presumably selects between inline `data` and `url`;
// the accepted values are not visible here — confirm against the provider API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentSource {
/// Source kind; serialized under the key `type`.
#[serde(rename = "type")]
pub source_type: String,
/// Optional media type string.
#[serde(skip_serializing_if = "Option::is_none")]
pub media_type: Option<String>,
/// Optional inline data string.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<String>,
/// Optional URL string.
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
}
/// Citation settings that can be attached to a document content block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CitationsConfig {
/// Citation on/off flag.
pub enabled: bool,
}
/// A single conversation message: a role, its content, and optional extras.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
/// Who authored the message.
pub role: Role,
/// Message body, either a plain string or a list of content blocks.
pub content: MessageContent,
/// Optional message identifier; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Optional metadata; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<MessageMetadata>,
}
/// Body of a [`Message`].
///
/// The enum is `untagged`: a bare string maps to `Text`, an array of block
/// objects maps to `Blocks`; variants are tried in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
/// The whole body as one string.
Text(String),
/// The body as a list of typed content blocks.
Blocks(Vec<ContentBlock>),
}
/// Optional bookkeeping attached to a [`Message`].
///
/// Every field is skipped on serialization when it is empty (`None` or JSON null).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MessageMetadata {
/// Optional model name.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Optional token usage.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
/// Optional stop reason.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Free-form JSON value; deserializes to JSON null when the key is absent and
/// is omitted from the output when null.
#[serde(default, skip_serializing_if = "Value::is_null")]
pub provider_data: Value,
}
impl Message {
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: MessageContent::Text(content.into()),
id: None,
metadata: None,
}
}
pub fn user_blocks(blocks: Vec<ContentBlock>) -> Self {
Self {
role: Role::User,
content: MessageContent::Blocks(blocks),
id: None,
metadata: None,
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: MessageContent::Text(content.into()),
id: None,
metadata: None,
}
}
pub fn assistant_blocks(blocks: Vec<ContentBlock>) -> Self {
Self {
role: Role::Assistant,
content: MessageContent::Blocks(blocks),
id: None,
metadata: None,
}
}
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: MessageContent::Text(content.into()),
id: None,
metadata: None,
}
}
pub fn get_text(&self) -> Option<&str> {
match &self.content {
MessageContent::Text(t) => Some(t.as_str()),
MessageContent::Blocks(blocks) => blocks.iter().find_map(|b| {
if let ContentBlock::Text { text } = b {
Some(text.as_str())
} else {
None
}
}),
}
}
pub fn get_all_text(&self) -> String {
match &self.content {
MessageContent::Text(t) => t.clone(),
MessageContent::Blocks(blocks) => blocks
.iter()
.filter_map(|b| {
if let ContentBlock::Text { text } = b {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join(""),
}
}
pub fn get_tool_use_blocks(&self) -> Vec<&ContentBlock> {
match &self.content {
MessageContent::Blocks(blocks) => blocks
.iter()
.filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
.collect(),
_ => vec![],
}
}
pub fn has_tool_use(&self) -> bool {
!self.get_tool_use_blocks().is_empty()
}
pub fn content_blocks(&self) -> Vec<ContentBlock> {
match &self.content {
MessageContent::Text(t) => vec![ContentBlock::Text { text: t.clone() }],
MessageContent::Blocks(b) => b.clone(),
}
}
}
/// Token and cost accounting for one or more model calls.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Usage {
/// Input token count; required when deserializing (no serde default).
pub input_tokens: u64,
/// Output token count; required when deserializing (no serde default).
pub output_tokens: u64,
/// Explicit total token count; deserializes to 0 when the key is absent.
#[serde(default)]
pub total_tokens: u64,
/// Optional cost in US dollars; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub cost_usd: Option<f64>,
/// Free-form JSON value; deserializes to JSON null when the key is absent and
/// is omitted from the output when null.
#[serde(default, skip_serializing_if = "Value::is_null")]
pub provider_usage: Value,
}
impl Usage {
    /// Returns the total token count for this record.
    ///
    /// A non-zero `total_tokens` takes precedence; otherwise the total is
    /// derived as `input_tokens + output_tokens` (saturating at `u64::MAX`).
    pub fn total(&self) -> u64 {
        if self.total_tokens > 0 {
            self.total_tokens
        } else {
            self.input_tokens.saturating_add(self.output_tokens)
        }
    }

    /// Accumulates `other` into `self`.
    ///
    /// * `input_tokens` and `output_tokens` are summed (saturating).
    /// * `total_tokens` becomes `self.total() + other.total()`, so an explicit
    ///   total on either side is carried forward instead of being replaced by
    ///   `input + output`. When neither side has an explicit total the result
    ///   equals the sum of the merged input and output counters.
    /// * `cost_usd` is summed when both sides have a cost; otherwise whichever
    ///   side has one is kept.
    ///
    /// Only the four fields above are touched.
    pub fn merge(&mut self, other: &Usage) {
        // Must be computed before the counters change: `total()` falls back to
        // `input + output`, which would otherwise already include `other`.
        let combined_total = self.total().saturating_add(other.total());

        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
        self.total_tokens = combined_total;

        self.cost_usd = match (self.cost_usd, other.cost_usd) {
            (Some(mine), Some(theirs)) => Some(mine + theirs),
            (mine, theirs) => mine.or(theirs),
        };
    }
}
/// Why generation stopped.
///
/// `rename_all = "snake_case"` gives the serialized names `end_turn`,
/// `max_tokens`, `tool_use`, `stop_sequence` and `content_filter`. There is no
/// fallback variant, so any other string fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
/// Serialized as `end_turn`.
EndTurn,
/// Serialized as `max_tokens`.
MaxTokens,
/// Serialized as `tool_use`.
ToolUse,
/// Serialized as `stop_sequence`.
StopSequence,
/// Serialized as `content_filter`.
ContentFilter,
}
/// Declaration of a tool: its name, a description, and an input schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Description text for the tool.
pub description: String,
/// Arbitrary JSON describing the tool's input.
// NOTE(review): presumably a JSON Schema object — nothing here validates it; confirm.
pub input_schema: Value,
}
/// Events of a streamed response.
///
/// This type derives only `Debug` and `Clone`; it has no serde representation.
// NOTE(review): the per-variant descriptions below are inferred from the variant
// and field names; confirm against the code that produces these events.
#[derive(Debug, Clone)]
pub enum StreamEvent {
/// Start of a message, carrying its id and model name.
MessageStart {
id: String,
model: String,
},
/// Start of the content block at `index`.
ContentBlockStart {
index: usize,
/// Kind of block as a free-form string.
block_type: String,
/// Optional identifier for the block.
#[allow(unused)]
id: Option<String>,
/// Optional name for the block.
#[allow(unused)]
name: Option<String>,
},
/// A text fragment for the block at `index`.
TextDelta {
index: usize,
text: String,
},
/// A fragment of JSON text for the block at `index`.
InputJsonDelta {
index: usize,
partial_json: String,
},
/// A thinking-text fragment for the block at `index`.
ThinkingDelta {
index: usize,
thinking: String,
},
/// End of the content block at `index`.
ContentBlockStop {
index: usize,
},
/// Message-level update with an optional stop reason and optional usage.
MessageDelta {
stop_reason: Option<StopReason>,
usage: Option<Usage>,
},
/// End of the message.
MessageStop,
/// An error reported in the stream, with its message text.
Error {
message: String,
},
/// A ping event carrying no data.
Ping,
}
/// Error type used by this module's `Result` alias.
///
/// `Display` text comes from the `#[error(...)]` attribute on each variant.
/// Variants marked `#[from]` get a generated `From` impl, so the `?` operator
/// converts those source errors automatically.
#[derive(thiserror::Error, Debug)]
pub enum CerseiError {
/// Provider failure described only by a message.
#[error("Provider error: {0}")]
Provider(String),
/// Provider failure with a numeric status and a message.
#[error("Provider error {status}: {message}")]
ProviderStatus { status: u16, message: String },
/// Authentication failure.
#[error("Authentication error: {0}")]
Auth(String),
/// Tool failure.
#[error("Tool error: {0}")]
Tool(String),
/// Permission was denied.
#[error("Permission denied: {0}")]
Permission(String),
/// Rate limit hit; `retry_after` optionally carries a wait duration and is
/// not part of the display text.
#[error("Rate limit exceeded")]
RateLimit { retry_after: Option<Duration> },
/// Token usage versus limit, displayed as `used/limit`.
#[error("Context overflow: {used}/{limit} tokens")]
ContextOverflow { used: u64, limit: u64 },
/// The operation was cancelled.
#[error("Cancelled")]
Cancelled,
/// Configuration problem.
#[error("Configuration error: {0}")]
Config(String),
/// MCP-related failure.
#[error("MCP error: {0}")]
Mcp(String),
/// Wraps `std::io::Error` (convertible via `From`).
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// Wraps `serde_json::Error` (convertible via `From`).
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// Wraps `reqwest::Error` (convertible via `From`).
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),
/// Wraps any `anyhow::Error` (convertible via `From`); displays the inner error as is.
#[error("{0}")]
Other(#[from] anyhow::Error),
}
impl CerseiError {
pub fn is_retryable(&self) -> bool {
matches!(
self,
CerseiError::RateLimit { .. }
| CerseiError::ProviderStatus { status: 429, .. }
| CerseiError::ProviderStatus { status: 529, .. }
)
}
pub fn is_context_limit(&self) -> bool {
matches!(self, CerseiError::ContextOverflow { .. })
}
}
/// Shorthand for `std::result::Result` with [`CerseiError`] as the error type.
pub type Result<T> = std::result::Result<T, CerseiError>;
/// Summary information about a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionInfo {
/// Session identifier.
pub id: String,
/// Creation timestamp in UTC.
pub created_at: chrono::DateTime<chrono::Utc>,
/// Number of messages.
pub message_count: usize,
/// Optional model name. No skip attribute is set, so `None` is written out
/// as JSON null rather than omitted.
pub model: Option<String>,
}
/// A piece of stored content with a relevance score and a source label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
/// The stored text.
pub content: String,
/// Relevance score.
// NOTE(review): scale and ordering (e.g. 0.0..=1.0, higher is better) are not
// established in this file — confirm with the producer.
pub relevance: f32,
/// Label saying where the entry came from.
pub source: String,
}