rig/providers/anthropic/
completion.rs

1//! Anthropic completion api implementation
2
3use crate::{
4    OneOrMany,
5    completion::{self, CompletionError, GetTokenUsage},
6    http_client::HttpClientExt,
7    json_utils,
8    message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
9    one_or_many::string_or_one_or_many,
10    telemetry::{ProviderResponseExt, SpanCombinator},
11    wasm_compat::*,
12};
13use std::{convert::Infallible, str::FromStr};
14
15use super::client::Client;
16use crate::completion::CompletionRequest;
17use crate::providers::anthropic::streaming::StreamingCompletionResponse;
18use bytes::Bytes;
19use serde::{Deserialize, Serialize};
20use serde_json::json;
21use tracing::{Instrument, info_span};
22
23// ================================================================
24// Anthropic Completion API
25// ================================================================
26
/// `claude-opus-4-0` completion model
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";

/// `claude-sonnet-4-0` completion model
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";

/// `claude-3-7-sonnet-latest` completion model
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";

/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// `2023-01-01` API version string
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// `2023-06-01` API version string
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Newest API version string defined in this module (currently `2023-06-01`)
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
54
55#[derive(Debug, Deserialize, Serialize)]
56pub struct CompletionResponse {
57    pub content: Vec<Content>,
58    pub id: String,
59    pub model: String,
60    pub role: String,
61    pub stop_reason: Option<String>,
62    pub stop_sequence: Option<String>,
63    pub usage: Usage,
64}
65
66impl ProviderResponseExt for CompletionResponse {
67    type OutputMessage = Content;
68    type Usage = Usage;
69
70    fn get_response_id(&self) -> Option<String> {
71        Some(self.id.to_owned())
72    }
73
74    fn get_response_model_name(&self) -> Option<String> {
75        Some(self.model.to_owned())
76    }
77
78    fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
79        self.content.clone()
80    }
81
82    fn get_text_response(&self) -> Option<String> {
83        let res = self
84            .content
85            .iter()
86            .filter_map(|x| {
87                if let Content::Text { text } = x {
88                    Some(text.to_owned())
89                } else {
90                    None
91                }
92            })
93            .collect::<Vec<String>>()
94            .join("\n");
95
96        if res.is_empty() { None } else { Some(res) }
97    }
98
99    fn get_usage(&self) -> Option<Self::Usage> {
100        Some(self.usage.clone())
101    }
102}
103
104#[derive(Clone, Debug, Deserialize, Serialize)]
105pub struct Usage {
106    pub input_tokens: u64,
107    pub cache_read_input_tokens: Option<u64>,
108    pub cache_creation_input_tokens: Option<u64>,
109    pub output_tokens: u64,
110}
111
112impl std::fmt::Display for Usage {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        write!(
115            f,
116            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
117            self.input_tokens,
118            match self.cache_read_input_tokens {
119                Some(token) => token.to_string(),
120                None => "n/a".to_string(),
121            },
122            match self.cache_creation_input_tokens {
123                Some(token) => token.to_string(),
124                None => "n/a".to_string(),
125            },
126            self.output_tokens
127        )
128    }
129}
130
131impl GetTokenUsage for Usage {
132    fn token_usage(&self) -> Option<crate::completion::Usage> {
133        let mut usage = crate::completion::Usage::new();
134
135        usage.input_tokens = self.input_tokens
136            + self.cache_creation_input_tokens.unwrap_or_default()
137            + self.cache_read_input_tokens.unwrap_or_default();
138        usage.output_tokens = self.output_tokens;
139        usage.total_tokens = usage.input_tokens + usage.output_tokens;
140
141        Some(usage)
142    }
143}
144
145#[derive(Debug, Deserialize, Serialize)]
146pub struct ToolDefinition {
147    pub name: String,
148    pub description: Option<String>,
149    pub input_schema: serde_json::Value,
150}
151
152#[derive(Debug, Deserialize, Serialize)]
153#[serde(tag = "type", rename_all = "snake_case")]
154pub enum CacheControl {
155    Ephemeral,
156}
157
158impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
159    type Error = CompletionError;
160
161    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
162        let content = response
163            .content
164            .iter()
165            .map(|content| {
166                Ok(match content {
167                    Content::Text { text } => completion::AssistantContent::text(text),
168                    Content::ToolUse { id, name, input } => {
169                        completion::AssistantContent::tool_call(id, name, input.clone())
170                    }
171                    _ => {
172                        return Err(CompletionError::ResponseError(
173                            "Response did not contain a message or tool call".into(),
174                        ));
175                    }
176                })
177            })
178            .collect::<Result<Vec<_>, _>>()?;
179
180        let choice = OneOrMany::many(content).map_err(|_| {
181            CompletionError::ResponseError(
182                "Response contained no message or tool call (empty)".to_owned(),
183            )
184        })?;
185
186        let usage = completion::Usage {
187            input_tokens: response.usage.input_tokens,
188            output_tokens: response.usage.output_tokens,
189            total_tokens: response.usage.input_tokens + response.usage.output_tokens,
190        };
191
192        Ok(completion::CompletionResponse {
193            choice,
194            usage,
195            raw_response: response,
196        })
197    }
198}
199
200#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
201pub struct Message {
202    pub role: Role,
203    #[serde(deserialize_with = "string_or_one_or_many")]
204    pub content: OneOrMany<Content>,
205}
206
207#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
208#[serde(rename_all = "lowercase")]
209pub enum Role {
210    User,
211    Assistant,
212}
213
214#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
215#[serde(tag = "type", rename_all = "snake_case")]
216pub enum Content {
217    Text {
218        text: String,
219    },
220    Image {
221        source: ImageSource,
222    },
223    ToolUse {
224        id: String,
225        name: String,
226        input: serde_json::Value,
227    },
228    ToolResult {
229        tool_use_id: String,
230        #[serde(deserialize_with = "string_or_one_or_many")]
231        content: OneOrMany<ToolResultContent>,
232        #[serde(skip_serializing_if = "Option::is_none")]
233        is_error: Option<bool>,
234    },
235    Document {
236        source: DocumentSource,
237    },
238    Thinking {
239        thinking: String,
240        signature: Option<String>,
241    },
242}
243
244impl FromStr for Content {
245    type Err = Infallible;
246
247    fn from_str(s: &str) -> Result<Self, Self::Err> {
248        Ok(Content::Text { text: s.to_owned() })
249    }
250}
251
252#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
253#[serde(tag = "type", rename_all = "snake_case")]
254pub enum ToolResultContent {
255    Text { text: String },
256    Image(ImageSource),
257}
258
259impl FromStr for ToolResultContent {
260    type Err = Infallible;
261
262    fn from_str(s: &str) -> Result<Self, Self::Err> {
263        Ok(ToolResultContent::Text { text: s.to_owned() })
264    }
265}
266
267#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
268#[serde(untagged)]
269pub enum ImageSourceData {
270    Base64(String),
271    Url(String),
272}
273
274impl From<ImageSourceData> for DocumentSourceKind {
275    fn from(value: ImageSourceData) -> Self {
276        match value {
277            ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
278            ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
279        }
280    }
281}
282
283impl TryFrom<DocumentSourceKind> for ImageSourceData {
284    type Error = MessageError;
285
286    fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
287        match value {
288            DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
289            DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
290            _ => Err(MessageError::ConversionError("Content has no body".into())),
291        }
292    }
293}
294
295impl From<ImageSourceData> for String {
296    fn from(value: ImageSourceData) -> Self {
297        match value {
298            ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
299        }
300    }
301}
302
303#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
304pub struct ImageSource {
305    pub data: ImageSourceData,
306    pub media_type: ImageFormat,
307    pub r#type: SourceType,
308}
309
310#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
311pub struct DocumentSource {
312    pub data: String,
313    pub media_type: DocumentFormat,
314    pub r#type: SourceType,
315}
316
317#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
318#[serde(rename_all = "lowercase")]
319pub enum ImageFormat {
320    #[serde(rename = "image/jpeg")]
321    JPEG,
322    #[serde(rename = "image/png")]
323    PNG,
324    #[serde(rename = "image/gif")]
325    GIF,
326    #[serde(rename = "image/webp")]
327    WEBP,
328}
329
330/// The document format to be used.
331///
332/// Currently, Anthropic only supports PDF for text documents over the API (within a message). You can find more information about this here: <https://docs.anthropic.com/en/docs/build-with-claude/pdf-support>
333#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
334#[serde(rename_all = "lowercase")]
335pub enum DocumentFormat {
336    #[serde(rename = "application/pdf")]
337    PDF,
338}
339
340#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
341#[serde(rename_all = "lowercase")]
342pub enum SourceType {
343    BASE64,
344    URL,
345}
346
347impl From<String> for Content {
348    fn from(text: String) -> Self {
349        Content::Text { text }
350    }
351}
352
353impl From<String> for ToolResultContent {
354    fn from(text: String) -> Self {
355        ToolResultContent::Text { text }
356    }
357}
358
359impl TryFrom<message::ContentFormat> for SourceType {
360    type Error = MessageError;
361
362    fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
363        match format {
364            message::ContentFormat::Base64 => Ok(SourceType::BASE64),
365            message::ContentFormat::Url => Ok(SourceType::URL),
366            message::ContentFormat::String => Err(MessageError::ConversionError(
367                "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
368            )),
369        }
370    }
371}
372
373impl From<SourceType> for message::ContentFormat {
374    fn from(source_type: SourceType) -> Self {
375        match source_type {
376            SourceType::BASE64 => message::ContentFormat::Base64,
377            SourceType::URL => message::ContentFormat::Url,
378        }
379    }
380}
381
382impl TryFrom<message::ImageMediaType> for ImageFormat {
383    type Error = MessageError;
384
385    fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
386        Ok(match media_type {
387            message::ImageMediaType::JPEG => ImageFormat::JPEG,
388            message::ImageMediaType::PNG => ImageFormat::PNG,
389            message::ImageMediaType::GIF => ImageFormat::GIF,
390            message::ImageMediaType::WEBP => ImageFormat::WEBP,
391            _ => {
392                return Err(MessageError::ConversionError(
393                    format!("Unsupported image media type: {media_type:?}").to_owned(),
394                ));
395            }
396        })
397    }
398}
399
400impl From<ImageFormat> for message::ImageMediaType {
401    fn from(format: ImageFormat) -> Self {
402        match format {
403            ImageFormat::JPEG => message::ImageMediaType::JPEG,
404            ImageFormat::PNG => message::ImageMediaType::PNG,
405            ImageFormat::GIF => message::ImageMediaType::GIF,
406            ImageFormat::WEBP => message::ImageMediaType::WEBP,
407        }
408    }
409}
410
411impl TryFrom<DocumentMediaType> for DocumentFormat {
412    type Error = MessageError;
413    fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
414        if !matches!(value, DocumentMediaType::PDF) {
415            return Err(MessageError::ConversionError(
416                "Anthropic only supports PDF documents".to_string(),
417            ));
418        };
419
420        Ok(DocumentFormat::PDF)
421    }
422}
423
424impl From<message::AssistantContent> for Content {
425    fn from(text: message::AssistantContent) -> Self {
426        match text {
427            message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
428            message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
429                Content::ToolUse {
430                    id,
431                    name: function.name,
432                    input: function.arguments,
433                }
434            }
435            message::AssistantContent::Reasoning(Reasoning {
436                reasoning,
437                signature,
438                ..
439            }) => Content::Thinking {
440                thinking: reasoning.first().cloned().unwrap_or(String::new()),
441                signature,
442            },
443        }
444    }
445}
446
447impl TryFrom<message::Message> for Message {
448    type Error = MessageError;
449
450    fn try_from(message: message::Message) -> Result<Self, Self::Error> {
451        Ok(match message {
452            message::Message::User { content } => Message {
453                role: Role::User,
454                content: content.try_map(|content| match content {
455                    message::UserContent::Text(message::Text { text }) => {
456                        Ok(Content::Text { text })
457                    }
458                    message::UserContent::ToolResult(message::ToolResult {
459                        id, content, ..
460                    }) => Ok(Content::ToolResult {
461                        tool_use_id: id,
462                        content: content.try_map(|content| match content {
463                            message::ToolResultContent::Text(message::Text { text }) => {
464                                Ok(ToolResultContent::Text { text })
465                            }
466                            message::ToolResultContent::Image(image) => {
467                                let DocumentSourceKind::Base64(data) = image.data else {
468                                    return Err(MessageError::ConversionError(
469                                        "Only base64 strings can be used with the Anthropic API"
470                                            .to_string(),
471                                    ));
472                                };
473                                let media_type =
474                                    image.media_type.ok_or(MessageError::ConversionError(
475                                        "Image media type is required".to_owned(),
476                                    ))?;
477                                Ok(ToolResultContent::Image(ImageSource {
478                                    data: ImageSourceData::Base64(data),
479                                    media_type: media_type.try_into()?,
480                                    r#type: SourceType::BASE64,
481                                }))
482                            }
483                        })?,
484                        is_error: None,
485                    }),
486                    message::UserContent::Image(message::Image {
487                        data, media_type, ..
488                    }) => {
489                        let media_type = media_type.ok_or(MessageError::ConversionError(
490                            "Image media type is required for Claude API".into(),
491                        ))?;
492
493                        let source = match data {
494                            DocumentSourceKind::Base64(data) => ImageSource {
495                                data: ImageSourceData::Base64(data),
496                                r#type: SourceType::BASE64,
497                                media_type: ImageFormat::try_from(media_type)?,
498                            },
499                            DocumentSourceKind::Url(url) => ImageSource {
500                                data: ImageSourceData::Url(url),
501                                r#type: SourceType::URL,
502                                media_type: ImageFormat::try_from(media_type)?,
503                            },
504                            DocumentSourceKind::Unknown => {
505                                return Err(MessageError::ConversionError(
506                                    "Image content has no body".into(),
507                                ));
508                            }
509                            doc => {
510                                return Err(MessageError::ConversionError(format!(
511                                    "Unsupported document type: {doc:?}"
512                                )));
513                            }
514                        };
515
516                        Ok(Content::Image { source })
517                    }
518                    message::UserContent::Document(message::Document {
519                        data, media_type, ..
520                    }) => {
521                        let media_type = media_type.ok_or(MessageError::ConversionError(
522                            "Document media type is required".to_string(),
523                        ))?;
524
525                        let data = match data {
526                            DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
527                                data
528                            }
529                            _ => {
530                                return Err(MessageError::ConversionError(
531                                    "Only base64 encoded documents currently supported".into(),
532                                ));
533                            }
534                        };
535
536                        let source = DocumentSource {
537                            data,
538                            media_type: media_type.try_into()?,
539                            r#type: SourceType::BASE64,
540                        };
541                        Ok(Content::Document { source })
542                    }
543                    message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
544                        "Audio is not supported in Anthropic".to_owned(),
545                    )),
546                    message::UserContent::Video { .. } => Err(MessageError::ConversionError(
547                        "Video is not supported in Anthropic".to_owned(),
548                    )),
549                })?,
550            },
551
552            message::Message::Assistant { content, .. } => Message {
553                content: content.map(|content| content.into()),
554                role: Role::Assistant,
555            },
556        })
557    }
558}
559
560impl TryFrom<Content> for message::AssistantContent {
561    type Error = MessageError;
562
563    fn try_from(content: Content) -> Result<Self, Self::Error> {
564        Ok(match content {
565            Content::Text { text } => message::AssistantContent::text(text),
566            Content::ToolUse { id, name, input } => {
567                message::AssistantContent::tool_call(id, name, input)
568            }
569            Content::Thinking {
570                thinking,
571                signature,
572            } => message::AssistantContent::Reasoning(
573                Reasoning::new(&thinking).with_signature(signature),
574            ),
575            _ => {
576                return Err(MessageError::ConversionError(
577                    format!("Unsupported content type for Assistant role: {content:?}").to_owned(),
578                ));
579            }
580        })
581    }
582}
583
584impl From<ToolResultContent> for message::ToolResultContent {
585    fn from(content: ToolResultContent) -> Self {
586        match content {
587            ToolResultContent::Text { text } => message::ToolResultContent::text(text),
588            ToolResultContent::Image(ImageSource {
589                data,
590                media_type: format,
591                ..
592            }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
593        }
594    }
595}
596
597impl TryFrom<Message> for message::Message {
598    type Error = MessageError;
599
600    fn try_from(message: Message) -> Result<Self, Self::Error> {
601        Ok(match message.role {
602            Role::User => message::Message::User {
603                content: message.content.try_map(|content| {
604                    Ok(match content {
605                        Content::Text { text } => message::UserContent::text(text),
606                        Content::ToolResult {
607                            tool_use_id,
608                            content,
609                            ..
610                        } => message::UserContent::tool_result(
611                            tool_use_id,
612                            content.map(|content| content.into()),
613                        ),
614                        Content::Image { source } => message::UserContent::Image(message::Image {
615                            data: source.data.into(),
616                            media_type: Some(source.media_type.into()),
617                            detail: None,
618                            additional_params: None,
619                        }),
620                        Content::Document { source } => message::UserContent::document(
621                            source.data,
622                            Some(message::DocumentMediaType::PDF),
623                        ),
624                        _ => {
625                            return Err(MessageError::ConversionError(
626                                "Unsupported content type for User role".to_owned(),
627                            ));
628                        }
629                    })
630                })?,
631            },
632            Role::Assistant => match message.content.first() {
633                Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
634                    message::Message::Assistant {
635                        id: None,
636                        content: message.content.try_map(|content| content.try_into())?,
637                    }
638                }
639
640                _ => {
641                    return Err(MessageError::ConversionError(
642                        format!("Unsupported message for Assistant role: {message:?}").to_owned(),
643                    ));
644                }
645            },
646        })
647    }
648}
649
650#[derive(Clone)]
651pub struct CompletionModel<T = reqwest::Client>
652where
653    T: WasmCompatSend,
654{
655    pub(crate) client: Client<T>,
656    pub model: String,
657    pub default_max_tokens: Option<u64>,
658}
659
660impl<T> CompletionModel<T>
661where
662    T: HttpClientExt,
663{
664    pub fn new(client: Client<T>, model: &str) -> Self {
665        Self {
666            client,
667            model: model.to_string(),
668            default_max_tokens: calculate_max_tokens(model),
669        }
670    }
671}
672
/// Returns the default `max_tokens` for a model, chosen by model-name prefix.
///
/// Anthropic requires `max_tokens` on every request, and the allowed value depends on the
/// model: a request without it, or with a value that is too high, fails. The table below
/// reflects the models available at the time of writing; unknown models yield `None`.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Each row: the name prefixes of a model family and the limit they share.
    const LIMITS: &[(&[&str], u64)] = &[
        (&["claude-opus-4"], 32000),
        (&["claude-sonnet-4", "claude-3-7-sonnet"], 64000),
        (&["claude-3-5-sonnet", "claude-3-5-haiku"], 8192),
        (
            &["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"],
            4096,
        ),
    ];

    LIMITS
        .iter()
        .find(|(prefixes, _)| prefixes.iter().any(|prefix| model.starts_with(*prefix)))
        .map(|&(_, limit)| limit)
}
694
695#[derive(Debug, Deserialize, Serialize)]
696pub struct Metadata {
697    user_id: Option<String>,
698}
699
700#[derive(Default, Debug, Serialize, Deserialize)]
701#[serde(tag = "type", rename_all = "snake_case")]
702pub enum ToolChoice {
703    #[default]
704    Auto,
705    Any,
706    None,
707    Tool {
708        name: String,
709    },
710}
711impl TryFrom<message::ToolChoice> for ToolChoice {
712    type Error = CompletionError;
713
714    fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
715        let res = match value {
716            message::ToolChoice::Auto => Self::Auto,
717            message::ToolChoice::None => Self::None,
718            message::ToolChoice::Required => Self::Any,
719            message::ToolChoice::Specific { function_names } => {
720                if function_names.len() != 1 {
721                    return Err(CompletionError::ProviderError(
722                        "Only one tool may be specified to be used by Claude".into(),
723                    ));
724                }
725
726                Self::Tool {
727                    name: function_names.first().unwrap().to_string(),
728                }
729            }
730        };
731
732        Ok(res)
733    }
734}
735impl<T> completion::CompletionModel for CompletionModel<T>
736where
737    T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
738{
739    type Response = CompletionResponse;
740    type StreamingResponse = StreamingCompletionResponse;
741
742    #[cfg_attr(feature = "worker", worker::send)]
743    async fn completion(
744        &self,
745        completion_request: completion::CompletionRequest,
746    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
747        let span = if tracing::Span::current().is_disabled() {
748            info_span!(
749                target: "rig::completions",
750                "chat",
751                gen_ai.operation.name = "chat",
752                gen_ai.provider.name = "anthropic",
753                gen_ai.request.model = self.model,
754                gen_ai.system_instructions = &completion_request.preamble,
755                gen_ai.response.id = tracing::field::Empty,
756                gen_ai.response.model = tracing::field::Empty,
757                gen_ai.usage.output_tokens = tracing::field::Empty,
758                gen_ai.usage.input_tokens = tracing::field::Empty,
759                gen_ai.input.messages = tracing::field::Empty,
760                gen_ai.output.messages = tracing::field::Empty,
761            )
762        } else {
763            tracing::Span::current()
764        };
765        // Note: Ideally we'd introduce provider-specific Request models to handle the
766        // specific requirements of each provider. For now, we just manually check while
767        // building the request as a raw JSON document.
768
769        // Check if max_tokens is set, required for Anthropic
770        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
771            tokens
772        } else if let Some(tokens) = self.default_max_tokens {
773            tokens
774        } else {
775            return Err(CompletionError::RequestError(
776                "`max_tokens` must be set for Anthropic".into(),
777            ));
778        };
779
780        let mut full_history = vec![];
781        if let Some(docs) = completion_request.normalized_documents() {
782            full_history.push(docs);
783        }
784        full_history.extend(completion_request.chat_history);
785        span.record_model_input(&full_history);
786
787        let full_history = full_history
788            .into_iter()
789            .map(Message::try_from)
790            .collect::<Result<Vec<Message>, _>>()?;
791
792        let mut request = json!({
793            "model": self.model,
794            "messages": full_history,
795            "max_tokens": max_tokens,
796            "system": completion_request.preamble.unwrap_or("".to_string()),
797        });
798
799        if let Some(temperature) = completion_request.temperature {
800            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
801        }
802
803        let tool_choice = if let Some(tool_choice) = completion_request.tool_choice {
804            Some(ToolChoice::try_from(tool_choice)?)
805        } else {
806            None
807        };
808
809        if !completion_request.tools.is_empty() {
810            let mut tools_json = json!({
811                "tools": completion_request
812                    .tools
813                    .into_iter()
814                    .map(|tool| ToolDefinition {
815                        name: tool.name,
816                        description: Some(tool.description),
817                        input_schema: tool.parameters,
818                    })
819                    .collect::<Vec<_>>(),
820            });
821
822            // Only include tool_choice if it's explicitly set (not None)
823            // When omitted, Anthropic defaults to "auto"
824            if let Some(tc) = tool_choice {
825                tools_json["tool_choice"] = serde_json::to_value(tc)?;
826            }
827
828            json_utils::merge_inplace(&mut request, tools_json);
829        }
830
831        if let Some(ref params) = completion_request.additional_params {
832            json_utils::merge_inplace(&mut request, params.clone())
833        }
834
835        async move {
836            let request: Vec<u8> = serde_json::to_vec(&request)?;
837
838            if let Ok(json_str) = String::from_utf8(request.clone()) {
839                tracing::debug!("Request body:\n{}", json_str);
840            }
841
842            let req = self
843                .client
844                .post("/v1/messages")
845                .header("Content-Type", "application/json")
846                .body(request)
847                .map_err(|e| CompletionError::HttpError(e.into()))?;
848
849            let response = self
850                .client
851                .send::<_, Bytes>(req)
852                .await
853                .map_err(CompletionError::HttpError)?;
854
855            if response.status().is_success() {
856                match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
857                    response
858                        .into_body()
859                        .await
860                        .map_err(CompletionError::HttpError)?
861                        .to_vec()
862                        .as_slice(),
863                )? {
864                    ApiResponse::Message(completion) => {
865                        let span = tracing::Span::current();
866                        span.record_model_output(&completion.content);
867                        span.record_response_metadata(&completion);
868                        span.record_token_usage(&completion.usage);
869                        completion.try_into()
870                    }
871                    ApiResponse::Error(ApiErrorResponse { message }) => {
872                        Err(CompletionError::ResponseError(message))
873                    }
874                }
875            } else {
876                let text: String = String::from_utf8_lossy(
877                    &response
878                        .into_body()
879                        .await
880                        .map_err(CompletionError::HttpError)?,
881                )
882                .into();
883                Err(CompletionError::ProviderError(text))
884            }
885        }
886        .instrument(span)
887        .await
888    }
889
    /// Streams a completion for `request`.
    ///
    /// This method contains no logic of its own: it forwards `self` and
    /// `request` unchanged to `CompletionModel::stream` and returns that
    /// call's result as-is, including any `CompletionError`.
    ///
    /// When the `worker` feature is enabled, the `worker::send` attribute is
    /// applied to this method.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<
        crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
        CompletionError,
    > {
        // NOTE(review): the path-qualified call presumably resolves to the
        // inherent `stream` defined outside this chunk rather than recursing
        // into this trait method — confirm before changing the call form.
        CompletionModel::stream(self, request).await
    }
900}
901
/// Error payload carried by [`ApiResponse::Error`].
///
/// The completion call in this file forwards `message` into
/// `CompletionError::ResponseError` when a successful HTTP response
/// deserializes to the error variant.
///
/// NOTE(review): because `ApiResponse` is internally tagged, `message` is read
/// from the same JSON object that holds the `type` tag. Anthropic's documented
/// error body appears to nest the message under an `error` object — confirm
/// this struct matches the actual wire format.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error text reported by the API.
    message: String,
}
906
/// Top-level shape of a response body, discriminated by its `type` field.
///
/// The enum is internally tagged (`tag = "type"`) with snake_case variant
/// names, so `"type": "message"` selects [`ApiResponse::Message`] and
/// `"type": "error"` selects [`ApiResponse::Error`]. In both cases the payload
/// is deserialized from the remaining fields of the same JSON object.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// A successful payload of type `T`.
    Message(T),
    /// An error payload reported in the response body.
    Error(ApiErrorResponse),
}
913
914#[cfg(test)]
915mod tests {
916    use super::*;
917    use serde_path_to_error::deserialize;
918
919    #[test]
920    fn test_deserialize_message() {
921        let assistant_message_json = r#"
922        {
923            "role": "assistant",
924            "content": "\n\nHello there, how may I assist you today?"
925        }
926        "#;
927
928        let assistant_message_json2 = r#"
929        {
930            "role": "assistant",
931            "content": [
932                {
933                    "type": "text",
934                    "text": "\n\nHello there, how may I assist you today?"
935                },
936                {
937                    "type": "tool_use",
938                    "id": "toolu_01A09q90qw90lq917835lq9",
939                    "name": "get_weather",
940                    "input": {"location": "San Francisco, CA"}
941                }
942            ]
943        }
944        "#;
945
946        let user_message_json = r#"
947        {
948            "role": "user",
949            "content": [
950                {
951                    "type": "image",
952                    "source": {
953                        "type": "base64",
954                        "media_type": "image/jpeg",
955                        "data": "/9j/4AAQSkZJRg..."
956                    }
957                },
958                {
959                    "type": "text",
960                    "text": "What is in this image?"
961                },
962                {
963                    "type": "tool_result",
964                    "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
965                    "content": "15 degrees"
966                }
967            ]
968        }
969        "#;
970
971        let assistant_message: Message = {
972            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
973            deserialize(jd).unwrap_or_else(|err| {
974                panic!("Deserialization error at {}: {}", err.path(), err);
975            })
976        };
977
978        let assistant_message2: Message = {
979            let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
980            deserialize(jd).unwrap_or_else(|err| {
981                panic!("Deserialization error at {}: {}", err.path(), err);
982            })
983        };
984
985        let user_message: Message = {
986            let jd = &mut serde_json::Deserializer::from_str(user_message_json);
987            deserialize(jd).unwrap_or_else(|err| {
988                panic!("Deserialization error at {}: {}", err.path(), err);
989            })
990        };
991
992        let Message { role, content } = assistant_message;
993        assert_eq!(role, Role::Assistant);
994        assert_eq!(
995            content.first(),
996            Content::Text {
997                text: "\n\nHello there, how may I assist you today?".to_owned()
998            }
999        );
1000
1001        let Message { role, content } = assistant_message2;
1002        {
1003            assert_eq!(role, Role::Assistant);
1004            assert_eq!(content.len(), 2);
1005
1006            let mut iter = content.into_iter();
1007
1008            match iter.next().unwrap() {
1009                Content::Text { text } => {
1010                    assert_eq!(text, "\n\nHello there, how may I assist you today?");
1011                }
1012                _ => panic!("Expected text content"),
1013            }
1014
1015            match iter.next().unwrap() {
1016                Content::ToolUse { id, name, input } => {
1017                    assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1018                    assert_eq!(name, "get_weather");
1019                    assert_eq!(input, json!({"location": "San Francisco, CA"}));
1020                }
1021                _ => panic!("Expected tool use content"),
1022            }
1023
1024            assert_eq!(iter.next(), None);
1025        }
1026
1027        let Message { role, content } = user_message;
1028        {
1029            assert_eq!(role, Role::User);
1030            assert_eq!(content.len(), 3);
1031
1032            let mut iter = content.into_iter();
1033
1034            match iter.next().unwrap() {
1035                Content::Image { source } => {
1036                    assert_eq!(
1037                        source,
1038                        ImageSource {
1039                            data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1040                            media_type: ImageFormat::JPEG,
1041                            r#type: SourceType::BASE64,
1042                        }
1043                    );
1044                }
1045                _ => panic!("Expected image content"),
1046            }
1047
1048            match iter.next().unwrap() {
1049                Content::Text { text } => {
1050                    assert_eq!(text, "What is in this image?");
1051                }
1052                _ => panic!("Expected text content"),
1053            }
1054
1055            match iter.next().unwrap() {
1056                Content::ToolResult {
1057                    tool_use_id,
1058                    content,
1059                    is_error,
1060                } => {
1061                    assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1062                    assert_eq!(
1063                        content.first(),
1064                        ToolResultContent::Text {
1065                            text: "15 degrees".to_owned()
1066                        }
1067                    );
1068                    assert_eq!(is_error, None);
1069                }
1070                _ => panic!("Expected tool result content"),
1071            }
1072
1073            assert_eq!(iter.next(), None);
1074        }
1075    }
1076
1077    #[test]
1078    fn test_message_to_message_conversion() {
1079        let user_message: Message = serde_json::from_str(
1080            r#"
1081        {
1082            "role": "user",
1083            "content": [
1084                {
1085                    "type": "image",
1086                    "source": {
1087                        "type": "base64",
1088                        "media_type": "image/jpeg",
1089                        "data": "/9j/4AAQSkZJRg..."
1090                    }
1091                },
1092                {
1093                    "type": "text",
1094                    "text": "What is in this image?"
1095                },
1096                {
1097                    "type": "document",
1098                    "source": {
1099                        "type": "base64",
1100                        "data": "base64_encoded_pdf_data",
1101                        "media_type": "application/pdf"
1102                    }
1103                }
1104            ]
1105        }
1106        "#,
1107        )
1108        .unwrap();
1109
1110        let assistant_message = Message {
1111            role: Role::Assistant,
1112            content: OneOrMany::one(Content::ToolUse {
1113                id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1114                name: "get_weather".to_string(),
1115                input: json!({"location": "San Francisco, CA"}),
1116            }),
1117        };
1118
1119        let tool_message = Message {
1120            role: Role::User,
1121            content: OneOrMany::one(Content::ToolResult {
1122                tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1123                content: OneOrMany::one(ToolResultContent::Text {
1124                    text: "15 degrees".to_string(),
1125                }),
1126                is_error: None,
1127            }),
1128        };
1129
1130        let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1131        let converted_assistant_message: message::Message =
1132            assistant_message.clone().try_into().unwrap();
1133        let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1134
1135        match converted_user_message.clone() {
1136            message::Message::User { content } => {
1137                assert_eq!(content.len(), 3);
1138
1139                let mut iter = content.into_iter();
1140
1141                match iter.next().unwrap() {
1142                    message::UserContent::Image(message::Image {
1143                        data, media_type, ..
1144                    }) => {
1145                        assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1146                        assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1147                    }
1148                    _ => panic!("Expected image content"),
1149                }
1150
1151                match iter.next().unwrap() {
1152                    message::UserContent::Text(message::Text { text }) => {
1153                        assert_eq!(text, "What is in this image?");
1154                    }
1155                    _ => panic!("Expected text content"),
1156                }
1157
1158                match iter.next().unwrap() {
1159                    message::UserContent::Document(message::Document {
1160                        data, media_type, ..
1161                    }) => {
1162                        assert_eq!(
1163                            data,
1164                            DocumentSourceKind::String("base64_encoded_pdf_data".into())
1165                        );
1166                        assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1167                    }
1168                    _ => panic!("Expected document content"),
1169                }
1170
1171                assert_eq!(iter.next(), None);
1172            }
1173            _ => panic!("Expected user message"),
1174        }
1175
1176        match converted_tool_message.clone() {
1177            message::Message::User { content } => {
1178                let message::ToolResult { id, content, .. } = match content.first() {
1179                    message::UserContent::ToolResult(tool_result) => tool_result,
1180                    _ => panic!("Expected tool result content"),
1181                };
1182                assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1183                match content.first() {
1184                    message::ToolResultContent::Text(message::Text { text }) => {
1185                        assert_eq!(text, "15 degrees");
1186                    }
1187                    _ => panic!("Expected text content"),
1188                }
1189            }
1190            _ => panic!("Expected tool result content"),
1191        }
1192
1193        match converted_assistant_message.clone() {
1194            message::Message::Assistant { content, .. } => {
1195                assert_eq!(content.len(), 1);
1196
1197                match content.first() {
1198                    message::AssistantContent::ToolCall(message::ToolCall {
1199                        id, function, ..
1200                    }) => {
1201                        assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1202                        assert_eq!(function.name, "get_weather");
1203                        assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1204                    }
1205                    _ => panic!("Expected tool call content"),
1206                }
1207            }
1208            _ => panic!("Expected assistant message"),
1209        }
1210
1211        let original_user_message: Message = converted_user_message.try_into().unwrap();
1212        let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1213        let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1214
1215        assert_eq!(user_message, original_user_message);
1216        assert_eq!(assistant_message, original_assistant_message);
1217        assert_eq!(tool_message, original_tool_message);
1218    }
1219
1220    #[test]
1221    fn test_content_format_conversion() {
1222        use crate::completion::message::ContentFormat;
1223
1224        let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1225        assert_eq!(source_type, SourceType::URL);
1226
1227        let content_format: ContentFormat = SourceType::URL.into();
1228        assert_eq!(content_format, ContentFormat::Url);
1229
1230        let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1231        assert_eq!(source_type, SourceType::BASE64);
1232
1233        let content_format: ContentFormat = SourceType::BASE64.into();
1234        assert_eq!(content_format, ContentFormat::Base64);
1235
1236        let result: Result<SourceType, _> = ContentFormat::String.try_into();
1237        assert!(result.is_err());
1238        assert!(
1239            result
1240                .unwrap_err()
1241                .to_string()
1242                .contains("ContentFormat::String is deprecated")
1243        );
1244    }
1245}