1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError},
6 json_utils,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
/// Model identifier for Claude 4 Opus (`claude-opus-4-0`).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// Model identifier for Claude 4 Sonnet (`claude-sonnet-4-0`).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// Model identifier for Claude 3.7 Sonnet, using the `-latest` alias.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Model identifier for Claude 3.5 Sonnet, using the `-latest` alias.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Model identifier for Claude 3.5 Haiku, using the `-latest` alias.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Model identifier for Claude 3 Opus, using the `-latest` alias.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Model identifier for Claude 3 Sonnet, pinned to the `20240229` snapshot.
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Model identifier for Claude 3 Haiku, pinned to the `20240307` snapshot.
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version string known to this module; currently an alias of
/// [`ANTHROPIC_VERSION_2023_06_01`].
// NOTE(review): presumably sent by the client as the `anthropic-version` header —
// confirm in the client module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
49
/// Body of a successful (`"type": "message"`) reply from Anthropic's `/v1/messages`
/// endpoint, as deserialized by `CompletionModel::completion`.
///
/// After conversion it is also kept verbatim as the `raw_response` of the generic
/// `completion::CompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks returned by the model, in order.
    pub content: Vec<Content>,
    /// The response's `id` field.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// The response's `role` field (kept as a raw string here).
    pub role: String,
    /// Reason generation stopped, if the API reported one.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
60
/// Token counts reported by Anthropic for a single request.
///
/// Only `input_tokens` and `output_tokens` are carried into rig's generic
/// `completion::Usage`; the cache counters are optional and are shown as `n/a` by the
/// `Display` impl when absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input (prompt) tokens billed for the request.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, if reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, if reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Tokens generated by the model.
    pub output_tokens: u64,
}
68
69impl std::fmt::Display for Usage {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(
72 f,
73 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
74 self.input_tokens,
75 match self.cache_read_input_tokens {
76 Some(token) => token.to_string(),
77 None => "n/a".to_string(),
78 },
79 match self.cache_creation_input_tokens {
80 Some(token) => token.to_string(),
81 None => "n/a".to_string(),
82 },
83 self.output_tokens
84 )
85 }
86}
87
/// A tool advertised to the model in the request's `tools` array.
///
/// `CompletionModel::completion` builds these from rig's tool definitions, taking
/// `input_schema` from the tool's `parameters`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name the model uses when emitting a `tool_use` block.
    pub name: String,
    /// Human-readable description; serialized as `null` when `None` (no
    /// `skip_serializing_if` is set).
    pub description: Option<String>,
    /// JSON schema describing the tool's input.
    pub input_schema: serde_json::Value,
}
94
/// Cache-control marker, serialized as `{"type": "ephemeral"}`.
///
/// Not referenced by the code visible in this module.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
100
101impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
102 type Error = CompletionError;
103
104 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
105 let content = response
106 .content
107 .iter()
108 .map(|content| {
109 Ok(match content {
110 Content::Text { text } => completion::AssistantContent::text(text),
111 Content::ToolUse { id, name, input } => {
112 completion::AssistantContent::tool_call(id, name, input.clone())
113 }
114 _ => {
115 return Err(CompletionError::ResponseError(
116 "Response did not contain a message or tool call".into(),
117 ));
118 }
119 })
120 })
121 .collect::<Result<Vec<_>, _>>()?;
122
123 let choice = OneOrMany::many(content).map_err(|_| {
124 CompletionError::ResponseError(
125 "Response contained no message or tool call (empty)".to_owned(),
126 )
127 })?;
128
129 let usage = completion::Usage {
130 input_tokens: response.usage.input_tokens,
131 output_tokens: response.usage.output_tokens,
132 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
133 };
134
135 Ok(completion::CompletionResponse {
136 choice,
137 usage,
138 raw_response: response,
139 })
140 }
141}
142
/// One entry of the Anthropic request's `messages` array.
///
/// When deserializing, `content` accepts either a bare string or an array of content
/// blocks (via `string_or_one_or_many`); a bare string becomes a single
/// `Content::Text`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Content blocks of the message (at least one).
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
149
/// Message author, serialized as `"user"` or `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
156
/// An Anthropic content block, tagged on the wire by `"type"` (`text`, `image`,
/// `tool_use`, `tool_result`, `document`, `thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// Inline image; only base64 sources are representable (see `SourceType`).
    Image {
        source: ImageSource,
    },
    /// A tool invocation emitted by the model; `input` holds the arguments as JSON.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool call, referencing the originating `tool_use` by id.
    ToolResult {
        tool_use_id: String,
        /// Accepts a bare string or an array of items when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// An attached document; `DocumentFormat` only allows PDF.
    Document {
        source: DocumentSource,
    },
    /// Extended-thinking output. The conversions in this module map `signature`
    /// to/from the `id` of rig's `Reasoning`.
    /// NOTE: `signature` is serialized as `null` when `None` (no `skip_serializing_if`).
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
186
187impl FromStr for Content {
188 type Err = Infallible;
189
190 fn from_str(s: &str) -> Result<Self, Self::Err> {
191 Ok(Content::Text { text: s.to_owned() })
192 }
193}
194
195#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum ToolResultContent {
198 Text { text: String },
199 Image(ImageSource),
200}
201
202impl FromStr for ToolResultContent {
203 type Err = Infallible;
204
205 fn from_str(s: &str) -> Result<Self, Self::Err> {
206 Ok(ToolResultContent::Text { text: s.to_owned() })
207 }
208}
209
/// Inline image payload, e.g.
/// `{"type": "base64", "media_type": "image/jpeg", "data": "..."}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Base64-encoded image bytes.
    pub data: String,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// Encoding of `data`; serialized under the key `"type"`.
    pub r#type: SourceType,
}
216
/// Inline document payload, e.g.
/// `{"type": "base64", "media_type": "application/pdf", "data": "..."}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Document contents (base64 when `r#type` is `BASE64`).
    pub data: String,
    /// MIME type of the document (only PDF is supported).
    pub media_type: DocumentFormat,
    /// Encoding of `data`; serialized under the key `"type"`.
    pub r#type: SourceType,
}
223
/// Image MIME types accepted by Anthropic, serialized as their MIME strings.
///
/// The container-level `rename_all` has no effect here because every variant carries
/// an explicit `rename`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
236
/// Document MIME types accepted by this provider; only PDF (`"application/pdf"`).
///
/// The container-level `rename_all` is overridden by the explicit variant `rename`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
246
/// Encoding of an inline image/document payload; only `"base64"` is modelled.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
252
253impl From<String> for Content {
254 fn from(text: String) -> Self {
255 Content::Text { text }
256 }
257}
258
259impl From<String> for ToolResultContent {
260 fn from(text: String) -> Self {
261 ToolResultContent::Text { text }
262 }
263}
264
265impl TryFrom<message::ContentFormat> for SourceType {
266 type Error = MessageError;
267
268 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
269 match format {
270 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
271 message::ContentFormat::String => Err(MessageError::ConversionError(
272 "Image urls are not supported in Anthropic".to_owned(),
273 )),
274 }
275 }
276}
277
278impl From<SourceType> for message::ContentFormat {
279 fn from(source_type: SourceType) -> Self {
280 match source_type {
281 SourceType::BASE64 => message::ContentFormat::Base64,
282 }
283 }
284}
285
286impl TryFrom<message::ImageMediaType> for ImageFormat {
287 type Error = MessageError;
288
289 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
290 Ok(match media_type {
291 message::ImageMediaType::JPEG => ImageFormat::JPEG,
292 message::ImageMediaType::PNG => ImageFormat::PNG,
293 message::ImageMediaType::GIF => ImageFormat::GIF,
294 message::ImageMediaType::WEBP => ImageFormat::WEBP,
295 _ => {
296 return Err(MessageError::ConversionError(
297 format!("Unsupported image media type: {media_type:?}").to_owned(),
298 ));
299 }
300 })
301 }
302}
303
304impl From<ImageFormat> for message::ImageMediaType {
305 fn from(format: ImageFormat) -> Self {
306 match format {
307 ImageFormat::JPEG => message::ImageMediaType::JPEG,
308 ImageFormat::PNG => message::ImageMediaType::PNG,
309 ImageFormat::GIF => message::ImageMediaType::GIF,
310 ImageFormat::WEBP => message::ImageMediaType::WEBP,
311 }
312 }
313}
314
315impl TryFrom<DocumentMediaType> for DocumentFormat {
316 type Error = MessageError;
317 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
318 if !matches!(value, DocumentMediaType::PDF) {
319 return Err(MessageError::ConversionError(
320 "Anthropic only supports PDF documents".to_string(),
321 ));
322 };
323
324 Ok(DocumentFormat::PDF)
325 }
326}
327
328impl From<message::AssistantContent> for Content {
329 fn from(text: message::AssistantContent) -> Self {
330 match text {
331 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
332 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
333 Content::ToolUse {
334 id,
335 name: function.name,
336 input: function.arguments,
337 }
338 }
339 message::AssistantContent::Reasoning(Reasoning { reasoning, id }) => {
340 Content::Thinking {
341 thinking: reasoning.first().cloned().unwrap_or(String::new()),
342 signature: id,
343 }
344 }
345 }
346 }
347}
348
349impl TryFrom<message::Message> for Message {
350 type Error = MessageError;
351
352 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
353 Ok(match message {
354 message::Message::User { content } => Message {
355 role: Role::User,
356 content: content.try_map(|content| match content {
357 message::UserContent::Text(message::Text { text }) => {
358 Ok(Content::Text { text })
359 }
360 message::UserContent::ToolResult(message::ToolResult {
361 id, content, ..
362 }) => Ok(Content::ToolResult {
363 tool_use_id: id,
364 content: content.try_map(|content| match content {
365 message::ToolResultContent::Text(message::Text { text }) => {
366 Ok(ToolResultContent::Text { text })
367 }
368 message::ToolResultContent::Image(image) => {
369 let DocumentSourceKind::Base64(data) = image.data else {
370 return Err(MessageError::ConversionError(
371 "Only base64 strings can be used with the Anthropic API"
372 .to_string(),
373 ));
374 };
375 let media_type =
376 image.media_type.ok_or(MessageError::ConversionError(
377 "Image media type is required".to_owned(),
378 ))?;
379 Ok(ToolResultContent::Image(ImageSource {
380 data,
381 media_type: media_type.try_into()?,
382 r#type: SourceType::BASE64,
383 }))
384 }
385 })?,
386 is_error: None,
387 }),
388 message::UserContent::Image(message::Image {
389 data, media_type, ..
390 }) => {
391 let DocumentSourceKind::Base64(data) = data else {
392 return Err(MessageError::ConversionError(
393 "Only base64 strings are allowed in the Anthropic API".to_string(),
394 ));
395 };
396
397 let source = ImageSource {
398 data,
399 media_type: match media_type {
400 Some(media_type) => media_type.try_into()?,
401 None => {
402 return Err(MessageError::ConversionError(
403 "Image media type is required".to_owned(),
404 ));
405 }
406 },
407 r#type: SourceType::BASE64,
408 };
409 Ok(Content::Image { source })
410 }
411 message::UserContent::Document(message::Document {
412 data,
413 format,
414 media_type,
415 ..
416 }) => {
417 let Some(media_type) = media_type else {
418 return Err(MessageError::ConversionError(
419 "Document media type is required".to_string(),
420 ));
421 };
422
423 let source = DocumentSource {
424 data,
425 media_type: media_type.try_into()?,
426 r#type: match format {
427 Some(format) => format.try_into()?,
428 None => SourceType::BASE64,
429 },
430 };
431 Ok(Content::Document { source })
432 }
433 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
434 "Audio is not supported in Anthropic".to_owned(),
435 )),
436 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
437 "Audio is not supported in Anthropic".to_owned(),
438 )),
439 })?,
440 },
441
442 message::Message::Assistant { content, .. } => Message {
443 content: content.map(|content| content.into()),
444 role: Role::Assistant,
445 },
446 })
447 }
448}
449
450impl TryFrom<Content> for message::AssistantContent {
451 type Error = MessageError;
452
453 fn try_from(content: Content) -> Result<Self, Self::Error> {
454 Ok(match content {
455 Content::Text { text } => message::AssistantContent::text(text),
456 Content::ToolUse { id, name, input } => {
457 message::AssistantContent::tool_call(id, name, input)
458 }
459 Content::Thinking {
460 thinking,
461 signature,
462 } => message::AssistantContent::Reasoning(
463 Reasoning::new(&thinking).optional_id(signature),
464 ),
465 _ => {
466 return Err(MessageError::ConversionError(
467 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
468 ));
469 }
470 })
471 }
472}
473
474impl From<ToolResultContent> for message::ToolResultContent {
475 fn from(content: ToolResultContent) -> Self {
476 match content {
477 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
478 ToolResultContent::Image(ImageSource {
479 data,
480 media_type: format,
481 ..
482 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
483 }
484 }
485}
486
487impl TryFrom<Message> for message::Message {
488 type Error = MessageError;
489
490 fn try_from(message: Message) -> Result<Self, Self::Error> {
491 Ok(match message.role {
492 Role::User => message::Message::User {
493 content: message.content.try_map(|content| {
494 Ok(match content {
495 Content::Text { text } => message::UserContent::text(text),
496 Content::ToolResult {
497 tool_use_id,
498 content,
499 ..
500 } => message::UserContent::tool_result(
501 tool_use_id,
502 content.map(|content| content.into()),
503 ),
504 Content::Image { source } => message::UserContent::Image(message::Image {
505 data: DocumentSourceKind::base64(&source.data),
506 media_type: Some(source.media_type.into()),
507 detail: None,
508 additional_params: None,
509 }),
510 Content::Document { source } => message::UserContent::document(
511 source.data,
512 Some(message::ContentFormat::Base64),
513 Some(message::DocumentMediaType::PDF),
514 ),
515 _ => {
516 return Err(MessageError::ConversionError(
517 "Unsupported content type for User role".to_owned(),
518 ));
519 }
520 })
521 })?,
522 },
523 Role::Assistant => match message.content.first() {
524 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
525 message::Message::Assistant {
526 id: None,
527 content: message.content.try_map(|content| content.try_into())?,
528 }
529 }
530
531 _ => {
532 return Err(MessageError::ConversionError(
533 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
534 ));
535 }
536 },
537 })
538 }
539}
540
/// Anthropic chat model handle; implements rig's `completion::CompletionModel`.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to issue requests (e.g. `POST /v1/messages`).
    pub(crate) client: Client,
    /// Model name sent as `"model"` in every request (e.g. `CLAUDE_4_SONNET`).
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one. When this is also
    /// `None`, `completion` fails with a `RequestError`.
    pub default_max_tokens: Option<u64>,
}
547
548impl CompletionModel {
549 pub fn new(client: Client, model: &str) -> Self {
550 Self {
551 client,
552 model: model.to_string(),
553 default_max_tokens: calculate_max_tokens(model),
554 }
555 }
556}
557
/// Returns the default `max_tokens` for a known Claude model family, matched by
/// model-name prefix, or `None` if the name matches no known family.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Checked in order; the first matching prefix wins.
    const LIMITS: &[(&str, u64)] = &[
        ("claude-opus-4", 32_000),
        ("claude-sonnet-4", 64_000),
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8_192),
        ("claude-3-5-haiku", 8_192),
        ("claude-3-opus", 4_096),
        ("claude-3-sonnet", 4_096),
        ("claude-3-haiku", 4_096),
    ];

    LIMITS
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
579
/// Request metadata carrying an optional `user_id`.
///
/// Not referenced by the code visible in this module.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}
584
/// How the model may use the supplied tools, serialized as `{"type": "auto"}`,
/// `{"type": "any"}` or `{"type": "tool", "name": ...}`.
///
/// `CompletionModel::completion` always sends `Auto` when tools are present.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    Tool {
        name: String,
    },
}
595
596impl completion::CompletionModel for CompletionModel {
597 type Response = CompletionResponse;
598 type StreamingResponse = StreamingCompletionResponse;
599
600 #[cfg_attr(feature = "worker", worker::send)]
601 async fn completion(
602 &self,
603 completion_request: completion::CompletionRequest,
604 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
605 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
611 tokens
612 } else if let Some(tokens) = self.default_max_tokens {
613 tokens
614 } else {
615 return Err(CompletionError::RequestError(
616 "`max_tokens` must be set for Anthropic".into(),
617 ));
618 };
619
620 let mut full_history = vec![];
621 if let Some(docs) = completion_request.normalized_documents() {
622 full_history.push(docs);
623 }
624 full_history.extend(completion_request.chat_history);
625
626 let full_history = full_history
627 .into_iter()
628 .map(Message::try_from)
629 .collect::<Result<Vec<Message>, _>>()?;
630
631 let mut request = json!({
632 "model": self.model,
633 "messages": full_history,
634 "max_tokens": max_tokens,
635 "system": completion_request.preamble.unwrap_or("".to_string()),
636 });
637
638 if let Some(temperature) = completion_request.temperature {
639 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
640 }
641
642 if !completion_request.tools.is_empty() {
643 json_utils::merge_inplace(
644 &mut request,
645 json!({
646 "tools": completion_request
647 .tools
648 .into_iter()
649 .map(|tool| ToolDefinition {
650 name: tool.name,
651 description: Some(tool.description),
652 input_schema: tool.parameters,
653 })
654 .collect::<Vec<_>>(),
655 "tool_choice": ToolChoice::Auto,
656 }),
657 );
658 }
659
660 if let Some(ref params) = completion_request.additional_params {
661 json_utils::merge_inplace(&mut request, params.clone())
662 }
663
664 tracing::debug!("Anthropic completion request: {request}");
665
666 let response = self
667 .client
668 .post("/v1/messages")
669 .json(&request)
670 .send()
671 .await?;
672
673 if response.status().is_success() {
674 match response.json::<ApiResponse<CompletionResponse>>().await? {
675 ApiResponse::Message(completion) => {
676 tracing::info!(target: "rig",
677 "Anthropic completion token usage: {}",
678 completion.usage
679 );
680 completion.try_into()
681 }
682 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
683 }
684 } else {
685 Err(CompletionError::ProviderError(response.text().await?))
686 }
687 }
688
689 #[cfg_attr(feature = "worker", worker::send)]
690 async fn stream(
691 &self,
692 request: CompletionRequest,
693 ) -> Result<
694 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
695 CompletionError,
696 > {
697 CompletionModel::stream(self, request).await
698 }
699}
700
/// Payload of a response body tagged `"type": "error"` (see `ApiResponse`).
// NOTE(review): Anthropic documents error bodies as
// `{"type": "error", "error": {"type": ..., "message": ...}}`, i.e. with `message`
// nested under `error`. With this flat layout such a body would fail to deserialize.
// Non-2xx responses bypass this type (their raw text is used instead), but confirm the
// layout if this variant is relied on.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
705
/// Envelope for a response body, discriminated by its `"type"` field:
/// `"message"` deserializes into `T`, `"error"` into [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
712
#[cfg(test)]
mod tests {
    // Unit tests for (de)serialization of Anthropic messages and for round-tripping
    // between Anthropic and rig's provider-agnostic message types.
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes Anthropic-shaped JSON into `Message`. Three cases are covered:
    /// - an assistant message whose content is a bare string;
    /// - an assistant message with `text` + `tool_use` blocks;
    /// - a user message with `image`, `text` and a string-valued `tool_result`.
    #[test]
    fn test_deserialize_message() {
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // serde_path_to_error reports the JSON path of a failure, which makes fixture
        // mistakes easier to locate than a plain serde_json error.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // A bare string content becomes a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // A string-valued `tool_result.content` becomes one text item, and an
            // absent `is_error` deserializes to `None`.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Converts Anthropic messages (user with image/text/document, assistant with a
    /// tool call, user with a tool result) into rig messages. It checks the converted
    /// shape, then converts back and asserts the round trip is lossless.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, "base64_encoded_pdf_data");
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original Anthropic messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}
1015}