1use std::{
2 ops::{Deref, DerefMut},
3 pin::Pin,
4};
5
6use derive_builder::Builder;
7use serde::{Deserialize, Serialize, Serializer};
8use serde_json::Value;
9use tokio_stream::Stream;
10
11use crate::{errors::AnthropicError, messages};
12
13#[derive(Serialize, Deserialize, Debug, Clone)]
14pub struct Usage {
15 pub input_tokens: Option<u32>,
16 pub output_tokens: Option<u32>,
17}
18
19#[derive(Clone, Debug, Deserialize)]
20pub enum ToolChoice {
21 Auto,
22 Any,
23 Tool(String),
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, Builder, PartialEq)]
27#[builder(setter(into, strip_option))]
28pub struct Message {
29 pub role: MessageRole,
30 pub content: MessageContentList,
31}
32
33impl Message {
34 pub fn tool_uses(&self) -> Vec<ToolUse> {
36 self.content
37 .0
38 .iter()
39 .filter(|c| matches!(c, MessageContent::ToolUse(_)))
40 .map(|c| match c {
41 MessageContent::ToolUse(tool_use) => tool_use.clone(),
42 _ => unreachable!(),
43 })
44 .collect()
45 }
46
47 pub fn text(&self) -> Option<String> {
49 self.content
50 .0
51 .iter()
52 .filter_map(|c| match c {
53 MessageContent::Text(text) => Some(text.text.clone()),
54 _ => None,
55 })
56 .next()
57 }
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
61pub struct MessageContentList(pub Vec<MessageContent>);
62
63impl Deref for MessageContentList {
64 type Target = Vec<MessageContent>;
65
66 fn deref(&self) -> &Self::Target {
67 &self.0
68 }
69}
70
71impl DerefMut for MessageContentList {
72 fn deref_mut(&mut self) -> &mut Self::Target {
73 &mut self.0
74 }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
78#[serde(rename_all = "snake_case")]
79pub enum MessageRole {
80 User,
81 Assistant,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
85#[builder(setter(into, strip_option))]
86pub struct CreateMessagesRequest {
87 pub messages: Vec<Message>,
88 pub model: String,
89 #[builder(default = messages::DEFAULT_MAX_TOKENS)]
90 pub max_tokens: i32,
91 #[builder(default)]
92 #[serde(skip_serializing_if = "Option::is_none")]
93 pub metadata: Option<serde_json::Map<String, Value>>,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 #[builder(default)]
96 pub stop_sequences: Option<Vec<String>>,
97 #[builder(default = "false")]
98 pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")]
100 #[builder(default)]
101 pub temperature: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
103 #[builder(default)]
104 pub tool_choice: Option<ToolChoice>,
105 #[serde(skip_serializing_if = "Option::is_none")]
107 #[builder(default)]
108 pub tools: Option<Vec<serde_json::Map<String, Value>>>,
109 #[serde(skip_serializing_if = "Option::is_none")]
110 #[builder(default)]
111 pub top_k: Option<u32>, #[serde(skip_serializing_if = "Option::is_none")]
113 #[builder(default)]
114 pub top_p: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
116 #[builder(default)]
117 pub system: Option<String>, }
119
120#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
121#[builder(setter(into, strip_option))]
122pub struct CreateMessagesResponse {
123 #[serde(default)]
124 pub id: Option<String>,
125 #[serde(default)]
126 pub content: Option<Vec<MessageContent>>,
127 #[serde(default)]
128 pub model: Option<String>,
129 #[serde(default)]
130 pub stop_reason: Option<String>,
131 #[serde(default)]
132 pub stop_sequence: Option<String>,
133 #[serde(default)]
134 pub usage: Option<Usage>,
135}
136
137impl CreateMessagesResponse {
138 pub fn messages(&self) -> Vec<Message> {
140 let Some(content) = &self.content else {
141 return vec![];
142 };
143 content
144 .iter()
145 .map(|c| Message {
146 role: MessageRole::Assistant,
147 content: c.clone().into(),
148 })
149 .collect()
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154#[serde(tag = "type", rename_all = "snake_case")]
155pub enum MessageContent {
156 ToolUse(ToolUse),
157 ToolResult(ToolResult),
158 Text(Text),
159 }
161
162impl MessageContent {
163 pub fn as_tool_use(&self) -> Option<&ToolUse> {
164 if let MessageContent::ToolUse(tool_use) = self {
165 Some(tool_use)
166 } else {
167 None
168 }
169 }
170
171 pub fn as_tool_result(&self) -> Option<&ToolResult> {
172 if let MessageContent::ToolResult(tool_result) = self {
173 Some(tool_result)
174 } else {
175 None
176 }
177 }
178
179 pub fn as_text(&self) -> Option<&Text> {
180 if let MessageContent::Text(text) = self {
181 Some(text)
182 } else {
183 None
184 }
185 }
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
189#[builder(setter(into, strip_option), default)]
190pub struct ToolUse {
191 pub id: String,
192 pub input: Value,
193 pub name: String,
194}
195
196impl From<ToolUse> for MessageContent {
197 fn from(tool_use: ToolUse) -> Self {
198 MessageContent::ToolUse(tool_use)
199 }
200}
201
202impl From<ToolUse> for MessageContentList {
203 fn from(tool_use: ToolUse) -> Self {
204 MessageContentList(vec![tool_use.into()])
205 }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
209#[builder(setter(into, strip_option), default)]
210pub struct ToolResult {
211 pub tool_use_id: String,
212 pub content: Option<String>,
213 pub is_error: bool,
214}
215
216impl From<ToolResult> for MessageContent {
217 fn from(tool_result: ToolResult) -> Self {
218 MessageContent::ToolResult(tool_result)
219 }
220}
221
222impl From<ToolResult> for MessageContentList {
223 fn from(tool_result: ToolResult) -> Self {
224 MessageContentList(vec![tool_result.into()])
225 }
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
229#[builder(setter(into, strip_option), default)]
230pub struct Text {
231 pub text: String,
232}
233
234impl<S: AsRef<str>> From<S> for Text {
235 fn from(s: S) -> Self {
236 Text {
237 text: s.as_ref().to_string(),
238 }
239 }
240}
241
242impl From<Text> for MessageContent {
243 fn from(text: Text) -> Self {
244 MessageContent::Text(text)
245 }
246}
247
248impl From<Text> for MessageContentList {
249 fn from(text: Text) -> Self {
250 MessageContentList(vec![text.into()])
251 }
252}
253
254impl<S: AsRef<str>> From<S> for MessageContent {
255 fn from(s: S) -> Self {
256 MessageContent::Text(Text {
257 text: s.as_ref().to_string(),
258 })
259 }
260}
261
262impl<S: AsRef<str>> From<S> for Message {
263 fn from(s: S) -> Self {
264 MessageBuilder::default()
265 .role(MessageRole::User)
266 .content(s.as_ref().to_string())
267 .build()
268 .expect("infallible")
269 }
270}
271
272impl<S: AsRef<str>> From<S> for MessageContentList {
274 fn from(s: S) -> Self {
275 MessageContentList(vec![s.as_ref().into()])
276 }
277}
278
279impl From<MessageContent> for MessageContentList {
280 fn from(content: MessageContent) -> Self {
281 MessageContentList(vec![content])
282 }
283}
284
285impl Serialize for ToolChoice {
286 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
287 where
288 S: Serializer,
289 {
290 match self {
291 ToolChoice::Auto => {
292 serde::Serialize::serialize(&serde_json::json!({"type": "auto"}), serializer)
293 }
294 ToolChoice::Any => {
295 serde::Serialize::serialize(&serde_json::json!({"type": "any"}), serializer)
296 }
297 ToolChoice::Tool(name) => serde::Serialize::serialize(
298 &serde_json::json!({"type": "tool", "name": name}),
299 serializer,
300 ),
301 }
302 }
303}
304#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
305#[serde(rename_all = "snake_case", tag = "type")]
306pub enum ContentBlockDelta {
307 TextDelta { text: String },
308}
309
310#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
311pub struct MessageDeltaUsage {
312 pub output_tokens: usize,
313}
314
315#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
316pub struct MessageDelta {
317 pub stop_reason: Option<String>,
318 pub stop_sequence: Option<String>,
319}
320
321#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
322#[serde(rename_all = "snake_case", tag = "type")]
323pub enum MessagesStreamEvent {
324 MessageStart {
325 message: Message,
326 },
327 ContentBlockStart {
328 index: usize,
329 content_block: MessageContent,
330 },
331 ContentBlockDelta {
332 index: usize,
333 delta: ContentBlockDelta,
334 },
335 ContentBlockStop {
336 index: usize,
337 },
338 MessageDelta {
339 delta: MessageDelta,
340 usage: MessageDeltaUsage,
341 },
342 MessageStop,
343}
344
345pub type CreateMessagesResponseStream =
346 Pin<Box<dyn Stream<Item = Result<MessagesStreamEvent, AnthropicError>> + Send>>;
347
348#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
349pub struct ListModelsResponse {
350 #[serde(default)]
351 pub data: Vec<Model>,
352
353 #[serde(default)]
354 pub first_id: Option<String>,
355 pub has_more: bool,
356 #[serde(default)]
357 pub last_id: Option<String>,
358}
359
360#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
361pub struct Model {
362 pub created_at: String,
363 pub display_name: String,
364 pub id: String,
365 #[serde(rename = "type")]
366 pub model_type: String,
367}
368
369pub type GetModelResponse = Model;
370
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// A realistic response parses, and its fields land where expected.
    #[test_log::test(tokio::test)]
    async fn test_deserialize_response() {
        let response = json!({
            "id": "msg_01KkaCASJuaAgTWD2wqdbwC8",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-5-sonnet-20241022",
            "content": [
                {"type": "text", "text": "Hi! How can I help you today?"}
            ],
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {
                "input_tokens": 10,
                "cache_creation_input_tokens": 0,
                "cache_read_input_tokens": 0,
                "output_tokens": 12
            }
        })
        .to_string();

        // `expect` rather than `is_ok()` so a failure reports the serde error.
        let parsed: CreateMessagesResponse =
            serde_json::from_str(&response).expect("response should deserialize");

        assert_eq!(parsed.id.as_deref(), Some("msg_01KkaCASJuaAgTWD2wqdbwC8"));
        assert_eq!(parsed.model.as_deref(), Some("claude-3-5-sonnet-20241022"));
        assert_eq!(parsed.stop_reason.as_deref(), Some("end_turn"));
        assert_eq!(parsed.stop_sequence, None);

        let usage = parsed.usage.clone().expect("usage should be present");
        assert_eq!(usage.input_tokens, Some(10));
        assert_eq!(usage.output_tokens, Some(12));

        let messages = parsed.messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, MessageRole::Assistant);
        assert_eq!(
            messages[0].text(),
            Some("Hi! How can I help you today?".to_string())
        );
    }

    /// A string converts into a single-text-block user message.
    #[test_log::test(tokio::test)]
    async fn test_from_str() {
        let message: Message = "Hello world!".into();

        assert_eq!(
            message,
            Message {
                role: MessageRole::User,
                content: MessageContentList(vec![MessageContent::Text(Text {
                    text: "Hello world!".to_string()
                })])
            }
        );

        assert_eq!(message.text(), Some("Hello world!".to_string()));
    }

    /// `tool_uses` returns only tool-use blocks and `text` only the first text block.
    #[test_log::test(tokio::test)]
    async fn test_tool_uses_and_text() {
        let tool_use = ToolUse {
            id: "toolu_1".to_string(),
            input: json!({"query": "rust"}),
            name: "search".to_string(),
        };
        let message = Message {
            role: MessageRole::Assistant,
            content: MessageContentList(vec![
                "first".into(),
                tool_use.clone().into(),
                "second".into(),
            ]),
        };

        assert_eq!(message.tool_uses(), vec![tool_use]);
        assert_eq!(message.text(), Some("first".to_string()));

        let empty = Message {
            role: MessageRole::User,
            content: MessageContentList::default(),
        };
        assert!(empty.tool_uses().is_empty());
        assert_eq!(empty.text(), None);
    }

    /// `ToolChoice` serializes to the `type`-tagged object shape.
    #[test_log::test(tokio::test)]
    async fn test_serialize_tool_choice() {
        assert_eq!(
            serde_json::to_value(ToolChoice::Auto).unwrap(),
            json!({"type": "auto"})
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::Any).unwrap(),
            json!({"type": "any"})
        );
        assert_eq!(
            serde_json::to_value(ToolChoice::Tool("search".to_string())).unwrap(),
            json!({"type": "tool", "name": "search"})
        );
    }
}
410}