1use crate::{
4 completion::{self, CompletionError},
5 json_utils,
6 message::{self, DocumentMediaType, MessageError},
7 one_or_many::string_or_one_or_many,
8 OneOrMany,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
/// Claude 3.7 Sonnet, addressed through the `-latest` alias.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Claude 3.5 Sonnet, addressed through the `-latest` alias.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Claude 3.5 Haiku, addressed through the `-latest` alias.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Claude 3 Opus, addressed through the `-latest` alias.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Claude 3 Sonnet, pinned to the `20240229` snapshot.
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Claude 3 Haiku, pinned to the `20240307` snapshot.
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// API version string `2023-01-01`.
/// NOTE(review): presumably sent as the `anthropic-version` header by the client — confirm.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version known to this module (currently an alias of `2023-06-01`).
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
/// Body of a successful Anthropic Messages API response.
///
/// Converted into the generic [`completion::CompletionResponse`], which keeps this value
/// verbatim as `raw_response`.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Content blocks returned by the model. Only `Text` and `ToolUse` blocks can be
    /// converted into generic assistant content; any other kind is a conversion error.
    pub content: Vec<Content>,
    /// Response identifier.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Role of the responding party, as sent by the API.
    pub role: String,
    /// Reason generation stopped, if reported.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
53
/// Token counts reported alongside an Anthropic response.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens consumed by the request.
    pub input_tokens: u64,
    /// Cache-read input tokens; `None` when the field is absent from the response.
    pub cache_read_input_tokens: Option<u64>,
    /// Cache-creation input tokens; `None` when the field is absent from the response.
    pub cache_creation_input_tokens: Option<u64>,
    /// Tokens generated in the response.
    pub output_tokens: u64,
}
61
62impl std::fmt::Display for Usage {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(
65 f,
66 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67 self.input_tokens,
68 match self.cache_read_input_tokens {
69 Some(token) => token.to_string(),
70 None => "n/a".to_string(),
71 },
72 match self.cache_creation_input_tokens {
73 Some(token) => token.to_string(),
74 None => "n/a".to_string(),
75 },
76 self.output_tokens
77 )
78 }
79}
80
/// Tool declaration sent in a request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name the model uses when emitting a `tool_use` block.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// Schema for the tool's input; populated from the generic tool's `parameters`.
    pub input_schema: serde_json::Value,
}
87
/// Cache-control marker, serialized as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95 type Error = CompletionError;
96
97 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98 let content = response
99 .content
100 .iter()
101 .map(|content| {
102 Ok(match content {
103 Content::Text { text } => completion::AssistantContent::text(text),
104 Content::ToolUse { id, name, input } => {
105 completion::AssistantContent::tool_call(id, name, input.clone())
106 }
107 _ => {
108 return Err(CompletionError::ResponseError(
109 "Response did not contain a message or tool call".into(),
110 ))
111 }
112 })
113 })
114 .collect::<Result<Vec<_>, _>>()?;
115
116 let choice = OneOrMany::many(content).map_err(|_| {
117 CompletionError::ResponseError(
118 "Response contained no message or tool call (empty)".to_owned(),
119 )
120 })?;
121
122 Ok(completion::CompletionResponse {
123 choice,
124 raw_response: response,
125 })
126 }
127}
128
/// One turn of an Anthropic conversation.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the turn.
    pub role: Role,
    /// Content blocks of the turn. On input, a bare JSON string is also accepted and
    /// parsed as a single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
135
/// Author of a message; serialized as `"user"` or `"assistant"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
142
/// A single Anthropic content block, discriminated by its `"type"` field
/// (`text`, `image`, `tool_use`, `tool_result` or `document`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An inline image; the payload lives under `source`.
    Image {
        source: ImageSource,
    },
    /// A tool invocation emitted by the assistant; `input` is the raw JSON arguments.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked to its `ToolUse` block by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        // Accepts a bare string (one text result) as well as one or many result blocks.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        // Left out of the serialized JSON entirely when unset.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// An attached document; only PDF sources are representable.
    Document {
        source: DocumentSource,
    },
}
168
169impl FromStr for Content {
170 type Err = Infallible;
171
172 fn from_str(s: &str) -> Result<Self, Self::Err> {
173 Ok(Content::Text { text: s.to_owned() })
174 }
175}
176
177#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
178#[serde(tag = "type", rename_all = "snake_case")]
179pub enum ToolResultContent {
180 Text { text: String },
181 Image(ImageSource),
182}
183
184impl FromStr for ToolResultContent {
185 type Err = Infallible;
186
187 fn from_str(s: &str) -> Result<Self, Self::Err> {
188 Ok(ToolResultContent::Text { text: s.to_owned() })
189 }
190}
191
/// Image payload, found under an image block's `source` key, e.g.
/// `{"type": "base64", "media_type": "image/jpeg", "data": "..."}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Encoded image bytes.
    pub data: String,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// Payload encoding; (de)serialized under the JSON key `type`.
    pub r#type: SourceType,
}
198
/// Document payload, found under a document block's `source` key.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Encoded document bytes.
    pub data: String,
    /// MIME type of the document (only PDF is representable).
    pub media_type: DocumentFormat,
    /// Payload encoding; (de)serialized under the JSON key `type`.
    pub r#type: SourceType,
}
205
/// Image MIME types representable on the Anthropic side.
///
/// Each variant's explicit `rename` (its MIME string) takes precedence over the
/// container-level `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
218
/// Document MIME types representable on the Anthropic side (PDF only).
///
/// The explicit `rename` on the variant takes precedence over `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
228
/// Encoding of a source payload. Only base64 is supported; it serializes as `"base64"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
}
234
235impl From<String> for Content {
236 fn from(text: String) -> Self {
237 Content::Text { text }
238 }
239}
240
241impl From<String> for ToolResultContent {
242 fn from(text: String) -> Self {
243 ToolResultContent::Text { text }
244 }
245}
246
247impl TryFrom<message::ContentFormat> for SourceType {
248 type Error = MessageError;
249
250 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
251 match format {
252 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
253 message::ContentFormat::String => Err(MessageError::ConversionError(
254 "Image urls are not supported in Anthropic".to_owned(),
255 )),
256 }
257 }
258}
259
260impl From<SourceType> for message::ContentFormat {
261 fn from(source_type: SourceType) -> Self {
262 match source_type {
263 SourceType::BASE64 => message::ContentFormat::Base64,
264 }
265 }
266}
267
268impl TryFrom<message::ImageMediaType> for ImageFormat {
269 type Error = MessageError;
270
271 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
272 Ok(match media_type {
273 message::ImageMediaType::JPEG => ImageFormat::JPEG,
274 message::ImageMediaType::PNG => ImageFormat::PNG,
275 message::ImageMediaType::GIF => ImageFormat::GIF,
276 message::ImageMediaType::WEBP => ImageFormat::WEBP,
277 _ => {
278 return Err(MessageError::ConversionError(
279 format!("Unsupported image media type: {media_type:?}").to_owned(),
280 ))
281 }
282 })
283 }
284}
285
286impl From<ImageFormat> for message::ImageMediaType {
287 fn from(format: ImageFormat) -> Self {
288 match format {
289 ImageFormat::JPEG => message::ImageMediaType::JPEG,
290 ImageFormat::PNG => message::ImageMediaType::PNG,
291 ImageFormat::GIF => message::ImageMediaType::GIF,
292 ImageFormat::WEBP => message::ImageMediaType::WEBP,
293 }
294 }
295}
296
297impl TryFrom<DocumentMediaType> for DocumentFormat {
298 type Error = MessageError;
299 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
300 if !matches!(value, DocumentMediaType::PDF) {
301 return Err(MessageError::ConversionError(
302 "Anthropic only supports PDF documents".to_string(),
303 ));
304 };
305
306 Ok(DocumentFormat::PDF)
307 }
308}
309
310impl From<message::AssistantContent> for Content {
311 fn from(text: message::AssistantContent) -> Self {
312 match text {
313 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
314 message::AssistantContent::ToolCall(message::ToolCall { id, function }) => {
315 Content::ToolUse {
316 id,
317 name: function.name,
318 input: function.arguments,
319 }
320 }
321 }
322 }
323}
324
325impl TryFrom<message::Message> for Message {
326 type Error = MessageError;
327
328 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
329 Ok(match message {
330 message::Message::User { content } => Message {
331 role: Role::User,
332 content: content.try_map(|content| match content {
333 message::UserContent::Text(message::Text { text }) => {
334 Ok(Content::Text { text })
335 }
336 message::UserContent::ToolResult(message::ToolResult { id, content }) => {
337 Ok(Content::ToolResult {
338 tool_use_id: id,
339 content: content.try_map(|content| match content {
340 message::ToolResultContent::Text(message::Text { text }) => {
341 Ok(ToolResultContent::Text { text })
342 }
343 message::ToolResultContent::Image(image) => {
344 let media_type =
345 image.media_type.ok_or(MessageError::ConversionError(
346 "Image media type is required".to_owned(),
347 ))?;
348 let format =
349 image.format.ok_or(MessageError::ConversionError(
350 "Image format is required".to_owned(),
351 ))?;
352 Ok(ToolResultContent::Image(ImageSource {
353 data: image.data,
354 media_type: media_type.try_into()?,
355 r#type: format.try_into()?,
356 }))
357 }
358 })?,
359 is_error: None,
360 })
361 }
362 message::UserContent::Image(message::Image {
363 data,
364 format,
365 media_type,
366 ..
367 }) => {
368 let source = ImageSource {
369 data,
370 media_type: match media_type {
371 Some(media_type) => media_type.try_into()?,
372 None => {
373 return Err(MessageError::ConversionError(
374 "Image media type is required".to_owned(),
375 ))
376 }
377 },
378 r#type: match format {
379 Some(format) => format.try_into()?,
380 None => SourceType::BASE64,
381 },
382 };
383 Ok(Content::Image { source })
384 }
385 message::UserContent::Document(message::Document {
386 data,
387 format,
388 media_type,
389 }) => {
390 let Some(media_type) = media_type else {
391 return Err(MessageError::ConversionError(
392 "Document media type is required".to_string(),
393 ));
394 };
395
396 let source = DocumentSource {
397 data,
398 media_type: media_type.try_into()?,
399 r#type: match format {
400 Some(format) => format.try_into()?,
401 None => SourceType::BASE64,
402 },
403 };
404 Ok(Content::Document { source })
405 }
406 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
407 "Audio is not supported in Anthropic".to_owned(),
408 )),
409 })?,
410 },
411
412 message::Message::Assistant { content } => Message {
413 content: content.map(|content| content.into()),
414 role: Role::Assistant,
415 },
416 })
417 }
418}
419
420impl TryFrom<Content> for message::AssistantContent {
421 type Error = MessageError;
422
423 fn try_from(content: Content) -> Result<Self, Self::Error> {
424 Ok(match content {
425 Content::Text { text } => message::AssistantContent::text(text),
426 Content::ToolUse { id, name, input } => {
427 message::AssistantContent::tool_call(id, name, input)
428 }
429 _ => {
430 return Err(MessageError::ConversionError(
431 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
432 ))
433 }
434 })
435 }
436}
437
438impl From<ToolResultContent> for message::ToolResultContent {
439 fn from(content: ToolResultContent) -> Self {
440 match content {
441 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
442 ToolResultContent::Image(ImageSource {
443 data,
444 media_type: format,
445 r#type,
446 }) => message::ToolResultContent::image(
447 data,
448 Some(r#type.into()),
449 Some(format.into()),
450 None,
451 ),
452 }
453 }
454}
455
456impl TryFrom<Message> for message::Message {
457 type Error = MessageError;
458
459 fn try_from(message: Message) -> Result<Self, Self::Error> {
460 Ok(match message.role {
461 Role::User => message::Message::User {
462 content: message.content.try_map(|content| {
463 Ok(match content {
464 Content::Text { text } => message::UserContent::text(text),
465 Content::ToolResult {
466 tool_use_id,
467 content,
468 ..
469 } => message::UserContent::tool_result(
470 tool_use_id,
471 content.map(|content| content.into()),
472 ),
473 Content::Image { source } => message::UserContent::Image(message::Image {
474 data: source.data,
475 format: Some(message::ContentFormat::Base64),
476 media_type: Some(source.media_type.into()),
477 detail: None,
478 }),
479 Content::Document { source } => message::UserContent::document(
480 source.data,
481 Some(message::ContentFormat::Base64),
482 Some(message::DocumentMediaType::PDF),
483 ),
484 _ => {
485 return Err(MessageError::ConversionError(
486 "Unsupported content type for User role".to_owned(),
487 ))
488 }
489 })
490 })?,
491 },
492 Role::Assistant => match message.content.first() {
493 Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
494 content: message.content.try_map(|content| content.try_into())?,
495 },
496
497 _ => {
498 return Err(MessageError::ConversionError(
499 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
500 ))
501 }
502 },
503 })
504 }
505}
506
/// An Anthropic chat-completion model bound to an API client.
#[derive(Clone)]
pub struct CompletionModel {
    pub(crate) client: Client,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one. `None` means the model
    /// has no known default, so every request must set `max_tokens` itself.
    pub default_max_tokens: Option<u64>,
}
513
514impl CompletionModel {
515 pub fn new(client: Client, model: &str) -> Self {
516 Self {
517 client,
518 model: model.to_string(),
519 default_max_tokens: calculate_max_tokens(model),
520 }
521 }
522}
523
/// Returns the default `max_tokens` for `model`, or `None` if the model family is unknown.
///
/// Matching is by name prefix, so dated snapshots (e.g. `claude-3-5-sonnet-20241022`) and
/// `-latest` aliases resolve to the same limit. The values are the documented per-model
/// output-token caps. For Claude 3.7 Sonnet, 64k is the standard (non-beta) cap.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // No prefix in this table is a prefix of another, so lookup order does not matter.
    const MAX_OUTPUT_TOKENS: &[(&str, u64)] = &[
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8_192),
        ("claude-3-5-haiku", 8_192),
        ("claude-3-opus", 4_096),
        ("claude-3-sonnet", 4_096),
        ("claude-3-haiku", 4_096),
    ];

    MAX_OUTPUT_TOKENS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}
541
/// Optional request metadata carrying a caller-supplied `user_id`.
/// NOTE(review): not referenced anywhere in this file; presumably reserved for the request's
/// `metadata` field — confirm before relying on it or remove it.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    user_id: Option<String>,
}
546
/// How the model may use the supplied tools. Serializes as one of:
/// - `{"type": "auto"}`
/// - `{"type": "any"}`
/// - `{"type": "tool", "name": "..."}`
///
/// `Auto` is the default and is what `completion` sends whenever tools are present.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    Tool {
        name: String,
    },
}
557
558impl completion::CompletionModel for CompletionModel {
559 type Response = CompletionResponse;
560 type StreamingResponse = StreamingCompletionResponse;
561
562 #[cfg_attr(feature = "worker", worker::send)]
563 async fn completion(
564 &self,
565 completion_request: completion::CompletionRequest,
566 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
567 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
573 tokens
574 } else if let Some(tokens) = self.default_max_tokens {
575 tokens
576 } else {
577 return Err(CompletionError::RequestError(
578 "`max_tokens` must be set for Anthropic".into(),
579 ));
580 };
581
582 let mut full_history = vec![];
583 if let Some(docs) = completion_request.normalized_documents() {
584 full_history.push(docs);
585 }
586 full_history.extend(completion_request.chat_history);
587
588 let full_history = full_history
589 .into_iter()
590 .map(Message::try_from)
591 .collect::<Result<Vec<Message>, _>>()?;
592
593 let mut request = json!({
594 "model": self.model,
595 "messages": full_history,
596 "max_tokens": max_tokens,
597 "system": completion_request.preamble.unwrap_or("".to_string()),
598 });
599
600 if let Some(temperature) = completion_request.temperature {
601 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
602 }
603
604 if !completion_request.tools.is_empty() {
605 json_utils::merge_inplace(
606 &mut request,
607 json!({
608 "tools": completion_request
609 .tools
610 .into_iter()
611 .map(|tool| ToolDefinition {
612 name: tool.name,
613 description: Some(tool.description),
614 input_schema: tool.parameters,
615 })
616 .collect::<Vec<_>>(),
617 "tool_choice": ToolChoice::Auto,
618 }),
619 );
620 }
621
622 if let Some(ref params) = completion_request.additional_params {
623 json_utils::merge_inplace(&mut request, params.clone())
624 }
625
626 tracing::debug!("Anthropic completion request: {request}");
627
628 let response = self
629 .client
630 .post("/v1/messages")
631 .json(&request)
632 .send()
633 .await?;
634
635 if response.status().is_success() {
636 match response.json::<ApiResponse<CompletionResponse>>().await? {
637 ApiResponse::Message(completion) => {
638 tracing::info!(target: "rig",
639 "Anthropic completion token usage: {}",
640 completion.usage
641 );
642 completion.try_into()
643 }
644 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
645 }
646 } else {
647 Err(CompletionError::ProviderError(response.text().await?))
648 }
649 }
650
651 #[cfg_attr(feature = "worker", worker::send)]
652 async fn stream(
653 &self,
654 request: CompletionRequest,
655 ) -> Result<
656 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
657 CompletionError,
658 > {
659 CompletionModel::stream(self, request).await
660 }
661}
662
/// Error payload of an `ApiResponse::Error` body; only `message` is read.
/// NOTE(review): Anthropic's documented error body nests the text under `error.message`
/// (`{"type":"error","error":{"type":...,"message":...}}`). Confirm that a top-level
/// `message` actually matches, or this variant will fail to deserialize.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
667
/// Response envelope, discriminated by the top-level `"type"` field:
/// `"message"` wraps a successful payload `T`, and `"error"` wraps an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
674
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes a [`Message`] fixture, panicking with the JSON path of the first
    /// offending field so broken fixtures are easy to locate.
    fn parse_message(raw: &str) -> Message {
        let jd = &mut serde_json::Deserializer::from_str(raw);
        deserialize(jd).unwrap_or_else(|err| {
            panic!("Deserialization error at {}: {}", err.path(), err);
        })
    }

    #[test]
    fn test_deserialize_message() {
        // A bare string body is shorthand for a single text block.
        let assistant_message = parse_message(
            r#"
            {
                "role": "assistant",
                "content": "\n\nHello there, how may I assist you today?"
            }
            "#,
        );

        // Text followed by a tool_use block; order must be preserved.
        let assistant_message2 = parse_message(
            r#"
            {
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": "\n\nHello there, how may I assist you today?"
                    },
                    {
                        "type": "tool_use",
                        "id": "toolu_01A09q90qw90lq917835lq9",
                        "name": "get_weather",
                        "input": {"location": "San Francisco, CA"}
                    }
                ]
            }
            "#,
        );

        // Image, text, and a tool_result whose content is a bare string.
        let user_message = parse_message(
            r#"
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": "/9j/4AAQSkZJRg..."
                        }
                    },
                    {
                        "type": "text",
                        "text": "What is in this image?"
                    },
                    {
                        "type": "tool_result",
                        "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                        "content": "15 degrees"
                    }
                ]
            }
            "#,
        );

        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        let Message { role, content } = assistant_message2;
        assert_eq!(role, Role::Assistant);
        assert_eq!(content.len(), 2);
        let mut blocks = content.into_iter();

        let Content::Text { text } = blocks.next().unwrap() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "\n\nHello there, how may I assist you today?");

        let Content::ToolUse { id, name, input } = blocks.next().unwrap() else {
            panic!("Expected tool use content");
        };
        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
        assert_eq!(name, "get_weather");
        assert_eq!(input, json!({"location": "San Francisco, CA"}));

        assert_eq!(blocks.next(), None);

        let Message { role, content } = user_message;
        assert_eq!(role, Role::User);
        assert_eq!(content.len(), 3);
        let mut blocks = content.into_iter();

        let Content::Image { source } = blocks.next().unwrap() else {
            panic!("Expected image content");
        };
        assert_eq!(
            source,
            ImageSource {
                data: "/9j/4AAQSkZJRg...".to_owned(),
                media_type: ImageFormat::JPEG,
                r#type: SourceType::BASE64,
            }
        );

        let Content::Text { text } = blocks.next().unwrap() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "What is in this image?");

        let Content::ToolResult {
            tool_use_id,
            content,
            is_error,
        } = blocks.next().unwrap()
        else {
            panic!("Expected tool result content");
        };
        assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
        assert_eq!(
            content.first(),
            ToolResultContent::Text {
                text: "15 degrees".to_owned()
            }
        );
        assert_eq!(is_error, None);

        assert_eq!(blocks.next(), None);
    }

    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": "/9j/4AAQSkZJRg..."
                        }
                    },
                    {
                        "type": "text",
                        "text": "What is in this image?"
                    },
                    {
                        "type": "document",
                        "source": {
                            "type": "base64",
                            "data": "base64_encoded_pdf_data",
                            "media_type": "application/pdf"
                        }
                    }
                ]
            }
            "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        // User message: image, text and document keep their payloads and metadata.
        let message::Message::User { content } = converted_user_message.clone() else {
            panic!("Expected user message");
        };
        assert_eq!(content.len(), 3);
        let mut parts = content.into_iter();

        let message::UserContent::Image(message::Image {
            data,
            format,
            media_type,
            ..
        }) = parts.next().unwrap()
        else {
            panic!("Expected image content");
        };
        assert_eq!(data, "/9j/4AAQSkZJRg...");
        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));

        let message::UserContent::Text(message::Text { text }) = parts.next().unwrap() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "What is in this image?");

        let message::UserContent::Document(message::Document {
            data,
            format,
            media_type,
        }) = parts.next().unwrap()
        else {
            panic!("Expected document content");
        };
        assert_eq!(data, "base64_encoded_pdf_data");
        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));

        assert_eq!(parts.next(), None);

        // Tool result: id and text content survive.
        let message::Message::User { content } = converted_tool_message.clone() else {
            panic!("Expected tool result content");
        };
        let message::UserContent::ToolResult(message::ToolResult { id, content, .. }) =
            content.first()
        else {
            panic!("Expected tool result content");
        };
        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
        let message::ToolResultContent::Text(message::Text { text }) = content.first() else {
            panic!("Expected text content");
        };
        assert_eq!(text, "15 degrees");

        // Assistant tool_use becomes a generic tool call.
        let message::Message::Assistant { content } = converted_assistant_message.clone() else {
            panic!("Expected assistant message");
        };
        assert_eq!(content.len(), 1);
        let message::AssistantContent::ToolCall(message::ToolCall { id, function }) =
            content.first()
        else {
            panic!("Expected tool call content");
        };
        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
        assert_eq!(function.name, "get_weather");
        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));

        // Converting back to the Anthropic form must be lossless.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
}
987}