1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError},
6 json_utils,
7 message::{self, DocumentMediaType, MessageError},
8 one_or_many::string_or_one_or_many,
9};
10use std::{convert::Infallible, str::FromStr};
11
12use super::client::Client;
13use crate::completion::CompletionRequest;
14use crate::providers::anthropic::streaming::StreamingCompletionResponse;
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
/// Model name `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Model name `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Model name `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Model name `claude-3-opus-latest`.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Model name `claude-3-sonnet-20240229` (a dated identifier rather than a `-latest` alias).
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Model name `claude-3-haiku-20240307` (a dated identifier rather than a `-latest` alias).
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// API version string `2023-01-01`.
// NOTE(review): presumably sent as the API version header by the client — confirm there.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias of [`ANTHROPIC_VERSION_2023_06_01`], the newest version constant defined here.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
42
/// Provider-specific body of a successful completion response.
///
/// Deserialized from the `message` variant of `ApiResponse` and kept as the
/// `raw_response` of the generic `completion::CompletionResponse`.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Content blocks returned by the model; only text and tool-use blocks
    /// are accepted when converting to the generic response type.
    pub content: Vec<Content>,
    /// Identifier carried by the response.
    pub id: String,
    /// Model name reported in the response.
    pub model: String,
    /// Role string reported in the response (kept as a raw string, not a [`Role`]).
    pub role: String,
    /// Optional stop reason reported in the response.
    pub stop_reason: Option<String>,
    /// Optional stop sequence reported in the response.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request.
    pub usage: Usage,
}
53
/// Token counters attached to a [`CompletionResponse`].
///
/// The two cache counters are optional; the `Display` impl prints them as
/// `n/a` when they are absent.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens: u64,
    /// Number of input tokens read from cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Number of input tokens used for cache creation, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Number of output tokens.
    pub output_tokens: u64,
}
61
62impl std::fmt::Display for Usage {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(
65 f,
66 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
67 self.input_tokens,
68 match self.cache_read_input_tokens {
69 Some(token) => token.to_string(),
70 None => "n/a".to_string(),
71 },
72 match self.cache_creation_input_tokens {
73 Some(token) => token.to_string(),
74 None => "n/a".to_string(),
75 },
76 self.output_tokens
77 )
78 }
79}
80
/// Tool description in the shape placed under the `tools` key of a
/// completion request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// JSON value describing the tool input; `completion` fills it from the
    /// generic tool's `parameters`.
    pub input_schema: serde_json::Value,
}
87
/// Cache-control marker.
///
/// Serialized as an internally tagged object whose `type` field holds the
/// snake_case variant name (so `Ephemeral` becomes `{"type": "ephemeral"}`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// The only cache-control kind defined here.
    Ephemeral,
}
93
94impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
95 type Error = CompletionError;
96
97 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
98 let content = response
99 .content
100 .iter()
101 .map(|content| {
102 Ok(match content {
103 Content::Text { text } => completion::AssistantContent::text(text),
104 Content::ToolUse { id, name, input } => {
105 completion::AssistantContent::tool_call(id, name, input.clone())
106 }
107 _ => {
108 return Err(CompletionError::ResponseError(
109 "Response did not contain a message or tool call".into(),
110 ));
111 }
112 })
113 })
114 .collect::<Result<Vec<_>, _>>()?;
115
116 let choice = OneOrMany::many(content).map_err(|_| {
117 CompletionError::ResponseError(
118 "Response contained no message or tool call (empty)".to_owned(),
119 )
120 })?;
121
122 let usage = completion::Usage {
123 input_tokens: response.usage.input_tokens,
124 output_tokens: response.usage.output_tokens,
125 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
126 };
127
128 Ok(completion::CompletionResponse {
129 choice,
130 usage,
131 raw_response: response,
132 })
133 }
134}
135
/// A single chat message in the provider's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// One or more content blocks. On deserialization a bare JSON string is
    /// also accepted (via `string_or_one_or_many`) and becomes a text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
142
/// Author of a [`Message`]; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message written by the user.
    User,
    /// Message written by the assistant.
    Assistant,
}
149
/// A content block of a [`Message`].
///
/// Internally tagged: the `type` field holds the snake_case variant name
/// (`text`, `image`, `tool_use`, `tool_result`, `document`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image described by its source.
    Image {
        source: ImageSource,
    },
    /// A tool invocation: call id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked to the call by `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        /// One or more result parts; a bare JSON string is accepted on
        /// deserialization and becomes a text part.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// A document described by its source.
    Document {
        source: DocumentSource,
    },
}
175
176impl FromStr for Content {
177 type Err = Infallible;
178
179 fn from_str(s: &str) -> Result<Self, Self::Err> {
180 Ok(Content::Text { text: s.to_owned() })
181 }
182}
183
/// One part of a tool result: either text or an image.
///
/// Internally tagged with a snake_case `type` field (`text` / `image`).
// NOTE(review): `ImageSource` itself has a field named `type`, which looks
// like it would collide with this enum's `type` tag for the `Image` variant —
// confirm how image tool results (de)serialize.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Plain text result.
    Text { text: String },
    /// Image result.
    Image(ImageSource),
}
190
191impl FromStr for ToolResultContent {
192 type Err = Infallible;
193
194 fn from_str(s: &str) -> Result<Self, Self::Err> {
195 Ok(ToolResultContent::Text { text: s.to_owned() })
196 }
197}
198
/// Source of an image content block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Image payload, encoded as indicated by `type`.
    pub data: String,
    /// Media type of the image (serialized as a MIME string such as `image/jpeg`).
    pub media_type: ImageFormat,
    /// Encoding of `data`; serialized under the key `type`.
    pub r#type: SourceType,
}
205
/// Source of a document content block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Document payload, encoded as indicated by `type`.
    pub data: String,
    /// Media type of the document (serialized as `application/pdf`).
    pub media_type: DocumentFormat,
    /// Encoding of `data`; serialized under the key `type`.
    pub r#type: SourceType,
}
212
/// Image media types that can appear in an [`ImageSource`].
///
/// Each variant carries an explicit serde rename to its MIME string, which
/// takes precedence over the enum-level `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// `image/jpeg`
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// `image/png`
    #[serde(rename = "image/png")]
    PNG,
    /// `image/gif`
    #[serde(rename = "image/gif")]
    GIF,
    /// `image/webp`
    #[serde(rename = "image/webp")]
    WEBP,
}
225
/// Document media types that can appear in a [`DocumentSource`].
///
/// PDF is the only variant; other generic document media types are rejected
/// by the `TryFrom<DocumentMediaType>` conversion.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// `application/pdf`
    #[serde(rename = "application/pdf")]
    PDF,
}
235
/// Encoding of the `data` field of an image or document source.
///
/// Serialized in lowercase, so `BASE64` appears as `"base64"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// Base64-encoded payload.
    BASE64,
}
241
242impl From<String> for Content {
243 fn from(text: String) -> Self {
244 Content::Text { text }
245 }
246}
247
248impl From<String> for ToolResultContent {
249 fn from(text: String) -> Self {
250 ToolResultContent::Text { text }
251 }
252}
253
254impl TryFrom<message::ContentFormat> for SourceType {
255 type Error = MessageError;
256
257 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
258 match format {
259 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
260 message::ContentFormat::String => Err(MessageError::ConversionError(
261 "Image urls are not supported in Anthropic".to_owned(),
262 )),
263 }
264 }
265}
266
267impl From<SourceType> for message::ContentFormat {
268 fn from(source_type: SourceType) -> Self {
269 match source_type {
270 SourceType::BASE64 => message::ContentFormat::Base64,
271 }
272 }
273}
274
275impl TryFrom<message::ImageMediaType> for ImageFormat {
276 type Error = MessageError;
277
278 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
279 Ok(match media_type {
280 message::ImageMediaType::JPEG => ImageFormat::JPEG,
281 message::ImageMediaType::PNG => ImageFormat::PNG,
282 message::ImageMediaType::GIF => ImageFormat::GIF,
283 message::ImageMediaType::WEBP => ImageFormat::WEBP,
284 _ => {
285 return Err(MessageError::ConversionError(
286 format!("Unsupported image media type: {media_type:?}").to_owned(),
287 ));
288 }
289 })
290 }
291}
292
293impl From<ImageFormat> for message::ImageMediaType {
294 fn from(format: ImageFormat) -> Self {
295 match format {
296 ImageFormat::JPEG => message::ImageMediaType::JPEG,
297 ImageFormat::PNG => message::ImageMediaType::PNG,
298 ImageFormat::GIF => message::ImageMediaType::GIF,
299 ImageFormat::WEBP => message::ImageMediaType::WEBP,
300 }
301 }
302}
303
304impl TryFrom<DocumentMediaType> for DocumentFormat {
305 type Error = MessageError;
306 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
307 if !matches!(value, DocumentMediaType::PDF) {
308 return Err(MessageError::ConversionError(
309 "Anthropic only supports PDF documents".to_string(),
310 ));
311 };
312
313 Ok(DocumentFormat::PDF)
314 }
315}
316
317impl From<message::AssistantContent> for Content {
318 fn from(text: message::AssistantContent) -> Self {
319 match text {
320 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
321 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
322 Content::ToolUse {
323 id,
324 name: function.name,
325 input: function.arguments,
326 }
327 }
328 }
329 }
330}
331
332impl TryFrom<message::Message> for Message {
333 type Error = MessageError;
334
335 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
336 Ok(match message {
337 message::Message::User { content } => Message {
338 role: Role::User,
339 content: content.try_map(|content| match content {
340 message::UserContent::Text(message::Text { text }) => {
341 Ok(Content::Text { text })
342 }
343 message::UserContent::ToolResult(message::ToolResult {
344 id, content, ..
345 }) => Ok(Content::ToolResult {
346 tool_use_id: id,
347 content: content.try_map(|content| match content {
348 message::ToolResultContent::Text(message::Text { text }) => {
349 Ok(ToolResultContent::Text { text })
350 }
351 message::ToolResultContent::Image(image) => {
352 let media_type =
353 image.media_type.ok_or(MessageError::ConversionError(
354 "Image media type is required".to_owned(),
355 ))?;
356 let format = image.format.ok_or(MessageError::ConversionError(
357 "Image format is required".to_owned(),
358 ))?;
359 Ok(ToolResultContent::Image(ImageSource {
360 data: image.data,
361 media_type: media_type.try_into()?,
362 r#type: format.try_into()?,
363 }))
364 }
365 })?,
366 is_error: None,
367 }),
368 message::UserContent::Image(message::Image {
369 data,
370 format,
371 media_type,
372 ..
373 }) => {
374 let source = ImageSource {
375 data,
376 media_type: match media_type {
377 Some(media_type) => media_type.try_into()?,
378 None => {
379 return Err(MessageError::ConversionError(
380 "Image media type is required".to_owned(),
381 ));
382 }
383 },
384 r#type: match format {
385 Some(format) => format.try_into()?,
386 None => SourceType::BASE64,
387 },
388 };
389 Ok(Content::Image { source })
390 }
391 message::UserContent::Document(message::Document {
392 data,
393 format,
394 media_type,
395 }) => {
396 let Some(media_type) = media_type else {
397 return Err(MessageError::ConversionError(
398 "Document media type is required".to_string(),
399 ));
400 };
401
402 let source = DocumentSource {
403 data,
404 media_type: media_type.try_into()?,
405 r#type: match format {
406 Some(format) => format.try_into()?,
407 None => SourceType::BASE64,
408 },
409 };
410 Ok(Content::Document { source })
411 }
412 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
413 "Audio is not supported in Anthropic".to_owned(),
414 )),
415 })?,
416 },
417
418 message::Message::Assistant { content, .. } => Message {
419 content: content.map(|content| content.into()),
420 role: Role::Assistant,
421 },
422 })
423 }
424}
425
426impl TryFrom<Content> for message::AssistantContent {
427 type Error = MessageError;
428
429 fn try_from(content: Content) -> Result<Self, Self::Error> {
430 Ok(match content {
431 Content::Text { text } => message::AssistantContent::text(text),
432 Content::ToolUse { id, name, input } => {
433 message::AssistantContent::tool_call(id, name, input)
434 }
435 _ => {
436 return Err(MessageError::ConversionError(
437 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
438 ));
439 }
440 })
441 }
442}
443
444impl From<ToolResultContent> for message::ToolResultContent {
445 fn from(content: ToolResultContent) -> Self {
446 match content {
447 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
448 ToolResultContent::Image(ImageSource {
449 data,
450 media_type: format,
451 r#type,
452 }) => message::ToolResultContent::image(
453 data,
454 Some(r#type.into()),
455 Some(format.into()),
456 None,
457 ),
458 }
459 }
460}
461
462impl TryFrom<Message> for message::Message {
463 type Error = MessageError;
464
465 fn try_from(message: Message) -> Result<Self, Self::Error> {
466 Ok(match message.role {
467 Role::User => message::Message::User {
468 content: message.content.try_map(|content| {
469 Ok(match content {
470 Content::Text { text } => message::UserContent::text(text),
471 Content::ToolResult {
472 tool_use_id,
473 content,
474 ..
475 } => message::UserContent::tool_result(
476 tool_use_id,
477 content.map(|content| content.into()),
478 ),
479 Content::Image { source } => message::UserContent::Image(message::Image {
480 data: source.data,
481 format: Some(message::ContentFormat::Base64),
482 media_type: Some(source.media_type.into()),
483 detail: None,
484 }),
485 Content::Document { source } => message::UserContent::document(
486 source.data,
487 Some(message::ContentFormat::Base64),
488 Some(message::DocumentMediaType::PDF),
489 ),
490 _ => {
491 return Err(MessageError::ConversionError(
492 "Unsupported content type for User role".to_owned(),
493 ));
494 }
495 })
496 })?,
497 },
498 Role::Assistant => match message.content.first() {
499 Content::Text { .. } | Content::ToolUse { .. } => message::Message::Assistant {
500 id: None,
501 content: message.content.try_map(|content| content.try_into())?,
502 },
503
504 _ => {
505 return Err(MessageError::ConversionError(
506 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
507 ));
508 }
509 },
510 })
511 }
512}
513
/// Completion model bound to a client and a model name.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub(crate) client: Client,
    /// Model name sent in the `model` field of each request.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one;
    /// `None` when no default is known for the model.
    pub default_max_tokens: Option<u64>,
}
520
521impl CompletionModel {
522 pub fn new(client: Client, model: &str) -> Self {
523 Self {
524 client,
525 model: model.to_string(),
526 default_max_tokens: calculate_max_tokens(model),
527 }
528 }
529}
530
/// Returns the default `max_tokens` for a known model family, or `None` when
/// the model name matches no known prefix.
///
/// `completion` needs a `max_tokens` value for every request; this default is
/// used when the caller does not set one. Families are matched by name prefix,
/// so both `-latest` aliases and dated identifiers resolve.
///
/// The `claude-3-7-sonnet` family is included so that the exported
/// `CLAUDE_3_7_SONNET` model works without an explicit `max_tokens`.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, default output-token limit)
    const DEFAULTS: &[(&str, u64)] = &[
        ("claude-3-7-sonnet", 64_000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    DEFAULTS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}
548
/// Request metadata carrying an optional user id.
// NOTE(review): not referenced anywhere in the visible part of this file —
// confirm whether it is still used.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    /// Optional user identifier.
    user_id: Option<String>,
}
553
/// Value of the `tool_choice` request field.
///
/// Internally tagged with a snake_case `type` field (`auto`, `any`, `tool`).
/// `completion` always sends `Auto` when tools are present.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The default variant.
    #[default]
    Auto,
    /// Serialized as `{"type": "any"}`.
    Any,
    /// Names one specific tool; serialized as `{"type": "tool", "name": ...}`.
    Tool {
        name: String,
    },
}
564
565impl completion::CompletionModel for CompletionModel {
566 type Response = CompletionResponse;
567 type StreamingResponse = StreamingCompletionResponse;
568
569 #[cfg_attr(feature = "worker", worker::send)]
570 async fn completion(
571 &self,
572 completion_request: completion::CompletionRequest,
573 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
574 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
580 tokens
581 } else if let Some(tokens) = self.default_max_tokens {
582 tokens
583 } else {
584 return Err(CompletionError::RequestError(
585 "`max_tokens` must be set for Anthropic".into(),
586 ));
587 };
588
589 let mut full_history = vec![];
590 if let Some(docs) = completion_request.normalized_documents() {
591 full_history.push(docs);
592 }
593 full_history.extend(completion_request.chat_history);
594
595 let full_history = full_history
596 .into_iter()
597 .map(Message::try_from)
598 .collect::<Result<Vec<Message>, _>>()?;
599
600 let mut request = json!({
601 "model": self.model,
602 "messages": full_history,
603 "max_tokens": max_tokens,
604 "system": completion_request.preamble.unwrap_or("".to_string()),
605 });
606
607 if let Some(temperature) = completion_request.temperature {
608 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
609 }
610
611 if !completion_request.tools.is_empty() {
612 json_utils::merge_inplace(
613 &mut request,
614 json!({
615 "tools": completion_request
616 .tools
617 .into_iter()
618 .map(|tool| ToolDefinition {
619 name: tool.name,
620 description: Some(tool.description),
621 input_schema: tool.parameters,
622 })
623 .collect::<Vec<_>>(),
624 "tool_choice": ToolChoice::Auto,
625 }),
626 );
627 }
628
629 if let Some(ref params) = completion_request.additional_params {
630 json_utils::merge_inplace(&mut request, params.clone())
631 }
632
633 tracing::debug!("Anthropic completion request: {request}");
634
635 let response = self
636 .client
637 .post("/v1/messages")
638 .json(&request)
639 .send()
640 .await?;
641
642 if response.status().is_success() {
643 match response.json::<ApiResponse<CompletionResponse>>().await? {
644 ApiResponse::Message(completion) => {
645 tracing::info!(target: "rig",
646 "Anthropic completion token usage: {}",
647 completion.usage
648 );
649 completion.try_into()
650 }
651 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
652 }
653 } else {
654 Err(CompletionError::ProviderError(response.text().await?))
655 }
656 }
657
658 #[cfg_attr(feature = "worker", worker::send)]
659 async fn stream(
660 &self,
661 request: CompletionRequest,
662 ) -> Result<
663 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
664 CompletionError,
665 > {
666 CompletionModel::stream(self, request).await
667 }
668}
669
/// Error payload of an [`ApiResponse`]; only the message text is kept.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error message, surfaced through `CompletionError::ProviderError`.
    message: String,
}
674
/// Body of a response with a success status.
///
/// Internally tagged: the snake_case `type` field (`message` / `error`)
/// selects the variant.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// A regular payload of type `T`.
    Message(T),
    /// An error reported in the body.
    Error(ApiErrorResponse),
}
681
682#[cfg(test)]
683mod tests {
684 use super::*;
685 use serde_path_to_error::deserialize;
686
    /// Checks that `Message` deserializes from both accepted shapes of
    /// `content` (a bare string and an array of typed blocks) for assistant
    /// and user roles, and that each block maps to the expected variant.
    #[test]
    fn test_deserialize_message() {
        // Assistant message whose `content` is a bare string.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Assistant message with a text block followed by a tool-use block.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // User message with an image, a text block and a tool result whose
        // `content` is a bare string.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // `serde_path_to_error` is used so that a failure names the JSON path.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // The bare string must have become a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        // Block array: text first, then the tool call, then nothing else.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        // User message: image, text, tool result, in that order.
        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: "/9j/4AAQSkZJRg...".to_owned(),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    // The bare string became a single text part.
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    // `is_error` was absent from the JSON.
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }
844
    /// Round-trips provider messages through the generic `message::Message`
    /// type and back, checking the intermediate generic values and that the
    /// final provider messages equal the originals.
    #[test]
    fn test_message_to_message_conversion() {
        // User message with an image, a text block and a PDF document.
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        // Assistant message consisting of a single tool call.
        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        // User-role message carrying the result of that tool call.
        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        // Provider -> generic.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data,
                        format,
                        media_type,
                        ..
                    }) => {
                        assert_eq!(data, "/9j/4AAQSkZJRg...");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data,
                        format,
                        media_type,
                    }) => {
                        assert_eq!(data, "base64_encoded_pdf_data");
                        assert_eq!(format.unwrap(), message::ContentFormat::Base64);
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        // The tool result must surface as generic user content.
        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Generic -> provider: the round trip must reproduce the originals.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }
991}