rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    json_utils,
8    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
9    one_or_many::string_or_one_or_many,
10    telemetry::{ProviderResponseExt, SpanCombinator},
11    wasm_compat::*,
12};
13use std::{convert::Infallible, str::FromStr};
14
15use super::client::Client;
16use crate::completion::CompletionRequest;
17use crate::providers::anthropic::streaming::StreamingCompletionResponse;
18use bytes::Bytes;
19use serde::{Deserialize, Serialize};
20use serde_json::json;
21use tracing::{Instrument, info_span};
22
23// ================================================================
24// Anthropic Completion API
25// ================================================================
26
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest API version string defined in this module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
54
/// Body of a successful, non-streaming reply from the `/v1/messages` endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model.
    pub content: Vec<Content>,
    /// Identifier of this response.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Role attached to the response message, kept as the raw string.
    pub role: String,
    /// Reason generation stopped, if one was reported.
    pub stop_reason: Option<String>,
    /// Stop sequence reported with the response, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request.
    pub usage: Usage,
}
65
66impl ProviderResponseExt for CompletionResponse {
67    type OutputMessage = Content;
68    type Usage = Usage;
69
70    fn get_response_id(&self) -> Option<String> {
71        Some(self.id.to_owned())
72    }
73
74    fn get_response_model_name(&self) -> Option<String> {
75        Some(self.model.to_owned())
76    }
77
78    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
79        self.content.clone()
80    }
81
82    fn get_text_response(&self) -> Option<String> {
83        let res = self
84            .content
85            .iter()
86            .filter_map(|x| {
87                if let Content::Text { text } = x {
88                    Some(text.to_owned())
89                } else {
90                    None
91                }
92            })
93            .collect::<Vec<String>>()
94            .join("\n");
95
96        if res.is_empty() { None } else { Some(res) }
97    }
98
99    fn get_usage(&self) -> Option<Self::Usage> {
100        Some(self.usage.clone())
101    }
102}
103
/// Token counters reported with a response.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input tokens reported by the API.
    pub input_tokens: u64,
    /// Input tokens read from the cache, when reported.
    pub cache_read_input_tokens: Option<u64>,
    /// Input tokens written to the cache, when reported.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output tokens reported by the API.
    pub output_tokens: u64,
}
111
112impl std::fmt::Display for Usage {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        write!(
115            f,
116            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
117            self.input_tokens,
118            match self.cache_read_input_tokens {
119                Some(token) => token.to_string(),
120                None => "n/a".to_string(),
121            },
122            match self.cache_creation_input_tokens {
123                Some(token) => token.to_string(),
124                None => "n/a".to_string(),
125            },
126            self.output_tokens
127        )
128    }
129}
130
131impl GetTokenUsage for Usage {
132    fn token_usage(&self) -> Option<crate::completion::Usage> {
133        let mut usage = crate::completion::Usage::new();
134
135        usage.input_tokens = self.input_tokens
136            + self.cache_creation_input_tokens.unwrap_or_default()
137            + self.cache_read_input_tokens.unwrap_or_default();
138        usage.output_tokens = self.output_tokens;
139        usage.total_tokens = usage.input_tokens + usage.output_tokens;
140
141        Some(usage)
142    }
143}
144
/// Tool description in the shape placed under the request's `tools` key.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// Schema for the tool's input, passed through as raw JSON.
    pub input_schema: serde_json::Value,
}
151
/// Cache-control marker, internally tagged on `type` with snake_case variant names.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// Serialized as `{"type": "ephemeral"}`.
    Ephemeral,
}
157
158impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
159    type Error = CompletionError;
160
161    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
162        let content = response
163            .content
164            .iter()
165            .map(|content| {
166                Ok(match content {
167                    Content::Text { text } => completion::AssistantContent::text(text),
168                    Content::ToolUse { id, name, input } => {
169                        completion::AssistantContent::tool_call(id, name, input.clone())
170                    }
171                    _ => {
172                        return Err(CompletionError::ResponseError(
173                            "Response did not contain a message or tool call".into(),
174                        ));
175                    }
176                })
177            })
178            .collect::<Result<Vec<_>, _>>()?;
179
180        let choice = OneOrMany::many(content).map_err(|_| {
181            CompletionError::ResponseError(
182                "Response contained no message or tool call (empty)".to_owned(),
183            )
184        })?;
185
186        let usage = completion::Usage {
187            input_tokens: response.usage.input_tokens,
188            output_tokens: response.usage.output_tokens,
189            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
190        };
191
192        Ok(completion::CompletionResponse {
193            choice,
194            usage,
195            raw_response: response,
196        })
197    }
198}
199
/// A single conversation message in Anthropic's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// One or more content blocks, deserialized through `string_or_one_or_many`.
    // NOTE(review): presumably a bare string is also accepted and parsed via `FromStr` — confirm
    // against the helper.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
206
/// Author of a [`Message`]; serialized in lowercase (`"user"`, `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message written by the user.
    User,
    /// Message written by the model.
    Assistant,
}
213
/// A content block inside a message, internally tagged on `type` with snake_case names
/// (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image described by its source.
    Image {
        source: ImageSource,
    },
    /// A tool invocation with its identifier, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked back through `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        /// Omitted from the serialized form when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// A document described by its source.
    Document {
        source: DocumentSource,
    },
    /// Reasoning text with an optional signature.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
243
244impl FromStr for Content {
245    type Err = Infallible;
246
247    fn from_str(s: &str) -> Result<Self, Self::Err> {
248        Ok(Content::Text { text: s.to_owned() })
249    }
250}
251
/// A single piece of content inside a [`Content::ToolResult`] block.
///
/// NOTE(review): this enum is internally tagged on `type`, and `Image` wraps [`ImageSource`],
/// which has its own `type` field — confirm the serialized form is what the API expects.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Plain text output.
    Text { text: String },
    /// Image output.
    Image(ImageSource),
}
258
259impl FromStr for ToolResultContent {
260    type Err = Infallible;
261
262    fn from_str(s: &str) -> Result<Self, Self::Err> {
263        Ok(ToolResultContent::Text { text: s.to_owned() })
264    }
265}
266
/// Payload of an image source: base64-encoded data or a URL.
///
/// The enum is `untagged`, so it serializes as the bare inner string.
/// NOTE(review): both variants wrap a `String`, so deserialization presumably always selects
/// `Base64` (the first variant) — confirm this is acceptable for callers.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    /// Base64-encoded image bytes.
    Base64(String),
    /// Location of the image.
    Url(String),
}
273
274impl From<ImageSourceData> for DocumentSourceKind {
275    fn from(value: ImageSourceData) -> Self {
276        match value {
277            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
278            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
279        }
280    }
281}
282
283impl TryFrom<DocumentSourceKind> for ImageSourceData {
284    type Error = MessageError;
285
286    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
287        match value {
288            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
289            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
290            _ => Err(MessageError::ConversionError("Content has no body".into())),
291        }
292    }
293}
294
295impl From<ImageSourceData> for String {
296    fn from(value: ImageSourceData) -> Self {
297        match value {
298            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
299        }
300    }
301}
302
/// Source of an image block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// Image payload (base64 data or URL).
    pub data: ImageSourceData,
    /// MIME type of the image.
    pub media_type: ImageFormat,
    /// How `data` is encoded; serialized under the key `type`.
    pub r#type: SourceType,
}
309
/// Source of a document block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// Document payload as a string.
    pub data: String,
    /// MIME type of the document.
    pub media_type: DocumentFormat,
    /// How `data` is encoded; serialized under the key `type`.
    pub r#type: SourceType,
}
316
/// Image formats accepted for image blocks; each variant serializes as its MIME type string
/// through the per-variant `rename` attributes.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// `image/jpeg`
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// `image/png`
    #[serde(rename = "image/png")]
    PNG,
    /// `image/gif`
    #[serde(rename = "image/gif")]
    GIF,
    /// `image/webp`
    #[serde(rename = "image/webp")]
    WEBP,
}
329
/// The document format to be used.
///
/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// Serialized as `application/pdf`.
    #[serde(rename = "application/pdf")]
    PDF,
}
339
/// Encoding of a source payload; serialized in lowercase (`"base64"`, `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// The payload is base64-encoded data.
    BASE64,
    /// The payload is a URL.
    URL,
}
346
347impl From<String> for Content {
348    fn from(text: String) -> Self {
349        Content::Text { text }
350    }
351}
352
353impl From<String> for ToolResultContent {
354    fn from(text: String) -> Self {
355        ToolResultContent::Text { text }
356    }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360    type Error = MessageError;
361
362    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363        match format {
364            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365            message::ContentFormat::Url => Ok(SourceType::URL),
366            message::ContentFormat::String => Err(MessageError::ConversionError(
367                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368            )),
369        }
370    }
371}
372
373impl From<SourceType> for message::ContentFormat {
374    fn from(source_type: SourceType) -> Self {
375        match source_type {
376            SourceType::BASE64 => message::ContentFormat::Base64,
377            SourceType::URL => message::ContentFormat::Url,
378        }
379    }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383    type Error = MessageError;
384
385    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386        Ok(match media_type {
387            message::ImageMediaType::JPEG => ImageFormat::JPEG,
388            message::ImageMediaType::PNG => ImageFormat::PNG,
389            message::ImageMediaType::GIF => ImageFormat::GIF,
390            message::ImageMediaType::WEBP => ImageFormat::WEBP,
391            _ => {
392                return Err(MessageError::ConversionError(
393                    format!("Unsupported image media type: {media_type:?}").to_owned(),
394                ));
395            }
396        })
397    }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401    fn from(format: ImageFormat) -> Self {
402        match format {
403            ImageFormat::JPEG => message::ImageMediaType::JPEG,
404            ImageFormat::PNG => message::ImageMediaType::PNG,
405            ImageFormat::GIF => message::ImageMediaType::GIF,
406            ImageFormat::WEBP => message::ImageMediaType::WEBP,
407        }
408    }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412    type Error = MessageError;
413    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414        if !matches!(value, DocumentMediaType::PDF) {
415            return Err(MessageError::ConversionError(
416                "Anthropic only supports PDF documents".to_string(),
417            ));
418        };
419
420        Ok(DocumentFormat::PDF)
421    }
422}
423
424impl From<message::AssistantContent> for Content {
425    fn from(text: message::AssistantContent) -> Self {
426        match text {
427            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
428            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
429                Content::ToolUse {
430                    id,
431                    name: function.name,
432                    input: function.arguments,
433                }
434            }
435            message::AssistantContent::Reasoning(Reasoning { reasoning, id }) => {
436                Content::Thinking {
437                    thinking: reasoning.first().cloned().unwrap_or(String::new()),
438                    signature: id,
439                }
440            }
441        }
442    }
443}
444
445impl TryFrom<message::Message> for Message {
446    type Error = MessageError;
447
448    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
449        Ok(match message {
450            message::Message::User { content } => Message {
451                role: Role::User,
452                content: content.try_map(|content| match content {
453                    message::UserContent::Text(message::Text { text }) => {
454                        Ok(Content::Text { text })
455                    }
456                    message::UserContent::ToolResult(message::ToolResult {
457                        id, content, ..
458                    }) => Ok(Content::ToolResult {
459                        tool_use_id: id,
460                        content: content.try_map(|content| match content {
461                            message::ToolResultContent::Text(message::Text { text }) => {
462                                Ok(ToolResultContent::Text { text })
463                            }
464                            message::ToolResultContent::Image(image) => {
465                                let DocumentSourceKind::Base64(data) = image.data else {
466                                    return Err(MessageError::ConversionError(
467                                        "Only base64 strings can be used with the Anthropic API"
468                                            .to_string(),
469                                    ));
470                                };
471                                let media_type =
472                                    image.media_type.ok_or(MessageError::ConversionError(
473                                        "Image media type is required".to_owned(),
474                                    ))?;
475                                Ok(ToolResultContent::Image(ImageSource {
476                                    data: ImageSourceData::Base64(data),
477                                    media_type: media_type.try_into()?,
478                                    r#type: SourceType::BASE64,
479                                }))
480                            }
481                        })?,
482                        is_error: None,
483                    }),
484                    message::UserContent::Image(message::Image {
485                        data, media_type, ..
486                    }) => {
487                        let media_type = media_type.ok_or(MessageError::ConversionError(
488                            "Image media type is required for Claude API".into(),
489                        ))?;
490
491                        let source = match data {
492                            DocumentSourceKind::Base64(data) => ImageSource {
493                                data: ImageSourceData::Base64(data),
494                                r#type: SourceType::BASE64,
495                                media_type: ImageFormat::try_from(media_type)?,
496                            },
497                            DocumentSourceKind::Url(url) => ImageSource {
498                                data: ImageSourceData::Url(url),
499                                r#type: SourceType::URL,
500                                media_type: ImageFormat::try_from(media_type)?,
501                            },
502                            DocumentSourceKind::Unknown => {
503                                return Err(MessageError::ConversionError(
504                                    "Image content has no body".into(),
505                                ));
506                            }
507                            doc => {
508                                return Err(MessageError::ConversionError(format!(
509                                    "Unsupported document type: {doc:?}"
510                                )));
511                            }
512                        };
513
514                        Ok(Content::Image { source })
515                    }
516                    message::UserContent::Document(message::Document {
517                        data, media_type, ..
518                    }) => {
519                        let media_type = media_type.ok_or(MessageError::ConversionError(
520                            "Document media type is required".to_string(),
521                        ))?;
522
523                        let data = match data {
524                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
525                                data
526                            }
527                            _ => {
528                                return Err(MessageError::ConversionError(
529                                    "Only base64 encoded documents currently supported".into(),
530                                ));
531                            }
532                        };
533
534                        let source = DocumentSource {
535                            data,
536                            media_type: media_type.try_into()?,
537                            r#type: SourceType::BASE64,
538                        };
539                        Ok(Content::Document { source })
540                    }
541                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
542                        "Audio is not supported in Anthropic".to_owned(),
543                    )),
544                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
545                        "Video is not supported in Anthropic".to_owned(),
546                    )),
547                })?,
548            },
549
550            message::Message::Assistant { content, .. } => Message {
551                content: content.map(|content| content.into()),
552                role: Role::Assistant,
553            },
554        })
555    }
556}
557
558impl TryFrom<Content> for message::AssistantContent {
559    type Error = MessageError;
560
561    fn try_from(content: Content) -> Result<Self, Self::Error> {
562        Ok(match content {
563            Content::Text { text } => message::AssistantContent::text(text),
564            Content::ToolUse { id, name, input } => {
565                message::AssistantContent::tool_call(id, name, input)
566            }
567            Content::Thinking {
568                thinking,
569                signature,
570            } => message::AssistantContent::Reasoning(
571                Reasoning::new(&thinking).optional_id(signature),
572            ),
573            _ => {
574                return Err(MessageError::ConversionError(
575                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
576                ));
577            }
578        })
579    }
580}
581
582impl From<ToolResultContent> for message::ToolResultContent {
583    fn from(content: ToolResultContent) -> Self {
584        match content {
585            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
586            ToolResultContent::Image(ImageSource {
587                data,
588                media_type: format,
589                ..
590            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
591        }
592    }
593}
594
595impl TryFrom<Message> for message::Message {
596    type Error = MessageError;
597
598    fn try_from(message: Message) -> Result<Self, Self::Error> {
599        Ok(match message.role {
600            Role::User => message::Message::User {
601                content: message.content.try_map(|content| {
602                    Ok(match content {
603                        Content::Text { text } => message::UserContent::text(text),
604                        Content::ToolResult {
605                            tool_use_id,
606                            content,
607                            ..
608                        } => message::UserContent::tool_result(
609                            tool_use_id,
610                            content.map(|content| content.into()),
611                        ),
612                        Content::Image { source } => message::UserContent::Image(message::Image {
613                            data: source.data.into(),
614                            media_type: Some(source.media_type.into()),
615                            detail: None,
616                            additional_params: None,
617                        }),
618                        Content::Document { source } => message::UserContent::document(
619                            source.data,
620                            Some(message::DocumentMediaType::PDF),
621                        ),
622                        _ => {
623                            return Err(MessageError::ConversionError(
624                                "Unsupported content type for User role".to_owned(),
625                            ));
626                        }
627                    })
628                })?,
629            },
630            Role::Assistant => match message.content.first() {
631                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
632                    message::Message::Assistant {
633                        id: None,
634                        content: message.content.try_map(|content| content.try_into())?,
635                    }
636                }
637
638                _ => {
639                    return Err(MessageError::ConversionError(
640                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
641                    ));
642                }
643            },
644        })
645    }
646}
647
/// Completion model handle bound to an Anthropic [`Client`].
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client>
where
    T: WasmCompatSend,
{
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
    /// Fallback `max_tokens` used when a request does not set one; `new` derives it from the
    /// model name and leaves it `None` for names it does not recognise.
    pub default_max_tokens: Option<u64>,
}
657
658impl<T> CompletionModel<T>
659where
660    T: HttpClientExt,
661{
662    pub fn new(client: Client<T>, model: &str) -> Self {
663        Self {
664            client,
665            model: model.to_string(),
666            default_max_tokens: calculate_max_tokens(model),
667        }
668    }
669}
670
/// Returns the default `max_tokens` for a model, chosen by the model name's prefix.
///
/// Anthropic requires `max_tokens` on every request, and the accepted ceiling depends on the
/// model. The table below reflects the models known when it was written; names that match no
/// prefix return `None`, so callers must supply a value themselves.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model name prefix, default max_tokens). No prefix is a prefix of another, so the
    // lookup order does not matter.
    const OUTPUT_LIMITS: &[(&str, u64)] = &[
        ("claude-opus-4", 32000),
        ("claude-sonnet-4", 64000),
        ("claude-3-7-sonnet", 64000),
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    OUTPUT_LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}
692
/// Request metadata carrying an optional user identifier.
// NOTE(review): not referenced elsewhere in the visible part of this file — confirm usage.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    /// Optional identifier of the end user.
    user_id: Option<String>,
}
697
/// Tool selection policy in Anthropic's wire format, internally tagged on `type` with
/// snake_case names (`auto`, `any`, `none`, `tool`). Defaults to `Auto`.
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// Converted from the generic `Auto` choice.
    #[default]
    Auto,
    /// Converted from the generic `Required` choice.
    Any,
    /// Converted from the generic `None` choice.
    None,
    /// Converted from a generic `Specific` choice naming exactly one tool.
    Tool {
        name: String,
    },
}
709impl TryFrom<message::ToolChoice> for ToolChoice {
710    type Error = CompletionError;
711
712    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
713        let res = match value {
714            message::ToolChoice::Auto => Self::Auto,
715            message::ToolChoice::None => Self::None,
716            message::ToolChoice::Required => Self::Any,
717            message::ToolChoice::Specific { function_names } => {
718                if function_names.len() != 1 {
719                    return Err(CompletionError::ProviderError(
720                        "Only one tool may be specified to be used by Claude".into(),
721                    ));
722                }
723
724                Self::Tool {
725                    name: function_names.first().unwrap().to_string(),
726                }
727            }
728        };
729
730        Ok(res)
731    }
732}
733impl<T> completion::CompletionModel for CompletionModel<T>
734where
735    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
736{
737    type Response = CompletionResponse;
738    type StreamingResponse = StreamingCompletionResponse;
739
740    #[cfg_attr(feature = "worker", worker::send)]
741    async fn completion(
742        &self,
743        completion_request: completion::CompletionRequest,
744    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
745        let span = if tracing::Span::current().is_disabled() {
746            info_span!(
747                target: "rig::completions",
748                "chat",
749                gen_ai.operation.name = "chat",
750                gen_ai.provider.name = "anthropic",
751                gen_ai.request.model = self.model,
752                gen_ai.system_instructions = &completion_request.preamble,
753                gen_ai.response.id = tracing::field::Empty,
754                gen_ai.response.model = tracing::field::Empty,
755                gen_ai.usage.output_tokens = tracing::field::Empty,
756                gen_ai.usage.input_tokens = tracing::field::Empty,
757                gen_ai.input.messages = tracing::field::Empty,
758                gen_ai.output.messages = tracing::field::Empty,
759            )
760        } else {
761            tracing::Span::current()
762        };
763        // Note: Ideally we'd introduce provider-specific Request models to handle the
764        // specific requirements of each provider. For now, we just manually check while
765        // building the request as a raw JSON document.
766
767        // Check if max_tokens is set, required for Anthropic
768        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
769            tokens
770        } else if let Some(tokens) = self.default_max_tokens {
771            tokens
772        } else {
773            return Err(CompletionError::RequestError(
774                "`max_tokens` must be set for Anthropic".into(),
775            ));
776        };
777
778        let mut full_history = vec![];
779        if let Some(docs) = completion_request.normalized_documents() {
780            full_history.push(docs);
781        }
782        full_history.extend(completion_request.chat_history);
783        span.record_model_input(&full_history);
784
785        let full_history = full_history
786            .into_iter()
787            .map(Message::try_from)
788            .collect::<Result<Vec<Message>, _>>()?;
789
790        let mut request = json!({
791            "model": self.model,
792            "messages": full_history,
793            "max_tokens": max_tokens,
794            "system": completion_request.preamble.unwrap_or("".to_string()),
795        });
796
797        if let Some(temperature) = completion_request.temperature {
798            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
799        }
800
801        let tool_choice = if let Some(tool_choice) = completion_request.tool_choice {
802            Some(ToolChoice::try_from(tool_choice)?)
803        } else {
804            None
805        };
806
807        if !completion_request.tools.is_empty() {
808            json_utils::merge_inplace(
809                &mut request,
810                json!({
811                    "tools": completion_request
812                        .tools
813                        .into_iter()
814                        .map(|tool| ToolDefinition {
815                            name: tool.name,
816                            description: Some(tool.description),
817                            input_schema: tool.parameters,
818                        })
819                        .collect::<Vec<_>>(),
820                    "tool_choice": tool_choice,
821                }),
822            );
823        }
824
825        if let Some(ref params) = completion_request.additional_params {
826            json_utils::merge_inplace(&mut request, params.clone())
827        }
828
829        async move {
830            let request: Vec<u8> = serde_json::to_vec(&request)?;
831
832            let req = self
833                .client
834                .post("/v1/messages")
835                .header("Content-Type", "application/json")
836                .body(request)
837                .map_err(|e| CompletionError::HttpError(e.into()))?;
838
839            let response = self
840                .client
841                .send::<_, Bytes>(req)
842                .await
843                .map_err(CompletionError::HttpError)?;
844
845            if response.status().is_success() {
846                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
847                    response
848                        .into_body()
849                        .await
850                        .map_err(CompletionError::HttpError)?
851                        .to_vec()
852                        .as_slice(),
853                )? {
854                    ApiResponse::Message(completion) => {
855                        let span = tracing::Span::current();
856                        span.record_model_output(&completion.content);
857                        span.record_response_metadata(&completion);
858                        span.record_token_usage(&completion.usage);
859                        completion.try_into()
860                    }
861                    ApiResponse::Error(ApiErrorResponse { message }) => {
862                        Err(CompletionError::ResponseError(message))
863                    }
864                }
865            } else {
866                let text: String = String::from_utf8_lossy(
867                    &response
868                        .into_body()
869                        .await
870                        .map_err(CompletionError::HttpError)?,
871                )
872                .into();
873                Err(CompletionError::ProviderError(text))
874            }
875        }
876        .instrument(span)
877        .await
878    }
879
    /// Streams a completion for `request`.
    ///
    /// Forwards to `CompletionModel::stream` with this model as the receiver and
    /// returns the awaited result unchanged: either a streaming completion
    /// response or a [`CompletionError`].
    // NOTE(review): the call is written as a path-qualified function call,
    // presumably so it resolves to a separately defined `stream` rather than
    // re-entering this method — confirm before simplifying it to
    // `self.stream(request)`.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        CompletionModel::stream(self, request).await
    }
890}
891
/// Error payload carried by the `Error` variant of [`ApiResponse`].
///
/// Only the `message` string is captured; any other fields present in the
/// JSON object are ignored during deserialization.
// NOTE(review): Anthropic's documented error envelope appears to nest the
// message under an `error` object — confirm this flat shape matches what the
// API actually sends on a successful HTTP status.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text reported by the API; surfaced to callers as a response error.
    message: String,
}
896
/// Body of a `/v1/messages` response, discriminated by its JSON `type` field.
///
/// Serde picks the variant from the snake_case value of `type`: `"message"`
/// yields a payload of type `T`, and `"error"` yields an
/// [`ApiErrorResponse`]. The remaining fields of the same JSON object are
/// deserialized into the selected variant's payload.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// Successful response body.
    Message(T),
    /// Error reported inside an otherwise well-formed response body.
    Error(ApiErrorResponse),
}
903
904#[cfg(test)]
905mod tests {
906    use super::*;
907    use serde_path_to_error::deserialize;
908
909    #[test]
910    fn test_deserialize_message() {
911        let assistant_message_json = r#"
912        {
913            "role": "assistant",
914            "content": "\n\nHello there, how may I assist you today?"
915        }
916        "#;
917
918        let assistant_message_json2 = r#"
919        {
920            "role": "assistant",
921            "content": [
922                {
923                    "type": "text",
924                    "text": "\n\nHello there, how may I assist you today?"
925                },
926                {
927                    "type": "tool_use",
928                    "id": "toolu_01A09q90qw90lq917835lq9",
929                    "name": "get_weather",
930                    "input": {"location": "San Francisco, CA"}
931                }
932            ]
933        }
934        "#;
935
936        let user_message_json = r#"
937        {
938            "role": "user",
939            "content": [
940                {
941                    "type": "image",
942                    "source": {
943                        "type": "base64",
944                        "media_type": "image/jpeg",
945                        "data": "/9j/4AAQSkZJRg..."
946                    }
947                },
948                {
949                    "type": "text",
950                    "text": "What is in this image?"
951                },
952                {
953                    "type": "tool_result",
954                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
955                    "content": "15 degrees"
956                }
957            ]
958        }
959        "#;
960
961        let assistant_message: Message = {
962            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
963            deserialize(jd).unwrap_or_else(|err| {
964                panic!("Deserialization error at {}: {}", err.path(), err);
965            })
966        };
967
968        let assistant_message2: Message = {
969            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
970            deserialize(jd).unwrap_or_else(|err| {
971                panic!("Deserialization error at {}: {}", err.path(), err);
972            })
973        };
974
975        let user_message: Message = {
976            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
977            deserialize(jd).unwrap_or_else(|err| {
978                panic!("Deserialization error at {}: {}", err.path(), err);
979            })
980        };
981
982        let Message { role, content } = assistant_message;
983        assert_eq!(role, Role::Assistant);
984        assert_eq!(
985            content.first(),
986            Content::Text {
987                text: "\n\nHello there, how may I assist you today?".to_owned()
988            }
989        );
990
991        let Message { role, content } = assistant_message2;
992        {
993            assert_eq!(role, Role::Assistant);
994            assert_eq!(content.len(), 2);
995
996            let mut iter = content.into_iter();
997
998            match iter.next().unwrap() {
999                Content::Text { text } => {
1000                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
1001                }
1002                _ => panic!("Expected text content"),
1003            }
1004
1005            match iter.next().unwrap() {
1006                Content::ToolUse { id, name, input } => {
1007                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1008                    assert_eq!(name, "get_weather");
1009                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
1010                }
1011                _ => panic!("Expected tool use content"),
1012            }
1013
1014            assert_eq!(iter.next(), None);
1015        }
1016
1017        let Message { role, content } = user_message;
1018        {
1019            assert_eq!(role, Role::User);
1020            assert_eq!(content.len(), 3);
1021
1022            let mut iter = content.into_iter();
1023
1024            match iter.next().unwrap() {
1025                Content::Image { source } => {
1026                    assert_eq!(
1027                        source,
1028                        ImageSource {
1029                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1030                            media_type: ImageFormat::JPEG,
1031                            r#type: SourceType::BASE64,
1032                        }
1033                    );
1034                }
1035                _ => panic!("Expected image content"),
1036            }
1037
1038            match iter.next().unwrap() {
1039                Content::Text { text } => {
1040                    assert_eq!(text, "What is in this image?");
1041                }
1042                _ => panic!("Expected text content"),
1043            }
1044
1045            match iter.next().unwrap() {
1046                Content::ToolResult {
1047                    tool_use_id,
1048                    content,
1049                    is_error,
1050                } => {
1051                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1052                    assert_eq!(
1053                        content.first(),
1054                        ToolResultContent::Text {
1055                            text: "15 degrees".to_owned()
1056                        }
1057                    );
1058                    assert_eq!(is_error, None);
1059                }
1060                _ => panic!("Expected tool result content"),
1061            }
1062
1063            assert_eq!(iter.next(), None);
1064        }
1065    }
1066
1067    #[test]
1068    fn test_message_to_message_conversion() {
1069        let user_message: Message = serde_json::from_str(
1070            r#"
1071        {
1072            "role": "user",
1073            "content": [
1074                {
1075                    "type": "image",
1076                    "source": {
1077                        "type": "base64",
1078                        "media_type": "image/jpeg",
1079                        "data": "/9j/4AAQSkZJRg..."
1080                    }
1081                },
1082                {
1083                    "type": "text",
1084                    "text": "What is in this image?"
1085                },
1086                {
1087                    "type": "document",
1088                    "source": {
1089                        "type": "base64",
1090                        "data": "base64_encoded_pdf_data",
1091                        "media_type": "application/pdf"
1092                    }
1093                }
1094            ]
1095        }
1096        "#,
1097        )
1098        .unwrap();
1099
1100        let assistant_message = Message {
1101            role: Role::Assistant,
1102            content: OneOrMany::one(Content::ToolUse {
1103                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1104                name: "get_weather".to_string(),
1105                input: json!({"location": "San Francisco, CA"}),
1106            }),
1107        };
1108
1109        let tool_message = Message {
1110            role: Role::User,
1111            content: OneOrMany::one(Content::ToolResult {
1112                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1113                content: OneOrMany::one(ToolResultContent::Text {
1114                    text: "15 degrees".to_string(),
1115                }),
1116                is_error: None,
1117            }),
1118        };
1119
1120        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1121        let converted_assistant_message: message::Message =
1122            assistant_message.clone().try_into().unwrap();
1123        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1124
1125        match converted_user_message.clone() {
1126            message::Message::User { content } => {
1127                assert_eq!(content.len(), 3);
1128
1129                let mut iter = content.into_iter();
1130
1131                match iter.next().unwrap() {
1132                    message::UserContent::Image(message::Image {
1133                        data, media_type, ..
1134                    }) => {
1135                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1136                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1137                    }
1138                    _ => panic!("Expected image content"),
1139                }
1140
1141                match iter.next().unwrap() {
1142                    message::UserContent::Text(message::Text { text }) => {
1143                        assert_eq!(text, "What is in this image?");
1144                    }
1145                    _ => panic!("Expected text content"),
1146                }
1147
1148                match iter.next().unwrap() {
1149                    message::UserContent::Document(message::Document {
1150                        data, media_type, ..
1151                    }) => {
1152                        assert_eq!(
1153                            data,
1154                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
1155                        );
1156                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1157                    }
1158                    _ => panic!("Expected document content"),
1159                }
1160
1161                assert_eq!(iter.next(), None);
1162            }
1163            _ => panic!("Expected user message"),
1164        }
1165
1166        match converted_tool_message.clone() {
1167            message::Message::User { content } => {
1168                let message::ToolResult { id, content, .. } = match content.first() {
1169                    message::UserContent::ToolResult(tool_result) => tool_result,
1170                    _ => panic!("Expected tool result content"),
1171                };
1172                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1173                match content.first() {
1174                    message::ToolResultContent::Text(message::Text { text }) => {
1175                        assert_eq!(text, "15 degrees");
1176                    }
1177                    _ => panic!("Expected text content"),
1178                }
1179            }
1180            _ => panic!("Expected tool result content"),
1181        }
1182
1183        match converted_assistant_message.clone() {
1184            message::Message::Assistant { content, .. } => {
1185                assert_eq!(content.len(), 1);
1186
1187                match content.first() {
1188                    message::AssistantContent::ToolCall(message::ToolCall {
1189                        id, function, ..
1190                    }) => {
1191                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1192                        assert_eq!(function.name, "get_weather");
1193                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1194                    }
1195                    _ => panic!("Expected tool call content"),
1196                }
1197            }
1198            _ => panic!("Expected assistant message"),
1199        }
1200
1201        let original_user_message: Message = converted_user_message.try_into().unwrap();
1202        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1203        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1204
1205        assert_eq!(user_message, original_user_message);
1206        assert_eq!(assistant_message, original_assistant_message);
1207        assert_eq!(tool_message, original_tool_message);
1208    }
1209
1210    #[test]
1211    fn test_content_format_conversion() {
1212        use crate::completion::message::ContentFormat;
1213
1214        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1215        assert_eq!(source_type, SourceType::URL);
1216
1217        let content_format: ContentFormat = SourceType::URL.into();
1218        assert_eq!(content_format, ContentFormat::Url);
1219
1220        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1221        assert_eq!(source_type, SourceType::BASE64);
1222
1223        let content_format: ContentFormat = SourceType::BASE64.into();
1224        assert_eq!(content_format, ContentFormat::Base64);
1225
1226        let result: Result<SourceType, _> = ContentFormat::String.try_into();
1227        assert!(result.is_err());
1228        assert!(
1229            result
1230                .unwrap_err()
1231                .to_string()
1232                .contains("ContentFormat::String is deprecated")
1233        );
1234    }
1235}