1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 json_utils,
8 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
9 one_or_many::string_or_one_or_many,
10 telemetry::{ProviderResponseExt, SpanCombinator},
11 wasm_compat::*,
12};
13use std::{convert::Infallible, str::FromStr};
14
15use super::client::Client;
16use crate::completion::CompletionRequest;
17use crate::providers::anthropic::streaming::StreamingCompletionResponse;
18use bytes::Bytes;
19use serde::{Deserialize, Serialize};
20use serde_json::json;
21use tracing::{Instrument, info_span};
22
/// Claude Opus 4 model identifier (`claude-opus-4-0` alias).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// Claude Sonnet 4 model identifier (`claude-sonnet-4-0` alias).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// Claude 3.7 Sonnet model identifier (`-latest` alias).
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Claude 3.5 Sonnet model identifier (`-latest` alias).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Claude 3.5 Haiku model identifier (`-latest` alias).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Claude 3 Opus model identifier (`-latest` alias).
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Claude 3 Sonnet model identifier (dated snapshot).
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Claude 3 Haiku model identifier (dated snapshot).
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version string known to this module (currently `2023-06-01`).
// NOTE(review): presumably sent as the `anthropic-version` header by the client — confirm.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
54
/// Body of a successful, non-streaming `POST /v1/messages` response.
///
/// Parsed by the `completion` method of this module's `CompletionModel` and then
/// converted into rig's generic `completion::CompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model, in order.
    pub content: Vec<Content>,
    /// Response identifier assigned by the API.
    pub id: String,
    /// Model that produced the response.
    pub model: String,
    /// Role of the generated message, kept as the raw string from the API.
    pub role: String,
    /// Stop reason reported by the API, if any.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request.
    pub usage: Usage,
}
65
66impl ProviderResponseExt for CompletionResponse {
67 type OutputMessage = Content;
68 type Usage = Usage;
69
70 fn get_response_id(&self) -> Option<String> {
71 Some(self.id.to_owned())
72 }
73
74 fn get_response_model_name(&self) -> Option<String> {
75 Some(self.model.to_owned())
76 }
77
78 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
79 self.content.clone()
80 }
81
82 fn get_text_response(&self) -> Option<String> {
83 let res = self
84 .content
85 .iter()
86 .filter_map(|x| {
87 if let Content::Text { text } = x {
88 Some(text.to_owned())
89 } else {
90 None
91 }
92 })
93 .collect::<Vec<String>>()
94 .join("\n");
95
96 if res.is_empty() { None } else { Some(res) }
97 }
98
99 fn get_usage(&self) -> Option<Self::Usage> {
100 Some(self.usage.clone())
101 }
102}
103
/// Token usage as reported by Anthropic.
///
/// The cache counters are optional and deserialize to `None` when the API
/// omits them.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens reported by the API.
    pub input_tokens: u64,
    /// Input tokens served from the prompt cache, if reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, if reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Tokens generated by the model.
    pub output_tokens: u64,
}
111
112impl std::fmt::Display for Usage {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(
115 f,
116 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
117 self.input_tokens,
118 match self.cache_read_input_tokens {
119 Some(token) => token.to_string(),
120 None => "n/a".to_string(),
121 },
122 match self.cache_creation_input_tokens {
123 Some(token) => token.to_string(),
124 None => "n/a".to_string(),
125 },
126 self.output_tokens
127 )
128 }
129}
130
131impl GetTokenUsage for Usage {
132 fn token_usage(&self) -> Option<crate::completion::Usage> {
133 let mut usage = crate::completion::Usage::new();
134
135 usage.input_tokens = self.input_tokens
136 + self.cache_creation_input_tokens.unwrap_or_default()
137 + self.cache_read_input_tokens.unwrap_or_default();
138 usage.output_tokens = self.output_tokens;
139 usage.total_tokens = usage.input_tokens + usage.output_tokens;
140
141 Some(usage)
142 }
143}
144
/// A tool definition as sent in the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name the model uses when emitting a `tool_use` block.
    pub name: String,
    /// Optional human-readable description shown to the model.
    pub description: Option<String>,
    /// JSON schema of the tool's input; filled from rig's `tool.parameters`.
    pub input_schema: serde_json::Value,
}
151
/// Cache-control marker, serialized as `{"type": "ephemeral"}`.
// NOTE(review): not referenced by the visible code in this file; presumably used
// for prompt caching elsewhere — confirm before removing.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
157
158impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
159 type Error = CompletionError;
160
161 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
162 let content = response
163 .content
164 .iter()
165 .map(|content| {
166 Ok(match content {
167 Content::Text { text } => completion::AssistantContent::text(text),
168 Content::ToolUse { id, name, input } => {
169 completion::AssistantContent::tool_call(id, name, input.clone())
170 }
171 _ => {
172 return Err(CompletionError::ResponseError(
173 "Response did not contain a message or tool call".into(),
174 ));
175 }
176 })
177 })
178 .collect::<Result<Vec<_>, _>>()?;
179
180 let choice = OneOrMany::many(content).map_err(|_| {
181 CompletionError::ResponseError(
182 "Response contained no message or tool call (empty)".to_owned(),
183 )
184 })?;
185
186 let usage = completion::Usage {
187 input_tokens: response.usage.input_tokens,
188 output_tokens: response.usage.output_tokens,
189 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
190 };
191
192 Ok(completion::CompletionResponse {
193 choice,
194 usage,
195 raw_response: response,
196 })
197 }
198}
199
/// A single message in Anthropic's wire format.
///
/// `content` deserializes from either a bare string or one/many content blocks
/// (`string_or_one_or_many`); the tests below exercise the bare-string form.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    pub role: Role,
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
206
/// Message author; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
213
/// A single content block in Anthropic's message format.
///
/// Serialized with an internal `type` tag in snake_case: `text`, `image`,
/// `tool_use`, `tool_result`, `document` or `thinking`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image, wrapped in a `source` object.
    Image {
        source: ImageSource,
    },
    /// A tool invocation requested by the assistant.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool call, sent back in a user message.
    ToolResult {
        /// Id of the `tool_use` block this result answers.
        tool_use_id: String,
        /// Accepts a bare string or one/many result blocks when deserializing.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// A document (PDF only; see `DocumentFormat`).
    Document {
        source: DocumentSource,
    },
    /// Model reasoning; `signature` is carried through as the rig reasoning id.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
243
244impl FromStr for Content {
245 type Err = Infallible;
246
247 fn from_str(s: &str) -> Result<Self, Self::Err> {
248 Ok(Content::Text { text: s.to_owned() })
249 }
250}
251
252#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
253#[serde(tag = "type", rename_all = "snake_case")]
254pub enum ToolResultContent {
255 Text { text: String },
256 Image(ImageSource),
257}
258
259impl FromStr for ToolResultContent {
260 type Err = Infallible;
261
262 fn from_str(s: &str) -> Result<Self, Self::Err> {
263 Ok(ToolResultContent::Text { text: s.to_owned() })
264 }
265}
266
/// Raw payload of an image source: base64 data or a URL string.
///
/// Serialized untagged, so both variants are written as a plain JSON string.
/// Because both variants hold a `String`, untagged deserialization always picks
/// the first variant (`Base64`). The sibling `ImageSource::r#type` field is the
/// only reliable indicator of which kind the payload is.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    Base64(String),
    Url(String),
}
273
274impl From<ImageSourceData> for DocumentSourceKind {
275 fn from(value: ImageSourceData) -> Self {
276 match value {
277 ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
278 ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
279 }
280 }
281}
282
283impl TryFrom<DocumentSourceKind> for ImageSourceData {
284 type Error = MessageError;
285
286 fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
287 match value {
288 DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
289 DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
290 _ => Err(MessageError::ConversionError("Content has no body".into())),
291 }
292 }
293}
294
295impl From<ImageSourceData> for String {
296 fn from(value: ImageSourceData) -> Self {
297 match value {
298 ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
299 }
300 }
301}
302
/// The `source` object of an image block.
///
/// `r#type` serializes as `"type"` and is what actually distinguishes base64
/// from URL payloads (see `ImageSourceData`).
// NOTE(review): Anthropic's documented URL image source is `{"type": "url", "url": ...}`,
// while this struct always writes the payload under `data` — confirm URL images are
// accepted by the API in this shape.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    pub data: ImageSourceData,
    pub media_type: ImageFormat,
    pub r#type: SourceType,
}
309
/// The `source` object of a document block; `r#type` serializes as `"type"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Document payload; conversions in this file always pair it with `SourceType::BASE64`.
    pub data: String,
    pub media_type: DocumentFormat,
    pub r#type: SourceType,
}
316
/// Image MIME types accepted by this provider.
///
/// The per-variant `rename`s override the enum-level `rename_all`, so each
/// variant serializes as its full MIME type (e.g. `"image/jpeg"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
329
/// Document MIME types accepted by this provider; only PDF (`"application/pdf"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
339
/// Encoding of a source payload; serialized in lowercase (`"base64"` / `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
}
346
347impl From<String> for Content {
348 fn from(text: String) -> Self {
349 Content::Text { text }
350 }
351}
352
353impl From<String> for ToolResultContent {
354 fn from(text: String) -> Self {
355 ToolResultContent::Text { text }
356 }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360 type Error = MessageError;
361
362 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363 match format {
364 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365 message::ContentFormat::Url => Ok(SourceType::URL),
366 message::ContentFormat::String => Err(MessageError::ConversionError(
367 "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368 )),
369 }
370 }
371}
372
373impl From<SourceType> for message::ContentFormat {
374 fn from(source_type: SourceType) -> Self {
375 match source_type {
376 SourceType::BASE64 => message::ContentFormat::Base64,
377 SourceType::URL => message::ContentFormat::Url,
378 }
379 }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383 type Error = MessageError;
384
385 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386 Ok(match media_type {
387 message::ImageMediaType::JPEG => ImageFormat::JPEG,
388 message::ImageMediaType::PNG => ImageFormat::PNG,
389 message::ImageMediaType::GIF => ImageFormat::GIF,
390 message::ImageMediaType::WEBP => ImageFormat::WEBP,
391 _ => {
392 return Err(MessageError::ConversionError(
393 format!("Unsupported image media type: {media_type:?}").to_owned(),
394 ));
395 }
396 })
397 }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401 fn from(format: ImageFormat) -> Self {
402 match format {
403 ImageFormat::JPEG => message::ImageMediaType::JPEG,
404 ImageFormat::PNG => message::ImageMediaType::PNG,
405 ImageFormat::GIF => message::ImageMediaType::GIF,
406 ImageFormat::WEBP => message::ImageMediaType::WEBP,
407 }
408 }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412 type Error = MessageError;
413 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414 if !matches!(value, DocumentMediaType::PDF) {
415 return Err(MessageError::ConversionError(
416 "Anthropic only supports PDF documents".to_string(),
417 ));
418 };
419
420 Ok(DocumentFormat::PDF)
421 }
422}
423
424impl From<message::AssistantContent> for Content {
425 fn from(text: message::AssistantContent) -> Self {
426 match text {
427 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
428 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
429 Content::ToolUse {
430 id,
431 name: function.name,
432 input: function.arguments,
433 }
434 }
435 message::AssistantContent::Reasoning(Reasoning { reasoning, id }) => {
436 Content::Thinking {
437 thinking: reasoning.first().cloned().unwrap_or(String::new()),
438 signature: id,
439 }
440 }
441 }
442 }
443}
444
445impl TryFrom<message::Message> for Message {
446 type Error = MessageError;
447
448 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
449 Ok(match message {
450 message::Message::User { content } => Message {
451 role: Role::User,
452 content: content.try_map(|content| match content {
453 message::UserContent::Text(message::Text { text }) => {
454 Ok(Content::Text { text })
455 }
456 message::UserContent::ToolResult(message::ToolResult {
457 id, content, ..
458 }) => Ok(Content::ToolResult {
459 tool_use_id: id,
460 content: content.try_map(|content| match content {
461 message::ToolResultContent::Text(message::Text { text }) => {
462 Ok(ToolResultContent::Text { text })
463 }
464 message::ToolResultContent::Image(image) => {
465 let DocumentSourceKind::Base64(data) = image.data else {
466 return Err(MessageError::ConversionError(
467 "Only base64 strings can be used with the Anthropic API"
468 .to_string(),
469 ));
470 };
471 let media_type =
472 image.media_type.ok_or(MessageError::ConversionError(
473 "Image media type is required".to_owned(),
474 ))?;
475 Ok(ToolResultContent::Image(ImageSource {
476 data: ImageSourceData::Base64(data),
477 media_type: media_type.try_into()?,
478 r#type: SourceType::BASE64,
479 }))
480 }
481 })?,
482 is_error: None,
483 }),
484 message::UserContent::Image(message::Image {
485 data, media_type, ..
486 }) => {
487 let media_type = media_type.ok_or(MessageError::ConversionError(
488 "Image media type is required for Claude API".into(),
489 ))?;
490
491 let source = match data {
492 DocumentSourceKind::Base64(data) => ImageSource {
493 data: ImageSourceData::Base64(data),
494 r#type: SourceType::BASE64,
495 media_type: ImageFormat::try_from(media_type)?,
496 },
497 DocumentSourceKind::Url(url) => ImageSource {
498 data: ImageSourceData::Url(url),
499 r#type: SourceType::URL,
500 media_type: ImageFormat::try_from(media_type)?,
501 },
502 DocumentSourceKind::Unknown => {
503 return Err(MessageError::ConversionError(
504 "Image content has no body".into(),
505 ));
506 }
507 doc => {
508 return Err(MessageError::ConversionError(format!(
509 "Unsupported document type: {doc:?}"
510 )));
511 }
512 };
513
514 Ok(Content::Image { source })
515 }
516 message::UserContent::Document(message::Document {
517 data, media_type, ..
518 }) => {
519 let media_type = media_type.ok_or(MessageError::ConversionError(
520 "Document media type is required".to_string(),
521 ))?;
522
523 let data = match data {
524 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
525 data
526 }
527 _ => {
528 return Err(MessageError::ConversionError(
529 "Only base64 encoded documents currently supported".into(),
530 ));
531 }
532 };
533
534 let source = DocumentSource {
535 data,
536 media_type: media_type.try_into()?,
537 r#type: SourceType::BASE64,
538 };
539 Ok(Content::Document { source })
540 }
541 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
542 "Audio is not supported in Anthropic".to_owned(),
543 )),
544 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
545 "Video is not supported in Anthropic".to_owned(),
546 )),
547 })?,
548 },
549
550 message::Message::Assistant { content, .. } => Message {
551 content: content.map(|content| content.into()),
552 role: Role::Assistant,
553 },
554 })
555 }
556}
557
558impl TryFrom<Content> for message::AssistantContent {
559 type Error = MessageError;
560
561 fn try_from(content: Content) -> Result<Self, Self::Error> {
562 Ok(match content {
563 Content::Text { text } => message::AssistantContent::text(text),
564 Content::ToolUse { id, name, input } => {
565 message::AssistantContent::tool_call(id, name, input)
566 }
567 Content::Thinking {
568 thinking,
569 signature,
570 } => message::AssistantContent::Reasoning(
571 Reasoning::new(&thinking).optional_id(signature),
572 ),
573 _ => {
574 return Err(MessageError::ConversionError(
575 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
576 ));
577 }
578 })
579 }
580}
581
582impl From<ToolResultContent> for message::ToolResultContent {
583 fn from(content: ToolResultContent) -> Self {
584 match content {
585 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
586 ToolResultContent::Image(ImageSource {
587 data,
588 media_type: format,
589 ..
590 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
591 }
592 }
593}
594
595impl TryFrom<Message> for message::Message {
596 type Error = MessageError;
597
598 fn try_from(message: Message) -> Result<Self, Self::Error> {
599 Ok(match message.role {
600 Role::User => message::Message::User {
601 content: message.content.try_map(|content| {
602 Ok(match content {
603 Content::Text { text } => message::UserContent::text(text),
604 Content::ToolResult {
605 tool_use_id,
606 content,
607 ..
608 } => message::UserContent::tool_result(
609 tool_use_id,
610 content.map(|content| content.into()),
611 ),
612 Content::Image { source } => message::UserContent::Image(message::Image {
613 data: source.data.into(),
614 media_type: Some(source.media_type.into()),
615 detail: None,
616 additional_params: None,
617 }),
618 Content::Document { source } => message::UserContent::document(
619 source.data,
620 Some(message::DocumentMediaType::PDF),
621 ),
622 _ => {
623 return Err(MessageError::ConversionError(
624 "Unsupported content type for User role".to_owned(),
625 ));
626 }
627 })
628 })?,
629 },
630 Role::Assistant => match message.content.first() {
631 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
632 message::Message::Assistant {
633 id: None,
634 content: message.content.try_map(|content| content.try_into())?,
635 }
636 }
637
638 _ => {
639 return Err(MessageError::ConversionError(
640 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
641 ));
642 }
643 },
644 })
645 }
646}
647
/// Anthropic Messages API completion model bound to a specific model name.
///
/// `T` is the HTTP client backend and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client>
where
    T: WasmCompatSend,
{
    /// Provider client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model identifier sent as `"model"` in each request.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one; `None` for
    /// models not recognized by `calculate_max_tokens`.
    pub default_max_tokens: Option<u64>,
}
657
658impl<T> CompletionModel<T>
659where
660 T: HttpClientExt,
661{
662 pub fn new(client: Client<T>, model: &str) -> Self {
663 Self {
664 client,
665 model: model.to_string(),
666 default_max_tokens: calculate_max_tokens(model),
667 }
668 }
669}
670
/// Returns the default `max_tokens` for a known Claude model family.
///
/// Matching is by model-name prefix. The first matching row wins, so
/// dated snapshots and `-latest` aliases resolve the same way. Unknown models
/// return `None`.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    const LIMITS: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32_000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64_000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8_192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4_096),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|&prefix| model.starts_with(prefix)))
        .map(|&(_, limit)| limit)
}
692
/// Request metadata object carrying an optional `user_id`.
// NOTE(review): not referenced by the visible code in this file; presumably maps to the
// Messages API `metadata` field — confirm before relying on it.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
697
/// Anthropic `tool_choice` setting.
///
/// Serialized with an internal snake_case `type` tag, e.g. `{"type": "auto"}`
/// or `{"type": "tool", "name": "..."}`. Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    /// Must use some tool; mapped from rig's `ToolChoice::Required`.
    Any,
    None,
    /// Must use the named tool.
    Tool {
        name: String,
    },
}
709impl TryFrom<message::ToolChoice> for ToolChoice {
710 type Error = CompletionError;
711
712 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
713 let res = match value {
714 message::ToolChoice::Auto => Self::Auto,
715 message::ToolChoice::None => Self::None,
716 message::ToolChoice::Required => Self::Any,
717 message::ToolChoice::Specific { function_names } => {
718 if function_names.len() != 1 {
719 return Err(CompletionError::ProviderError(
720 "Only one tool may be specified to be used by Claude".into(),
721 ));
722 }
723
724 Self::Tool {
725 name: function_names.first().unwrap().to_string(),
726 }
727 }
728 };
729
730 Ok(res)
731 }
732}
733impl<T> completion::CompletionModel for CompletionModel<T>
734where
735 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
736{
737 type Response = CompletionResponse;
738 type StreamingResponse = StreamingCompletionResponse;
739
740 #[cfg_attr(feature = "worker", worker::send)]
741 async fn completion(
742 &self,
743 completion_request: completion::CompletionRequest,
744 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
745 let span = if tracing::Span::current().is_disabled() {
746 info_span!(
747 target: "rig::completions",
748 "chat",
749 gen_ai.operation.name = "chat",
750 gen_ai.provider.name = "anthropic",
751 gen_ai.request.model = self.model,
752 gen_ai.system_instructions = &completion_request.preamble,
753 gen_ai.response.id = tracing::field::Empty,
754 gen_ai.response.model = tracing::field::Empty,
755 gen_ai.usage.output_tokens = tracing::field::Empty,
756 gen_ai.usage.input_tokens = tracing::field::Empty,
757 gen_ai.input.messages = tracing::field::Empty,
758 gen_ai.output.messages = tracing::field::Empty,
759 )
760 } else {
761 tracing::Span::current()
762 };
763 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
769 tokens
770 } else if let Some(tokens) = self.default_max_tokens {
771 tokens
772 } else {
773 return Err(CompletionError::RequestError(
774 "`max_tokens` must be set for Anthropic".into(),
775 ));
776 };
777
778 let mut full_history = vec![];
779 if let Some(docs) = completion_request.normalized_documents() {
780 full_history.push(docs);
781 }
782 full_history.extend(completion_request.chat_history);
783 span.record_model_input(&full_history);
784
785 let full_history = full_history
786 .into_iter()
787 .map(Message::try_from)
788 .collect::<Result<Vec<Message>, _>>()?;
789
790 let mut request = json!({
791 "model": self.model,
792 "messages": full_history,
793 "max_tokens": max_tokens,
794 "system": completion_request.preamble.unwrap_or("".to_string()),
795 });
796
797 if let Some(temperature) = completion_request.temperature {
798 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
799 }
800
801 let tool_choice = if let Some(tool_choice) = completion_request.tool_choice {
802 Some(ToolChoice::try_from(tool_choice)?)
803 } else {
804 None
805 };
806
807 if !completion_request.tools.is_empty() {
808 json_utils::merge_inplace(
809 &mut request,
810 json!({
811 "tools": completion_request
812 .tools
813 .into_iter()
814 .map(|tool| ToolDefinition {
815 name: tool.name,
816 description: Some(tool.description),
817 input_schema: tool.parameters,
818 })
819 .collect::<Vec<_>>(),
820 "tool_choice": tool_choice,
821 }),
822 );
823 }
824
825 if let Some(ref params) = completion_request.additional_params {
826 json_utils::merge_inplace(&mut request, params.clone())
827 }
828
829 async move {
830 let request: Vec<u8> = serde_json::to_vec(&request)?;
831
832 let req = self
833 .client
834 .post("/v1/messages")
835 .header("Content-Type", "application/json")
836 .body(request)
837 .map_err(|e| CompletionError::HttpError(e.into()))?;
838
839 let response = self
840 .client
841 .send::<_, Bytes>(req)
842 .await
843 .map_err(CompletionError::HttpError)?;
844
845 if response.status().is_success() {
846 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
847 response
848 .into_body()
849 .await
850 .map_err(CompletionError::HttpError)?
851 .to_vec()
852 .as_slice(),
853 )? {
854 ApiResponse::Message(completion) => {
855 let span = tracing::Span::current();
856 span.record_model_output(&completion.content);
857 span.record_response_metadata(&completion);
858 span.record_token_usage(&completion.usage);
859 completion.try_into()
860 }
861 ApiResponse::Error(ApiErrorResponse { message }) => {
862 Err(CompletionError::ResponseError(message))
863 }
864 }
865 } else {
866 let text: String = String::from_utf8_lossy(
867 &response
868 .into_body()
869 .await
870 .map_err(CompletionError::HttpError)?,
871 )
872 .into();
873 Err(CompletionError::ProviderError(text))
874 }
875 }
876 .instrument(span)
877 .await
878 }
879
880 #[cfg_attr(feature = "worker", worker::send)]
881 async fn stream(
882 &self,
883 request: CompletionRequest,
884 ) -> Result<
885 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
886 CompletionError,
887 > {
888 CompletionModel::stream(self, request).await
889 }
890}
891
892#[derive(Debug, Deserialize)]
893struct ApiErrorResponse {
894 message: String,
895}
896
/// Top-level response envelope, discriminated by the snake_case `type` field:
/// `"message"` yields the payload `T`, and `"error"` yields an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
903
// Unit tests for Anthropic wire (de)serialization and for conversions between the
// Anthropic types and rig's provider-agnostic message types.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_path_to_error::deserialize;

    // Deserializes assistant and user messages in both the bare-string and the
    // content-block-array forms, and checks every resulting block.
    #[test]
    fn test_deserialize_message() {
        // Bare-string content (handled by `string_or_one_or_many`).
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Array content mixing a text block and a tool_use block.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // User content with an image, text, and a tool_result whose content is a bare string.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // `serde_path_to_error` reports the JSON path of any deserialization failure.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // The bare-string tool result becomes a single text block via `FromStr`.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    // Round-trips user, assistant and tool-result messages through rig's
    // `message::Message` and back, checking both the intermediate and final forms.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // `UserContent::document` yields a `DocumentSourceKind::String` payload.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Converting back must reproduce the original Anthropic messages exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    // Checks the `ContentFormat` <-> `SourceType` mapping, including rejection of
    // the deprecated `ContentFormat::String`.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let result: Result<SourceType, _> = ContentFormat::String.try_into();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ContentFormat::String is deprecated")
        );
    }
}