1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError},
6 json_utils,
7 message::{self, DocumentMediaType, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
/// Model identifier for Claude 4 Opus.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// Model identifier for Claude 4 Sonnet.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// Model identifier for Claude 3.7 Sonnet.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Model identifier for Claude 3.5 Sonnet.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Model identifier for Claude 3.5 Haiku.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Model identifier for Claude 3 Opus.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Model identifier for Claude 3 Sonnet (dated snapshot).
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Model identifier for Claude 3 Haiku (dated snapshot).
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

// Anthropic API version strings.
// NOTE(review): presumably sent as the `anthropic-version` header by the client — confirm in client.rs.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest API version defined in this module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
49
/// Successful response body of Anthropic's Messages API (`POST /v1/messages`).
///
/// When converted into rig's [`completion::CompletionResponse`] the whole value is kept as
/// `raw_response`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model, in order.
    pub content: Vec<Content>,
    /// Response identifier.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Role string as reported by the API (presumably always `"assistant"` — confirm).
    pub role: String,
    /// Why generation stopped; `None` when the field is absent or null.
    pub stop_reason: Option<String>,
    /// The stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
60
/// Token usage reported alongside an Anthropic response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Input tokens read from the cache; `None` when the field is not present.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the cache; `None` when the field is not present.
    pub cache_creation_input_tokens: Option<u64>,
    /// Number of generated (output) tokens.
    pub output_tokens: u64,
}
68
69impl std::fmt::Display for Usage {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(
72 f,
73 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
74 self.input_tokens,
75 match self.cache_read_input_tokens {
76 Some(token) => token.to_string(),
77 None => "n/a".to_string(),
78 },
79 match self.cache_creation_input_tokens {
80 Some(token) => token.to_string(),
81 None => "n/a".to_string(),
82 },
83 self.output_tokens
84 )
85 }
86}
87
/// A tool declaration as sent in the `tools` array of a Messages API request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses when calling the tool.
    pub name: String,
    /// Optional description of the tool (serialized as `null` when `None`).
    pub description: Option<String>,
    /// Schema of the tool's input; filled from rig's tool `parameters` when building a request.
    pub input_schema: serde_json::Value,
}
94
/// Cache-control marker; serializes as `{"type": "ephemeral"}`.
///
/// NOTE(review): presumably mirrors Anthropic's `cache_control` field for prompt caching — confirm.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
100
101impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
102 type Error = CompletionError;
103
104 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
105 let content = response
106 .content
107 .iter()
108 .map(|content| {
109 Ok(match content {
110 Content::Text { text } => completion::AssistantContent::text(text),
111 Content::ToolUse { id, name, input } => {
112 completion::AssistantContent::tool_call(id, name, input.clone())
113 }
114 _ => {
115 return Err(CompletionError::ResponseError(
116 "Response did not contain a message or tool call".into(),
117 ));
118 }
119 })
120 })
121 .collect::<Result<Vec<_>, _>>()?;
122
123 let choice = OneOrMany::many(content).map_err(|_| {
124 CompletionError::ResponseError(
125 "Response contained no message or tool call (empty)".to_owned(),
126 )
127 })?;
128
129 let usage = completion::Usage {
130 input_tokens: response.usage.input_tokens,
131 output_tokens: response.usage.output_tokens,
132 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
133 };
134
135 Ok(completion::CompletionResponse {
136 choice,
137 usage,
138 raw_response: response,
139 })
140 }
141}
142
/// One conversation turn in Anthropic's Messages API format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the turn.
    pub role: Role,
    /// Content blocks of the turn. When deserializing, a bare string is also accepted and
    /// becomes a single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
149
/// Author of a [`Message`]; serialized as `"user"` or `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
156
/// A single content block of a [`Message`].
///
/// Serialized with a `type` tag: `text`, `image`, `tool_use`, `tool_result`, `document`
/// or `thinking`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An inline image described by its [`ImageSource`].
    Image {
        source: ImageSource,
    },
    /// A tool invocation requested by the model.
    ToolUse {
        id: String,
        name: String,
        /// Arguments for the tool, as arbitrary JSON.
        input: serde_json::Value,
    },
    /// The result of a tool invocation, sent back in a user turn.
    ToolResult {
        /// The `id` of the [`Content::ToolUse`] block this result answers.
        tool_use_id: String,
        /// Result payload; a bare string is accepted when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// An inline document described by its [`DocumentSource`] (only PDF is representable).
    Document {
        source: DocumentSource,
    },
    /// Model reasoning text with an optional signature.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
186
187impl FromStr for Content {
188 type Err = Infallible;
189
190 fn from_str(s: &str) -> Result<Self, Self::Err> {
191 Ok(Content::Text { text: s.to_owned() })
192 }
193}
194
195#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
196#[serde(tag = "type", rename_all = "snake_case")]
197pub enum ToolResultContent {
198 Text { text: String },
199 Image(ImageSource),
200}
201
202impl FromStr for ToolResultContent {
203 type Err = Infallible;
204
205 fn from_str(s: &str) -> Result<Self, Self::Err> {
206 Ok(ToolResultContent::Text { text: s.to_owned() })
207 }
208}
209
/// Inline image payload (the `source` object of an image block).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Encoded image bytes (base64, per [`SourceType`]).
    pub data: String,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// Payload encoding; serialized under the JSON key `type`.
    pub r#type: SourceType,
}
216
/// Inline document payload (the `source` object of a document block).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Encoded document bytes (base64, per [`SourceType`]).
    pub data: String,
    /// MIME type of the document.
    pub media_type: DocumentFormat,
    /// Payload encoding; serialized under the JSON key `type`.
    pub r#type: SourceType,
}
223
/// Image MIME types accepted by Anthropic.
///
/// Every variant carries an explicit `rename`, so the `rename_all` attribute has no effect here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
236
/// Document MIME types accepted by this provider; only PDF is representable.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
246
/// Encoding of an inline image/document payload; only `"base64"` is representable.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
252
253impl From<String> for Content {
254 fn from(text: String) -> Self {
255 Content::Text { text }
256 }
257}
258
259impl From<String> for ToolResultContent {
260 fn from(text: String) -> Self {
261 ToolResultContent::Text { text }
262 }
263}
264
265impl TryFrom<message::ContentFormat> for SourceType {
266 type Error = MessageError;
267
268 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
269 match format {
270 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
271 message::ContentFormat::String => Err(MessageError::ConversionError(
272 "Image urls are not supported in Anthropic".to_owned(),
273 )),
274 }
275 }
276}
277
278impl From<SourceType> for message::ContentFormat {
279 fn from(source_type: SourceType) -> Self {
280 match source_type {
281 SourceType::BASE64 => message::ContentFormat::Base64,
282 }
283 }
284}
285
286impl TryFrom<message::ImageMediaType> for ImageFormat {
287 type Error = MessageError;
288
289 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
290 Ok(match media_type {
291 message::ImageMediaType::JPEG => ImageFormat::JPEG,
292 message::ImageMediaType::PNG => ImageFormat::PNG,
293 message::ImageMediaType::GIF => ImageFormat::GIF,
294 message::ImageMediaType::WEBP => ImageFormat::WEBP,
295 _ => {
296 return Err(MessageError::ConversionError(
297 format!("Unsupported image media type: {media_type:?}").to_owned(),
298 ));
299 }
300 })
301 }
302}
303
304impl From<ImageFormat> for message::ImageMediaType {
305 fn from(format: ImageFormat) -> Self {
306 match format {
307 ImageFormat::JPEG => message::ImageMediaType::JPEG,
308 ImageFormat::PNG => message::ImageMediaType::PNG,
309 ImageFormat::GIF => message::ImageMediaType::GIF,
310 ImageFormat::WEBP => message::ImageMediaType::WEBP,
311 }
312 }
313}
314
315impl TryFrom<DocumentMediaType> for DocumentFormat {
316 type Error = MessageError;
317 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
318 if !matches!(value, DocumentMediaType::PDF) {
319 return Err(MessageError::ConversionError(
320 "Anthropic only supports PDF documents".to_string(),
321 ));
322 };
323
324 Ok(DocumentFormat::PDF)
325 }
326}
327
328impl From<message::AssistantContent> for Content {
329 fn from(text: message::AssistantContent) -> Self {
330 match text {
331 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
332 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
333 Content::ToolUse {
334 id,
335 name: function.name,
336 input: function.arguments,
337 }
338 }
339 message::AssistantContent::Reasoning(Reasoning { reasoning, id }) => {
340 Content::Thinking {
341 thinking: reasoning.first().cloned().unwrap_or(String::new()),
342 signature: id,
343 }
344 }
345 }
346 }
347}
348
349impl TryFrom<message::Message> for Message {
350 type Error = MessageError;
351
352 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
353 Ok(match message {
354 message::Message::User { content } => Message {
355 role: Role::User,
356 content: content.try_map(|content| match content {
357 message::UserContent::Text(message::Text { text }) => {
358 Ok(Content::Text { text })
359 }
360 message::UserContent::ToolResult(message::ToolResult {
361 id, content, ..
362 }) => Ok(Content::ToolResult {
363 tool_use_id: id,
364 content: content.try_map(|content| match content {
365 message::ToolResultContent::Text(message::Text { text }) => {
366 Ok(ToolResultContent::Text { text })
367 }
368 message::ToolResultContent::Image(image) => {
369 let media_type =
370 image.media_type.ok_or(MessageError::ConversionError(
371 "Image media type is required".to_owned(),
372 ))?;
373 let format = image.format.ok_or(MessageError::ConversionError(
374 "Image format is required".to_owned(),
375 ))?;
376 Ok(ToolResultContent::Image(ImageSource {
377 data: image.data,
378 media_type: media_type.try_into()?,
379 r#type: format.try_into()?,
380 }))
381 }
382 })?,
383 is_error: None,
384 }),
385 message::UserContent::Image(message::Image {
386 data,
387 format,
388 media_type,
389 ..
390 }) => {
391 let source = ImageSource {
392 data,
393 media_type: match media_type {
394 Some(media_type) => media_type.try_into()?,
395 None => {
396 return Err(MessageError::ConversionError(
397 "Image media type is required".to_owned(),
398 ));
399 }
400 },
401 r#type: match format {
402 Some(format) => format.try_into()?,
403 None => SourceType::BASE64,
404 },
405 };
406 Ok(Content::Image { source })
407 }
408 message::UserContent::Document(message::Document {
409 data,
410 format,
411 media_type,
412 ..
413 }) => {
414 let Some(media_type) = media_type else {
415 return Err(MessageError::ConversionError(
416 "Document media type is required".to_string(),
417 ));
418 };
419
420 let source = DocumentSource {
421 data,
422 media_type: media_type.try_into()?,
423 r#type: match format {
424 Some(format) => format.try_into()?,
425 None => SourceType::BASE64,
426 },
427 };
428 Ok(Content::Document { source })
429 }
430 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
431 "Audio is not supported in Anthropic".to_owned(),
432 )),
433 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
434 "Audio is not supported in Anthropic".to_owned(),
435 )),
436 })?,
437 },
438
439 message::Message::Assistant { content, .. } => Message {
440 content: content.map(|content| content.into()),
441 role: Role::Assistant,
442 },
443 })
444 }
445}
446
447impl TryFrom<Content> for message::AssistantContent {
448 type Error = MessageError;
449
450 fn try_from(content: Content) -> Result<Self, Self::Error> {
451 Ok(match content {
452 Content::Text { text } => message::AssistantContent::text(text),
453 Content::ToolUse { id, name, input } => {
454 message::AssistantContent::tool_call(id, name, input)
455 }
456 Content::Thinking {
457 thinking,
458 signature,
459 } => message::AssistantContent::Reasoning(
460 Reasoning::new(&thinking).optional_id(signature),
461 ),
462 _ => {
463 return Err(MessageError::ConversionError(
464 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
465 ));
466 }
467 })
468 }
469}
470
471impl From<ToolResultContent> for message::ToolResultContent {
472 fn from(content: ToolResultContent) -> Self {
473 match content {
474 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
475 ToolResultContent::Image(ImageSource {
476 data,
477 media_type: format,
478 r#type,
479 }) => message::ToolResultContent::image(
480 data,
481 Some(r#type.into()),
482 Some(format.into()),
483 None,
484 ),
485 }
486 }
487}
488
489impl TryFrom<Message> for message::Message {
490 type Error = MessageError;
491
492 fn try_from(message: Message) -> Result<Self, Self::Error> {
493 Ok(match message.role {
494 Role::User => message::Message::User {
495 content: message.content.try_map(|content| {
496 Ok(match content {
497 Content::Text { text } => message::UserContent::text(text),
498 Content::ToolResult {
499 tool_use_id,
500 content,
501 ..
502 } => message::UserContent::tool_result(
503 tool_use_id,
504 content.map(|content| content.into()),
505 ),
506 Content::Image { source } => message::UserContent::Image(message::Image {
507 data: source.data,
508 format: Some(message::ContentFormat::Base64),
509 media_type: Some(source.media_type.into()),
510 detail: None,
511 additional_params: None,
512 }),
513 Content::Document { source } => message::UserContent::document(
514 source.data,
515 Some(message::ContentFormat::Base64),
516 Some(message::DocumentMediaType::PDF),
517 ),
518 _ => {
519 return Err(MessageError::ConversionError(
520 "Unsupported content type for User role".to_owned(),
521 ));
522 }
523 })
524 })?,
525 },
526 Role::Assistant => match message.content.first() {
527 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
528 message::Message::Assistant {
529 id: None,
530 content: message.content.try_map(|content| content.try_into())?,
531 }
532 }
533
534 _ => {
535 return Err(MessageError::ConversionError(
536 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
537 ));
538 }
539 },
540 })
541 }
542}
543
/// Anthropic completion model bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests to the Anthropic API.
    pub(crate) client: Client,
    /// Model identifier sent as `model` in each request.
    pub model: String,
    /// `max_tokens` used when a request does not set one. When `None`, such requests are
    /// rejected.
    pub default_max_tokens: Option<u64>,
}
550
551impl CompletionModel {
552 pub fn new(client: Client, model: &str) -> Self {
553 Self {
554 client,
555 model: model.to_string(),
556 default_max_tokens: calculate_max_tokens(model),
557 }
558 }
559}
560
/// Returns the default `max_tokens` for `model`, chosen by model-name prefix.
///
/// Prefixes are checked in table order and the first match wins. Returns `None` for
/// unrecognised models, in which case requests must specify `max_tokens` themselves.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    const PREFIX_LIMITS: [(&str, u64); 8] = [
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    PREFIX_LIMITS
        .iter()
        .find(|&&(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
582
/// Request metadata carrying an optional user identifier.
///
/// NOTE(review): presumably mirrors the Messages API `metadata` field, and it appears unused by
/// the request-building code in this file — confirm before relying on it.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}
587
/// How the model may pick among the supplied tools.
///
/// Serialized with a `type` tag: `{"type": "auto"}` (the default), `{"type": "any"}`, or
/// `{"type": "tool", "name": ...}`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    Tool {
        name: String,
    },
}
598
599impl completion::CompletionModel for CompletionModel {
600 type Response = CompletionResponse;
601 type StreamingResponse = StreamingCompletionResponse;
602
603 #[cfg_attr(feature = "worker", worker::send)]
604 async fn completion(
605 &self,
606 completion_request: completion::CompletionRequest,
607 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
608 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
614 tokens
615 } else if let Some(tokens) = self.default_max_tokens {
616 tokens
617 } else {
618 return Err(CompletionError::RequestError(
619 "`max_tokens` must be set for Anthropic".into(),
620 ));
621 };
622
623 let mut full_history = vec![];
624 if let Some(docs) = completion_request.normalized_documents() {
625 full_history.push(docs);
626 }
627 full_history.extend(completion_request.chat_history);
628
629 let full_history = full_history
630 .into_iter()
631 .map(Message::try_from)
632 .collect::<Result<Vec<Message>, _>>()?;
633
634 let mut request = json!({
635 "model": self.model,
636 "messages": full_history,
637 "max_tokens": max_tokens,
638 "system": completion_request.preamble.unwrap_or("".to_string()),
639 });
640
641 if let Some(temperature) = completion_request.temperature {
642 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
643 }
644
645 if !completion_request.tools.is_empty() {
646 json_utils::merge_inplace(
647 &mut request,
648 json!({
649 "tools": completion_request
650 .tools
651 .into_iter()
652 .map(|tool| ToolDefinition {
653 name: tool.name,
654 description: Some(tool.description),
655 input_schema: tool.parameters,
656 })
657 .collect::<Vec<_>>(),
658 "tool_choice": ToolChoice::Auto,
659 }),
660 );
661 }
662
663 if let Some(ref params) = completion_request.additional_params {
664 json_utils::merge_inplace(&mut request, params.clone())
665 }
666
667 tracing::debug!("Anthropic completion request: {request}");
668
669 let response = self
670 .client
671 .post("/v1/messages")
672 .json(&request)
673 .send()
674 .await?;
675
676 if response.status().is_success() {
677 match response.json::<ApiResponse<CompletionResponse>>().await? {
678 ApiResponse::Message(completion) => {
679 tracing::info!(target: "rig",
680 "Anthropic completion token usage: {}",
681 completion.usage
682 );
683 completion.try_into()
684 }
685 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
686 }
687 } else {
688 Err(CompletionError::ProviderError(response.text().await?))
689 }
690 }
691
692 #[cfg_attr(feature = "worker", worker::send)]
693 async fn stream(
694 &self,
695 request: CompletionRequest,
696 ) -> Result<
697 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
698 CompletionError,
699 > {
700 CompletionModel::stream(self, request).await
701 }
702}
703
/// Error payload of an [`ApiResponse`] envelope.
///
/// NOTE(review): Anthropic documents errors as
/// `{"type": "error", "error": {"type": ..., "message": ...}}`. This flat `message` field may not
/// match that shape — verify against the API before relying on it.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
708
/// Response envelope discriminated by its `type` field.
///
/// `"message"` decodes the remaining fields as `T`; `"error"` decodes them as
/// [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
715
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    const GREETING: &str = "\n\nHello there, how may I assist you today?";
    const TOOL_USE_ID: &str = "toolu_01A09q90qw90lq917835lq9";

    /// Deserializes `json` into a [`Message`], panicking with the failing JSON path on error.
    fn parse_message(json: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(json);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    #[test]
    fn test_deserialize_message() {
        // `content` supplied as a bare string.
        let assistant_message = parse_message(
            r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#,
        );

        // `content` supplied as a list of text + tool_use blocks.
        let assistant_message2 = parse_message(
            r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#,
        );

        // Image, text and a tool_result whose content is a bare string.
        let user_message = parse_message(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#,
        );

        assert_eq!(assistant_message.role, Role::Assistant);
        assert_eq!(
            assistant_message.content.first(),
            Content::Text {
                text: GREETING.to_owned()
            }
        );

        assert_eq!(assistant_message2.role, Role::Assistant);
        assert_eq!(assistant_message2.content.len(), 2);
        let mut blocks = assistant_message2.content.into_iter();

        let Some(Content::Text { text }) = blocks.next() else {
            panic!("Expected text content");
        };
        assert_eq!(text, GREETING);

        let Some(Content::ToolUse { id, name, input }) = blocks.next() else {
            panic!("Expected tool use content");
        };
        assert_eq!(id, TOOL_USE_ID);
        assert_eq!(name, "get_weather");
        assert_eq!(input, json!({"location": "San Francisco, CA"}));

        assert_eq!(blocks.next(), None);

        assert_eq!(user_message.role, Role::User);
        assert_eq!(user_message.content.len(), 3);
        let mut blocks = user_message.content.into_iter();

        let Some(Content::Image { source }) = blocks.next() else {
            panic!("Expected image content");
        };
        assert_eq!(
            source,
            ImageSource {
                data: "/9j/4AAQSkZJRg...".to_owned(),
                media_type: ImageFormat::JPEG,
                r#type: SourceType::BASE64,
            }
        );

        let Some(Content::Text { text }) = blocks.next() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "What is in this image?");

        let Some(Content::ToolResult {
            tool_use_id,
            content,
            is_error,
        }) = blocks.next()
        else {
            panic!("Expected tool result content");
        };
        assert_eq!(tool_use_id, TOOL_USE_ID);
        assert_eq!(
            content.first(),
            ToolResultContent::Text {
                text: "15 degrees".to_owned()
            }
        );
        assert_eq!(is_error, None);

        assert_eq!(blocks.next(), None);
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: TOOL_USE_ID.to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: TOOL_USE_ID.to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        // User message: image, text and document survive in order.
        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.len(), 3);
        let mut items = content.into_iter();

        let Some(message::UserContent::Image(message::Image {
            data,
            format,
            media_type,
            ..
        })) = items.next()
        else {
            panic!("Expected image content");
        };
        assert_eq!(data, "/9j/4AAQSkZJRg...");
        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));

        let Some(message::UserContent::Text(message::Text { text })) = items.next() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "What is in this image?");

        let Some(message::UserContent::Document(message::Document {
            data,
            format,
            media_type,
            ..
        })) = items.next()
        else {
            panic!("Expected document content");
        };
        assert_eq!(data, "base64_encoded_pdf_data");
        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));

        assert_eq!(items.next(), None);

        // Tool message: a single tool result carrying text.
        let message::Message::User { content } = converted_tool_message.clone() else {
            panic!("Expected tool result content");
        };
        let message::UserContent::ToolResult(message::ToolResult { id, content, .. }) =
            content.first()
        else {
            panic!("Expected tool result content");
        };
        assert_eq!(id, TOOL_USE_ID);
        let message::ToolResultContent::Text(message::Text { text }) = content.first() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "15 degrees");

        // Assistant message: a single tool call.
        let message::Message::Assistant { content, .. } = converted_assistant_message.clone()
        else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.len(), 1);
        let message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) =
            content.first()
        else {
            panic!("Expected tool call content");
        };
        assert_eq!(id, TOOL_USE_ID);
        assert_eq!(function.name, "get_weather");
        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));

        // Round-trip back to Anthropic's format reproduces the originals.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}