1use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::time::Duration;
7
8mod media;
9pub use media::{detect_mime, MediaKind};
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
14#[serde(rename_all = "lowercase")]
15pub enum Role {
16 User,
17 Assistant,
18 System,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "type", rename_all = "snake_case")]
25pub enum ContentBlock {
26 Text {
27 text: String,
28 },
29 Image {
30 source: ImageSource,
31 },
32 ToolUse {
33 id: String,
34 name: String,
35 input: Value,
36 },
37 ToolResult {
38 tool_use_id: String,
39 content: ToolResultContent,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 is_error: Option<bool>,
42 },
43 Thinking {
44 thinking: String,
45 #[serde(default)]
46 signature: String,
47 },
48 RedactedThinking {
49 data: String,
50 },
51 Document {
52 source: DocumentSource,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 title: Option<String>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 context: Option<String>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 citations: Option<CitationsConfig>,
59 },
60 #[serde(other)]
62 Opaque,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(untagged)]
67pub enum ToolResultContent {
68 Text(String),
69 Blocks(Vec<ContentBlock>),
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct ImageSource {
74 #[serde(rename = "type")]
75 pub source_type: String,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 pub media_type: Option<String>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub data: Option<String>,
80 #[serde(skip_serializing_if = "Option::is_none")]
81 pub url: Option<String>,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct DocumentSource {
86 #[serde(rename = "type")]
87 pub source_type: String,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 pub media_type: Option<String>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 pub data: Option<String>,
92 #[serde(skip_serializing_if = "Option::is_none")]
93 pub url: Option<String>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct CitationsConfig {
98 pub enabled: bool,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Message {
105 pub role: Role,
106 pub content: MessageContent,
107 #[serde(skip_serializing_if = "Option::is_none")]
108 pub id: Option<String>,
109 #[serde(skip_serializing_if = "Option::is_none")]
110 pub metadata: Option<MessageMetadata>,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(untagged)]
115pub enum MessageContent {
116 Text(String),
117 Blocks(Vec<ContentBlock>),
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize, Default)]
121pub struct MessageMetadata {
122 #[serde(skip_serializing_if = "Option::is_none")]
123 pub model: Option<String>,
124 #[serde(skip_serializing_if = "Option::is_none")]
125 pub usage: Option<Usage>,
126 #[serde(skip_serializing_if = "Option::is_none")]
127 pub stop_reason: Option<StopReason>,
128 #[serde(default, skip_serializing_if = "Value::is_null")]
130 pub provider_data: Value,
131}
132
133impl Message {
134 pub fn user(content: impl Into<String>) -> Self {
135 Self {
136 role: Role::User,
137 content: MessageContent::Text(content.into()),
138 id: None,
139 metadata: None,
140 }
141 }
142
143 pub fn user_blocks(blocks: Vec<ContentBlock>) -> Self {
144 Self {
145 role: Role::User,
146 content: MessageContent::Blocks(blocks),
147 id: None,
148 metadata: None,
149 }
150 }
151
152 pub fn assistant(content: impl Into<String>) -> Self {
153 Self {
154 role: Role::Assistant,
155 content: MessageContent::Text(content.into()),
156 id: None,
157 metadata: None,
158 }
159 }
160
161 pub fn assistant_blocks(blocks: Vec<ContentBlock>) -> Self {
162 Self {
163 role: Role::Assistant,
164 content: MessageContent::Blocks(blocks),
165 id: None,
166 metadata: None,
167 }
168 }
169
170 pub fn system(content: impl Into<String>) -> Self {
171 Self {
172 role: Role::System,
173 content: MessageContent::Text(content.into()),
174 id: None,
175 metadata: None,
176 }
177 }
178
179 pub fn get_text(&self) -> Option<&str> {
181 match &self.content {
182 MessageContent::Text(t) => Some(t.as_str()),
183 MessageContent::Blocks(blocks) => blocks.iter().find_map(|b| {
184 if let ContentBlock::Text { text } = b {
185 Some(text.as_str())
186 } else {
187 None
188 }
189 }),
190 }
191 }
192
193 pub fn get_all_text(&self) -> String {
195 match &self.content {
196 MessageContent::Text(t) => t.clone(),
197 MessageContent::Blocks(blocks) => blocks
198 .iter()
199 .filter_map(|b| {
200 if let ContentBlock::Text { text } = b {
201 Some(text.as_str())
202 } else {
203 None
204 }
205 })
206 .collect::<Vec<_>>()
207 .join(""),
208 }
209 }
210
211 pub fn get_tool_use_blocks(&self) -> Vec<&ContentBlock> {
212 match &self.content {
213 MessageContent::Blocks(blocks) => blocks
214 .iter()
215 .filter(|b| matches!(b, ContentBlock::ToolUse { .. }))
216 .collect(),
217 _ => vec![],
218 }
219 }
220
221 pub fn has_tool_use(&self) -> bool {
222 !self.get_tool_use_blocks().is_empty()
223 }
224
225 pub fn content_blocks(&self) -> Vec<ContentBlock> {
226 match &self.content {
227 MessageContent::Text(t) => vec![ContentBlock::Text { text: t.clone() }],
228 MessageContent::Blocks(b) => b.clone(),
229 }
230 }
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize, Default)]
236pub struct Usage {
237 pub input_tokens: u64,
238 pub output_tokens: u64,
239 #[serde(default)]
240 pub total_tokens: u64,
241 #[serde(skip_serializing_if = "Option::is_none")]
242 pub cost_usd: Option<f64>,
243 #[serde(default, skip_serializing_if = "Value::is_null")]
245 pub provider_usage: Value,
246}
247
248impl Usage {
249 pub fn total(&self) -> u64 {
250 if self.total_tokens > 0 {
251 self.total_tokens
252 } else {
253 self.input_tokens + self.output_tokens
254 }
255 }
256
257 pub fn merge(&mut self, other: &Usage) {
258 self.input_tokens += other.input_tokens;
259 self.output_tokens += other.output_tokens;
260 self.total_tokens = self.input_tokens + self.output_tokens;
261 if let (Some(a), Some(b)) = (self.cost_usd, other.cost_usd) {
262 self.cost_usd = Some(a + b);
263 } else if other.cost_usd.is_some() {
264 self.cost_usd = other.cost_usd;
265 }
266 }
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
272#[serde(rename_all = "snake_case")]
273pub enum StopReason {
274 EndTurn,
275 MaxTokens,
276 ToolUse,
277 StopSequence,
278 ContentFilter,
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct ToolDefinition {
285 pub name: String,
286 pub description: String,
287 pub input_schema: Value,
288}
289
290#[derive(Debug, Clone)]
293pub enum StreamEvent {
294 MessageStart {
295 id: String,
296 model: String,
297 },
298 ContentBlockStart {
299 index: usize,
300 block_type: String,
301 #[allow(unused)]
303 id: Option<String>,
304 #[allow(unused)]
306 name: Option<String>,
307 },
308 TextDelta {
309 index: usize,
310 text: String,
311 },
312 InputJsonDelta {
313 index: usize,
314 partial_json: String,
315 },
316 ThinkingDelta {
317 index: usize,
318 thinking: String,
319 },
320 ContentBlockStop {
321 index: usize,
322 },
323 MessageDelta {
324 stop_reason: Option<StopReason>,
325 usage: Option<Usage>,
326 },
327 MessageStop,
328 Error {
329 message: String,
330 },
331 Ping,
332}
333
/// Error type of this crate.
///
/// `Display` text comes from the `#[error(...)]` attribute on each variant.
/// The `Io`, `Json`, `Http` and `Other` variants also get `From` conversions
/// through `#[from]`, so `?` converts those source errors automatically.
#[derive(Debug, thiserror::Error)]
pub enum CerseiError {
    /// Provider failure described by a message only.
    #[error("Provider error: {0}")]
    Provider(String),

    /// Provider failure with a numeric status and a message.
    #[error("Provider error {status}: {message}")]
    ProviderStatus { status: u16, message: String },

    /// Authentication failure.
    #[error("Authentication error: {0}")]
    Auth(String),

    /// Tool failure.
    #[error("Tool error: {0}")]
    Tool(String),

    /// An action was not permitted.
    #[error("Permission denied: {0}")]
    Permission(String),

    /// Rate limit hit. `retry_after` is carried on the variant but is not
    /// part of the `Display` text.
    #[error("Rate limit exceeded")]
    RateLimit { retry_after: Option<Duration> },

    /// Token usage `used` against a `limit`.
    #[error("Context overflow: {used}/{limit} tokens")]
    ContextOverflow { used: u64, limit: u64 },

    /// The operation was cancelled.
    #[error("Cancelled")]
    Cancelled,

    /// Configuration problem.
    #[error("Configuration error: {0}")]
    Config(String),

    /// MCP-related failure.
    #[error("MCP error: {0}")]
    Mcp(String),

    /// Wrapped `std::io::Error`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Wrapped `serde_json::Error`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// Wrapped `reqwest::Error`.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// Any other error, wrapped as `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
380
381impl CerseiError {
382 pub fn is_retryable(&self) -> bool {
383 matches!(
384 self,
385 CerseiError::RateLimit { .. }
386 | CerseiError::ProviderStatus { status: 429, .. }
387 | CerseiError::ProviderStatus { status: 529, .. }
388 )
389 }
390
391 pub fn is_context_limit(&self) -> bool {
392 matches!(self, CerseiError::ContextOverflow { .. })
393 }
394}
395
396pub type Result<T> = std::result::Result<T, CerseiError>;
397
398#[derive(Debug, Clone, Serialize, Deserialize)]
401pub struct SessionInfo {
402 pub id: String,
403 pub created_at: chrono::DateTime<chrono::Utc>,
404 pub message_count: usize,
405 pub model: Option<String>,
406}
407
408#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct MemoryEntry {
410 pub content: String,
411 pub relevance: f32,
412 pub source: String,
413}