Skip to main content

cersei_types/
lib.rs

//! cersei-types: Provider-agnostic message types, errors, and content blocks
//! for the Cersei coding agent SDK.
3
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::time::Duration;
7
// ─── Roles ───────────────────────────────────────────────────────────────────
9
10#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
11#[serde(rename_all = "lowercase")]
12pub enum Role {
13    User,
14    Assistant,
15    System,
16}
17
// ─── Content blocks ──────────────────────────────────────────────────────────
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(tag = "type", rename_all = "snake_case")]
22pub enum ContentBlock {
23    Text {
24        text: String,
25    },
26    Image {
27        source: ImageSource,
28    },
29    ToolUse {
30        id: String,
31        name: String,
32        input: Value,
33    },
34    ToolResult {
35        tool_use_id: String,
36        content: ToolResultContent,
37        #[serde(skip_serializing_if = "Option::is_none")]
38        is_error: Option<bool>,
39    },
40    Thinking {
41        thinking: String,
42        #[serde(default)]
43        signature: String,
44    },
45    RedactedThinking {
46        data: String,
47    },
48    Document {
49        source: DocumentSource,
50        #[serde(skip_serializing_if = "Option::is_none")]
51        title: Option<String>,
52        #[serde(skip_serializing_if = "Option::is_none")]
53        context: Option<String>,
54        #[serde(skip_serializing_if = "Option::is_none")]
55        citations: Option<CitationsConfig>,
56    },
57    /// Escape hatch for provider-specific block types not covered above.
58    #[serde(other)]
59    Opaque,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(untagged)]
64pub enum ToolResultContent {
65    Text(String),
66    Blocks(Vec<ContentBlock>),
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct ImageSource {
71    #[serde(rename = "type")]
72    pub source_type: String,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub media_type: Option<String>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub data: Option<String>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub url: Option<String>,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct DocumentSource {
83    #[serde(rename = "type")]
84    pub source_type: String,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub media_type: Option<String>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub data: Option<String>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub url: Option<String>,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct CitationsConfig {
95    pub enabled: bool,
96}
97
// ─── Messages ────────────────────────────────────────────────────────────────
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct Message {
102    pub role: Role,
103    pub content: MessageContent,
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub id: Option<String>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub metadata: Option<MessageMetadata>,
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111#[serde(untagged)]
112pub enum MessageContent {
113    Text(String),
114    Blocks(Vec<ContentBlock>),
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize, Default)]
118pub struct MessageMetadata {
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub model: Option<String>,
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub usage: Option<Usage>,
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub stop_reason: Option<StopReason>,
125    /// Provider-specific metadata (cache tokens, etc.)
126    #[serde(default, skip_serializing_if = "Value::is_null")]
127    pub provider_data: Value,
128}
129
130impl Message {
131    pub fn user(content: impl Into<String>) -> Self {
132        Self {
133            role: Role::User,
134            content: MessageContent::Text(content.into()),
135            id: None,
136            metadata: None,
137        }
138    }
139
140    pub fn user_blocks(blocks: Vec<ContentBlock>) -> Self {
141        Self {
142            role: Role::User,
143            content: MessageContent::Blocks(blocks),
144            id: None,
145            metadata: None,
146        }
147    }
148
149    pub fn assistant(content: impl Into<String>) -> Self {
150        Self {
151            role: Role::Assistant,
152            content: MessageContent::Text(content.into()),
153            id: None,
154            metadata: None,
155        }
156    }
157
158    pub fn assistant_blocks(blocks: Vec<ContentBlock>) -> Self {
159        Self {
160            role: Role::Assistant,
161            content: MessageContent::Blocks(blocks),
162            id: None,
163            metadata: None,
164        }
165    }
166
167    pub fn system(content: impl Into<String>) -> Self {
168        Self {
169            role: Role::System,
170            content: MessageContent::Text(content.into()),
171            id: None,
172            metadata: None,
173        }
174    }
175
176    /// Extract the first text content from this message.
177    pub fn get_text(&self) -> Option<&str> {
178        match &self.content {
179            MessageContent::Text(t) => Some(t.as_str()),
180            MessageContent::Blocks(blocks) => blocks.iter().find_map(|b| {
181                if let ContentBlock::Text { text } = b {
182                    Some(text.as_str())
183                } else {
184                    None
185                }
186            }),
187        }
188    }
189
190    /// Collect all text content blocks into one concatenated string.
191    pub fn get_all_text(&self) -> String {
192        match &self.content {
193            MessageContent::Text(t) => t.clone(),
194            MessageContent::Blocks(blocks) => blocks
195                .iter()
196                .filter_map(|b| {
197                    if let ContentBlock::Text { text } = b {
198                        Some(text.as_str())
199                    } else {
200                        None
201                    }
202                })
203                .collect::<Vec<_>>()
204                .join(""),
205        }
206    }
207
208    pub fn get_tool_use_blocks(&self) -> Vec<&ContentBlock> {
209        match &self.content {
210            MessageContent::Blocks(blocks) => blocks
211                .iter()
212                .filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
213                .collect(),
214            _ => vec![],
215        }
216    }
217
218    pub fn has_tool_use(&self) -> bool {
219        !self.get_tool_use_blocks().is_empty()
220    }
221
222    pub fn content_blocks(&self) -> Vec<ContentBlock> {
223        match &self.content {
224            MessageContent::Text(t) => vec![ContentBlock::Text { text: t.clone() }],
225            MessageContent::Blocks(b) => b.clone(),
226        }
227    }
228}
229
// ─── Usage / Cost ────────────────────────────────────────────────────────────
231
232#[derive(Debug, Clone, Serialize, Deserialize, Default)]
233pub struct Usage {
234    pub input_tokens: u64,
235    pub output_tokens: u64,
236    #[serde(default)]
237    pub total_tokens: u64,
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub cost_usd: Option<f64>,
240    /// Provider-specific usage data (e.g. cache_creation_input_tokens for Anthropic)
241    #[serde(default, skip_serializing_if = "Value::is_null")]
242    pub provider_usage: Value,
243}
244
245impl Usage {
246    pub fn total(&self) -> u64 {
247        if self.total_tokens > 0 {
248            self.total_tokens
249        } else {
250            self.input_tokens + self.output_tokens
251        }
252    }
253
254    pub fn merge(&mut self, other: &Usage) {
255        self.input_tokens += other.input_tokens;
256        self.output_tokens += other.output_tokens;
257        self.total_tokens = self.input_tokens + self.output_tokens;
258        if let (Some(a), Some(b)) = (self.cost_usd, other.cost_usd) {
259            self.cost_usd = Some(a + b);
260        } else if other.cost_usd.is_some() {
261            self.cost_usd = other.cost_usd;
262        }
263    }
264}
265
// ─── Stop reasons ────────────────────────────────────────────────────────────
267
268#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
269#[serde(rename_all = "snake_case")]
270pub enum StopReason {
271    EndTurn,
272    MaxTokens,
273    ToolUse,
274    StopSequence,
275    ContentFilter,
276}
277
// ─── Tool definition (sent to providers) ─────────────────────────────────────
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct ToolDefinition {
282    pub name: String,
283    pub description: String,
284    pub input_schema: Value,
285}
286
// ─── Stream events ───────────────────────────────────────────────────────────
288
289#[derive(Debug, Clone)]
290pub enum StreamEvent {
291    MessageStart {
292        id: String,
293        model: String,
294    },
295    ContentBlockStart {
296        index: usize,
297        block_type: String,
298        /// For tool_use blocks: the tool use ID. Default: None.
299        #[allow(unused)]
300        id: Option<String>,
301        /// For tool_use blocks: the tool name. Default: None.
302        #[allow(unused)]
303        name: Option<String>,
304    },
305    TextDelta {
306        index: usize,
307        text: String,
308    },
309    InputJsonDelta {
310        index: usize,
311        partial_json: String,
312    },
313    ThinkingDelta {
314        index: usize,
315        thinking: String,
316    },
317    ContentBlockStop {
318        index: usize,
319    },
320    MessageDelta {
321        stop_reason: Option<StopReason>,
322        usage: Option<Usage>,
323    },
324    MessageStop,
325    Error {
326        message: String,
327    },
328    Ping,
329}
330
// ─── Errors ──────────────────────────────────────────────────────────────────
332
/// Error type used throughout the crate.
///
/// The `Io`, `Json`, `Http`, and `Other` variants carry `#[from]`, so the
/// wrapped error types convert into `CerseiError` automatically via `?`.
#[derive(thiserror::Error, Debug)]
pub enum CerseiError {
    /// Provider failure described only by a message.
    #[error("Provider error: {0}")]
    Provider(String),

    /// Provider failure with a numeric status code.
    #[error("Provider error {status}: {message}")]
    ProviderStatus { status: u16, message: String },

    /// Authentication failure.
    #[error("Authentication error: {0}")]
    Auth(String),

    /// Tool failure.
    #[error("Tool error: {0}")]
    Tool(String),

    /// An operation was not permitted.
    #[error("Permission denied: {0}")]
    Permission(String),

    /// Rate limit hit. `retry_after` optionally carries a wait duration; it is
    /// not included in the display message.
    #[error("Rate limit exceeded")]
    RateLimit { retry_after: Option<Duration> },

    /// Token usage `used` reported against the context `limit`.
    #[error("Context overflow: {used}/{limit} tokens")]
    ContextOverflow { used: u64, limit: u64 },

    /// The operation was cancelled.
    #[error("Cancelled")]
    Cancelled,

    /// Configuration problem.
    #[error("Configuration error: {0}")]
    Config(String),

    /// MCP-related failure.
    #[error("MCP error: {0}")]
    Mcp(String),

    /// Wrapped I/O error.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapped JSON (de)serialization error.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// Wrapped HTTP client error.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// Any other error; displayed using the wrapped error's own message.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
377
378impl CerseiError {
379    pub fn is_retryable(&self) -> bool {
380        matches!(
381            self,
382            CerseiError::RateLimit { .. }
383                | CerseiError::ProviderStatus { status: 429, .. }
384                | CerseiError::ProviderStatus { status: 529, .. }
385        )
386    }
387
388    pub fn is_context_limit(&self) -> bool {
389        matches!(self, CerseiError::ContextOverflow { .. })
390    }
391}
392
393pub type Result<T> = std::result::Result<T, CerseiError>;
394
// ─── Session info ────────────────────────────────────────────────────────────
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
398pub struct SessionInfo {
399    pub id: String,
400    pub created_at: chrono::DateTime<chrono::Utc>,
401    pub message_count: usize,
402    pub model: Option<String>,
403}
404
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct MemoryEntry {
407    pub content: String,
408    pub relevance: f32,
409    pub source: String,
410}