rig/providers/anthropic/
completion.rs

//! Anthropic completion API implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8    one_or_many::string_or_one_or_many,
9    telemetry::{ProviderResponseExt, SpanCombinator},
10    wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, info_span};
20
21// ================================================================
22// Anthropic Completion API
23// ================================================================
24
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// Anthropic API version identifier `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version identifier `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest Anthropic API version known to this module; currently an alias for
/// [`ANTHROPIC_VERSION_2023_06_01`].
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
40#[derive(Debug, Deserialize, Serialize)]
41pub struct CompletionResponse {
42    pub content: Vec<Content>,
43    pub id: String,
44    pub model: String,
45    pub role: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51impl ProviderResponseExt for CompletionResponse {
52    type OutputMessage = Content;
53    type Usage = Usage;
54
55    fn get_response_id(&self) -> Option<String> {
56        Some(self.id.to_owned())
57    }
58
59    fn get_response_model_name(&self) -> Option<String> {
60        Some(self.model.to_owned())
61    }
62
63    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64        self.content.clone()
65    }
66
67    fn get_text_response(&self) -> Option<String> {
68        let res = self
69            .content
70            .iter()
71            .filter_map(|x| {
72                if let Content::Text { text } = x {
73                    Some(text.to_owned())
74                } else {
75                    None
76                }
77            })
78            .collect::<Vec<String>>()
79            .join("\n");
80
81        if res.is_empty() { None } else { Some(res) }
82    }
83
84    fn get_usage(&self) -> Option<Self::Usage> {
85        Some(self.usage.clone())
86    }
87}
88
89#[derive(Clone, Debug, Deserialize, Serialize)]
90pub struct Usage {
91    pub input_tokens: u64,
92    pub cache_read_input_tokens: Option<u64>,
93    pub cache_creation_input_tokens: Option<u64>,
94    pub output_tokens: u64,
95}
96
97impl std::fmt::Display for Usage {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(
100            f,
101            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102            self.input_tokens,
103            match self.cache_read_input_tokens {
104                Some(token) => token.to_string(),
105                None => "n/a".to_string(),
106            },
107            match self.cache_creation_input_tokens {
108                Some(token) => token.to_string(),
109                None => "n/a".to_string(),
110            },
111            self.output_tokens
112        )
113    }
114}
115
116impl GetTokenUsage for Usage {
117    fn token_usage(&self) -> Option<crate::completion::Usage> {
118        let mut usage = crate::completion::Usage::new();
119
120        usage.input_tokens = self.input_tokens
121            + self.cache_creation_input_tokens.unwrap_or_default()
122            + self.cache_read_input_tokens.unwrap_or_default();
123        usage.output_tokens = self.output_tokens;
124        usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126        Some(usage)
127    }
128}
129
130#[derive(Debug, Deserialize, Serialize)]
131pub struct ToolDefinition {
132    pub name: String,
133    pub description: Option<String>,
134    pub input_schema: serde_json::Value,
135}
136
137#[derive(Debug, Deserialize, Serialize)]
138#[serde(tag = "type", rename_all = "snake_case")]
139pub enum CacheControl {
140    Ephemeral,
141}
142
143impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
144    type Error = CompletionError;
145
146    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
147        let content = response
148            .content
149            .iter()
150            .map(|content| content.clone().try_into())
151            .collect::<Result<Vec<_>, _>>()?;
152
153        let choice = OneOrMany::many(content).map_err(|_| {
154            CompletionError::ResponseError(
155                "Response contained no message or tool call (empty)".to_owned(),
156            )
157        })?;
158
159        let usage = completion::Usage {
160            input_tokens: response.usage.input_tokens,
161            output_tokens: response.usage.output_tokens,
162            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
163        };
164
165        Ok(completion::CompletionResponse {
166            choice,
167            usage,
168            raw_response: response,
169        })
170    }
171}
172
173#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
174pub struct Message {
175    pub role: Role,
176    #[serde(deserialize_with = "string_or_one_or_many")]
177    pub content: OneOrMany<Content>,
178}
179
180#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
181#[serde(rename_all = "lowercase")]
182pub enum Role {
183    User,
184    Assistant,
185}
186
187#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
188#[serde(tag = "type", rename_all = "snake_case")]
189pub enum Content {
190    Text {
191        text: String,
192    },
193    Image {
194        source: ImageSource,
195    },
196    ToolUse {
197        id: String,
198        name: String,
199        input: serde_json::Value,
200    },
201    ToolResult {
202        tool_use_id: String,
203        #[serde(deserialize_with = "string_or_one_or_many")]
204        content: OneOrMany<ToolResultContent>,
205        #[serde(skip_serializing_if = "Option::is_none")]
206        is_error: Option<bool>,
207    },
208    Document {
209        source: DocumentSource,
210    },
211    Thinking {
212        thinking: String,
213        signature: Option<String>,
214    },
215}
216
217impl FromStr for Content {
218    type Err = Infallible;
219
220    fn from_str(s: &str) -> Result<Self, Self::Err> {
221        Ok(Content::Text { text: s.to_owned() })
222    }
223}
224
225#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
226#[serde(tag = "type", rename_all = "snake_case")]
227pub enum ToolResultContent {
228    Text { text: String },
229    Image(ImageSource),
230}
231
232impl FromStr for ToolResultContent {
233    type Err = Infallible;
234
235    fn from_str(s: &str) -> Result<Self, Self::Err> {
236        Ok(ToolResultContent::Text { text: s.to_owned() })
237    }
238}
239
240#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
241#[serde(untagged)]
242pub enum ImageSourceData {
243    Base64(String),
244    Url(String),
245}
246
247impl From<ImageSourceData> for DocumentSourceKind {
248    fn from(value: ImageSourceData) -> Self {
249        match value {
250            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
251            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
252        }
253    }
254}
255
256impl TryFrom<DocumentSourceKind> for ImageSourceData {
257    type Error = MessageError;
258
259    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
260        match value {
261            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
262            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
263            _ => Err(MessageError::ConversionError("Content has no body".into())),
264        }
265    }
266}
267
268impl From<ImageSourceData> for String {
269    fn from(value: ImageSourceData) -> Self {
270        match value {
271            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
272        }
273    }
274}
275
276#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
277pub struct ImageSource {
278    pub data: ImageSourceData,
279    pub media_type: ImageFormat,
280    pub r#type: SourceType,
281}
282
283#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
284pub struct DocumentSource {
285    pub data: String,
286    pub media_type: DocumentFormat,
287    pub r#type: SourceType,
288}
289
290#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
291#[serde(rename_all = "lowercase")]
292pub enum ImageFormat {
293    #[serde(rename = "image/jpeg")]
294    JPEG,
295    #[serde(rename = "image/png")]
296    PNG,
297    #[serde(rename = "image/gif")]
298    GIF,
299    #[serde(rename = "image/webp")]
300    WEBP,
301}
302
303/// The document format to be used.
304///
305/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
306#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
307#[serde(rename_all = "lowercase")]
308pub enum DocumentFormat {
309    #[serde(rename = "application/pdf")]
310    PDF,
311}
312
313#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
314#[serde(rename_all = "lowercase")]
315pub enum SourceType {
316    BASE64,
317    URL,
318}
319
320impl From<String> for Content {
321    fn from(text: String) -> Self {
322        Content::Text { text }
323    }
324}
325
326impl From<String> for ToolResultContent {
327    fn from(text: String) -> Self {
328        ToolResultContent::Text { text }
329    }
330}
331
332impl TryFrom<message::ContentFormat> for SourceType {
333    type Error = MessageError;
334
335    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
336        match format {
337            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
338            message::ContentFormat::Url => Ok(SourceType::URL),
339            message::ContentFormat::String => Err(MessageError::ConversionError(
340                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
341            )),
342        }
343    }
344}
345
346impl From<SourceType> for message::ContentFormat {
347    fn from(source_type: SourceType) -> Self {
348        match source_type {
349            SourceType::BASE64 => message::ContentFormat::Base64,
350            SourceType::URL => message::ContentFormat::Url,
351        }
352    }
353}
354
355impl TryFrom<message::ImageMediaType> for ImageFormat {
356    type Error = MessageError;
357
358    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
359        Ok(match media_type {
360            message::ImageMediaType::JPEG => ImageFormat::JPEG,
361            message::ImageMediaType::PNG => ImageFormat::PNG,
362            message::ImageMediaType::GIF => ImageFormat::GIF,
363            message::ImageMediaType::WEBP => ImageFormat::WEBP,
364            _ => {
365                return Err(MessageError::ConversionError(
366                    format!("Unsupported image media type: {media_type:?}").to_owned(),
367                ));
368            }
369        })
370    }
371}
372
373impl From<ImageFormat> for message::ImageMediaType {
374    fn from(format: ImageFormat) -> Self {
375        match format {
376            ImageFormat::JPEG => message::ImageMediaType::JPEG,
377            ImageFormat::PNG => message::ImageMediaType::PNG,
378            ImageFormat::GIF => message::ImageMediaType::GIF,
379            ImageFormat::WEBP => message::ImageMediaType::WEBP,
380        }
381    }
382}
383
384impl TryFrom<DocumentMediaType> for DocumentFormat {
385    type Error = MessageError;
386    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
387        if !matches!(value, DocumentMediaType::PDF) {
388            return Err(MessageError::ConversionError(
389                "Anthropic only supports PDF documents".to_string(),
390            ));
391        };
392
393        Ok(DocumentFormat::PDF)
394    }
395}
396
397impl TryFrom<message::AssistantContent> for Content {
398    type Error = MessageError;
399    fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
400        match text {
401            message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text { text }),
402            message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
403                "Anthropic currently doesn't support images.".to_string(),
404            )),
405            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
406                Ok(Content::ToolUse {
407                    id,
408                    name: function.name,
409                    input: function.arguments,
410                })
411            }
412            message::AssistantContent::Reasoning(Reasoning {
413                reasoning,
414                signature,
415                ..
416            }) => Ok(Content::Thinking {
417                thinking: reasoning.first().cloned().unwrap_or(String::new()),
418                signature,
419            }),
420        }
421    }
422}
423
424impl TryFrom<message::Message> for Message {
425    type Error = MessageError;
426
427    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
428        Ok(match message {
429            message::Message::User { content } => Message {
430                role: Role::User,
431                content: content.try_map(|content| match content {
432                    message::UserContent::Text(message::Text { text }) => {
433                        Ok(Content::Text { text })
434                    }
435                    message::UserContent::ToolResult(message::ToolResult {
436                        id, content, ..
437                    }) => Ok(Content::ToolResult {
438                        tool_use_id: id,
439                        content: content.try_map(|content| match content {
440                            message::ToolResultContent::Text(message::Text { text }) => {
441                                Ok(ToolResultContent::Text { text })
442                            }
443                            message::ToolResultContent::Image(image) => {
444                                let DocumentSourceKind::Base64(data) = image.data else {
445                                    return Err(MessageError::ConversionError(
446                                        "Only base64 strings can be used with the Anthropic API"
447                                            .to_string(),
448                                    ));
449                                };
450                                let media_type =
451                                    image.media_type.ok_or(MessageError::ConversionError(
452                                        "Image media type is required".to_owned(),
453                                    ))?;
454                                Ok(ToolResultContent::Image(ImageSource {
455                                    data: ImageSourceData::Base64(data),
456                                    media_type: media_type.try_into()?,
457                                    r#type: SourceType::BASE64,
458                                }))
459                            }
460                        })?,
461                        is_error: None,
462                    }),
463                    message::UserContent::Image(message::Image {
464                        data, media_type, ..
465                    }) => {
466                        let media_type = media_type.ok_or(MessageError::ConversionError(
467                            "Image media type is required for Claude API".to_string(),
468                        ))?;
469
470                        let source = match data {
471                            DocumentSourceKind::Base64(data) => ImageSource {
472                                data: ImageSourceData::Base64(data),
473                                r#type: SourceType::BASE64,
474                                media_type: ImageFormat::try_from(media_type)?,
475                            },
476                            DocumentSourceKind::Url(url) => ImageSource {
477                                data: ImageSourceData::Url(url),
478                                r#type: SourceType::URL,
479                                media_type: ImageFormat::try_from(media_type)?,
480                            },
481                            DocumentSourceKind::Unknown => {
482                                return Err(MessageError::ConversionError(
483                                    "Image content has no body".into(),
484                                ));
485                            }
486                            doc => {
487                                return Err(MessageError::ConversionError(format!(
488                                    "Unsupported document type: {doc:?}"
489                                )));
490                            }
491                        };
492
493                        Ok(Content::Image { source })
494                    }
495                    message::UserContent::Document(message::Document {
496                        data, media_type, ..
497                    }) => {
498                        let media_type = media_type.ok_or(MessageError::ConversionError(
499                            "Document media type is required".to_string(),
500                        ))?;
501
502                        let data = match data {
503                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
504                                data
505                            }
506                            _ => {
507                                return Err(MessageError::ConversionError(
508                                    "Only base64 encoded documents currently supported".into(),
509                                ));
510                            }
511                        };
512
513                        let source = DocumentSource {
514                            data,
515                            media_type: media_type.try_into()?,
516                            r#type: SourceType::BASE64,
517                        };
518                        Ok(Content::Document { source })
519                    }
520                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
521                        "Audio is not supported in Anthropic".to_owned(),
522                    )),
523                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
524                        "Video is not supported in Anthropic".to_owned(),
525                    )),
526                })?,
527            },
528
529            message::Message::Assistant { content, .. } => Message {
530                content: content.try_map(|content| content.try_into())?,
531                role: Role::Assistant,
532            },
533        })
534    }
535}
536
537impl TryFrom<Content> for message::AssistantContent {
538    type Error = MessageError;
539
540    fn try_from(content: Content) -> Result<Self, Self::Error> {
541        Ok(match content {
542            Content::Text { text } => message::AssistantContent::text(text),
543            Content::ToolUse { id, name, input } => {
544                message::AssistantContent::tool_call(id, name, input)
545            }
546            Content::Thinking {
547                thinking,
548                signature,
549            } => message::AssistantContent::Reasoning(
550                Reasoning::new(&thinking).with_signature(signature),
551            ),
552            _ => {
553                return Err(MessageError::ConversionError(
554                    "Content did not contain a message, tool call, or reasoning".to_owned(),
555                ));
556            }
557        })
558    }
559}
560
561impl From<ToolResultContent> for message::ToolResultContent {
562    fn from(content: ToolResultContent) -> Self {
563        match content {
564            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
565            ToolResultContent::Image(ImageSource {
566                data,
567                media_type: format,
568                ..
569            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
570        }
571    }
572}
573
574impl TryFrom<Message> for message::Message {
575    type Error = MessageError;
576
577    fn try_from(message: Message) -> Result<Self, Self::Error> {
578        Ok(match message.role {
579            Role::User => message::Message::User {
580                content: message.content.try_map(|content| {
581                    Ok(match content {
582                        Content::Text { text } => message::UserContent::text(text),
583                        Content::ToolResult {
584                            tool_use_id,
585                            content,
586                            ..
587                        } => message::UserContent::tool_result(
588                            tool_use_id,
589                            content.map(|content| content.into()),
590                        ),
591                        Content::Image { source } => message::UserContent::Image(message::Image {
592                            data: source.data.into(),
593                            media_type: Some(source.media_type.into()),
594                            detail: None,
595                            additional_params: None,
596                        }),
597                        Content::Document { source } => message::UserContent::document(
598                            source.data,
599                            Some(message::DocumentMediaType::PDF),
600                        ),
601                        _ => {
602                            return Err(MessageError::ConversionError(
603                                "Unsupported content type for User role".to_owned(),
604                            ));
605                        }
606                    })
607                })?,
608            },
609            Role::Assistant => match message.content.first() {
610                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
611                    message::Message::Assistant {
612                        id: None,
613                        content: message.content.try_map(|content| content.try_into())?,
614                    }
615                }
616
617                _ => {
618                    return Err(MessageError::ConversionError(
619                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
620                    ));
621                }
622            },
623        })
624    }
625}
626
627#[derive(Clone)]
628pub struct CompletionModel<T = reqwest::Client> {
629    pub(crate) client: Client<T>,
630    pub model: String,
631    pub default_max_tokens: Option<u64>,
632}
633
634impl<T> CompletionModel<T>
635where
636    T: HttpClientExt,
637{
638    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
639        let model = model.into();
640        let default_max_tokens = calculate_max_tokens(&model);
641
642        Self {
643            client,
644            model,
645            default_max_tokens,
646        }
647    }
648
649    pub fn with_model(client: Client<T>, model: &str) -> Self {
650        Self {
651            client,
652            model: model.to_string(),
653            default_max_tokens: Some(calculate_max_tokens_custom(model)),
654        }
655    }
656}
657
658/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
659/// set or if set too high, the request will fail. The following values are based on the models
660/// available at the time of writing.
661fn calculate_max_tokens(model: &str) -> Option<u64> {
662    match model {
663        CLAUDE_4_OPUS => Some(32_000),
664        CLAUDE_4_SONNET | CLAUDE_3_7_SONNET => Some(64_000),
665        CLAUDE_3_5_SONNET | CLAUDE_3_5_HAIKU => Some(8_192),
666        _ => None,
667    }
668}
669
670fn calculate_max_tokens_custom(model: &str) -> u64 {
671    match model {
672        "claude-4-opus" => 32_000,
673        "claude-4-sonnet" | "claude-3.7-sonnet" => 64_000,
674        "claude-3.5-sonnet" | "claude-3.5-haiku" => 8_192,
675        _ => 4_096,
676    }
677}
678
679#[derive(Debug, Deserialize, Serialize)]
680pub struct Metadata {
681    user_id: Option<String>,
682}
683
684#[derive(Default, Debug, Serialize, Deserialize)]
685#[serde(tag = "type", rename_all = "snake_case")]
686pub enum ToolChoice {
687    #[default]
688    Auto,
689    Any,
690    None,
691    Tool {
692        name: String,
693    },
694}
695impl TryFrom<message::ToolChoice> for ToolChoice {
696    type Error = CompletionError;
697
698    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
699        let res = match value {
700            message::ToolChoice::Auto => Self::Auto,
701            message::ToolChoice::None => Self::None,
702            message::ToolChoice::Required => Self::Any,
703            message::ToolChoice::Specific { function_names } => {
704                if function_names.len() != 1 {
705                    return Err(CompletionError::ProviderError(
706                        "Only one tool may be specified to be used by Claude".into(),
707                    ));
708                }
709
710                Self::Tool {
711                    name: function_names.first().unwrap().to_string(),
712                }
713            }
714        };
715
716        Ok(res)
717    }
718}
719
720#[derive(Debug, Deserialize, Serialize)]
721struct AnthropicCompletionRequest {
722    model: String,
723    messages: Vec<Message>,
724    max_tokens: u64,
725    system: String,
726    #[serde(skip_serializing_if = "Option::is_none")]
727    temperature: Option<f64>,
728    #[serde(skip_serializing_if = "Option::is_none")]
729    tool_choice: Option<ToolChoice>,
730    #[serde(skip_serializing_if = "Vec::is_empty")]
731    tools: Vec<ToolDefinition>,
732    #[serde(flatten, skip_serializing_if = "Option::is_none")]
733    additional_params: Option<serde_json::Value>,
734}
735
736impl TryFrom<(&str, CompletionRequest)> for AnthropicCompletionRequest {
737    type Error = CompletionError;
738
739    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
740        // Check if max_tokens is set, required for Anthropic
741        let Some(max_tokens) = req.max_tokens else {
742            return Err(CompletionError::RequestError(
743                "`max_tokens` must be set for Anthropic".into(),
744            ));
745        };
746
747        let mut full_history = vec![];
748        if let Some(docs) = req.normalized_documents() {
749            full_history.push(docs);
750        }
751        full_history.extend(req.chat_history);
752
753        let messages = full_history
754            .into_iter()
755            .map(Message::try_from)
756            .collect::<Result<Vec<Message>, _>>()?;
757
758        let tools = req
759            .tools
760            .into_iter()
761            .map(|tool| ToolDefinition {
762                name: tool.name,
763                description: Some(tool.description),
764                input_schema: tool.parameters,
765            })
766            .collect::<Vec<_>>();
767
768        Ok(Self {
769            model: model.to_string(),
770            messages,
771            max_tokens,
772            system: req.preamble.unwrap_or_default(),
773            temperature: req.temperature,
774            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
775            tools,
776            additional_params: req.additional_params,
777        })
778    }
779}
780
781impl<T> completion::CompletionModel for CompletionModel<T>
782where
783    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
784{
785    type Response = CompletionResponse;
786    type StreamingResponse = StreamingCompletionResponse;
787    type Client = Client<T>;
788
789    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
790        Self::new(client.clone(), model.into())
791    }
792
793    #[cfg_attr(feature = "worker", worker::send)]
794    async fn completion(
795        &self,
796        mut completion_request: completion::CompletionRequest,
797    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
798        let span = if tracing::Span::current().is_disabled() {
799            info_span!(
800                target: "rig::completions",
801                "chat",
802                gen_ai.operation.name = "chat",
803                gen_ai.provider.name = "anthropic",
804                gen_ai.request.model = &self.model,
805                gen_ai.system_instructions = &completion_request.preamble,
806                gen_ai.response.id = tracing::field::Empty,
807                gen_ai.response.model = tracing::field::Empty,
808                gen_ai.usage.output_tokens = tracing::field::Empty,
809                gen_ai.usage.input_tokens = tracing::field::Empty,
810                gen_ai.input.messages = tracing::field::Empty,
811                gen_ai.output.messages = tracing::field::Empty,
812            )
813        } else {
814            tracing::Span::current()
815        };
816
817        // Check if max_tokens is set, required for Anthropic
818        if completion_request.max_tokens.is_none() {
819            if let Some(tokens) = self.default_max_tokens {
820                completion_request.max_tokens = Some(tokens);
821            } else {
822                return Err(CompletionError::RequestError(
823                    "`max_tokens` must be set for Anthropic".into(),
824                ));
825            }
826        }
827
828        let request =
829            AnthropicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
830        span.record_model_input(&request.messages);
831
832        async move {
833            let request: Vec<u8> = serde_json::to_vec(&request)?;
834
835            let req = self
836                .client
837                .post("/v1/messages")?
838                .body(request)
839                .map_err(|e| CompletionError::HttpError(e.into()))?;
840
841            let response = self
842                .client
843                .send::<_, Bytes>(req)
844                .await
845                .map_err(CompletionError::HttpError)?;
846
847            if response.status().is_success() {
848                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
849                    response
850                        .into_body()
851                        .await
852                        .map_err(CompletionError::HttpError)?
853                        .to_vec()
854                        .as_slice(),
855                )? {
856                    ApiResponse::Message(completion) => {
857                        let span = tracing::Span::current();
858                        span.record_model_output(&completion.content);
859                        span.record_response_metadata(&completion);
860                        span.record_token_usage(&completion.usage);
861                        tracing::trace!(
862                            target: "rig::completions",
863                            "Anthropic completion response: {}",
864                            serde_json::to_string_pretty(&completion)?
865                        );
866                        completion.try_into()
867                    }
868                    ApiResponse::Error(ApiErrorResponse { message }) => {
869                        Err(CompletionError::ResponseError(message))
870                    }
871                }
872            } else {
873                let text: String = String::from_utf8_lossy(
874                    &response
875                        .into_body()
876                        .await
877                        .map_err(CompletionError::HttpError)?,
878                )
879                .into();
880                Err(CompletionError::ProviderError(text))
881            }
882        }
883        .instrument(span)
884        .await
885    }
886
887    #[cfg_attr(feature = "worker", worker::send)]
888    async fn stream(
889        &self,
890        request: CompletionRequest,
891    ) -> Result<
892        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
893        CompletionError,
894    > {
895        CompletionModel::stream(self, request).await
896    }
897}
898
899#[derive(Debug, Deserialize)]
900struct ApiErrorResponse {
901    message: String,
902}
903
904#[derive(Debug, Deserialize)]
905#[serde(tag = "type", rename_all = "snake_case")]
906enum ApiResponse<T> {
907    Message(T),
908    Error(ApiErrorResponse),
909}
910
// Unit tests for the Anthropic wire types: JSON deserialization of `Message`,
// lossless round-trips through rig's generic `message::Message`, and the
// `ContentFormat` <-> `SourceType` mapping.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // `serde_path_to_error` reports the JSON path of a failure, which makes
    // fixture mismatches much easier to locate than plain `serde_json` errors.
    use serde_path_to_error::deserialize;

    /// Deserializes three Anthropic-format `Message` fixtures and checks the
    /// resulting `Role` and `Content` variants, including the shorthand where
    /// `content` (or a tool result's `content`) is a bare string.
    #[test]
    fn test_deserialize_message() {
        // Shorthand: `content` is a plain string rather than an array of blocks.
        let assistant_message_json = r#"
        {
            "role": "assistant",
            "content": "\n\nHello there, how may I assist you today?"
        }
        "#;

        // Array form mixing a text block and a tool_use block.
        let assistant_message_json2 = r#"
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "\n\nHello there, how may I assist you today?"
                },
                {
                    "type": "tool_use",
                    "id": "toolu_01A09q90qw90lq917835lq9",
                    "name": "get_weather",
                    "input": {"location": "San Francisco, CA"}
                }
            ]
        }
        "#;

        // User turn with a base64 image, a text block and a tool_result whose
        // `content` uses the bare-string shorthand and omits `is_error`.
        let user_message_json = r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "tool_result",
                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
                    "content": "15 degrees"
                }
            ]
        }
        "#;

        let assistant_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let assistant_message2: Message = {
            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        let user_message: Message = {
            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
            deserialize(jd).unwrap_or_else(|err| {
                panic!("Deserialization error at {}: {}", err.path(), err);
            })
        };

        // The bare-string `content` becomes a single `Content::Text`.
        let Message { role, content } = assistant_message;
        assert_eq!(role, Role::Assistant);
        assert_eq!(
            content.first(),
            Content::Text {
                text: "\n\nHello there, how may I assist you today?".to_owned()
            }
        );

        // Array content preserves block order: text first, then tool_use.
        let Message { role, content } = assistant_message2;
        {
            assert_eq!(role, Role::Assistant);
            assert_eq!(content.len(), 2);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
                }
                _ => panic!("Expected text content"),
            }

            match iter.next().unwrap() {
                Content::ToolUse { id, name, input } => {
                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(name, "get_weather");
                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
                }
                _ => panic!("Expected tool use content"),
            }

            assert_eq!(iter.next(), None);
        }

        let Message { role, content } = user_message;
        {
            assert_eq!(role, Role::User);
            assert_eq!(content.len(), 3);

            let mut iter = content.into_iter();

            match iter.next().unwrap() {
                Content::Image { source } => {
                    assert_eq!(
                        source,
                        ImageSource {
                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
                            media_type: ImageFormat::JPEG,
                            r#type: SourceType::BASE64,
                        }
                    );
                }
                _ => panic!("Expected image content"),
            }

            match iter.next().unwrap() {
                Content::Text { text } => {
                    assert_eq!(text, "What is in this image?");
                }
                _ => panic!("Expected text content"),
            }

            // The string tool-result content is wrapped as a single
            // `ToolResultContent::Text`, and the absent `is_error` is `None`.
            match iter.next().unwrap() {
                Content::ToolResult {
                    tool_use_id,
                    content,
                    is_error,
                } => {
                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
                    assert_eq!(
                        content.first(),
                        ToolResultContent::Text {
                            text: "15 degrees".to_owned()
                        }
                    );
                    assert_eq!(is_error, None);
                }
                _ => panic!("Expected tool result content"),
            }

            assert_eq!(iter.next(), None);
        }
    }

    /// Converts Anthropic `Message`s (user with image/text/document, assistant
    /// tool_use, user tool_result) into rig's generic `message::Message`, checks
    /// the intermediate form, then converts back and asserts the round trip
    /// reproduces the originals exactly.
    #[test]
    fn test_message_to_message_conversion() {
        let user_message: Message = serde_json::from_str(
            r#"
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/jpeg",
                        "data": "/9j/4AAQSkZJRg..."
                    }
                },
                {
                    "type": "text",
                    "text": "What is in this image?"
                },
                {
                    "type": "document",
                    "source": {
                        "type": "base64",
                        "data": "base64_encoded_pdf_data",
                        "media_type": "application/pdf"
                    }
                }
            ]
        }
        "#,
        )
        .unwrap();

        let assistant_message = Message {
            role: Role::Assistant,
            content: OneOrMany::one(Content::ToolUse {
                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                name: "get_weather".to_string(),
                input: json!({"location": "San Francisco, CA"}),
            }),
        };

        let tool_message = Message {
            role: Role::User,
            content: OneOrMany::one(Content::ToolResult {
                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
                content: OneOrMany::one(ToolResultContent::Text {
                    text: "15 degrees".to_string(),
                }),
                is_error: None,
            }),
        };

        // Anthropic -> generic. Originals are cloned so they remain available
        // for the final round-trip equality checks.
        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
        let converted_assistant_message: message::Message =
            assistant_message.clone().try_into().unwrap();
        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();

        match converted_user_message.clone() {
            message::Message::User { content } => {
                assert_eq!(content.len(), 3);

                let mut iter = content.into_iter();

                match iter.next().unwrap() {
                    message::UserContent::Image(message::Image {
                        data, media_type, ..
                    }) => {
                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
                    }
                    _ => panic!("Expected image content"),
                }

                match iter.next().unwrap() {
                    message::UserContent::Text(message::Text { text }) => {
                        assert_eq!(text, "What is in this image?");
                    }
                    _ => panic!("Expected text content"),
                }

                // NOTE(review): unlike the image above (asserted as
                // `DocumentSourceKind::base64`), the base64 PDF document is
                // expected as `DocumentSourceKind::String`. This pins current
                // conversion behaviour; confirm the asymmetry is intended.
                match iter.next().unwrap() {
                    message::UserContent::Document(message::Document {
                        data, media_type, ..
                    }) => {
                        assert_eq!(
                            data,
                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
                        );
                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
                    }
                    _ => panic!("Expected document content"),
                }

                assert_eq!(iter.next(), None);
            }
            _ => panic!("Expected user message"),
        }

        // An Anthropic tool_result maps to a user-role `UserContent::ToolResult`.
        match converted_tool_message.clone() {
            message::Message::User { content } => {
                let message::ToolResult { id, content, .. } = match content.first() {
                    message::UserContent::ToolResult(tool_result) => tool_result,
                    _ => panic!("Expected tool result content"),
                };
                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                match content.first() {
                    message::ToolResultContent::Text(message::Text { text }) => {
                        assert_eq!(text, "15 degrees");
                    }
                    _ => panic!("Expected text content"),
                }
            }
            _ => panic!("Expected tool result content"),
        }

        // An Anthropic tool_use maps to `AssistantContent::ToolCall`, with
        // `input` carried over as `function.arguments`.
        match converted_assistant_message.clone() {
            message::Message::Assistant { content, .. } => {
                assert_eq!(content.len(), 1);

                match content.first() {
                    message::AssistantContent::ToolCall(message::ToolCall {
                        id, function, ..
                    }) => {
                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
                        assert_eq!(function.name, "get_weather");
                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
                    }
                    _ => panic!("Expected tool call content"),
                }
            }
            _ => panic!("Expected assistant message"),
        }

        // Generic -> Anthropic: the round trip must be lossless.
        let original_user_message: Message = converted_user_message.try_into().unwrap();
        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
        let original_tool_message: Message = converted_tool_message.try_into().unwrap();

        assert_eq!(user_message, original_user_message);
        assert_eq!(assistant_message, original_assistant_message);
        assert_eq!(tool_message, original_tool_message);
    }

    /// Checks the `ContentFormat` <-> `SourceType` mapping in both directions
    /// for `Url` and `Base64`, and that converting the deprecated
    /// `ContentFormat::String` into a `SourceType` is rejected with an error
    /// mentioning the deprecation.
    #[test]
    fn test_content_format_conversion() {
        use crate::completion::message::ContentFormat;

        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
        assert_eq!(source_type, SourceType::URL);

        let content_format: ContentFormat = SourceType::URL.into();
        assert_eq!(content_format, ContentFormat::Url);

        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
        assert_eq!(source_type, SourceType::BASE64);

        let content_format: ContentFormat = SourceType::BASE64.into();
        assert_eq!(content_format, ContentFormat::Base64);

        let result: Result<SourceType, _> = ContentFormat::String.try_into();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ContentFormat::String is deprecated")
        );
    }
}