async_anthropic/
types.rs

1use std::{
2    ops::{Deref, DerefMut},
3    pin::Pin,
4};
5
6use derive_builder::Builder;
7use serde::{Deserialize, Serialize, Serializer};
8use serde_json::Value;
9use tokio_stream::Stream;
10
11use crate::{errors::AnthropicError, messages};
12
13#[derive(Serialize, Deserialize, Debug, Clone)]
14pub struct Usage {
15    pub input_tokens: Option<u32>,
16    pub output_tokens: Option<u32>,
17}
18
19#[derive(Clone, Debug, Deserialize)]
20pub enum ToolChoice {
21    Auto,
22    Any,
23    Tool(String),
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, Builder, PartialEq)]
27#[builder(setter(into, strip_option))]
28pub struct Message {
29    pub role: MessageRole,
30    pub content: MessageContentList,
31}
32
33impl Message {
34    /// Returns all the tool uses in the message
35    pub fn tool_uses(&self) -> Vec<ToolUse> {
36        self.content
37            .0
38            .iter()
39            .filter(|c| matches!(c, MessageContent::ToolUse(_)))
40            .map(|c| match c {
41                MessageContent::ToolUse(tool_use) => tool_use.clone(),
42                _ => unreachable!(),
43            })
44            .collect()
45    }
46
47    /// Returns the first text content in the message
48    pub fn text(&self) -> Option<String> {
49        self.content
50            .0
51            .iter()
52            .filter_map(|c| match c {
53                MessageContent::Text(text) => Some(text.text.clone()),
54                _ => None,
55            })
56            .next()
57    }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
61pub struct MessageContentList(pub Vec<MessageContent>);
62
63impl Deref for MessageContentList {
64    type Target = Vec<MessageContent>;
65
66    fn deref(&self) -> &Self::Target {
67        &self.0
68    }
69}
70
71impl DerefMut for MessageContentList {
72    fn deref_mut(&mut self) -> &mut Self::Target {
73        &mut self.0
74    }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78#[serde(rename_all = "snake_case")]
79pub enum MessageRole {
80    User,
81    Assistant,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
85#[builder(setter(into, strip_option))]
86pub struct CreateMessagesRequest {
87    pub messages: Vec<Message>,
88    pub model: String,
89    #[builder(default = messages::DEFAULT_MAX_TOKENS)]
90    pub max_tokens: i32,
91    #[builder(default)]
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub metadata: Option<serde_json::Map<String, Value>>,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    #[builder(default)]
96    pub stop_sequences: Option<Vec<String>>,
97    #[builder(default = "false")]
98    pub stream: bool, // Optional default false
99    #[serde(skip_serializing_if = "Option::is_none")]
100    #[builder(default)]
101    pub temperature: Option<f32>, // 0 < x < 1
102    #[serde(skip_serializing_if = "Option::is_none")]
103    #[builder(default)]
104    pub tool_choice: Option<ToolChoice>,
105    // TODO: Type this
106    #[serde(skip_serializing_if = "Option::is_none")]
107    #[builder(default)]
108    pub tools: Option<Vec<serde_json::Map<String, Value>>>,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    #[builder(default)]
111    pub top_k: Option<u32>, // > 0
112    #[serde(skip_serializing_if = "Option::is_none")]
113    #[builder(default)]
114    pub top_p: Option<f32>, // 0 < x < 1
115    #[serde(skip_serializing_if = "Option::is_none")]
116    #[builder(default)]
117    pub system: Option<String>, // 0 < x < 1
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
121#[builder(setter(into, strip_option))]
122pub struct CreateMessagesResponse {
123    #[serde(default)]
124    pub id: Option<String>,
125    #[serde(default)]
126    pub content: Option<Vec<MessageContent>>,
127    #[serde(default)]
128    pub model: Option<String>,
129    #[serde(default)]
130    pub stop_reason: Option<String>,
131    #[serde(default)]
132    pub stop_sequence: Option<String>,
133    #[serde(default)]
134    pub usage: Option<Usage>,
135}
136
137impl CreateMessagesResponse {
138    /// Returns the content as Messages so they are more easily reusable
139    pub fn messages(&self) -> Vec<Message> {
140        let Some(content) = &self.content else {
141            return vec![];
142        };
143        content
144            .iter()
145            .map(|c| Message {
146                role: MessageRole::Assistant,
147                content: c.clone().into(),
148            })
149            .collect()
150    }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154#[serde(tag = "type", rename_all = "snake_case")]
155pub enum MessageContent {
156    ToolUse(ToolUse),
157    ToolResult(ToolResult),
158    Text(Text),
159    // TODO: Implement images and documents
160}
161
162impl MessageContent {
163    pub fn as_tool_use(&self) -> Option<&ToolUse> {
164        if let MessageContent::ToolUse(tool_use) = self {
165            Some(tool_use)
166        } else {
167            None
168        }
169    }
170
171    pub fn as_tool_result(&self) -> Option<&ToolResult> {
172        if let MessageContent::ToolResult(tool_result) = self {
173            Some(tool_result)
174        } else {
175            None
176        }
177    }
178
179    pub fn as_text(&self) -> Option<&Text> {
180        if let MessageContent::Text(text) = self {
181            Some(text)
182        } else {
183            None
184        }
185    }
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
189#[builder(setter(into, strip_option), default)]
190pub struct ToolUse {
191    pub id: String,
192    pub input: Value,
193    pub name: String,
194}
195
196impl From<ToolUse> for MessageContent {
197    fn from(tool_use: ToolUse) -> Self {
198        MessageContent::ToolUse(tool_use)
199    }
200}
201
202impl From<ToolUse> for MessageContentList {
203    fn from(tool_use: ToolUse) -> Self {
204        MessageContentList(vec![tool_use.into()])
205    }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
209#[builder(setter(into, strip_option), default)]
210pub struct ToolResult {
211    pub tool_use_id: String,
212    pub content: Option<String>,
213    pub is_error: bool,
214}
215
216impl From<ToolResult> for MessageContent {
217    fn from(tool_result: ToolResult) -> Self {
218        MessageContent::ToolResult(tool_result)
219    }
220}
221
222impl From<ToolResult> for MessageContentList {
223    fn from(tool_result: ToolResult) -> Self {
224        MessageContentList(vec![tool_result.into()])
225    }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
229#[builder(setter(into, strip_option), default)]
230pub struct Text {
231    pub text: String,
232}
233
234impl<S: AsRef<str>> From<S> for Text {
235    fn from(s: S) -> Self {
236        Text {
237            text: s.as_ref().to_string(),
238        }
239    }
240}
241
242impl From<Text> for MessageContent {
243    fn from(text: Text) -> Self {
244        MessageContent::Text(text)
245    }
246}
247
248impl From<Text> for MessageContentList {
249    fn from(text: Text) -> Self {
250        MessageContentList(vec![text.into()])
251    }
252}
253
254impl<S: AsRef<str>> From<S> for MessageContent {
255    fn from(s: S) -> Self {
256        MessageContent::Text(Text {
257            text: s.as_ref().to_string(),
258        })
259    }
260}
261
262impl<S: AsRef<str>> From<S> for Message {
263    fn from(s: S) -> Self {
264        MessageBuilder::default()
265            .role(MessageRole::User)
266            .content(s.as_ref().to_string())
267            .build()
268            .expect("infallible")
269    }
270}
271
272// Any single AsRef<str> can be converted to a MessageContent, in a list as a single item
273impl<S: AsRef<str>> From<S> for MessageContentList {
274    fn from(s: S) -> Self {
275        MessageContentList(vec![s.as_ref().into()])
276    }
277}
278
279impl From<MessageContent> for MessageContentList {
280    fn from(content: MessageContent) -> Self {
281        MessageContentList(vec![content])
282    }
283}
284
285impl Serialize for ToolChoice {
286    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
287    where
288        S: Serializer,
289    {
290        match self {
291            ToolChoice::Auto => {
292                serde::Serialize::serialize(&serde_json::json!({"type": "auto"}), serializer)
293            }
294            ToolChoice::Any => {
295                serde::Serialize::serialize(&serde_json::json!({"type": "any"}), serializer)
296            }
297            ToolChoice::Tool(name) => serde::Serialize::serialize(
298                &serde_json::json!({"type": "tool", "name": name}),
299                serializer,
300            ),
301        }
302    }
303}
304#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
305#[serde(rename_all = "snake_case", tag = "type")]
306pub enum ContentBlockDelta {
307    TextDelta { text: String },
308}
309
310#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
311pub struct MessageDeltaUsage {
312    pub output_tokens: usize,
313}
314
315#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
316pub struct MessageDelta {
317    pub stop_reason: Option<String>,
318    pub stop_sequence: Option<String>,
319}
320
321#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
322#[serde(rename_all = "snake_case", tag = "type")]
323pub enum MessagesStreamEvent {
324    MessageStart {
325        message: Message,
326    },
327    ContentBlockStart {
328        index: usize,
329        content_block: MessageContent,
330    },
331    ContentBlockDelta {
332        index: usize,
333        delta: ContentBlockDelta,
334    },
335    ContentBlockStop {
336        index: usize,
337    },
338    MessageDelta {
339        delta: MessageDelta,
340        usage: MessageDeltaUsage,
341    },
342    MessageStop,
343}
344
345pub type CreateMessagesResponseStream =
346    Pin<Box<dyn Stream<Item = Result<MessagesStreamEvent, AnthropicError>> + Send>>;
347
348#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
349pub struct ListModelsResponse {
350    #[serde(default)]
351    pub data: Vec<Model>,
352
353    #[serde(default)]
354    pub first_id: Option<String>,
355    pub has_more: bool,
356    #[serde(default)]
357    pub last_id: Option<String>,
358}
359
360#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
361pub struct Model {
362    pub created_at: String,
363    pub display_name: String,
364    pub id: String,
365    #[serde(rename = "type")]
366    pub model_type: String,
367}
368
369pub type GetModelResponse = Model;
370
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test_log::test(tokio::test)]
    async fn test_deserialize_response() {
        // A complete non-streaming response. The usage object carries extra
        // `cache_*` counters that `Usage` does not model; they must be ignored.
        let payload = json!({
            "id": "msg_01KkaCASJuaAgTWD2wqdbwC8",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-5-sonnet-20241022",
            "content": [
                {"type": "text", "text": "Hi! How can I help you today?"}
            ],
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {
                "input_tokens": 10,
                "cache_creation_input_tokens": 0,
                "cache_read_input_tokens": 0,
                "output_tokens": 12
            }
        });
        let raw = payload.to_string();

        let parsed = serde_json::from_str::<CreateMessagesResponse>(&raw);
        assert!(parsed.is_ok(), "{parsed:?}");
    }

    #[test_log::test(tokio::test)]
    async fn test_from_str() {
        let message: Message = "Hello world!".into();

        let expected = Message {
            role: MessageRole::User,
            content: MessageContentList(vec![MessageContent::Text(Text {
                text: "Hello world!".to_string(),
            })]),
        };
        assert_eq!(message, expected);

        assert_eq!(message.text(), Some("Hello world!".to_string()));
    }
}