1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 json_utils,
8 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
9 one_or_many::string_or_one_or_many,
10 telemetry::{ProviderResponseExt, SpanCombinator},
11 wasm_compat::*,
12};
13use std::{convert::Infallible, str::FromStr};
14
15use super::client::Client;
16use crate::completion::CompletionRequest;
17use crate::providers::anthropic::streaming::StreamingCompletionResponse;
18use bytes::Bytes;
19use serde::{Deserialize, Serialize};
20use serde_json::json;
21use tracing::{Instrument, info_span};
22
/// Model alias `claude-opus-4-0` (Claude Opus 4).
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// Model alias `claude-sonnet-4-0` (Claude Sonnet 4).
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// Model alias `claude-3-7-sonnet-latest` (Claude 3.7 Sonnet).
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// Model alias `claude-3-5-sonnet-latest` (Claude 3.5 Sonnet).
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// Model alias `claude-3-5-haiku-latest` (Claude 3.5 Haiku).
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Model alias `claude-3-opus-latest` (Claude 3 Opus).
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// Dated model identifier `claude-3-sonnet-20240229` (Claude 3 Sonnet).
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// Dated model identifier `claude-3-haiku-20240307` (Claude 3 Haiku).
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version string defined in this module (an alias for `2023-06-01`).
// NOTE(review): presumably used by the client as the `anthropic-version` header — confirm in client.rs.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
54
/// Successful (non-streaming) response body of the Anthropic Messages API.
///
/// Field names mirror the JSON keys exactly; the struct is deserialized directly from the
/// HTTP body and kept as the `raw_response` of rig's generic completion response.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model (text, tool use, thinking, ...).
    pub content: Vec<Content>,
    /// Provider-assigned response identifier.
    pub id: String,
    /// Name of the model that actually served the request.
    pub model: String,
    /// Role of the generated message as reported by the API.
    pub role: String,
    /// Why generation stopped, if the API reported it.
    pub stop_reason: Option<String>,
    /// The stop sequence that was hit, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request/response.
    pub usage: Usage,
}
65
66impl ProviderResponseExt for CompletionResponse {
67 type OutputMessage = Content;
68 type Usage = Usage;
69
70 fn get_response_id(&self) -> Option<String> {
71 Some(self.id.to_owned())
72 }
73
74 fn get_response_model_name(&self) -> Option<String> {
75 Some(self.model.to_owned())
76 }
77
78 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
79 self.content.clone()
80 }
81
82 fn get_text_response(&self) -> Option<String> {
83 let res = self
84 .content
85 .iter()
86 .filter_map(|x| {
87 if let Content::Text { text } = x {
88 Some(text.to_owned())
89 } else {
90 None
91 }
92 })
93 .collect::<Vec<String>>()
94 .join("\n");
95
96 if res.is_empty() { None } else { Some(res) }
97 }
98
99 fn get_usage(&self) -> Option<Self::Usage> {
100 Some(self.usage.clone())
101 }
102}
103
/// Token usage reported by the Anthropic API for a single request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens as reported in the `input_tokens` field.
    pub input_tokens: u64,
    /// Input tokens read from the prompt cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the prompt cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Tokens generated by the model.
    pub output_tokens: u64,
}
111
112impl std::fmt::Display for Usage {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(
115 f,
116 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
117 self.input_tokens,
118 match self.cache_read_input_tokens {
119 Some(token) => token.to_string(),
120 None => "n/a".to_string(),
121 },
122 match self.cache_creation_input_tokens {
123 Some(token) => token.to_string(),
124 None => "n/a".to_string(),
125 },
126 self.output_tokens
127 )
128 }
129}
130
131impl GetTokenUsage for Usage {
132 fn token_usage(&self) -> Option<crate::completion::Usage> {
133 let mut usage = crate::completion::Usage::new();
134
135 usage.input_tokens = self.input_tokens
136 + self.cache_creation_input_tokens.unwrap_or_default()
137 + self.cache_read_input_tokens.unwrap_or_default();
138 usage.output_tokens = self.output_tokens;
139 usage.total_tokens = usage.input_tokens + usage.output_tokens;
140
141 Some(usage)
142 }
143}
144
145#[derive(Debug, Deserialize, Serialize)]
146pub struct ToolDefinition {
147 pub name: String,
148 pub description: Option<String>,
149 pub input_schema: serde_json::Value,
150}
151
/// Prompt-cache control marker; serializes as `{"type": "ephemeral"}`.
// NOTE(review): not referenced elsewhere in this file — confirm whether other modules use it.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
157
158impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
159 type Error = CompletionError;
160
161 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
162 let content = response
163 .content
164 .iter()
165 .map(|content| {
166 Ok(match content {
167 Content::Text { text } => completion::AssistantContent::text(text),
168 Content::ToolUse { id, name, input } => {
169 completion::AssistantContent::tool_call(id, name, input.clone())
170 }
171 _ => {
172 return Err(CompletionError::ResponseError(
173 "Response did not contain a message or tool call".into(),
174 ));
175 }
176 })
177 })
178 .collect::<Result<Vec<_>, _>>()?;
179
180 let choice = OneOrMany::many(content).map_err(|_| {
181 CompletionError::ResponseError(
182 "Response contained no message or tool call (empty)".to_owned(),
183 )
184 })?;
185
186 let usage = completion::Usage {
187 input_tokens: response.usage.input_tokens,
188 output_tokens: response.usage.output_tokens,
189 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
190 };
191
192 Ok(completion::CompletionResponse {
193 choice,
194 usage,
195 raw_response: response,
196 })
197 }
198}
199
/// A single conversation turn in Anthropic's Messages API format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the turn.
    pub role: Role,
    /// Content blocks of the turn. On input, a bare JSON string is also accepted (via
    /// `string_or_one_or_many`) and parsed through `FromStr for Content` into a text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
206
/// Message author; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
213
/// An Anthropic content block, discriminated on the wire by a snake_case `type` field
/// (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image given by an [`ImageSource`].
    Image {
        source: ImageSource,
    },
    /// A tool invocation emitted by the assistant.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, sent back by the user side.
    ToolResult {
        /// The `id` of the `ToolUse` block this result answers.
        tool_use_id: String,
        /// Result payload; a bare JSON string is also accepted on input.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// A document (currently only PDF, see [`DocumentFormat`]).
    Document {
        source: DocumentSource,
    },
    /// Model reasoning text with an optional signature. A `None` signature serializes as `null`.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
243
244impl FromStr for Content {
245 type Err = Infallible;
246
247 fn from_str(s: &str) -> Result<Self, Self::Err> {
248 Ok(Content::Text { text: s.to_owned() })
249 }
250}
251
/// One piece of a tool result payload, tagged by a snake_case `type` field.
// NOTE(review): `Image(ImageSource)` is a newtype variant under `tag = "type"`, so serde
// flattens the `ImageSource` fields next to the tag. That emits two `"type"` keys
// (`"image"` and the source type) instead of nesting the fields under `"source"`. Confirm
// against the Anthropic tool_result image shape; fixing it changes the variant's public form.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    Text { text: String },
    Image(ImageSource),
}
258
259impl FromStr for ToolResultContent {
260 type Err = Infallible;
261
262 fn from_str(s: &str) -> Result<Self, Self::Err> {
263 Ok(ToolResultContent::Text { text: s.to_owned() })
264 }
265}
266
/// Raw image payload: base64-encoded bytes or a URL. Serialized as a bare string (`untagged`).
// NOTE(review): both variants wrap a `String`, and untagged deserialization picks the first
// variant that parses. Deserialized values are therefore always `Base64`, even for URLs; the
// sibling `ImageSource::r#type` is the only reliable discriminator.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    Base64(String),
    Url(String),
}
273
274impl From<ImageSourceData> for DocumentSourceKind {
275 fn from(value: ImageSourceData) -> Self {
276 match value {
277 ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
278 ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
279 }
280 }
281}
282
283impl TryFrom<DocumentSourceKind> for ImageSourceData {
284 type Error = MessageError;
285
286 fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
287 match value {
288 DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
289 DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
290 _ => Err(MessageError::ConversionError("Content has no body".into())),
291 }
292 }
293}
294
295impl From<ImageSourceData> for String {
296 fn from(value: ImageSourceData) -> Self {
297 match value {
298 ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
299 }
300 }
301}
302
/// The `source` object of an image block: payload, media type and source type.
// NOTE(review): URL sources are serialized with a `data` field (and a `media_type`). Confirm
// this against Anthropic's URL image source shape, which may expect a `url` field instead.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    pub data: ImageSourceData,
    pub media_type: ImageFormat,
    pub r#type: SourceType,
}
309
/// The `source` object of a document block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Document payload; the conversions in this file always pair it with `SourceType::BASE64`.
    pub data: String,
    pub media_type: DocumentFormat,
    pub r#type: SourceType,
}
316
/// Image MIME types accepted by this provider. Each variant carries an explicit `rename`, so
/// the enum-level `rename_all` has no effect.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    #[serde(rename = "image/jpeg")]
    JPEG,
    #[serde(rename = "image/png")]
    PNG,
    #[serde(rename = "image/gif")]
    GIF,
    #[serde(rename = "image/webp")]
    WEBP,
}
329
/// Document MIME types accepted by this provider; only PDF is supported.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    #[serde(rename = "application/pdf")]
    PDF,
}
339
/// How a media payload is provided; serialized as `"base64"` or `"url"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    BASE64,
    URL,
}
346
347impl From<String> for Content {
348 fn from(text: String) -> Self {
349 Content::Text { text }
350 }
351}
352
353impl From<String> for ToolResultContent {
354 fn from(text: String) -> Self {
355 ToolResultContent::Text { text }
356 }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360 type Error = MessageError;
361
362 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363 match format {
364 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365 message::ContentFormat::Url => Ok(SourceType::URL),
366 message::ContentFormat::String => Err(MessageError::ConversionError(
367 "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368 )),
369 }
370 }
371}
372
373impl From<SourceType> for message::ContentFormat {
374 fn from(source_type: SourceType) -> Self {
375 match source_type {
376 SourceType::BASE64 => message::ContentFormat::Base64,
377 SourceType::URL => message::ContentFormat::Url,
378 }
379 }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383 type Error = MessageError;
384
385 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386 Ok(match media_type {
387 message::ImageMediaType::JPEG => ImageFormat::JPEG,
388 message::ImageMediaType::PNG => ImageFormat::PNG,
389 message::ImageMediaType::GIF => ImageFormat::GIF,
390 message::ImageMediaType::WEBP => ImageFormat::WEBP,
391 _ => {
392 return Err(MessageError::ConversionError(
393 format!("Unsupported image media type: {media_type:?}").to_owned(),
394 ));
395 }
396 })
397 }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401 fn from(format: ImageFormat) -> Self {
402 match format {
403 ImageFormat::JPEG => message::ImageMediaType::JPEG,
404 ImageFormat::PNG => message::ImageMediaType::PNG,
405 ImageFormat::GIF => message::ImageMediaType::GIF,
406 ImageFormat::WEBP => message::ImageMediaType::WEBP,
407 }
408 }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412 type Error = MessageError;
413 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414 if !matches!(value, DocumentMediaType::PDF) {
415 return Err(MessageError::ConversionError(
416 "Anthropic only supports PDF documents".to_string(),
417 ));
418 };
419
420 Ok(DocumentFormat::PDF)
421 }
422}
423
424impl From<message::AssistantContent> for Content {
425 fn from(text: message::AssistantContent) -> Self {
426 match text {
427 message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
428 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
429 Content::ToolUse {
430 id,
431 name: function.name,
432 input: function.arguments,
433 }
434 }
435 message::AssistantContent::Reasoning(Reasoning {
436 reasoning,
437 signature,
438 ..
439 }) => Content::Thinking {
440 thinking: reasoning.first().cloned().unwrap_or(String::new()),
441 signature,
442 },
443 }
444 }
445}
446
impl TryFrom<message::Message> for Message {
    type Error = MessageError;

    /// Converts a provider-agnostic rig message into Anthropic's wire format.
    ///
    /// User content maps onto text, tool-result, image and document blocks. Assistant content
    /// is converted item-by-item via `From<message::AssistantContent> for Content`.
    ///
    /// # Errors
    /// Returns `MessageError::ConversionError` when:
    /// - a tool-result image is not base64, lacks a media type, or has an unsupported media type;
    /// - a user image lacks a media type, has no body (`DocumentSourceKind::Unknown`), uses any
    ///   other non-base64/non-URL source kind, or is not JPEG/PNG/GIF/WEBP;
    /// - a document lacks a media type, is not base64/string data, or is not a PDF;
    /// - the content is audio or video.
    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
        Ok(match message {
            message::Message::User { content } => Message {
                role: Role::User,
                content: content.try_map(|content| match content {
                    message::UserContent::Text(message::Text { text }) => {
                        Ok(Content::Text { text })
                    }
                    message::UserContent::ToolResult(message::ToolResult {
                        id, content, ..
                    }) => Ok(Content::ToolResult {
                        // rig's tool-result `id` is the Anthropic `tool_use_id`.
                        tool_use_id: id,
                        content: content.try_map(|content| match content {
                            message::ToolResultContent::Text(message::Text { text }) => {
                                Ok(ToolResultContent::Text { text })
                            }
                            message::ToolResultContent::Image(image) => {
                                // Tool-result images must be inline base64; URLs are rejected here.
                                let DocumentSourceKind::Base64(data) = image.data else {
                                    return Err(MessageError::ConversionError(
                                        "Only base64 strings can be used with the Anthropic API"
                                            .to_string(),
                                    ));
                                };
                                let media_type =
                                    image.media_type.ok_or(MessageError::ConversionError(
                                        "Image media type is required".to_owned(),
                                    ))?;
                                Ok(ToolResultContent::Image(ImageSource {
                                    data: ImageSourceData::Base64(data),
                                    media_type: media_type.try_into()?,
                                    r#type: SourceType::BASE64,
                                }))
                            }
                        })?,
                        // Error status is not carried over; `None` omits `is_error` when serialized.
                        is_error: None,
                    }),
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Image media type is required for Claude API".into(),
                        ))?;

                        // Base64 and URL images are both allowed; the source type follows the payload kind.
                        let source = match data {
                            DocumentSourceKind::Base64(data) => ImageSource {
                                data: ImageSourceData::Base64(data),
                                r#type: SourceType::BASE64,
                                media_type: ImageFormat::try_from(media_type)?,
                            },
                            DocumentSourceKind::Url(url) => ImageSource {
                                data: ImageSourceData::Url(url),
                                r#type: SourceType::URL,
                                media_type: ImageFormat::try_from(media_type)?,
                            },
                            DocumentSourceKind::Unknown => {
                                return Err(MessageError::ConversionError(
                                    "Image content has no body".into(),
                                ));
                            }
                            doc => {
                                return Err(MessageError::ConversionError(format!(
                                    "Unsupported document type: {doc:?}"
                                )));
                            }
                        };

                        Ok(Content::Image { source })
                    }
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        let media_type = media_type.ok_or(MessageError::ConversionError(
                            "Document media type is required".to_string(),
                        ))?;

                        // NOTE(review): `String` payloads are forwarded as `SourceType::BASE64`, so they
                        // must already be base64-encoded — confirm callers guarantee this.
                        let data = match data {
                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
                                data
                            }
                            _ => {
                                return Err(MessageError::ConversionError(
                                    "Only base64 encoded documents currently supported".into(),
                                ));
                            }
                        };

                        let source = DocumentSource {
                            data,
                            media_type: media_type.try_into()?,
                            r#type: SourceType::BASE64,
                        };
                        Ok(Content::Document { source })
                    }
                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
                        "Audio is not supported in Anthropic".to_owned(),
                    )),
                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
                        "Video is not supported in Anthropic".to_owned(),
                    )),
                })?,
            },

            // Assistant content always converts (infallible `From`); the message `id` is dropped.
            message::Message::Assistant { content, .. } => Message {
                content: content.map(|content| content.into()),
                role: Role::Assistant,
            },
        })
    }
}
559
560impl TryFrom<Content> for message::AssistantContent {
561 type Error = MessageError;
562
563 fn try_from(content: Content) -> Result<Self, Self::Error> {
564 Ok(match content {
565 Content::Text { text } => message::AssistantContent::text(text),
566 Content::ToolUse { id, name, input } => {
567 message::AssistantContent::tool_call(id, name, input)
568 }
569 Content::Thinking {
570 thinking,
571 signature,
572 } => message::AssistantContent::Reasoning(
573 Reasoning::new(&thinking).with_signature(signature),
574 ),
575 _ => {
576 return Err(MessageError::ConversionError(
577 format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
578 ));
579 }
580 })
581 }
582}
583
584impl From<ToolResultContent> for message::ToolResultContent {
585 fn from(content: ToolResultContent) -> Self {
586 match content {
587 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
588 ToolResultContent::Image(ImageSource {
589 data,
590 media_type: format,
591 ..
592 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
593 }
594 }
595}
596
597impl TryFrom<Message> for message::Message {
598 type Error = MessageError;
599
600 fn try_from(message: Message) -> Result<Self, Self::Error> {
601 Ok(match message.role {
602 Role::User => message::Message::User {
603 content: message.content.try_map(|content| {
604 Ok(match content {
605 Content::Text { text } => message::UserContent::text(text),
606 Content::ToolResult {
607 tool_use_id,
608 content,
609 ..
610 } => message::UserContent::tool_result(
611 tool_use_id,
612 content.map(|content| content.into()),
613 ),
614 Content::Image { source } => message::UserContent::Image(message::Image {
615 data: source.data.into(),
616 media_type: Some(source.media_type.into()),
617 detail: None,
618 additional_params: None,
619 }),
620 Content::Document { source } => message::UserContent::document(
621 source.data,
622 Some(message::DocumentMediaType::PDF),
623 ),
624 _ => {
625 return Err(MessageError::ConversionError(
626 "Unsupported content type for User role".to_owned(),
627 ));
628 }
629 })
630 })?,
631 },
632 Role::Assistant => match message.content.first() {
633 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
634 message::Message::Assistant {
635 id: None,
636 content: message.content.try_map(|content| content.try_into())?,
637 }
638 }
639
640 _ => {
641 return Err(MessageError::ConversionError(
642 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
643 ));
644 }
645 },
646 })
647 }
648}
649
/// Anthropic completion model bound to a client and a model name.
///
/// `T` is the HTTP client backend and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client>
where
    T: WasmCompatSend,
{
    pub(crate) client: Client<T>,
    /// Model identifier sent as `model` in every request.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one (see `calculate_max_tokens`).
    pub default_max_tokens: Option<u64>,
}
659
660impl<T> CompletionModel<T>
661where
662 T: HttpClientExt,
663{
664 pub fn new(client: Client<T>, model: &str) -> Self {
665 Self {
666 client,
667 model: model.to_string(),
668 default_max_tokens: calculate_max_tokens(model),
669 }
670 }
671}
672
/// Returns the default `max_tokens` for a model, chosen by model-name prefix.
///
/// Families are checked in table order and the first matching prefix wins. Unknown models
/// yield `None`, in which case callers must supply `max_tokens` explicitly.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    const LIMITS_BY_PREFIX: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (&["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"], 4096),
    ];

    LIMITS_BY_PREFIX
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| model.starts_with(*prefix)))
        .map(|&(_, limit)| limit)
}
694
/// Request `metadata` object. A `None` `user_id` serializes as `null`.
// NOTE(review): `user_id` is private and this file defines no constructor, so values can only
// be built inside this module or via deserialization — confirm that is intended.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    user_id: Option<String>,
}
699
/// Anthropic `tool_choice` value, serialized as `{"type": "auto" | "any" | "none"}` or
/// `{"type": "tool", "name": ...}`. Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    #[default]
    Auto,
    Any,
    None,
    Tool {
        name: String,
    },
}
711impl TryFrom<message::ToolChoice> for ToolChoice {
712 type Error = CompletionError;
713
714 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
715 let res = match value {
716 message::ToolChoice::Auto => Self::Auto,
717 message::ToolChoice::None => Self::None,
718 message::ToolChoice::Required => Self::Any,
719 message::ToolChoice::Specific { function_names } => {
720 if function_names.len() != 1 {
721 return Err(CompletionError::ProviderError(
722 "Only one tool may be specified to be used by Claude".into(),
723 ));
724 }
725
726 Self::Tool {
727 name: function_names.first().unwrap().to_string(),
728 }
729 }
730 };
731
732 Ok(res)
733 }
734}
735impl<T> completion::CompletionModel for CompletionModel<T>
736where
737 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
738{
739 type Response = CompletionResponse;
740 type StreamingResponse = StreamingCompletionResponse;
741
742 #[cfg_attr(feature = "worker", worker::send)]
743 async fn completion(
744 &self,
745 completion_request: completion::CompletionRequest,
746 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
747 let span = if tracing::Span::current().is_disabled() {
748 info_span!(
749 target: "rig::completions",
750 "chat",
751 gen_ai.operation.name = "chat",
752 gen_ai.provider.name = "anthropic",
753 gen_ai.request.model = self.model,
754 gen_ai.system_instructions = &completion_request.preamble,
755 gen_ai.response.id = tracing::field::Empty,
756 gen_ai.response.model = tracing::field::Empty,
757 gen_ai.usage.output_tokens = tracing::field::Empty,
758 gen_ai.usage.input_tokens = tracing::field::Empty,
759 gen_ai.input.messages = tracing::field::Empty,
760 gen_ai.output.messages = tracing::field::Empty,
761 )
762 } else {
763 tracing::Span::current()
764 };
765 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
771 tokens
772 } else if let Some(tokens) = self.default_max_tokens {
773 tokens
774 } else {
775 return Err(CompletionError::RequestError(
776 "`max_tokens` must be set for Anthropic".into(),
777 ));
778 };
779
780 let mut full_history = vec![];
781 if let Some(docs) = completion_request.normalized_documents() {
782 full_history.push(docs);
783 }
784 full_history.extend(completion_request.chat_history);
785 span.record_model_input(&full_history);
786
787 let full_history = full_history
788 .into_iter()
789 .map(Message::try_from)
790 .collect::<Result<Vec<Message>, _>>()?;
791
792 let mut request = json!({
793 "model": self.model,
794 "messages": full_history,
795 "max_tokens": max_tokens,
796 "system": completion_request.preamble.unwrap_or("".to_string()),
797 });
798
799 if let Some(temperature) = completion_request.temperature {
800 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
801 }
802
803 let tool_choice = if let Some(tool_choice) = completion_request.tool_choice {
804 Some(ToolChoice::try_from(tool_choice)?)
805 } else {
806 None
807 };
808
809 if !completion_request.tools.is_empty() {
810 let mut tools_json = json!({
811 "tools": completion_request
812 .tools
813 .into_iter()
814 .map(|tool| ToolDefinition {
815 name: tool.name,
816 description: Some(tool.description),
817 input_schema: tool.parameters,
818 })
819 .collect::<Vec<_>>(),
820 });
821
822 if let Some(tc) = tool_choice {
825 tools_json["tool_choice"] = serde_json::to_value(tc)?;
826 }
827
828 json_utils::merge_inplace(&mut request, tools_json);
829 }
830
831 if let Some(ref params) = completion_request.additional_params {
832 json_utils::merge_inplace(&mut request, params.clone())
833 }
834
835 async move {
836 let request: Vec<u8> = serde_json::to_vec(&request)?;
837
838 if let Ok(json_str) = String::from_utf8(request.clone()) {
839 tracing::debug!("Request body:\n{}", json_str);
840 }
841
842 let req = self
843 .client
844 .post("/v1/messages")
845 .header("Content-Type", "application/json")
846 .body(request)
847 .map_err(|e| CompletionError::HttpError(e.into()))?;
848
849 let response = self
850 .client
851 .send::<_, Bytes>(req)
852 .await
853 .map_err(CompletionError::HttpError)?;
854
855 if response.status().is_success() {
856 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
857 response
858 .into_body()
859 .await
860 .map_err(CompletionError::HttpError)?
861 .to_vec()
862 .as_slice(),
863 )? {
864 ApiResponse::Message(completion) => {
865 let span = tracing::Span::current();
866 span.record_model_output(&completion.content);
867 span.record_response_metadata(&completion);
868 span.record_token_usage(&completion.usage);
869 completion.try_into()
870 }
871 ApiResponse::Error(ApiErrorResponse { message }) => {
872 Err(CompletionError::ResponseError(message))
873 }
874 }
875 } else {
876 let text: String = String::from_utf8_lossy(
877 &response
878 .into_body()
879 .await
880 .map_err(CompletionError::HttpError)?,
881 )
882 .into();
883 Err(CompletionError::ProviderError(text))
884 }
885 }
886 .instrument(span)
887 .await
888 }
889
890 #[cfg_attr(feature = "worker", worker::send)]
891 async fn stream(
892 &self,
893 request: CompletionRequest,
894 ) -> Result<
895 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
896 CompletionError,
897 > {
898 CompletionModel::stream(self, request).await
899 }
900}
901
/// Body of an `error`-typed API response; only `message` is read and other fields are ignored.
// NOTE(review): Anthropic's documented error envelope nests `message` under an `error` object.
// Confirm that this flat shape actually matches what reaches the success-status branch.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
906
/// Top-level response envelope discriminated by its `type` field: `"message"` bodies
/// deserialize into `T`, and `"error"` bodies into [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}
913
#[cfg(test)]
mod tests {
    // Unit tests for Anthropic message (de)serialization and rig <-> Anthropic conversions.
    use super::*;
    use serde_path_to_error::deserialize;

    /// Deserializes three Anthropic-format messages and checks every parsed block: a
    /// string-content assistant turn, a text + tool-use assistant turn, and a user turn with
    /// image, text and string-content tool result.
    #[test]
    fn test_deserialize_message() {
        // `content` given as a bare string rather than an array of blocks.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // A text block followed by a tool-use block.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // Image, text, and a tool result whose `content` is a bare string.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        // `serde_path_to_error` reports the JSON path of any failure.
        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // The bare string becomes a single text block.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // The bare-string tool-result content becomes one text part; `is_error` stays unset.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Round-trips user, assistant and tool-result messages through rig's generic
    /// `message::Message` and back, checking both the intermediate form and equality of the
    /// final result with the original.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        // Anthropic -> rig.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // The document payload comes back as `DocumentSourceKind::String`, not `Base64`.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // rig -> Anthropic: the round trip must reproduce the originals exactly.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    /// Checks the `ContentFormat` <-> `SourceType` conversions in both directions, including
    /// rejection of the deprecated `ContentFormat::String`.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let result: Result<SourceType, _> = ContentFormat::String.try_into();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ContentFormat::String is deprecated")
        );
    }
}