1use crate::{
4 OneOrMany,
5 completion::{self, CompletionError, GetTokenUsage},
6 http_client::HttpClientExt,
7 message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
8 one_or_many::string_or_one_or_many,
9 telemetry::{ProviderResponseExt, SpanCombinator},
10 wasm_compat::*,
11};
12use std::{convert::Infallible, str::FromStr};
13
14use super::client::Client;
15use crate::completion::CompletionRequest;
16use crate::providers::anthropic::streaming::StreamingCompletionResponse;
17use bytes::Bytes;
18use serde::{Deserialize, Serialize};
19use tracing::{Instrument, info_span};
20
/// Model identifier string `claude-opus-4-0`.
pub const CLAUDE_4_OPUS: &str = "claude-opus-4-0";
/// Model identifier string `claude-sonnet-4-0`.
pub const CLAUDE_4_SONNET: &str = "claude-sonnet-4-0";
/// Model identifier string `claude-3-7-sonnet-latest`.
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet-latest";
/// Model identifier string `claude-3-5-sonnet-latest`.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
/// Model identifier string `claude-3-5-haiku-latest`.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";
35
/// Anthropic API version string `2023-01-01`.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// Anthropic API version string `2023-06-01`.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest version constant defined here, currently
/// [`ANTHROPIC_VERSION_2023_06_01`].
// NOTE(review): presumably sent as the API version header by the client — confirm there.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
39
/// Body of a successful completion response, as parsed from the
/// `/v1/messages` endpoint by the `completion` method in this file.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    /// Content blocks produced by the model.
    pub content: Vec<Content>,
    /// Response identifier.
    pub id: String,
    /// Name of the model reported by the response.
    pub model: String,
    /// Role string reported by the response.
    pub role: String,
    /// Stop reason reported by the response, if any.
    pub stop_reason: Option<String>,
    /// Stop sequence reported by the response, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this response.
    pub usage: Usage,
}
50
51impl ProviderResponseExt for CompletionResponse {
52 type OutputMessage = Content;
53 type Usage = Usage;
54
55 fn get_response_id(&self) -> Option<String> {
56 Some(self.id.to_owned())
57 }
58
59 fn get_response_model_name(&self) -> Option<String> {
60 Some(self.model.to_owned())
61 }
62
63 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
64 self.content.clone()
65 }
66
67 fn get_text_response(&self) -> Option<String> {
68 let res = self
69 .content
70 .iter()
71 .filter_map(|x| {
72 if let Content::Text { text } = x {
73 Some(text.to_owned())
74 } else {
75 None
76 }
77 })
78 .collect::<Vec<String>>()
79 .join("\n");
80
81 if res.is_empty() { None } else { Some(res) }
82 }
83
84 fn get_usage(&self) -> Option<Self::Usage> {
85 Some(self.usage.clone())
86 }
87}
88
/// Token counters attached to a [`CompletionResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Input token count.
    pub input_tokens: u64,
    /// Cache-read input token count, when present in the response.
    pub cache_read_input_tokens: Option<u64>,
    /// Cache-creation input token count, when present in the response.
    pub cache_creation_input_tokens: Option<u64>,
    /// Output token count.
    pub output_tokens: u64,
}
96
97impl std::fmt::Display for Usage {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(
100 f,
101 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
102 self.input_tokens,
103 match self.cache_read_input_tokens {
104 Some(token) => token.to_string(),
105 None => "n/a".to_string(),
106 },
107 match self.cache_creation_input_tokens {
108 Some(token) => token.to_string(),
109 None => "n/a".to_string(),
110 },
111 self.output_tokens
112 )
113 }
114}
115
116impl GetTokenUsage for Usage {
117 fn token_usage(&self) -> Option<crate::completion::Usage> {
118 let mut usage = crate::completion::Usage::new();
119
120 usage.input_tokens = self.input_tokens
121 + self.cache_creation_input_tokens.unwrap_or_default()
122 + self.cache_read_input_tokens.unwrap_or_default();
123 usage.output_tokens = self.output_tokens;
124 usage.total_tokens = usage.input_tokens + usage.output_tokens;
125
126 Some(usage)
127 }
128}
129
/// Tool description as serialized into the `tools` array of a request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// Schema of the tool's input, as an arbitrary JSON value.
    pub input_schema: serde_json::Value,
}
136
/// Cache-control marker, serialized as an object tagged with `type` in
/// snake_case (`{"type": "ephemeral"}`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    /// The only variant currently defined.
    Ephemeral,
}
142
143impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
144 type Error = CompletionError;
145
146 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
147 let content = response
148 .content
149 .iter()
150 .map(|content| content.clone().try_into())
151 .collect::<Result<Vec<_>, _>>()?;
152
153 let choice = OneOrMany::many(content).map_err(|_| {
154 CompletionError::ResponseError(
155 "Response contained no message or tool call (empty)".to_owned(),
156 )
157 })?;
158
159 let usage = completion::Usage {
160 input_tokens: response.usage.input_tokens,
161 output_tokens: response.usage.output_tokens,
162 total_tokens: response.usage.input_tokens + response.usage.output_tokens,
163 };
164
165 Ok(completion::CompletionResponse {
166 choice,
167 usage,
168 raw_response: response,
169 })
170 }
171}
172
/// A single conversation message in the provider's wire format.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// One or more content blocks. Deserialization goes through
    /// `string_or_one_or_many`; the tests in this file show that a bare JSON
    /// string is accepted and becomes a single text block.
    #[serde(deserialize_with = "string_or_one_or_many")]
    pub content: OneOrMany<Content>,
}
179
/// Author of a [`Message`], serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message authored by the user.
    User,
    /// Message authored by the assistant.
    Assistant,
}
186
/// One content block of a [`Message`] or [`CompletionResponse`].
///
/// Serialized as an object tagged with `type`, using the snake_case variant
/// name (`text`, `image`, `tool_use`, `tool_result`, `document`, `thinking`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Content {
    /// Plain text.
    Text {
        text: String,
    },
    /// An image together with its source description.
    Image {
        source: ImageSource,
    },
    /// A tool invocation: call id, tool name and JSON input.
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, linked to it through `tool_use_id`.
    ToolResult {
        tool_use_id: String,
        // Deserialized through `string_or_one_or_many`; the tests in this
        // file show a bare JSON string becoming a single text item.
        #[serde(deserialize_with = "string_or_one_or_many")]
        content: OneOrMany<ToolResultContent>,
        // Left out of the serialized object when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// A document together with its source description.
    Document {
        source: DocumentSource,
    },
    /// Model reasoning text with an optional signature.
    Thinking {
        thinking: String,
        signature: Option<String>,
    },
}
216
217impl FromStr for Content {
218 type Err = Infallible;
219
220 fn from_str(s: &str) -> Result<Self, Self::Err> {
221 Ok(Content::Text { text: s.to_owned() })
222 }
223}
224
/// One item inside a `Content::ToolResult` block, tagged with `type` in
/// snake_case.
// NOTE(review): `Image` wraps `ImageSource`, which itself has a field named
// `type`; that key is also this enum's serde tag. Verify the serialized shape
// of image tool results against the provider API.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
    /// Plain text result.
    Text { text: String },
    /// Image result.
    Image(ImageSource),
}
231
232impl FromStr for ToolResultContent {
233 type Err = Infallible;
234
235 fn from_str(s: &str) -> Result<Self, Self::Err> {
236 Ok(ToolResultContent::Text { text: s.to_owned() })
237 }
238}
239
/// Payload of an [`ImageSource`]: either base64 data or a URL.
///
/// The enum is `untagged` and both variants wrap a plain `String`, so the
/// serialized form is just the string. On deserialization serde tries
/// variants in declaration order, so a string is read back as `Base64`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum ImageSourceData {
    /// Base64-encoded image bytes.
    Base64(String),
    /// URL pointing at the image.
    Url(String),
}
246
247impl From<ImageSourceData> for DocumentSourceKind {
248 fn from(value: ImageSourceData) -> Self {
249 match value {
250 ImageSourceData::Base64(data) => DocumentSourceKind::Base64(data),
251 ImageSourceData::Url(url) => DocumentSourceKind::Url(url),
252 }
253 }
254}
255
256impl TryFrom<DocumentSourceKind> for ImageSourceData {
257 type Error = MessageError;
258
259 fn try_from(value: DocumentSourceKind) -> Result<Self, Self::Error> {
260 match value {
261 DocumentSourceKind::Base64(data) => Ok(ImageSourceData::Base64(data)),
262 DocumentSourceKind::Url(url) => Ok(ImageSourceData::Url(url)),
263 _ => Err(MessageError::ConversionError("Content has no body".into())),
264 }
265 }
266}
267
268impl From<ImageSourceData> for String {
269 fn from(value: ImageSourceData) -> Self {
270 match value {
271 ImageSourceData::Base64(s) | ImageSourceData::Url(s) => s,
272 }
273 }
274}
275
/// Source description of an image block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct ImageSource {
    /// The image payload (base64 data or URL).
    pub data: ImageSourceData,
    /// Media type of the image.
    pub media_type: ImageFormat,
    /// How `data` is to be interpreted; serialized under the key `type`.
    pub r#type: SourceType,
}
282
/// Source description of a document block.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct DocumentSource {
    /// The document payload as a string.
    pub data: String,
    /// Media type of the document.
    pub media_type: DocumentFormat,
    /// How `data` is to be interpreted; serialized under the key `type`.
    pub r#type: SourceType,
}
289
/// Image media types, serialized as their MIME strings.
// Every variant carries an explicit `rename`, which takes precedence over the
// container-level `rename_all`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// Serialized as `image/jpeg`.
    #[serde(rename = "image/jpeg")]
    JPEG,
    /// Serialized as `image/png`.
    #[serde(rename = "image/png")]
    PNG,
    /// Serialized as `image/gif`.
    #[serde(rename = "image/gif")]
    GIF,
    /// Serialized as `image/webp`.
    #[serde(rename = "image/webp")]
    WEBP,
}
302
/// Document media types, serialized as their MIME strings.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DocumentFormat {
    /// Serialized as `application/pdf`.
    #[serde(rename = "application/pdf")]
    PDF,
}
312
/// Kind of payload held by an image or document source, serialized in
/// lowercase (`"base64"` / `"url"`).
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SourceType {
    /// Payload is base64-encoded data.
    BASE64,
    /// Payload is a URL.
    URL,
}
319
320impl From<String> for Content {
321 fn from(text: String) -> Self {
322 Content::Text { text }
323 }
324}
325
326impl From<String> for ToolResultContent {
327 fn from(text: String) -> Self {
328 ToolResultContent::Text { text }
329 }
330}
331
332impl TryFrom<message::ContentFormat> for SourceType {
333 type Error = MessageError;
334
335 fn try_from(format: message::ContentFormat) -> Result<Self, Self::Error> {
336 match format {
337 message::ContentFormat::Base64 => Ok(SourceType::BASE64),
338 message::ContentFormat::Url => Ok(SourceType::URL),
339 message::ContentFormat::String => Err(MessageError::ConversionError(
340 "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(),
341 )),
342 }
343 }
344}
345
346impl From<SourceType> for message::ContentFormat {
347 fn from(source_type: SourceType) -> Self {
348 match source_type {
349 SourceType::BASE64 => message::ContentFormat::Base64,
350 SourceType::URL => message::ContentFormat::Url,
351 }
352 }
353}
354
355impl TryFrom<message::ImageMediaType> for ImageFormat {
356 type Error = MessageError;
357
358 fn try_from(media_type: message::ImageMediaType) -> Result<Self, Self::Error> {
359 Ok(match media_type {
360 message::ImageMediaType::JPEG => ImageFormat::JPEG,
361 message::ImageMediaType::PNG => ImageFormat::PNG,
362 message::ImageMediaType::GIF => ImageFormat::GIF,
363 message::ImageMediaType::WEBP => ImageFormat::WEBP,
364 _ => {
365 return Err(MessageError::ConversionError(
366 format!("Unsupported image media type: {media_type:?}").to_owned(),
367 ));
368 }
369 })
370 }
371}
372
373impl From<ImageFormat> for message::ImageMediaType {
374 fn from(format: ImageFormat) -> Self {
375 match format {
376 ImageFormat::JPEG => message::ImageMediaType::JPEG,
377 ImageFormat::PNG => message::ImageMediaType::PNG,
378 ImageFormat::GIF => message::ImageMediaType::GIF,
379 ImageFormat::WEBP => message::ImageMediaType::WEBP,
380 }
381 }
382}
383
384impl TryFrom<DocumentMediaType> for DocumentFormat {
385 type Error = MessageError;
386 fn try_from(value: DocumentMediaType) -> Result<Self, Self::Error> {
387 if !matches!(value, DocumentMediaType::PDF) {
388 return Err(MessageError::ConversionError(
389 "Anthropic only supports PDF documents".to_string(),
390 ));
391 };
392
393 Ok(DocumentFormat::PDF)
394 }
395}
396
397impl TryFrom<message::AssistantContent> for Content {
398 type Error = MessageError;
399 fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
400 match text {
401 message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text { text }),
402 message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
403 "Anthropic currently doesn't support images.".to_string(),
404 )),
405 message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
406 Ok(Content::ToolUse {
407 id,
408 name: function.name,
409 input: function.arguments,
410 })
411 }
412 message::AssistantContent::Reasoning(Reasoning {
413 reasoning,
414 signature,
415 ..
416 }) => Ok(Content::Thinking {
417 thinking: reasoning.first().cloned().unwrap_or(String::new()),
418 signature,
419 }),
420 }
421 }
422}
423
424impl TryFrom<message::Message> for Message {
425 type Error = MessageError;
426
427 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
428 Ok(match message {
429 message::Message::User { content } => Message {
430 role: Role::User,
431 content: content.try_map(|content| match content {
432 message::UserContent::Text(message::Text { text }) => {
433 Ok(Content::Text { text })
434 }
435 message::UserContent::ToolResult(message::ToolResult {
436 id, content, ..
437 }) => Ok(Content::ToolResult {
438 tool_use_id: id,
439 content: content.try_map(|content| match content {
440 message::ToolResultContent::Text(message::Text { text }) => {
441 Ok(ToolResultContent::Text { text })
442 }
443 message::ToolResultContent::Image(image) => {
444 let DocumentSourceKind::Base64(data) = image.data else {
445 return Err(MessageError::ConversionError(
446 "Only base64 strings can be used with the Anthropic API"
447 .to_string(),
448 ));
449 };
450 let media_type =
451 image.media_type.ok_or(MessageError::ConversionError(
452 "Image media type is required".to_owned(),
453 ))?;
454 Ok(ToolResultContent::Image(ImageSource {
455 data: ImageSourceData::Base64(data),
456 media_type: media_type.try_into()?,
457 r#type: SourceType::BASE64,
458 }))
459 }
460 })?,
461 is_error: None,
462 }),
463 message::UserContent::Image(message::Image {
464 data, media_type, ..
465 }) => {
466 let media_type = media_type.ok_or(MessageError::ConversionError(
467 "Image media type is required for Claude API".to_string(),
468 ))?;
469
470 let source = match data {
471 DocumentSourceKind::Base64(data) => ImageSource {
472 data: ImageSourceData::Base64(data),
473 r#type: SourceType::BASE64,
474 media_type: ImageFormat::try_from(media_type)?,
475 },
476 DocumentSourceKind::Url(url) => ImageSource {
477 data: ImageSourceData::Url(url),
478 r#type: SourceType::URL,
479 media_type: ImageFormat::try_from(media_type)?,
480 },
481 DocumentSourceKind::Unknown => {
482 return Err(MessageError::ConversionError(
483 "Image content has no body".into(),
484 ));
485 }
486 doc => {
487 return Err(MessageError::ConversionError(format!(
488 "Unsupported document type: {doc:?}"
489 )));
490 }
491 };
492
493 Ok(Content::Image { source })
494 }
495 message::UserContent::Document(message::Document {
496 data, media_type, ..
497 }) => {
498 let media_type = media_type.ok_or(MessageError::ConversionError(
499 "Document media type is required".to_string(),
500 ))?;
501
502 let data = match data {
503 DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => {
504 data
505 }
506 _ => {
507 return Err(MessageError::ConversionError(
508 "Only base64 encoded documents currently supported".into(),
509 ));
510 }
511 };
512
513 let source = DocumentSource {
514 data,
515 media_type: media_type.try_into()?,
516 r#type: SourceType::BASE64,
517 };
518 Ok(Content::Document { source })
519 }
520 message::UserContent::Audio { .. } => Err(MessageError::ConversionError(
521 "Audio is not supported in Anthropic".to_owned(),
522 )),
523 message::UserContent::Video { .. } => Err(MessageError::ConversionError(
524 "Video is not supported in Anthropic".to_owned(),
525 )),
526 })?,
527 },
528
529 message::Message::Assistant { content, .. } => Message {
530 content: content.try_map(|content| content.try_into())?,
531 role: Role::Assistant,
532 },
533 })
534 }
535}
536
537impl TryFrom<Content> for message::AssistantContent {
538 type Error = MessageError;
539
540 fn try_from(content: Content) -> Result<Self, Self::Error> {
541 Ok(match content {
542 Content::Text { text } => message::AssistantContent::text(text),
543 Content::ToolUse { id, name, input } => {
544 message::AssistantContent::tool_call(id, name, input)
545 }
546 Content::Thinking {
547 thinking,
548 signature,
549 } => message::AssistantContent::Reasoning(
550 Reasoning::new(&thinking).with_signature(signature),
551 ),
552 _ => {
553 return Err(MessageError::ConversionError(
554 "Content did not contain a message, tool call, or reasoning".to_owned(),
555 ));
556 }
557 })
558 }
559}
560
561impl From<ToolResultContent> for message::ToolResultContent {
562 fn from(content: ToolResultContent) -> Self {
563 match content {
564 ToolResultContent::Text { text } => message::ToolResultContent::text(text),
565 ToolResultContent::Image(ImageSource {
566 data,
567 media_type: format,
568 ..
569 }) => message::ToolResultContent::image_base64(data, Some(format.into()), None),
570 }
571 }
572}
573
574impl TryFrom<Message> for message::Message {
575 type Error = MessageError;
576
577 fn try_from(message: Message) -> Result<Self, Self::Error> {
578 Ok(match message.role {
579 Role::User => message::Message::User {
580 content: message.content.try_map(|content| {
581 Ok(match content {
582 Content::Text { text } => message::UserContent::text(text),
583 Content::ToolResult {
584 tool_use_id,
585 content,
586 ..
587 } => message::UserContent::tool_result(
588 tool_use_id,
589 content.map(|content| content.into()),
590 ),
591 Content::Image { source } => message::UserContent::Image(message::Image {
592 data: source.data.into(),
593 media_type: Some(source.media_type.into()),
594 detail: None,
595 additional_params: None,
596 }),
597 Content::Document { source } => message::UserContent::document(
598 source.data,
599 Some(message::DocumentMediaType::PDF),
600 ),
601 _ => {
602 return Err(MessageError::ConversionError(
603 "Unsupported content type for User role".to_owned(),
604 ));
605 }
606 })
607 })?,
608 },
609 Role::Assistant => match message.content.first() {
610 Content::Text { .. } | Content::ToolUse { .. } | Content::Thinking { .. } => {
611 message::Message::Assistant {
612 id: None,
613 content: message.content.try_map(|content| content.try_into())?,
614 }
615 }
616
617 _ => {
618 return Err(MessageError::ConversionError(
619 format!("Unsupported message for Assistant role: {message:?}").to_owned(),
620 ));
621 }
622 },
623 })
624 }
625}
626
/// Completion model handle: a client plus the model name to call.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    pub(crate) client: Client<T>,
    /// Model name placed in each request.
    pub model: String,
    /// `max_tokens` value applied when a request does not set one; when this
    /// is `None` such requests are rejected by `completion`.
    pub default_max_tokens: Option<u64>,
}
633
634impl<T> CompletionModel<T>
635where
636 T: HttpClientExt,
637{
638 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
639 let model = model.into();
640 let default_max_tokens = calculate_max_tokens(&model);
641
642 Self {
643 client,
644 model,
645 default_max_tokens,
646 }
647 }
648
649 pub fn with_model(client: Client<T>, model: &str) -> Self {
650 Self {
651 client,
652 model: model.to_string(),
653 default_max_tokens: Some(calculate_max_tokens_custom(model)),
654 }
655 }
656}
657
658fn calculate_max_tokens(model: &str) -> Option<u64> {
662 match model {
663 CLAUDE_4_OPUS => Some(32_000),
664 CLAUDE_4_SONNET | CLAUDE_3_7_SONNET => Some(64_000),
665 CLAUDE_3_5_SONNET | CLAUDE_3_5_HAIKU => Some(8_192),
666 _ => None,
667 }
668}
669
670fn calculate_max_tokens_custom(model: &str) -> u64 {
671 match model {
672 "claude-4-opus" => 32_000,
673 "claude-4-sonnet" | "claude-3.7-sonnet" => 64_000,
674 "claude-3.5-sonnet" | "claude-3.5-haiku" => 8_192,
675 _ => 4_096,
676 }
677}
678
/// Request metadata object.
#[derive(Debug, Deserialize, Serialize)]
pub struct Metadata {
    // Private field: outside this module the struct can only be obtained
    // through deserialization.
    user_id: Option<String>,
}
683
/// Tool-selection directive of a request, serialized as an object tagged with
/// `type` in snake_case (`auto`, `any`, `none`, `tool`).
#[derive(Default, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
    /// The default variant.
    #[default]
    Auto,
    /// Produced from the generic `Required` choice.
    Any,
    /// Produced from the generic `None` choice.
    None,
    /// Selects the single tool with the given name.
    Tool {
        name: String,
    },
}
695impl TryFrom<message::ToolChoice> for ToolChoice {
696 type Error = CompletionError;
697
698 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
699 let res = match value {
700 message::ToolChoice::Auto => Self::Auto,
701 message::ToolChoice::None => Self::None,
702 message::ToolChoice::Required => Self::Any,
703 message::ToolChoice::Specific { function_names } => {
704 if function_names.len() != 1 {
705 return Err(CompletionError::ProviderError(
706 "Only one tool may be specified to be used by Claude".into(),
707 ));
708 }
709
710 Self::Tool {
711 name: function_names.first().unwrap().to_string(),
712 }
713 }
714 };
715
716 Ok(res)
717 }
718}
719
/// JSON body sent to the `/v1/messages` endpoint by `completion`.
#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
    /// Model name.
    model: String,
    /// Conversation messages in provider format.
    messages: Vec<Message>,
    /// Maximum number of tokens; always present in the body.
    max_tokens: u64,
    /// System prompt; always serialized, even when empty.
    system: String,
    // The optional fields below are left out of the body when unset or empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    // Flattened: the keys of this JSON value are merged into the top level
    // of the request body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}
735
736impl TryFrom<(&str, CompletionRequest)> for AnthropicCompletionRequest {
737 type Error = CompletionError;
738
739 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
740 let Some(max_tokens) = req.max_tokens else {
742 return Err(CompletionError::RequestError(
743 "`max_tokens` must be set for Anthropic".into(),
744 ));
745 };
746
747 let mut full_history = vec![];
748 if let Some(docs) = req.normalized_documents() {
749 full_history.push(docs);
750 }
751 full_history.extend(req.chat_history);
752
753 let messages = full_history
754 .into_iter()
755 .map(Message::try_from)
756 .collect::<Result<Vec<Message>, _>>()?;
757
758 let tools = req
759 .tools
760 .into_iter()
761 .map(|tool| ToolDefinition {
762 name: tool.name,
763 description: Some(tool.description),
764 input_schema: tool.parameters,
765 })
766 .collect::<Vec<_>>();
767
768 Ok(Self {
769 model: model.to_string(),
770 messages,
771 max_tokens,
772 system: req.preamble.unwrap_or_default(),
773 temperature: req.temperature,
774 tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
775 tools,
776 additional_params: req.additional_params,
777 })
778 }
779}
780
781impl<T> completion::CompletionModel for CompletionModel<T>
782where
783 T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
784{
785 type Response = CompletionResponse;
786 type StreamingResponse = StreamingCompletionResponse;
787 type Client = Client<T>;
788
789 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
790 Self::new(client.clone(), model.into())
791 }
792
793 #[cfg_attr(feature = "worker", worker::send)]
794 async fn completion(
795 &self,
796 mut completion_request: completion::CompletionRequest,
797 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
798 let span = if tracing::Span::current().is_disabled() {
799 info_span!(
800 target: "rig::completions",
801 "chat",
802 gen_ai.operation.name = "chat",
803 gen_ai.provider.name = "anthropic",
804 gen_ai.request.model = &self.model,
805 gen_ai.system_instructions = &completion_request.preamble,
806 gen_ai.response.id = tracing::field::Empty,
807 gen_ai.response.model = tracing::field::Empty,
808 gen_ai.usage.output_tokens = tracing::field::Empty,
809 gen_ai.usage.input_tokens = tracing::field::Empty,
810 gen_ai.input.messages = tracing::field::Empty,
811 gen_ai.output.messages = tracing::field::Empty,
812 )
813 } else {
814 tracing::Span::current()
815 };
816
817 if completion_request.max_tokens.is_none() {
819 if let Some(tokens) = self.default_max_tokens {
820 completion_request.max_tokens = Some(tokens);
821 } else {
822 return Err(CompletionError::RequestError(
823 "`max_tokens` must be set for Anthropic".into(),
824 ));
825 }
826 }
827
828 let request =
829 AnthropicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
830 span.record_model_input(&request.messages);
831
832 async move {
833 let request: Vec<u8> = serde_json::to_vec(&request)?;
834
835 let req = self
836 .client
837 .post("/v1/messages")?
838 .body(request)
839 .map_err(|e| CompletionError::HttpError(e.into()))?;
840
841 let response = self
842 .client
843 .send::<_, Bytes>(req)
844 .await
845 .map_err(CompletionError::HttpError)?;
846
847 if response.status().is_success() {
848 match serde_json::from_slice::<ApiResponse<CompletionResponse>>(
849 response
850 .into_body()
851 .await
852 .map_err(CompletionError::HttpError)?
853 .to_vec()
854 .as_slice(),
855 )? {
856 ApiResponse::Message(completion) => {
857 let span = tracing::Span::current();
858 span.record_model_output(&completion.content);
859 span.record_response_metadata(&completion);
860 span.record_token_usage(&completion.usage);
861 tracing::trace!(
862 target: "rig::completions",
863 "Anthropic completion response: {}",
864 serde_json::to_string_pretty(&completion)?
865 );
866 completion.try_into()
867 }
868 ApiResponse::Error(ApiErrorResponse { message }) => {
869 Err(CompletionError::ResponseError(message))
870 }
871 }
872 } else {
873 let text: String = String::from_utf8_lossy(
874 &response
875 .into_body()
876 .await
877 .map_err(CompletionError::HttpError)?,
878 )
879 .into();
880 Err(CompletionError::ProviderError(text))
881 }
882 }
883 .instrument(span)
884 .await
885 }
886
887 #[cfg_attr(feature = "worker", worker::send)]
888 async fn stream(
889 &self,
890 request: CompletionRequest,
891 ) -> Result<
892 crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>,
893 CompletionError,
894 > {
895 CompletionModel::stream(self, request).await
896 }
897}
898
/// Error payload carried by [`ApiResponse::Error`].
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Error message; surfaced as a `ResponseError` by `completion`.
    message: String,
}
903
/// Top-level response envelope, discriminated by the `type` key in snake_case
/// (`message` or `error`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// A successful payload of type `T`.
    Message(T),
    /// An error payload.
    Error(ApiErrorResponse),
}
910
911#[cfg(test)]
912mod tests {
913 use super::*;
914 use serde_json::json;
915 use serde_path_to_error::deserialize;
916
917 #[test]
918 fn test_deserialize_message() {
919 let assistant_message_json = r#"
920 {
921 "role": "assistant",
922 "content": "\n\nHello there, how may I assist you today?"
923 }
924 "#;
925
926 let assistant_message_json2 = r#"
927 {
928 "role": "assistant",
929 "content": [
930 {
931 "type": "text",
932 "text": "\n\nHello there, how may I assist you today?"
933 },
934 {
935 "type": "tool_use",
936 "id": "toolu_01A09q90qw90lq917835lq9",
937 "name": "get_weather",
938 "input": {"location": "San Francisco, CA"}
939 }
940 ]
941 }
942 "#;
943
944 let user_message_json = r#"
945 {
946 "role": "user",
947 "content": [
948 {
949 "type": "image",
950 "source": {
951 "type": "base64",
952 "media_type": "image/jpeg",
953 "data": "/9j/4AAQSkZJRg..."
954 }
955 },
956 {
957 "type": "text",
958 "text": "What is in this image?"
959 },
960 {
961 "type": "tool_result",
962 "tool_use_id": "toolu_01A09q90qw90lq917835lq9",
963 "content": "15 degrees"
964 }
965 ]
966 }
967 "#;
968
969 let assistant_message: Message = {
970 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
971 deserialize(jd).unwrap_or_else(|err| {
972 panic!("Deserialization error at {}: {}", err.path(), err);
973 })
974 };
975
976 let assistant_message2: Message = {
977 let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
978 deserialize(jd).unwrap_or_else(|err| {
979 panic!("Deserialization error at {}: {}", err.path(), err);
980 })
981 };
982
983 let user_message: Message = {
984 let jd = &mut serde_json::Deserializer::from_str(user_message_json);
985 deserialize(jd).unwrap_or_else(|err| {
986 panic!("Deserialization error at {}: {}", err.path(), err);
987 })
988 };
989
990 let Message { role, content } = assistant_message;
991 assert_eq!(role, Role::Assistant);
992 assert_eq!(
993 content.first(),
994 Content::Text {
995 text: "\n\nHello there, how may I assist you today?".to_owned()
996 }
997 );
998
999 let Message { role, content } = assistant_message2;
1000 {
1001 assert_eq!(role, Role::Assistant);
1002 assert_eq!(content.len(), 2);
1003
1004 let mut iter = content.into_iter();
1005
1006 match iter.next().unwrap() {
1007 Content::Text { text } => {
1008 assert_eq!(text, "\n\nHello there, how may I assist you today?");
1009 }
1010 _ => panic!("Expected text content"),
1011 }
1012
1013 match iter.next().unwrap() {
1014 Content::ToolUse { id, name, input } => {
1015 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1016 assert_eq!(name, "get_weather");
1017 assert_eq!(input, json!({"location": "San Francisco, CA"}));
1018 }
1019 _ => panic!("Expected tool use content"),
1020 }
1021
1022 assert_eq!(iter.next(), None);
1023 }
1024
1025 let Message { role, content } = user_message;
1026 {
1027 assert_eq!(role, Role::User);
1028 assert_eq!(content.len(), 3);
1029
1030 let mut iter = content.into_iter();
1031
1032 match iter.next().unwrap() {
1033 Content::Image { source } => {
1034 assert_eq!(
1035 source,
1036 ImageSource {
1037 data: ImageSourceData::Base64("/9j/4AAQSkZJRg...".to_owned()),
1038 media_type: ImageFormat::JPEG,
1039 r#type: SourceType::BASE64,
1040 }
1041 );
1042 }
1043 _ => panic!("Expected image content"),
1044 }
1045
1046 match iter.next().unwrap() {
1047 Content::Text { text } => {
1048 assert_eq!(text, "What is in this image?");
1049 }
1050 _ => panic!("Expected text content"),
1051 }
1052
1053 match iter.next().unwrap() {
1054 Content::ToolResult {
1055 tool_use_id,
1056 content,
1057 is_error,
1058 } => {
1059 assert_eq!(tool_use_id, "toolu_01A09q90qw90lq917835lq9");
1060 assert_eq!(
1061 content.first(),
1062 ToolResultContent::Text {
1063 text: "15 degrees".to_owned()
1064 }
1065 );
1066 assert_eq!(is_error, None);
1067 }
1068 _ => panic!("Expected tool result content"),
1069 }
1070
1071 assert_eq!(iter.next(), None);
1072 }
1073 }
1074
1075 #[test]
1076 fn test_message_to_message_conversion() {
1077 let user_message: Message = serde_json::from_str(
1078 r#"
1079 {
1080 "role": "user",
1081 "content": [
1082 {
1083 "type": "image",
1084 "source": {
1085 "type": "base64",
1086 "media_type": "image/jpeg",
1087 "data": "/9j/4AAQSkZJRg..."
1088 }
1089 },
1090 {
1091 "type": "text",
1092 "text": "What is in this image?"
1093 },
1094 {
1095 "type": "document",
1096 "source": {
1097 "type": "base64",
1098 "data": "base64_encoded_pdf_data",
1099 "media_type": "application/pdf"
1100 }
1101 }
1102 ]
1103 }
1104 "#,
1105 )
1106 .unwrap();
1107
1108 let assistant_message = Message {
1109 role: Role::Assistant,
1110 content: OneOrMany::one(Content::ToolUse {
1111 id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1112 name: "get_weather".to_string(),
1113 input: json!({"location": "San Francisco, CA"}),
1114 }),
1115 };
1116
1117 let tool_message = Message {
1118 role: Role::User,
1119 content: OneOrMany::one(Content::ToolResult {
1120 tool_use_id: "toolu_01A09q90qw90lq917835lq9".to_string(),
1121 content: OneOrMany::one(ToolResultContent::Text {
1122 text: "15 degrees".to_string(),
1123 }),
1124 is_error: None,
1125 }),
1126 };
1127
1128 let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
1129 let converted_assistant_message: message::Message =
1130 assistant_message.clone().try_into().unwrap();
1131 let converted_tool_message: message::Message = tool_message.clone().try_into().unwrap();
1132
1133 match converted_user_message.clone() {
1134 message::Message::User { content } => {
1135 assert_eq!(content.len(), 3);
1136
1137 let mut iter = content.into_iter();
1138
1139 match iter.next().unwrap() {
1140 message::UserContent::Image(message::Image {
1141 data, media_type, ..
1142 }) => {
1143 assert_eq!(data, DocumentSourceKind::base64("/9j/4AAQSkZJRg..."));
1144 assert_eq!(media_type, Some(message::ImageMediaType::JPEG));
1145 }
1146 _ => panic!("Expected image content"),
1147 }
1148
1149 match iter.next().unwrap() {
1150 message::UserContent::Text(message::Text { text }) => {
1151 assert_eq!(text, "What is in this image?");
1152 }
1153 _ => panic!("Expected text content"),
1154 }
1155
1156 match iter.next().unwrap() {
1157 message::UserContent::Document(message::Document {
1158 data, media_type, ..
1159 }) => {
1160 assert_eq!(
1161 data,
1162 DocumentSourceKind::String("base64_encoded_pdf_data".into())
1163 );
1164 assert_eq!(media_type, Some(message::DocumentMediaType::PDF));
1165 }
1166 _ => panic!("Expected document content"),
1167 }
1168
1169 assert_eq!(iter.next(), None);
1170 }
1171 _ => panic!("Expected user message"),
1172 }
1173
1174 match converted_tool_message.clone() {
1175 message::Message::User { content } => {
1176 let message::ToolResult { id, content, .. } = match content.first() {
1177 message::UserContent::ToolResult(tool_result) => tool_result,
1178 _ => panic!("Expected tool result content"),
1179 };
1180 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1181 match content.first() {
1182 message::ToolResultContent::Text(message::Text { text }) => {
1183 assert_eq!(text, "15 degrees");
1184 }
1185 _ => panic!("Expected text content"),
1186 }
1187 }
1188 _ => panic!("Expected tool result content"),
1189 }
1190
1191 match converted_assistant_message.clone() {
1192 message::Message::Assistant { content, .. } => {
1193 assert_eq!(content.len(), 1);
1194
1195 match content.first() {
1196 message::AssistantContent::ToolCall(message::ToolCall {
1197 id, function, ..
1198 }) => {
1199 assert_eq!(id, "toolu_01A09q90qw90lq917835lq9");
1200 assert_eq!(function.name, "get_weather");
1201 assert_eq!(function.arguments, json!({"location": "San Francisco, CA"}));
1202 }
1203 _ => panic!("Expected tool call content"),
1204 }
1205 }
1206 _ => panic!("Expected assistant message"),
1207 }
1208
1209 let original_user_message: Message = converted_user_message.try_into().unwrap();
1210 let original_assistant_message: Message = converted_assistant_message.try_into().unwrap();
1211 let original_tool_message: Message = converted_tool_message.try_into().unwrap();
1212
1213 assert_eq!(user_message, original_user_message);
1214 assert_eq!(assistant_message, original_assistant_message);
1215 assert_eq!(tool_message, original_tool_message);
1216 }
1217
1218 #[test]
1219 fn test_content_format_conversion() {
1220 use crate::completion::message::ContentFormat;
1221
1222 let source_type: SourceType = ContentFormat::Url.try_into().unwrap();
1223 assert_eq!(source_type, SourceType::URL);
1224
1225 let content_format: ContentFormat = SourceType::URL.into();
1226 assert_eq!(content_format, ContentFormat::Url);
1227
1228 let source_type: SourceType = ContentFormat::Base64.try_into().unwrap();
1229 assert_eq!(source_type, SourceType::BASE64);
1230
1231 let content_format: ContentFormat = SourceType::BASE64.into();
1232 assert_eq!(content_format, ContentFormat::Base64);
1233
1234 let result: Result<SourceType, _> = ContentFormat::String.try_into();
1235 assert!(result.is_err());
1236 assert!(
1237 result
1238 .unwrap_err()
1239 .to_string()
1240 .contains("ContentFormat::String is deprecated")
1241 );
1242 }
1243}