1use std::ops::{Deref, DerefMut};
2
3use derive_builder::Builder;
4use serde::{Deserialize, Serialize, Serializer};
5use serde_json::Value;
6
7use crate::messages;
8
9#[derive(Serialize, Deserialize, Debug, Clone)]
10pub struct Usage {
11 pub input_tokens: Option<u32>,
12 pub output_tokens: Option<u32>,
13}
14
15#[derive(Clone, Debug, Deserialize)]
16pub enum ToolChoice {
17 Auto,
18 Any,
19 Tool(String),
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, Builder, PartialEq)]
23#[builder(setter(into, strip_option))]
24pub struct Message {
25 pub role: MessageRole,
26 pub content: MessageContentList,
27}
28
29impl Message {
30 pub fn tool_uses(&self) -> Vec<ToolUse> {
32 self.content
33 .0
34 .iter()
35 .filter(|c| matches!(c, MessageContent::ToolUse(_)))
36 .map(|c| match c {
37 MessageContent::ToolUse(tool_use) => tool_use.clone(),
38 _ => unreachable!(),
39 })
40 .collect()
41 }
42
43 pub fn text(&self) -> Option<String> {
45 self.content
46 .0
47 .iter()
48 .filter_map(|c| match c {
49 MessageContent::Text(text) => Some(text.text.clone()),
50 _ => None,
51 })
52 .next()
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
57pub struct MessageContentList(pub Vec<MessageContent>);
58
59impl Deref for MessageContentList {
60 type Target = Vec<MessageContent>;
61
62 fn deref(&self) -> &Self::Target {
63 &self.0
64 }
65}
66
67impl DerefMut for MessageContentList {
68 fn deref_mut(&mut self) -> &mut Self::Target {
69 &mut self.0
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
74#[serde(rename_all = "snake_case")]
75pub enum MessageRole {
76 User,
77 Assistant,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
81#[builder(setter(into, strip_option))]
82pub struct CreateMessagesRequest {
83 messages: Vec<Message>,
84 model: String,
85 #[builder(default = messages::DEFAULT_MAX_TOKENS)]
86 max_tokens: i32,
87 #[builder(default)]
88 #[serde(skip_serializing_if = "Option::is_none")]
89 metadata: Option<serde_json::Map<String, Value>>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 #[builder(default)]
92 stop_sequences: Option<Vec<String>>,
93 #[builder(default = "false")]
94 stream: bool, #[serde(skip_serializing_if = "Option::is_none")]
96 #[builder(default)]
97 temperature: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
99 #[builder(default)]
100 tool_choice: Option<ToolChoice>,
101 #[serde(skip_serializing_if = "Option::is_none")]
103 #[builder(default)]
104 tools: Option<Vec<serde_json::Map<String, Value>>>,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 #[builder(default)]
107 top_k: Option<u32>, #[serde(skip_serializing_if = "Option::is_none")]
109 #[builder(default)]
110 top_p: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
112 #[builder(default)]
113 system: Option<String>, }
115
116#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
126#[builder(setter(into, strip_option))]
127pub struct CreateMessagesResponse {
128 #[serde(default)]
129 pub id: Option<String>,
130 #[serde(default)]
131 pub content: Option<Vec<MessageContent>>,
132 #[serde(default)]
133 pub model: Option<String>,
134 #[serde(default)]
135 pub stop_reason: Option<String>,
136 #[serde(default)]
137 pub stop_sequence: Option<String>,
138 #[serde(default)]
139 pub usage: Option<Usage>,
140}
141
142impl CreateMessagesResponse {
143 pub fn messages(&self) -> Vec<Message> {
145 let Some(content) = &self.content else {
146 return vec![];
147 };
148 content
149 .iter()
150 .map(|c| Message {
151 role: MessageRole::Assistant,
152 content: c.clone().into(),
153 })
154 .collect()
155 }
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
159#[serde(tag = "type", rename_all = "snake_case")]
160pub enum MessageContent {
161 ToolUse(ToolUse),
162 ToolResult(ToolResult),
163 Text(Text),
164 }
166
167#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
168#[builder(setter(into, strip_option), default)]
169pub struct ToolUse {
170 pub id: String,
171 pub input: Value,
172 pub name: String,
173}
174
175impl From<ToolUse> for MessageContent {
176 fn from(tool_use: ToolUse) -> Self {
177 MessageContent::ToolUse(tool_use)
178 }
179}
180
181impl From<ToolUse> for MessageContentList {
182 fn from(tool_use: ToolUse) -> Self {
183 MessageContentList(vec![tool_use.into()])
184 }
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
188#[builder(setter(into, strip_option), default)]
189pub struct ToolResult {
190 pub tool_use_id: String,
191 pub content: Option<String>,
192 pub is_error: bool,
193}
194
195impl From<ToolResult> for MessageContent {
196 fn from(tool_result: ToolResult) -> Self {
197 MessageContent::ToolResult(tool_result)
198 }
199}
200
201impl From<ToolResult> for MessageContentList {
202 fn from(tool_result: ToolResult) -> Self {
203 MessageContentList(vec![tool_result.into()])
204 }
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, Builder)]
208#[builder(setter(into, strip_option), default)]
209pub struct Text {
210 pub text: String,
211}
212
213impl From<Text> for MessageContent {
214 fn from(text: Text) -> Self {
215 MessageContent::Text(text)
216 }
217}
218
219impl From<Text> for MessageContentList {
220 fn from(text: Text) -> Self {
221 MessageContentList(vec![text.into()])
222 }
223}
224
225impl<S: AsRef<str>> From<S> for MessageContent {
226 fn from(s: S) -> Self {
227 MessageContent::Text(Text {
228 text: s.as_ref().to_string(),
229 })
230 }
231}
232
233impl<S: AsRef<str>> From<S> for Message {
234 fn from(s: S) -> Self {
235 MessageBuilder::default()
236 .role(MessageRole::User)
237 .content(s.as_ref().to_string())
238 .build()
239 .expect("infallible")
240 }
241}
242
243impl<S: AsRef<str>> From<S> for MessageContentList {
261 fn from(s: S) -> Self {
262 MessageContentList(vec![s.as_ref().into()])
263 }
264}
265
266impl From<MessageContent> for MessageContentList {
267 fn from(content: MessageContent) -> Self {
268 MessageContentList(vec![content])
269 }
270}
271
272impl Serialize for ToolChoice {
274 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
275 where
276 S: Serializer,
277 {
278 match self {
279 ToolChoice::Auto => {
280 serde::Serialize::serialize(&serde_json::json!({"type": "auto"}), serializer)
281 }
282 ToolChoice::Any => {
283 serde::Serialize::serialize(&serde_json::json!({"type": "any"}), serializer)
284 }
285 ToolChoice::Tool(name) => serde::Serialize::serialize(
286 &serde_json::json!({"type": "tool", "name": name}),
287 serializer,
288 ),
289 }
290 }
291}
292
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// A representative response payload, including keys this module does not
    /// model (`type`, `role`, cache token counters), must deserialize.
    #[test_log::test(tokio::test)]
    async fn test_deserialize_response() {
        let payload = json!({
            "id": "msg_01KkaCASJuaAgTWD2wqdbwC8",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-5-sonnet-20241022",
            "content": [
                {
                    "type": "text",
                    "text": "Hi! How can I help you today?"
                }
            ],
            "stop_reason": "end_turn",
            "stop_sequence": null,
            "usage": {
                "input_tokens": 10,
                "cache_creation_input_tokens": 0,
                "cache_read_input_tokens": 0,
                "output_tokens": 12
            }
        });
        let raw = payload.to_string();

        let parsed = serde_json::from_str::<CreateMessagesResponse>(&raw);

        assert!(parsed.is_ok());
    }

    /// A string converts into a user message with a single text block, and
    /// `text()` returns that string.
    #[test_log::test(tokio::test)]
    async fn test_from_str() {
        let message = Message::from("Hello world!");

        let expected = Message {
            role: MessageRole::User,
            content: MessageContentList(vec![MessageContent::Text(Text {
                text: "Hello world!".to_string(),
            })]),
        };
        assert_eq!(message, expected);

        assert_eq!(message.text(), Some("Hello world!".to_string()));
    }
}