1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError},
6 json_utils,
7 message::{self, DocumentMediaType, MessageError},
8 one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
// ---------------------------------------------------------------------------
// Model identifiers accepted by the Anthropic Messages API.
// ---------------------------------------------------------------------------

/// Claude 3.7 Sonnet, via the `-latest` alias.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Claude 3.5 Sonnet, via the `-latest` alias.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Claude 3.5 Haiku, via the `-latest` alias.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Claude 3 Opus, via the `-latest` alias.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Claude 3 Sonnet, pinned to its dated snapshot.
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Claude 3 Haiku, pinned to its dated snapshot.
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

// ---------------------------------------------------------------------------
// API version strings (presumably sent as the `anthropic-version` header by the
// client — confirm in `super::client`).
// ---------------------------------------------------------------------------

/// API version `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest API version known to this module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
43#[derive(Debug, Deserialize)]
44pub struct CompletionResponse {
45 pub content: Vec<Content>,
46 pub id: String,
47 pub model: String,
48 pub role: String,
49 pub stop_reason: Option<String>,
50 pub stop_sequence: Option<String>,
51 pub usage: Usage,
52}
53
54#[derive(Debug, Deserialize, Serialize)]
55pub struct Usage {
56 pub input_tokens: u64,
57 pub cache_read_input_tokens: Option<u64>,
58 pub cache_creation_input_tokens: Option<u64>,
59 pub output_tokens: u64,
60}
61
62impl std::fmt::Display for Usage {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(
65 f,
66 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67 self.input_tokens,
68 match self.cache_read_input_tokens {
69 Some(token) => token.to_string(),
70 None => "n/a".to_string(),
71 },
72 match self.cache_creation_input_tokens {
73 Some(token) => token.to_string(),
74 None => "n/a".to_string(),
75 },
76 self.output_tokens
77 )
78 }
79}
80
81#[derive(Debug, Deserialize, Serialize)]
82pub struct ToolDefinition {
83 pub name: String,
84 pub description: Option<String>,
85 pub input_schema: serde_json::Value,
86}
87
88#[derive(Debug, Deserialize, Serialize)]
89#[serde(tag = "type", rename_all = "snake_case")]
90pub enum CacheControl {
91 Ephemeral,
92}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95 type Error = CompletionError;
96
97 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98 let content = response
99 .content
100 .iter()
101 .map(|content| {
102 Ok(match content {
103 Content::Text { text } => completion::AssistantContent::text(text),
104 Content::ToolUse { id, name, input } => {
105 completion::AssistantContent::tool_call(id, name, input.clone())
106 }
107 _ => {
108 return Err(CompletionError::ResponseError(
109 "Response did not contain a message or tool call".into(),
110 ));
111 }
112 })
113 })
114 .collect::<Result<Vec<_>, _>>()?;
115
116 let choice = OneOrMany::many(content).map_err(|_| {
117 CompletionError::ResponseError(
118 "Response contained no message or tool call (empty)".to_owned(),
119 )
120 })?;
121
122 Ok(completion::CompletionResponse {
123 choice,
124 raw_response: response,
125 })
126 }
127}
128
129#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
130pub struct Message {
131 pub role: Role,
132 #[serde(deserialize_with = "string_or_one_or_many")]
133 pub content: OneOrMany<Content>,
134}
135
136#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
137#[serde(rename_all = "lowercase")]
138pub enum Role {
139 User,
140 Assistant,
141}
142
143#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
144#[serde(tag = "type", rename_all = "snake_case")]
145pub enum Content {
146 Text {
147 text: String,
148 },
149 Image {
150 source: ImageSource,
151 },
152 ToolUse {
153 id: String,
154 name: String,
155 input: serde_json::Value,
156 },
157 ToolResult {
158 tool_use_id: String,
159 #[serde(deserialize_with = "string_or_one_or_many")]
160 content: OneOrMany<ToolResultContent>,
161 #[serde(skip_serializing_if = "Option::is_none")]
162 is_error: Option<bool>,
163 },
164 Document {
165 source: DocumentSource,
166 },
167}
168
169impl FromStr for Content {
170 type Err = Infallible;
171
172 fn from_str(s: &str) -> Result<Self, Self::Err> {
173 Ok(Content::Text { text: s.to_owned() })
174 }
175}
176
177#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
178#[serde(tag = "type", rename_all = "snake_case")]
179pub enum ToolResultContent {
180 Text { text: String },
181 Image(ImageSource),
182}
183
184impl FromStr for ToolResultContent {
185 type Err = Infallible;
186
187 fn from_str(s: &str) -> Result<Self, Self::Err> {
188 Ok(ToolResultContent::Text { text: s.to_owned() })
189 }
190}
191
192#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
193pub struct ImageSource {
194 pub data: String,
195 pub media_type: ImageFormat,
196 pub r#type: SourceType,
197}
198
199#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
200pub struct DocumentSource {
201 pub data: String,
202 pub media_type: DocumentFormat,
203 pub r#type: SourceType,
204}
205
206#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
207#[serde(rename_all = "lowercase")]
208pub enum ImageFormat {
209 #[serde(rename = "image/jpeg")]
210 JPEG,
211 #[serde(rename = "image/png")]
212 PNG,
213 #[serde(rename = "image/gif")]
214 GIF,
215 #[serde(rename = "image/webp")]
216 WEBP,
217}
218
219#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
223#[serde(rename_all = "lowercase")]
224pub enum DocumentFormat {
225 #[serde(rename = "application/pdf")]
226 PDF,
227}
228
229#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
230#[serde(rename_all = "lowercase")]
231pub enum SourceType {
232 BASE64,
233}
234
235impl From<String> for Content {
236 fn from(text: String) -> Self {
237 Content::Text { text }
238 }
239}
240
241impl From<String> for ToolResultContent {
242 fn from(text: String) -> Self {
243 ToolResultContent::Text { text }
244 }
245}
246
247impl TryFrom<message::ContentFormat> for SourceType {
248 type Error = MessageError;
249
250 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
251 match format {
252 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
253 message::ContentFormat::String => Err(MessageError::ConversionError(
254 "Image urls are not supported in Anthropic".to_owned(),
255 )),
256 }
257 }
258}
259
260impl From<SourceType> for message::ContentFormat {
261 fn from(source_type: SourceType) -> Self {
262 match source_type {
263 SourceType::BASE64 => message::ContentFormat::Base64,
264 }
265 }
266}
267
268impl TryFrom<message::ImageMediaType> for ImageFormat {
269 type Error = MessageError;
270
271 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
272 Ok(match media_type {
273 message::ImageMediaType::JPEG => ImageFormat::JPEG,
274 message::ImageMediaType::PNG => ImageFormat::PNG,
275 message::ImageMediaType::GIF => ImageFormat::GIF,
276 message::ImageMediaType::WEBP => ImageFormat::WEBP,
277 _ => {
278 return Err(MessageError::ConversionError(
279 format!("Unsupported image media type: {media_type:?}").to_owned(),
280 ));
281 }
282 })
283 }
284}
285
286impl From<ImageFormat> for message::ImageMediaType {
287 fn from(format: ImageFormat) -> Self {
288 match format {
289 ImageFormat::JPEG => message::ImageMediaType::JPEG,
290 ImageFormat::PNG => message::ImageMediaType::PNG,
291 ImageFormat::GIF => message::ImageMediaType::GIF,
292 ImageFormat::WEBP => message::ImageMediaType::WEBP,
293 }
294 }
295}
296
297impl TryFrom<DocumentMediaType> for DocumentFormat {
298 type Error = MessageError;
299 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
300 if !matches!(value, DocumentMediaType::PDF) {
301 return Err(MessageError::ConversionError(
302 "Anthropic only supports PDF documents".to_string(),
303 ));
304 };
305
306 Ok(DocumentFormat::PDF)
307 }
308}
309
310impl From<message::AssistantContent> for Content {
311 fn from(text: message::AssistantContent) -> Self {
312 match text {
313 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
314 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
315 Content::ToolUse {
316 id,
317 name: function.name,
318 input: function.arguments,
319 }
320 }
321 }
322 }
323}
324
325impl TryFrom<message::Message> for Message {
326 type Error = MessageError;
327
328 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
329 Ok(match message {
330 message::Message::User { content } => Message {
331 role: Role::User,
332 content: content.try_map(|content| match content {
333 message::UserContent::Text(message::Text { text }) => {
334 Ok(Content::Text { text })
335 }
336 message::UserContent::ToolResult(message::ToolResult {
337 id, content, ..
338 }) => Ok(Content::ToolResult {
339 tool_use_id: id,
340 content: content.try_map(|content| match content {
341 message::ToolResultContent::Text(message::Text { text }) => {
342 Ok(ToolResultContent::Text { text })
343 }
344 message::ToolResultContent::Image(image) => {
345 let media_type =
346 image.media_type.ok_or(MessageError::ConversionError(
347 "Image media type is required".to_owned(),
348 ))?;
349 let format = image.format.ok_or(MessageError::ConversionError(
350 "Image format is required".to_owned(),
351 ))?;
352 Ok(ToolResultContent::Image(ImageSource {
353 data: image.data,
354 media_type: media_type.try_into()?,
355 r#type: format.try_into()?,
356 }))
357 }
358 })?,
359 is_error: None,
360 }),
361 message::UserContent::Image(message::Image {
362 data,
363 format,
364 media_type,
365 ..
366 }) => {
367 let source = ImageSource {
368 data,
369 media_type: match media_type {
370 Some(media_type) => media_type.try_into()?,
371 None => {
372 return Err(MessageError::ConversionError(
373 "Image media type is required".to_owned(),
374 ));
375 }
376 },
377 r#type: match format {
378 Some(format) => format.try_into()?,
379 None => SourceType::BASE64,
380 },
381 };
382 Ok(Content::Image { source })
383 }
384 message::UserContent::Document(message::Document {
385 data,
386 format,
387 media_type,
388 }) => {
389 let Some(media_type) = media_type else {
390 return Err(MessageError::ConversionError(
391 "Document media type is required".to_string(),
392 ));
393 };
394
395 let source = DocumentSource {
396 data,
397 media_type: media_type.try_into()?,
398 r#type: match format {
399 Some(format) => format.try_into()?,
400 None => SourceType::BASE64,
401 },
402 };
403 Ok(Content::Document { source })
404 }
405 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
406 "Audio is not supported in Anthropic".to_owned(),
407 )),
408 })?,
409 },
410
411 message::Message::Assistant { content, .. } => Message {
412 content: content.map(|content| content.into()),
413 role: Role::Assistant,
414 },
415 })
416 }
417}
418
419impl TryFrom<Content> for message::AssistantContent {
420 type Error = MessageError;
421
422 fn try_from(content: Content) -> Result<Self, Self::Error> {
423 Ok(match content {
424 Content::Text { text } => message::AssistantContent::text(text),
425 Content::ToolUse { id, name, input } => {
426 message::AssistantContent::tool_call(id, name, input)
427 }
428 _ => {
429 return Err(MessageError::ConversionError(
430 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
431 ));
432 }
433 })
434 }
435}
436
437impl From<ToolResultContent> for message::ToolResultContent {
438 fn from(content: ToolResultContent) -> Self {
439 match content {
440 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
441 ToolResultContent::Image(ImageSource {
442 data,
443 media_type: format,
444 r#type,
445 }) => message::ToolResultContent::image(
446 data,
447 Some(r#type.into()),
448 Some(format.into()),
449 None,
450 ),
451 }
452 }
453}
454
455impl TryFrom<Message> for message::Message {
456 type Error = MessageError;
457
458 fn try_from(message: Message) -> Result<Self, Self::Error> {
459 Ok(match message.role {
460 Role::User => message::Message::User {
461 content: message.content.try_map(|content| {
462 Ok(match content {
463 Content::Text { text } => message::UserContent::text(text),
464 Content::ToolResult {
465 tool_use_id,
466 content,
467 ..
468 } => message::UserContent::tool_result(
469 tool_use_id,
470 content.map(|content| content.into()),
471 ),
472 Content::Image { source } => message::UserContent::Image(message::Image {
473 data: source.data,
474 format: Some(message::ContentFormat::Base64),
475 media_type: Some(source.media_type.into()),
476 detail: None,
477 }),
478 Content::Document { source } => message::UserContent::document(
479 source.data,
480 Some(message::ContentFormat::Base64),
481 Some(message::DocumentMediaType::PDF),
482 ),
483 _ => {
484 return Err(MessageError::ConversionError(
485 "Unsupported content type for User role".to_owned(),
486 ));
487 }
488 })
489 })?,
490 },
491 Role::Assistant => match message.content.first() {
492 Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
493 id: None,
494 content: message.content.try_map(|content| content.try_into())?,
495 },
496
497 _ => {
498 return Err(MessageError::ConversionError(
499 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
500 ));
501 }
502 },
503 })
504 }
505}
506
507#[derive(Clone)]
508pub struct CompletionModel {
509 pub(crate) client: Client,
510 pub model: String,
511 pub default_max_tokens: Option<u64>,
512}
513
514impl CompletionModel {
515 pub fn new(client: Client, model: &str) -> Self {
516 Self {
517 client,
518 model: model.to_string(),
519 default_max_tokens: calculate_max_tokens(model),
520 }
521 }
522}
523
/// Default `max_tokens` for known Claude model families, matched by model-name prefix.
///
/// Anthropic requires `max_tokens` on every request. The values below are each family's
/// documented maximum output size at the time of writing. Returns `None` for unrecognised
/// models; callers must then set `max_tokens` explicitly.
///
/// Claude 3.7 Sonnet is included so that the exported [`CLAUDE_3_7_SONNET`] model works without
/// an explicit `max_tokens`.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    if model.starts_with("claude-3-7-sonnet") {
        Some(64000)
    } else if model.starts_with("claude-3-5-sonnet") || model.starts_with("claude-3-5-haiku") {
        Some(8192)
    } else if model.starts_with("claude-3-opus")
        || model.starts_with("claude-3-sonnet")
        || model.starts_with("claude-3-haiku")
    {
        Some(4096)
    } else {
        None
    }
}
541
542#[derive(Debug, Deserialize, Serialize)]
543struct Metadata {
544 user_id: Option<String>,
545}
546
547#[derive(Default, Debug, Serialize, Deserialize)]
548#[serde(tag = "type", rename_all = "snake_case")]
549pub enum ToolChoice {
550 #[default]
551 Auto,
552 Any,
553 Tool {
554 name: String,
555 },
556}
557
558impl completion::CompletionModel for CompletionModel {
559 type Response = CompletionResponse;
560 type StreamingResponse = StreamingCompletionResponse;
561
562 #[cfg_attr(feature = "worker", worker::send)]
563 async fn completion(
564 &self,
565 completion_request: completion::CompletionRequest,
566 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
567 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
573 tokens
574 } else if let Some(tokens) = self.default_max_tokens {
575 tokens
576 } else {
577 return Err(CompletionError::RequestError(
578 "`max_tokens` must be set for Anthropic".into(),
579 ));
580 };
581
582 let mut full_history = vec![];
583 if let Some(docs) = completion_request.normalized_documents() {
584 full_history.push(docs);
585 }
586 full_history.extend(completion_request.chat_history);
587
588 let full_history = full_history
589 .into_iter()
590 .map(Message::try_from)
591 .collect::<Result<Vec<Message>, _>>()?;
592
593 let mut request = json!({
594 "model": self.model,
595 "messages": full_history,
596 "max_tokens": max_tokens,
597 "system": completion_request.preamble.unwrap_or("".to_string()),
598 });
599
600 if let Some(temperature) = completion_request.temperature {
601 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
602 }
603
604 if !completion_request.tools.is_empty() {
605 json_utils::merge_inplace(
606 &mut request,
607 json!({
608 "tools": completion_request
609 .tools
610 .into_iter()
611 .map(|tool| ToolDefinition {
612 name: tool.name,
613 description: Some(tool.description),
614 input_schema: tool.parameters,
615 })
616 .collect::<Vec<_>>(),
617 "tool_choice": ToolChoice::Auto,
618 }),
619 );
620 }
621
622 if let Some(ref params) = completion_request.additional_params {
623 json_utils::merge_inplace(&mut request, params.clone())
624 }
625
626 tracing::debug!("Anthropic completion request: {request}");
627
628 let response = self
629 .client
630 .post("/v1/messages")
631 .json(&request)
632 .send()
633 .await?;
634
635 if response.status().is_success() {
636 match response.json::<ApiResponse<CompletionResponse>>().await? {
637 ApiResponse::Message(completion) => {
638 tracing::info!(target: "rig",
639 "Anthropic completion token usage: {}",
640 completion.usage
641 );
642 completion.try_into()
643 }
644 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
645 }
646 } else {
647 Err(CompletionError::ProviderError(response.text().await?))
648 }
649 }
650
651 #[cfg_attr(feature = "worker", worker::send)]
652 async fn stream(
653 &self,
654 request: CompletionRequest,
655 ) -> Result<
656 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657 CompletionError,
658 > {
659 CompletionModel::stream(self, request).await
660 }
661}
662
663#[derive(Debug, Deserialize)]
664struct ApiErrorResponse {
665 message: String,
666}
667
668#[derive(Debug, Deserialize)]
669#[serde(tag = "type", rename_all = "snake_case")]
670enum ApiResponse<T> {
671 Message(T),
672 Error(ApiErrorResponse),
673}
674
// Unit tests for the Anthropic message types: JSON deserialization and round-trip
// conversion to and from rig's provider-agnostic `message::Message`.
#[cfg(test)]
mod tests {
    use super::*;
    // Wraps serde_json deserialization so failures report the JSON path of the offending field.
    use serde_path_to_error::deserialize;

    /// Deserializes three JSON fixtures into [`Message`]:
    /// 1. assistant content given as a bare string;
    /// 2. assistant content given as a text + tool_use array;
    /// 3. user content given as an image + text + tool_result array.
    #[test]
    fn test_deserialize_message() {
        // Bare-string content, accepted via `string_or_one_or_many`.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // The tool_result's `content` is a bare string, again parsed via `string_or_one_or_many`.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // Fixture 1: the bare string became a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        // Fixture 2: both blocks are preserved in order.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        // Fixture 3: image, text and tool_result blocks, with `is_error` absent -> None.
        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Converts Anthropic messages to rig messages and back, checking the intermediate rig
    /// values and that each round trip reproduces the original message exactly.
    #[test]
    fn test_message_to_message_conversion() {
        // User message with image, text and PDF document blocks.
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        // `is_error: None` matters: the rig -> Anthropic conversion always produces `None`,
        // so any other value would break the round-trip equality at the end.
        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data,
                        format,
                        media_type,
                        ..
                    }) => {
                        assert_eq!(data, "/9j/4AAQSkZJRg...");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data,
                        format,
                        media_type,
                    }) => {
                        assert_eq!(data, "base64_encoded_pdf_data");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Convert back: every message must survive the round trip unchanged.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}