async_anthropic/
types.rs

1use std::ops::{Deref, DerefMut};
2
3use derive_builder::Builder;
4use serde::{Deserialize, Serialize, Serializer};
5use serde_json::Value;
6
7use crate::messages;
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct Usage {
11    pub input_tokens: Option<u32>,
12    pub output_tokens: Option<u32>,
13}
14
15#[derive(Clone, Debug, Deserialize)]
16pub enum ToolChoice {
17    Auto,
18    Any,
19    Tool(String),
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, Builder, PartialEq)]
23#[builder(setter(into, strip_option))]
24pub struct Message {
25    pub role: MessageRole,
26    pub content: MessageContentList,
27}
28
29impl Message {
30    /// Returns all the tool uses in the message
31    pub fn tool_uses(&self) -> Vec<ToolUse> {
32        self.content
33            .0
34            .iter()
35            .filter(|c| matches!(c, MessageContent::ToolUse(_)))
36            .map(|c| match c {
37                MessageContent::ToolUse(tool_use) => tool_use.clone(),
38                _ => unreachable!(),
39            })
40            .collect()
41    }
42
43    /// Returns the first text content in the message
44    pub fn text(&self) -> Option<String> {
45        self.content
46            .0
47            .iter()
48            .filter_map(|c| match c {
49                MessageContent::Text(text) => Some(text.text.clone()),
50                _ => None,
51            })
52            .next()
53    }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
57pub struct MessageContentList(pub Vec<MessageContent>);
58
59impl Deref for MessageContentList {
60    type Target = Vec<MessageContent>;
61
62    fn deref(&self) -> &Self::Target {
63        &self.0
64    }
65}
66
67impl DerefMut for MessageContentList {
68    fn deref_mut(&mut self) -> &mut Self::Target {
69        &mut self.0
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
74#[serde(rename_all = "snake_case")]
75pub enum MessageRole {
76    User,
77    Assistant,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
81#[builder(setter(into, strip_option))]
82pub struct CreateMessagesRequest {
83    messages: Vec<Message>,
84    model: String,
85    #[builder(default = messages::DEFAULT_MAX_TOKENS)]
86    max_tokens: i32,
87    #[builder(default)]
88    #[serde(skip_serializing_if = "Option::is_none")]
89    metadata: Option<serde_json::Map<String, Value>>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    #[builder(default)]
92    stop_sequences: Option<Vec<String>>,
93    #[builder(default = "false")]
94    stream: bool, // Optional default false
95    #[serde(skip_serializing_if = "Option::is_none")]
96    #[builder(default)]
97    temperature: Option<f32>, // 0 < x < 1
98    #[serde(skip_serializing_if = "Option::is_none")]
99    #[builder(default)]
100    tool_choice: Option<ToolChoice>,
101    // TODO: Type this
102    #[serde(skip_serializing_if = "Option::is_none")]
103    #[builder(default)]
104    tools: Option<Vec<serde_json::Map<String, Value>>>,
105    #[serde(skip_serializing_if = "Option::is_none")]
106    #[builder(default)]
107    top_k: Option<u32>, // > 0
108    #[serde(skip_serializing_if = "Option::is_none")]
109    #[builder(default)]
110    top_p: Option<f32>, // 0 < x < 1
111    #[serde(skip_serializing_if = "Option::is_none")]
112    #[builder(default)]
113    system: Option<String>, // 0 < x < 1
114}
115
116// {"id":"msg_01KkaCASJuaAgTWD2wqdbwC8",
117// "type":"message",
118// "role":"assistant",
119// "model":"claude-3-5-sonnet-20241022",
120// "content":[{"type":"text",
121// "text":"Hi! How can I help you today?"}],
122// "stop_reason":"end_turn",
123// "stop_sequence":null,
124// "usage":{"input_tokens":10,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12}}
125#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
126#[builder(setter(into, strip_option))]
127pub struct CreateMessagesResponse {
128    #[serde(default)]
129    pub id: Option<String>,
130    #[serde(default)]
131    pub content: Option<Vec<MessageContent>>,
132    #[serde(default)]
133    pub model: Option<String>,
134    #[serde(default)]
135    pub stop_reason: Option<String>,
136    #[serde(default)]
137    pub stop_sequence: Option<String>,
138    #[serde(default)]
139    pub usage: Option<Usage>,
140}
141
142impl CreateMessagesResponse {
143    /// Returns the content as Messages so they are more easily reusable
144    pub fn messages(&self) -> Vec<Message> {
145        let Some(content) = &self.content else {
146            return vec![];
147        };
148        content
149            .iter()
150            .map(|c| Message {
151                role: MessageRole::Assistant,
152                content: c.clone().into(),
153            })
154            .collect()
155    }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
159#[serde(tag = "type", rename_all = "snake_case")]
160pub enum MessageContent {
161    ToolUse(ToolUse),
162    ToolResult(ToolResult),
163    Text(Text),
164    // TODO: Implement images and documents
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
168#[builder(setter(into, strip_option), default)]
169pub struct ToolUse {
170    pub id: String,
171    pub input: Value,
172    pub name: String,
173}
174
175impl From<ToolUse> for MessageContent {
176    fn from(tool_use: ToolUse) -> Self {
177        MessageContent::ToolUse(tool_use)
178    }
179}
180
181impl From<ToolUse> for MessageContentList {
182    fn from(tool_use: ToolUse) -> Self {
183        MessageContentList(vec![tool_use.into()])
184    }
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
188#[builder(setter(into, strip_option), default)]
189pub struct ToolResult {
190    pub tool_use_id: String,
191    pub content: Option<String>,
192    pub is_error: bool,
193}
194
195impl From<ToolResult> for MessageContent {
196    fn from(tool_result: ToolResult) -> Self {
197        MessageContent::ToolResult(tool_result)
198    }
199}
200
201impl From<ToolResult> for MessageContentList {
202    fn from(tool_result: ToolResult) -> Self {
203        MessageContentList(vec![tool_result.into()])
204    }
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
208#[builder(setter(into, strip_option), default)]
209pub struct Text {
210    pub text: String,
211}
212
213impl From<Text> for MessageContent {
214    fn from(text: Text) -> Self {
215        MessageContent::Text(text)
216    }
217}
218
219impl From<Text> for MessageContentList {
220    fn from(text: Text) -> Self {
221        MessageContentList(vec![text.into()])
222    }
223}
224
225impl<S: AsRef<str>> From<S> for MessageContent {
226    fn from(s: S) -> Self {
227        MessageContent::Text(Text {
228            text: s.as_ref().to_string(),
229        })
230    }
231}
232
233impl<S: AsRef<str>> From<S> for Message {
234    fn from(s: S) -> Self {
235        MessageBuilder::default()
236            .role(MessageRole::User)
237            .content(s.as_ref().to_string())
238            .build()
239            .expect("infallible")
240    }
241}
242
// A blanket conversion from any `IntoIterator` whose items are `Into<MessageContent>`
// would be convenient. Without specialization, though, it overlaps with the blanket
// `impl<S: AsRef<str>> From<S> for MessageContentList` below (a type could implement
// both `AsRef<str>` and `IntoIterator`), so it is left disabled.
// TODO: Uncomment when specialization is available.
// impl<I, T> From<I> for MessageContentList
// where
//     I: IntoIterator<Item = T>,
//     T: Into<MessageContent>,
// {
//     fn from(iter: I) -> Self {
//         MessageContentList(
//             iter.into_iter()
//                 .map(|item| item.into())
//                 .collect::<Vec<MessageContent>>(),
//         )
//     }
// }
258
259// Any single AsRef<str> can be converted to a MessageContent, in a list as a single item
260impl<S: AsRef<str>> From<S> for MessageContentList {
261    fn from(s: S) -> Self {
262        MessageContentList(vec![s.as_ref().into()])
263    }
264}
265
266impl From<MessageContent> for MessageContentList {
267    fn from(content: MessageContent) -> Self {
268        MessageContentList(vec![content])
269    }
270}
271
272// TODO: check if needed
273impl Serialize for ToolChoice {
274    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
275    where
276        S: Serializer,
277    {
278        match self {
279            ToolChoice::Auto => {
280                serde::Serialize::serialize(&serde_json::json!({"type": "auto"}), serializer)
281            }
282            ToolChoice::Any => {
283                serde::Serialize::serialize(&serde_json::json!({"type": "any"}), serializer)
284            }
285            ToolChoice::Tool(name) => serde::Serialize::serialize(
286                &serde_json::json!({"type": "tool", "name": name}),
287                serializer,
288            ),
289        }
290    }
291}
292
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// The sample payload must deserialize with every modelled field populated.
    #[test_log::test(tokio::test)]
    async fn test_deserialize_response() {
        let response = json!({
        "id":"msg_01KkaCASJuaAgTWD2wqdbwC8",
        "type":"message",
        "role":"assistant",
        "model":"claude-3-5-sonnet-20241022",
        "content":[
            {"type":"text",
        "text":"Hi! How can I help you today?"}],
        "stop_reason":"end_turn",
        "stop_sequence":null,
        "usage":{"input_tokens":10,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":12}}).to_string();

        let parsed = serde_json::from_str::<CreateMessagesResponse>(&response)
            .expect("sample response should deserialize");

        assert_eq!(parsed.id.as_deref(), Some("msg_01KkaCASJuaAgTWD2wqdbwC8"));
        assert_eq!(parsed.model.as_deref(), Some("claude-3-5-sonnet-20241022"));
        assert_eq!(parsed.stop_reason.as_deref(), Some("end_turn"));
        assert_eq!(parsed.stop_sequence, None);

        let usage = parsed.usage.as_ref().expect("usage should be present");
        assert_eq!(usage.input_tokens, Some(10));
        assert_eq!(usage.output_tokens, Some(12));

        let messages = parsed.messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::Assistant);
        assert_eq!(
            messages[0].text().as_deref(),
            Some("Hi! How can I help you today?")
        );
    }

    #[test_log::test(tokio::test)]
    async fn test_from_str() {
        let message: Message = "Hello world!".into();

        assert_eq!(
            message,
            Message {
                role: MessageRole::User,
                content: MessageContentList(vec![MessageContent::Text(Text {
                    text: "Hello world!".to_string()
                })])
            }
        );

        assert_eq!(message.text(), Some("Hello world!".to_string()));
    }

    /// `tool_uses` keeps only tool-use blocks; `text` still finds the text block.
    #[test_log::test(tokio::test)]
    async fn test_tool_uses_on_mixed_content() {
        let tool_use = ToolUse {
            id: "toolu_1".to_string(),
            input: json!({"q": "rust"}),
            name: "search".to_string(),
        };
        let message = Message {
            role: MessageRole::Assistant,
            content: MessageContentList(vec![
                "Let me look that up.".into(),
                tool_use.clone().into(),
            ]),
        };

        assert_eq!(message.tool_uses(), vec![tool_use]);
        assert_eq!(message.text().as_deref(), Some("Let me look that up."));
    }

    /// `ToolChoice` must serialize to the tagged object shape (key order ignored).
    #[test_log::test(tokio::test)]
    async fn test_serialize_tool_choice() {
        assert_eq!(
            serde_json::to_value(ToolChoice::Auto).unwrap(),
            json!({"type": "auto"})
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::Any).unwrap(),
            json!({"type": "any"})
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::Tool("search".to_string())).unwrap(),
            json!({"type": "tool", "name": "search"})
        );
    }
}