1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use thiserror::Error;
4
5#[derive(Debug, Error)]
7pub enum MessageError {
8 #[error("API request failed: {0}")]
9 RequestFailed(String),
10 #[error("API error: {0}")]
11 ApiError(String),
12}
13
14impl From<String> for MessageError {
15 fn from(error: String) -> Self {
16 MessageError::ApiError(error)
17 }
18}
19
20#[async_trait]
21pub trait MessageClient {
22 async fn create_message<'a>(
23 &'a self,
24 params: Option<&'a CreateMessageParams>,
25 ) -> Result<CreateMessageResponse, MessageError>;
26
27 async fn count_tokens<'a>(
28 &'a self,
29 params: Option<&'a CountMessageTokensParams>,
30 ) -> Result<CountMessageTokensResponse, MessageError>;
31
32 async fn create_message_streaming<'a>(
33 &'a self,
34 body: &'a CreateMessageParams,
35 ) -> Result<
36 impl futures_util::Stream<Item = Result<StreamEvent, MessageError>> + 'a,
37 MessageError,
38 >;
39}
40
41#[derive(Debug)]
42pub struct RequiredMessageParams {
43 pub model: String,
44 pub messages: Vec<Message>,
45 pub max_tokens: u32,
46}
47
48#[derive(Debug, Deserialize, Serialize, Default)]
50pub struct CreateMessageParams {
51 pub max_tokens: u32,
53 pub messages: Vec<Message>,
55 pub model: String,
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub system: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub temperature: Option<f32>,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub stop_sequences: Option<Vec<String>>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub stream: Option<bool>,
69 #[serde(skip_serializing_if = "Option::is_none")]
71 pub top_k: Option<u32>,
72 #[serde(skip_serializing_if = "Option::is_none")]
74 pub top_p: Option<f32>,
75 #[serde(skip_serializing_if = "Option::is_none")]
77 pub tools: Option<Vec<Tool>>,
78 #[serde(skip_serializing_if = "Option::is_none")]
80 pub tool_choice: Option<ToolChoice>,
81 #[serde(skip_serializing_if = "Option::is_none")]
83 pub thinking: Option<Thinking>,
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub metadata: Option<Metadata>,
87}
88
89impl From<RequiredMessageParams> for CreateMessageParams {
90 fn from(required: RequiredMessageParams) -> Self {
91 Self {
92 model: required.model,
93 messages: required.messages,
94 max_tokens: required.max_tokens,
95 ..Default::default()
96 }
97 }
98}
99
100impl CreateMessageParams {
101 pub fn new(required: RequiredMessageParams) -> Self {
103 required.into()
104 }
105
106 pub fn with_system(mut self, system: impl Into<String>) -> Self {
108 self.system = Some(system.into());
109 self
110 }
111
112 pub fn with_temperature(mut self, temperature: f32) -> Self {
113 self.temperature = Some(temperature);
114 self
115 }
116
117 pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
118 self.stop_sequences = Some(stop_sequences);
119 self
120 }
121
122 pub fn with_stream(mut self, stream: bool) -> Self {
123 self.stream = Some(stream);
124 self
125 }
126
127 pub fn with_top_k(mut self, top_k: u32) -> Self {
128 self.top_k = Some(top_k);
129 self
130 }
131
132 pub fn with_top_p(mut self, top_p: f32) -> Self {
133 self.top_p = Some(top_p);
134 self
135 }
136
137 pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
138 self.tools = Some(tools);
139 self
140 }
141
142 pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
143 self.tool_choice = Some(tool_choice);
144 self
145 }
146
147 pub fn with_thinking(mut self, thinking: Thinking) -> Self {
148 self.thinking = Some(thinking);
149 self
150 }
151
152 pub fn with_metadata(mut self, metadata: Metadata) -> Self {
153 self.metadata = Some(metadata);
154 self
155 }
156}
157
158#[derive(Debug, Serialize, Deserialize, Clone)]
160pub struct Message {
161 pub role: Role,
163 #[serde(flatten)]
165 pub content: MessageContent,
166}
167
168#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
170#[serde(rename_all = "lowercase")]
171pub enum Role {
172 User,
173 Assistant,
174}
175
176#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
178#[serde(untagged)]
179pub enum MessageContent {
180 Text { content: String },
182 Blocks { content: Vec<ContentBlock> },
184}
185
186#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
188#[serde(tag = "type")]
189pub enum ContentBlock {
190 #[serde(rename = "text")]
192 Text { text: String },
193 #[serde(rename = "image")]
195 Image { source: ImageSource },
196 #[serde(rename = "tool_use")]
198 ToolUse {
199 id: String,
200 name: String,
201 input: serde_json::Value,
202 },
203 #[serde(rename = "tool_result")]
205 ToolResult {
206 tool_use_id: String,
207 content: String,
208 },
209 #[serde(rename = "thinking")]
211 Thinking { thinking: String, signature: String },
212 #[serde(rename = "redacted_thinking")]
214 RedactedThinking { data: String },
215}
216
217#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
219pub struct ImageSource {
220 #[serde(rename = "type")]
222 pub type_: String,
223 pub media_type: String,
225 pub data: String,
227}
228
229#[derive(Debug, Serialize, Deserialize)]
231pub struct Tool {
232 pub name: String,
234 #[serde(skip_serializing_if = "Option::is_none")]
236 pub description: Option<String>,
237 pub input_schema: serde_json::Value,
239}
240
241#[derive(Debug, Serialize, Deserialize)]
243#[serde(tag = "type")]
244pub enum ToolChoice {
245 #[serde(rename = "auto")]
247 Auto,
248 #[serde(rename = "any")]
250 Any,
251 #[serde(rename = "tool")]
253 Tool { name: String },
254 #[serde(rename = "none")]
256 None,
257}
258
259#[derive(Debug, Deserialize, Serialize)]
261pub struct Thinking {
262 pub budget_tokens: usize,
264 #[serde(rename = "type")]
265 pub type_: ThinkingType,
266}
267
268#[derive(Debug, Deserialize, Serialize)]
269pub enum ThinkingType {
270 #[serde(rename = "enabled")]
271 Enabled,
272}
273#[derive(Debug, Serialize, Deserialize, Default)]
275pub struct Metadata {
276 #[serde(flatten)]
278 pub fields: std::collections::HashMap<String, String>,
279}
280
281#[derive(Debug, Deserialize, Serialize)]
283pub struct CreateMessageResponse {
284 pub content: Vec<ContentBlock>,
286 pub id: String,
288 pub model: String,
290 pub role: Role,
292 pub stop_reason: Option<StopReason>,
294 pub stop_sequence: Option<String>,
296 #[serde(rename = "type")]
298 pub type_: String,
299 pub usage: Usage,
301}
302
303#[derive(Debug, Deserialize, Serialize)]
305#[serde(rename_all = "snake_case")]
306pub enum StopReason {
307 EndTurn,
308 MaxTokens,
309 StopSequence,
310 ToolUse,
311 Refusal,
312}
313
314#[derive(Debug, Deserialize, Serialize)]
316pub struct Usage {
317 pub input_tokens: u32,
319 pub output_tokens: u32,
321}
322
323#[derive(Debug, Deserialize, Serialize)]
324pub struct StreamUsage {
325 #[serde(default)]
327 pub input_tokens: u32,
328 pub output_tokens: u32,
330}
331
332impl Message {
333 pub fn new_text(role: Role, text: impl Into<String>) -> Self {
335 Self {
336 role,
337 content: MessageContent::Text {
338 content: text.into(),
339 },
340 }
341 }
342
343 pub fn new_blocks(role: Role, blocks: Vec<ContentBlock>) -> Self {
345 Self {
346 role,
347 content: MessageContent::Blocks { content: blocks },
348 }
349 }
350}
351
352impl ContentBlock {
354 pub fn text(text: impl Into<String>) -> Self {
356 Self::Text { text: text.into() }
357 }
358
359 pub fn image(
361 type_: impl Into<String>,
362 media_type: impl Into<String>,
363 data: impl Into<String>,
364 ) -> Self {
365 Self::Image {
366 source: ImageSource {
367 type_: type_.into(),
368 media_type: media_type.into(),
369 data: data.into(),
370 },
371 }
372 }
373}
374
375#[derive(Debug, Serialize, Default)]
376pub struct CountMessageTokensParams {
377 pub model: String,
378 pub messages: Vec<Message>,
379}
380
381#[derive(Debug, Deserialize)]
382pub struct CountMessageTokensResponse {
383 pub input_tokens: u32,
384}
385
386#[derive(Debug, Deserialize, Serialize)]
387#[serde(tag = "type")]
388pub enum StreamEvent {
389 #[serde(rename = "message_start")]
390 MessageStart { message: MessageStartContent },
391 #[serde(rename = "content_block_start")]
392 ContentBlockStart {
393 index: usize,
394 content_block: ContentBlock,
395 },
396 #[serde(rename = "content_block_delta")]
397 ContentBlockDelta {
398 index: usize,
399 delta: ContentBlockDelta,
400 },
401 #[serde(rename = "content_block_stop")]
402 ContentBlockStop { index: usize },
403 #[serde(rename = "message_delta")]
404 MessageDelta {
405 delta: MessageDeltaContent,
406 usage: Option<StreamUsage>,
407 },
408 #[serde(rename = "message_stop")]
409 MessageStop,
410 #[serde(rename = "ping")]
411 Ping,
412 #[serde(rename = "error")]
413 Error { error: StreamError },
414}
415
416#[derive(Debug, Deserialize, Serialize)]
417pub struct MessageStartContent {
418 pub id: String,
419 #[serde(rename = "type")]
420 pub type_: String,
421 pub role: Role,
422 pub content: Vec<ContentBlock>,
423 pub model: String,
424 pub stop_reason: Option<StopReason>,
425 pub stop_sequence: Option<String>,
426 pub usage: Usage,
427}
428
429#[derive(Debug, Deserialize, Serialize)]
430#[serde(tag = "type")]
431pub enum ContentBlockDelta {
432 #[serde(rename = "text_delta")]
433 TextDelta { text: String },
434 #[serde(rename = "input_json_delta")]
435 InputJsonDelta { partial_json: String },
436 #[serde(rename = "thinking_delta")]
437 ThinkingDelta { thinking: String },
438 #[serde(rename = "signature_delta")]
439 SignatureDelta { signature: String },
440}
441
442#[derive(Debug, Deserialize, Serialize)]
443pub struct MessageDeltaContent {
444 pub stop_reason: Option<StopReason>,
445 pub stop_sequence: Option<String>,
446}
447
448#[derive(Debug, Deserialize, Serialize)]
449pub struct StreamError {
450 #[serde(rename = "type")]
451 pub type_: String,
452 pub message: String,
453}