Skip to main content

llm/providers/anthropic/types.rs

1use crate::TokenUsage;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
5#[serde(rename_all = "lowercase")]
6pub enum Role {
7    User,
8    Assistant,
9}
10
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase")]
13pub enum CacheType {
14    Ephemeral,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(rename_all = "snake_case")]
19pub struct Request {
20    pub model: String,
21    pub messages: Vec<Message>,
22    pub max_tokens: u32,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub temperature: Option<f32>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub system: Option<SystemContent>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub tools: Option<Vec<Tool>>,
29    pub stream: bool,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub thinking: Option<Thinking>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub cache_control: Option<CacheControl>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct Thinking {
38    #[serde(rename = "type")]
39    pub thinking_type: String,
40    pub budget_tokens: u32,
41}
42
43impl Thinking {
44    pub fn new(budget_tokens: u32) -> Self {
45        Self { thinking_type: "enabled".to_string(), budget_tokens }
46    }
47}
48
49impl Request {
50    pub fn new(model: String, messages: Vec<Message>) -> Self {
51        Self {
52            model,
53            messages,
54            max_tokens: 4096,
55            temperature: None,
56            system: None,
57            tools: None,
58            stream: false,
59            thinking: None,
60            cache_control: None,
61        }
62    }
63
64    pub fn with_auto_caching(mut self) -> Self {
65        self.cache_control = Some(CacheControl::ephemeral());
66        self
67    }
68
69    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
70        self.max_tokens = max_tokens;
71        self
72    }
73
74    pub fn with_temperature(mut self, temperature: f32) -> Self {
75        self.temperature = Some(temperature);
76        self
77    }
78
79    pub fn with_system_cached(mut self, system: String) -> Self {
80        self.system = Some(SystemContent::Blocks(vec![SystemContentBlock::Text {
81            text: system,
82            cache_control: Some(CacheControl::ephemeral()),
83        }]));
84        self
85    }
86
87    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
88        self.tools = Some(tools);
89        self
90    }
91
92    pub fn with_stream(mut self, stream: bool) -> Self {
93        self.stream = stream;
94        self
95    }
96
97    pub fn with_thinking(mut self, thinking: Thinking) -> Self {
98        self.thinking = Some(thinking);
99        self
100    }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(rename_all = "snake_case")]
105pub struct Message {
106    pub role: Role,
107    pub content: Content,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub cache_control: Option<CacheControl>,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113#[serde(untagged)]
114pub enum Content {
115    Text(String),
116    Blocks(Vec<ContentBlock>),
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
120#[serde(untagged)]
121pub enum SystemContent {
122    Text(String),
123    Blocks(Vec<SystemContentBlock>),
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(tag = "type")]
128pub enum SystemContentBlock {
129    #[serde(rename = "text")]
130    Text {
131        text: String,
132        #[serde(skip_serializing_if = "Option::is_none")]
133        cache_control: Option<CacheControl>,
134    },
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
138#[serde(tag = "type")]
139pub enum ContentBlock {
140    #[serde(rename = "text")]
141    Text {
142        text: String,
143        #[serde(skip_serializing_if = "Option::is_none")]
144        cache_control: Option<CacheControl>,
145    },
146    #[serde(rename = "image")]
147    Image { source: ImageSource },
148    #[serde(rename = "tool_use")]
149    ToolUse { id: String, name: String, input: serde_json::Value },
150    #[serde(rename = "tool_result")]
151    ToolResult {
152        tool_use_id: String,
153        content: String,
154        #[serde(skip_serializing_if = "Option::is_none")]
155        is_error: Option<bool>,
156    },
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct ImageSource {
161    #[serde(rename = "type")]
162    pub source_type: String,
163    pub media_type: String,
164    pub data: String,
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct CacheControl {
169    #[serde(rename = "type")]
170    pub cache_type: CacheType,
171}
172
173impl CacheControl {
174    pub fn ephemeral() -> Self {
175        Self { cache_type: CacheType::Ephemeral }
176    }
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
180#[serde(rename_all = "snake_case")]
181pub struct Tool {
182    pub name: String,
183    pub description: String,
184    pub input_schema: serde_json::Value,
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub cache_control: Option<CacheControl>,
187}
188
189#[derive(Debug, Clone, Deserialize)]
190#[serde(tag = "type", rename_all = "snake_case")]
191pub enum StreamEvent {
192    MessageStart {
193        #[serde(flatten)]
194        data: MessageStart,
195    },
196    ContentBlockStart {
197        #[serde(flatten)]
198        data: ContentBlockStart,
199    },
200    ContentBlockDelta {
201        #[serde(flatten)]
202        data: ContentBlockDelta,
203    },
204    ContentBlockStop {
205        #[serde(flatten)]
206        data: ContentBlockStop,
207    },
208    MessageDelta {
209        #[serde(flatten)]
210        data: MessageDelta,
211    },
212    MessageStop {
213        #[serde(flatten)]
214        data: MessageStop,
215    },
216    #[serde(rename = "error")]
217    Error {
218        #[serde(flatten)]
219        data: ErrorEvent,
220    },
221    #[serde(rename = "ping")]
222    Ping,
223}
224
225#[derive(Debug, Clone, Deserialize)]
226pub struct MessageStart {
227    pub message: ResponseMessage,
228}
229
230#[derive(Debug, Clone, Deserialize)]
231#[serde(rename_all = "snake_case")]
232pub struct ResponseMessage {
233    pub id: String,
234    #[serde(rename = "type")]
235    pub message_type: String,
236    pub role: Role,
237    pub content: Vec<serde_json::Value>,
238    pub model: String,
239    pub stop_reason: Option<String>,
240    pub stop_sequence: Option<String>,
241    pub usage: Usage,
242}
243
244#[derive(Debug, Clone, Deserialize)]
245#[serde(rename_all = "snake_case")]
246pub struct Usage {
247    pub input_tokens: u32,
248    pub output_tokens: u32,
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub cache_creation_input_tokens: Option<u32>,
251    #[serde(skip_serializing_if = "Option::is_none")]
252    pub cache_read_input_tokens: Option<u32>,
253}
254
255impl From<&Usage> for TokenUsage {
256    fn from(usage: &Usage) -> Self {
257        TokenUsage {
258            input_tokens: usage.input_tokens,
259            output_tokens: usage.output_tokens,
260            cache_read_tokens: usage.cache_read_input_tokens,
261            cache_creation_tokens: usage.cache_creation_input_tokens,
262            ..TokenUsage::default()
263        }
264    }
265}
266
267#[derive(Debug, Clone, Deserialize)]
268#[serde(rename_all = "snake_case")]
269pub struct ContentBlockStart {
270    pub index: u32,
271    pub content_block: ContentBlockStartData,
272}
273
274#[derive(Debug, Clone, Deserialize)]
275#[serde(tag = "type")]
276pub enum ContentBlockStartData {
277    #[serde(rename = "text")]
278    Text { text: String },
279    #[serde(rename = "tool_use")]
280    ToolUse { id: String, name: String },
281    #[serde(rename = "thinking")]
282    Thinking { thinking: String },
283}
284
285#[derive(Debug, Clone, Deserialize)]
286#[serde(rename_all = "snake_case")]
287pub struct ContentBlockDelta {
288    pub index: u32,
289    pub delta: ContentBlockDeltaData,
290}
291
292#[derive(Debug, Clone, Deserialize)]
293#[serde(tag = "type", rename_all = "snake_case")]
294pub enum ContentBlockDeltaData {
295    TextDelta { text: String },
296    InputJsonDelta { partial_json: String },
297    ThinkingDelta { thinking: String },
298}
299
300#[derive(Debug, Clone, Deserialize)]
301pub struct ContentBlockStop {
302    pub index: u32,
303}
304
305#[derive(Debug, Clone, Deserialize)]
306#[serde(rename_all = "snake_case")]
307pub struct MessageDelta {
308    pub delta: MessageDeltaData,
309    #[serde(default)]
310    pub usage: Option<Usage>,
311}
312
313#[derive(Debug, Clone, Deserialize)]
314#[serde(rename_all = "snake_case")]
315pub struct MessageDeltaData {
316    #[serde(skip_serializing_if = "Option::is_none")]
317    pub stop_reason: Option<String>,
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub stop_sequence: Option<String>,
320}
321
322#[derive(Debug, Clone, Deserialize)]
323pub struct MessageStop {}
324
325#[derive(Debug, Clone, Deserialize)]
326#[serde(rename_all = "snake_case")]
327pub struct ErrorEvent {
328    pub error: Error,
329}
330
331#[derive(Debug, Clone, Deserialize)]
332#[serde(rename_all = "snake_case")]
333pub struct Error {
334    #[serde(rename = "type")]
335    pub error_type: String,
336    pub message: String,
337}