Skip to main content

cersei_types/
lib.rs

//! cersei-types: Provider-agnostic message types, errors, and content blocks
//! for the Cersei coding agent SDK.
3
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::time::Duration;
7
8mod media;
9pub use media::{detect_mime, MediaKind};
10
11// ─── Roles ───────────────────────────────────────────────────────────────────
12
13#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16    User,
17    Assistant,
18    System,
19}
20
21// ─── Content blocks ──────────────────────────────────────────────────────────
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "type", rename_all = "snake_case")]
25pub enum ContentBlock {
26    Text {
27        text: String,
28    },
29    Image {
30        source: ImageSource,
31    },
32    ToolUse {
33        id: String,
34        name: String,
35        input: Value,
36    },
37    ToolResult {
38        tool_use_id: String,
39        content: ToolResultContent,
40        #[serde(skip_serializing_if = "Option::is_none")]
41        is_error: Option<bool>,
42    },
43    Thinking {
44        thinking: String,
45        #[serde(default)]
46        signature: String,
47    },
48    RedactedThinking {
49        data: String,
50    },
51    Document {
52        source: DocumentSource,
53        #[serde(skip_serializing_if = "Option::is_none")]
54        title: Option<String>,
55        #[serde(skip_serializing_if = "Option::is_none")]
56        context: Option<String>,
57        #[serde(skip_serializing_if = "Option::is_none")]
58        citations: Option<CitationsConfig>,
59    },
60    /// Escape hatch for provider-specific block types not covered above.
61    #[serde(other)]
62    Opaque,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(untagged)]
67pub enum ToolResultContent {
68    Text(String),
69    Blocks(Vec<ContentBlock>),
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ImageSource {
74    #[serde(rename = "type")]
75    pub source_type: String,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub media_type: Option<String>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub data: Option<String>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub url: Option<String>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct DocumentSource {
86    #[serde(rename = "type")]
87    pub source_type: String,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub media_type: Option<String>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub data: Option<String>,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub url: Option<String>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct CitationsConfig {
98    pub enabled: bool,
99}
100
101// ─── Messages ────────────────────────────────────────────────────────────────
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Message {
105    pub role: Role,
106    pub content: MessageContent,
107    #[serde(skip_serializing_if = "Option::is_none")]
108    pub id: Option<String>,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub metadata: Option<MessageMetadata>,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(untagged)]
115pub enum MessageContent {
116    Text(String),
117    Blocks(Vec<ContentBlock>),
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize, Default)]
121pub struct MessageMetadata {
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub model: Option<String>,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub usage: Option<Usage>,
126    #[serde(skip_serializing_if = "Option::is_none")]
127    pub stop_reason: Option<StopReason>,
128    /// Provider-specific metadata (cache tokens, etc.)
129    #[serde(default, skip_serializing_if = "Value::is_null")]
130    pub provider_data: Value,
131}
132
133impl Message {
134    pub fn user(content: impl Into<String>) -> Self {
135        Self {
136            role: Role::User,
137            content: MessageContent::Text(content.into()),
138            id: None,
139            metadata: None,
140        }
141    }
142
143    pub fn user_blocks(blocks: Vec<ContentBlock>) -> Self {
144        Self {
145            role: Role::User,
146            content: MessageContent::Blocks(blocks),
147            id: None,
148            metadata: None,
149        }
150    }
151
152    pub fn assistant(content: impl Into<String>) -> Self {
153        Self {
154            role: Role::Assistant,
155            content: MessageContent::Text(content.into()),
156            id: None,
157            metadata: None,
158        }
159    }
160
161    pub fn assistant_blocks(blocks: Vec<ContentBlock>) -> Self {
162        Self {
163            role: Role::Assistant,
164            content: MessageContent::Blocks(blocks),
165            id: None,
166            metadata: None,
167        }
168    }
169
170    pub fn system(content: impl Into<String>) -> Self {
171        Self {
172            role: Role::System,
173            content: MessageContent::Text(content.into()),
174            id: None,
175            metadata: None,
176        }
177    }
178
179    /// Extract the first text content from this message.
180    pub fn get_text(&self) -> Option<&str> {
181        match &self.content {
182            MessageContent::Text(t) => Some(t.as_str()),
183            MessageContent::Blocks(blocks) => blocks.iter().find_map(|b| {
184                if let ContentBlock::Text { text } = b {
185                    Some(text.as_str())
186                } else {
187                    None
188                }
189            }),
190        }
191    }
192
193    /// Collect all text content blocks into one concatenated string.
194    pub fn get_all_text(&self) -> String {
195        match &self.content {
196            MessageContent::Text(t) => t.clone(),
197            MessageContent::Blocks(blocks) => blocks
198                .iter()
199                .filter_map(|b| {
200                    if let ContentBlock::Text { text } = b {
201                        Some(text.as_str())
202                    } else {
203                        None
204                    }
205                })
206                .collect::<Vec<_>>()
207                .join(""),
208        }
209    }
210
211    pub fn get_tool_use_blocks(&self) -> Vec<&ContentBlock> {
212        match &self.content {
213            MessageContent::Blocks(blocks) => blocks
214                .iter()
215                .filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
216                .collect(),
217            _ => vec![],
218        }
219    }
220
221    pub fn has_tool_use(&self) -> bool {
222        !self.get_tool_use_blocks().is_empty()
223    }
224
225    pub fn content_blocks(&self) -> Vec<ContentBlock> {
226        match &self.content {
227            MessageContent::Text(t) => vec![ContentBlock::Text { text: t.clone() }],
228            MessageContent::Blocks(b) => b.clone(),
229        }
230    }
231}
232
233// ─── Usage / Cost ────────────────────────────────────────────────────────────
234
235#[derive(Debug, Clone, Serialize, Deserialize, Default)]
236pub struct Usage {
237    pub input_tokens: u64,
238    pub output_tokens: u64,
239    #[serde(default)]
240    pub total_tokens: u64,
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub cost_usd: Option<f64>,
243    /// Provider-specific usage data (e.g. cache_creation_input_tokens for Anthropic)
244    #[serde(default, skip_serializing_if = "Value::is_null")]
245    pub provider_usage: Value,
246}
247
248impl Usage {
249    pub fn total(&self) -> u64 {
250        if self.total_tokens > 0 {
251            self.total_tokens
252        } else {
253            self.input_tokens + self.output_tokens
254        }
255    }
256
257    pub fn merge(&mut self, other: &Usage) {
258        self.input_tokens += other.input_tokens;
259        self.output_tokens += other.output_tokens;
260        self.total_tokens = self.input_tokens + self.output_tokens;
261        if let (Some(a), Some(b)) = (self.cost_usd, other.cost_usd) {
262            self.cost_usd = Some(a + b);
263        } else if other.cost_usd.is_some() {
264            self.cost_usd = other.cost_usd;
265        }
266    }
267}
268
269// ─── Stop reasons ────────────────────────────────────────────────────────────
270
271#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
272#[serde(rename_all = "snake_case")]
273pub enum StopReason {
274    EndTurn,
275    MaxTokens,
276    ToolUse,
277    StopSequence,
278    ContentFilter,
279}
280
281// ─── Tool definition (sent to providers) ─────────────────────────────────────
282
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct ToolDefinition {
285    pub name: String,
286    pub description: String,
287    pub input_schema: Value,
288}
289
290// ─── Stream events ───────────────────────────────────────────────────────────
291
292#[derive(Debug, Clone)]
293pub enum StreamEvent {
294    MessageStart {
295        id: String,
296        model: String,
297    },
298    ContentBlockStart {
299        index: usize,
300        block_type: String,
301        /// For tool_use blocks: the tool use ID. Default: None.
302        #[allow(unused)]
303        id: Option<String>,
304        /// For tool_use blocks: the tool name. Default: None.
305        #[allow(unused)]
306        name: Option<String>,
307    },
308    TextDelta {
309        index: usize,
310        text: String,
311    },
312    InputJsonDelta {
313        index: usize,
314        partial_json: String,
315    },
316    ThinkingDelta {
317        index: usize,
318        thinking: String,
319    },
320    ContentBlockStop {
321        index: usize,
322    },
323    MessageDelta {
324        stop_reason: Option<StopReason>,
325        usage: Option<Usage>,
326    },
327    MessageStop,
328    Error {
329        message: String,
330    },
331    Ping,
332}
333
334// ─── Errors ──────────────────────────────────────────────────────────────────
335
/// Error type shared across the Cersei SDK.
///
/// `?` converts `std::io::Error`, `serde_json::Error`, `reqwest::Error`, and
/// `anyhow::Error` into this type automatically (via `#[from]`).
#[derive(Debug, thiserror::Error)]
pub enum CerseiError {
    /// Provider failure described only by a message.
    #[error("Provider error: {0}")]
    Provider(String),

    /// Provider failure carrying a numeric status code and message.
    #[error("Provider error {status}: {message}")]
    ProviderStatus { status: u16, message: String },

    /// Authentication failure.
    #[error("Authentication error: {0}")]
    Auth(String),

    /// Failure raised by a tool.
    #[error("Tool error: {0}")]
    Tool(String),

    /// An operation was not permitted.
    #[error("Permission denied: {0}")]
    Permission(String),

    /// Rate limit hit. NOTE(review): `retry_after` is presumably the wait the
    /// provider suggested, when it supplied one — confirm with producers.
    #[error("Rate limit exceeded")]
    RateLimit { retry_after: Option<Duration> },

    /// Context overflow, reported as `used`/`limit` tokens.
    #[error("Context overflow: {used}/{limit} tokens")]
    ContextOverflow { used: u64, limit: u64 },

    /// The operation was cancelled.
    #[error("Cancelled")]
    Cancelled,

    /// Invalid or missing configuration.
    #[error("Configuration error: {0}")]
    Config(String),

    /// MCP failure.
    #[error("MCP error: {0}")]
    Mcp(String),

    /// Wrapped I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapped JSON (de)serialization error.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// Wrapped HTTP client error.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// Any other error; displays the wrapped error's message.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
380
381impl CerseiError {
382    pub fn is_retryable(&self) -> bool {
383        matches!(
384            self,
385            CerseiError::RateLimit { .. }
386                | CerseiError::ProviderStatus { status: 429, .. }
387                | CerseiError::ProviderStatus { status: 529, .. }
388        )
389    }
390
391    pub fn is_context_limit(&self) -> bool {
392        matches!(self, CerseiError::ContextOverflow { .. })
393    }
394}
395
396pub type Result<T> = std::result::Result<T, CerseiError>;
397
398// ─── Session info ────────────────────────────────────────────────────────────
399
400#[derive(Debug, Clone, Serialize, Deserialize)]
401pub struct SessionInfo {
402    pub id: String,
403    pub created_at: chrono::DateTime<chrono::Utc>,
404    pub message_count: usize,
405    pub model: Option<String>,
406}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct MemoryEntry {
410    pub content: String,
411    pub relevance: f32,
412    pub source: String,
413}