use crate::{
    anthropic::api::{
        self, Base64ImageSource, ContentBlock, ContentBlockDelta, ContentBlockDeltaEvent,
        ContentBlockStartEvent, CreateMessageParams, ImageSource, InputContentBlock, InputMessage,
        InputMessageContent, Message as AnthropicMessage, MessageDeltaEvent, MessageDeltaUsage,
        MessageStartEvent, MessageStreamEvent, RequestCitationsConfig, RequestImageBlock,
        RequestSearchResultBlock, RequestTextBlock, RequestThinkingBlock, RequestToolResultBlock,
        RequestToolUseBlock, SystemPrompt, ThinkingConfigDisabled, ThinkingConfigEnabled,
        ThinkingConfigParam, Tool, ToolResultContent, ToolResultContentBlock, Usage,
    },
    client_utils, stream_utils, Citation, CitationDelta, ContentDelta, ImagePart, LanguageModel,
    LanguageModelError, LanguageModelInput, LanguageModelMetadata, LanguageModelResult,
    LanguageModelStream, Message, ModelResponse, ModelUsage, Part, PartDelta, PartialModelResponse,
    ReasoningOptions, ReasoningPart, ReasoningPartDelta, TextPart, TextPartDelta, Tool as SdkTool,
    ToolCallPart, ToolCallPartDelta, ToolChoiceOption, ToolResultPart,
};
use async_stream::try_stream;
use futures::{future::BoxFuture, StreamExt};
use reqwest::{
    header::{HeaderMap, HeaderName, HeaderValue},
    Client,
};
use serde_json::{Map, Value};
use std::{collections::HashMap, sync::Arc};

const PROVIDER: &str = "anthropic";
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const DEFAULT_API_VERSION: &str = "2023-06-01";

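/// Anthropic Messages API backend for [`LanguageModel`].
///
/// A minimal construction sketch; the model id and environment variable below
/// are illustrative, not prescribed by this crate:
///
/// ```ignore
/// let model = AnthropicModel::new(
///     "claude-sonnet-4-20250514",
///     AnthropicModelOptions {
///         api_key: std::env::var("ANTHROPIC_API_KEY").unwrap_or_default(),
///         ..Default::default()
///     },
/// );
/// ```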
pub struct AnthropicModel {
    model_id: String,
    api_key: String,
    base_url: String,
    api_version: String,
    client: Client,
    metadata: Option<Arc<LanguageModelMetadata>>,
    headers: HashMap<String, String>,
}

#[derive(Clone, Default)]
pub struct AnthropicModelOptions {
    pub base_url: Option<String>,
    pub api_key: String,
    pub api_version: Option<String>,
    pub headers: Option<HashMap<String, String>>,
    pub client: Option<Client>,
}

impl AnthropicModel {
    #[must_use]
    pub fn new(model_id: impl Into<String>, mut options: AnthropicModelOptions) -> Self {
        let base_url = options
            .base_url
            .take()
            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
            .trim_end_matches('/')
            .to_string();

        let api_version = options
            .api_version
            .take()
            .unwrap_or_else(|| DEFAULT_API_VERSION.to_string());

        let client = options.client.take().unwrap_or_default();

        let headers = options.headers.unwrap_or_default();

        Self {
            model_id: model_id.into(),
            api_key: options.api_key,
            base_url,
            api_version,
            client,
            metadata: None,
            headers,
        }
    }

    #[must_use]
    pub fn with_metadata(mut self, metadata: LanguageModelMetadata) -> Self {
        self.metadata = Some(Arc::new(metadata));
        self
    }

    fn request_headers(&self) -> LanguageModelResult<HeaderMap> {
        let mut headers = HeaderMap::new();

        headers.insert(
            "x-api-key",
            HeaderValue::from_str(&self.api_key).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic API key header value: {error}"
                ))
            })?,
        );
        headers.insert(
            "anthropic-version",
            HeaderValue::from_str(&self.api_version).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic version header value: {error}"
                ))
            })?,
        );

        for (key, value) in &self.headers {
            let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic header name '{key}': {error}"
                ))
            })?;
            let header_value = HeaderValue::from_str(value).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic header value for '{key}': {error}"
                ))
            })?;
            headers.insert(header_name, header_value);
        }

        Ok(headers)
    }
}

impl LanguageModel for AnthropicModel {
    fn provider(&self) -> &'static str {
        PROVIDER
    }

    fn model_id(&self) -> String {
        self.model_id.clone()
    }

    fn metadata(&self) -> Option<&LanguageModelMetadata> {
        self.metadata.as_deref()
    }

    fn generate(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
        Box::pin(async move {
            crate::opentelemetry::trace_generate(
                self.provider(),
                &self.model_id,
                input,
                |input| async move {
                    let payload = convert_to_anthropic_create_params(input, &self.model_id, false)?;

                    let headers = self.request_headers()?;

                    let response: AnthropicMessage = client_utils::send_json(
                        &self.client,
                        &format!("{}/v1/messages", self.base_url),
                        &payload,
                        headers,
                    )
                    .await?;

                    let content = map_anthropic_message(response.content);
                    let usage = Some(map_anthropic_usage(&response.usage));

                    // Cost can only be derived when both usage and pricing metadata are available.
                    let cost =
                        if let (Some(usage), Some(metadata)) = (usage.as_ref(), self.metadata()) {
                            metadata
                                .pricing
                                .as_ref()
                                .map(|pricing| usage.calculate_cost(pricing))
                        } else {
                            None
                        };

                    Ok(ModelResponse {
                        content,
                        usage,
                        cost,
                    })
                },
            )
            .await
        })
    }

    fn stream(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
        Box::pin(async move {
            crate::opentelemetry::trace_stream(
                self.provider(),
                &self.model_id,
                input,
                |input| async move {
                    let payload = convert_to_anthropic_create_params(input, &self.model_id, true)?;

                    let headers = self.request_headers()?;
                    let mut chunk_stream = client_utils::send_sse_stream::<_, MessageStreamEvent>(
                        &self.client,
                        &format!("{}/v1/messages", self.base_url),
                        &payload,
                        headers,
                        self.provider(),
                    )
                    .await?;

                    let metadata = self.metadata.clone();

                    let stream = try_stream! {
                        while let Some(event) = chunk_stream.next().await {
                            match event? {
                                // `message_start` carries the usage snapshot for the new message.
                                MessageStreamEvent::MessageStart(MessageStartEvent { message }) => {
                                    let usage = map_anthropic_usage(&message.usage);
                                    let cost = metadata
                                        .as_ref()
                                        .and_then(|meta| meta.pricing.as_ref())
                                        .map(|pricing| usage.calculate_cost(pricing));

                                    yield PartialModelResponse {
                                        delta: None,
                                        usage: Some(usage),
                                        cost,
                                    };
                                }
                                // `message_delta` reports updated usage for the in-flight message.
                                MessageStreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
                                    let usage = map_anthropic_message_delta_usage(&usage);
                                    let cost = metadata
                                        .as_ref()
                                        .and_then(|meta| meta.pricing.as_ref())
                                        .map(|pricing| usage.calculate_cost(pricing));

                                    yield PartialModelResponse {
                                        delta: None,
                                        usage: Some(usage),
                                        cost,
                                    };
                                }
                                MessageStreamEvent::ContentBlockStart(ContentBlockStartEvent { content_block, index }) => {
                                    let deltas = map_anthropic_content_block_start_event(content_block, index)?;
                                    for delta in deltas {
                                        yield PartialModelResponse {
                                            delta: Some(delta),
                                            ..Default::default()
                                        };
                                    }
                                }
                                MessageStreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { delta, index }) => {
                                    if let Some(delta) = map_anthropic_content_block_delta_event(delta, index) {
                                        yield PartialModelResponse {
                                            delta: Some(delta),
                                            ..Default::default()
                                        };
                                    }
                                }
                                _ => {}
                            }
                        }
                    };

                    Ok(LanguageModelStream::from_stream(stream))
                },
            )
            .await
        })
    }
}

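/// Builds the JSON body for the Messages API request from a [`LanguageModelInput`].
///
/// Unsupported input fields are dropped, `max_tokens` defaults to 4096 when unset,
/// and any top-level keys in `extra` are merged over the serialized parameters.
///
/// A rough usage sketch; it assumes `LanguageModelInput` implements `Default`,
/// which this file does not itself guarantee:
///
/// ```ignore
/// let input = LanguageModelInput {
///     max_tokens: Some(1024),
///     extra: Some(serde_json::json!({ "stop_sequences": ["END"] })),
///     ..Default::default()
/// };
/// let payload = convert_to_anthropic_create_params(input, "claude-test", false).unwrap();
/// assert_eq!(payload["stop_sequences"], serde_json::json!(["END"]));
/// ```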
fn convert_to_anthropic_create_params(
    input: LanguageModelInput,
    model_id: &str,
    stream: bool,
) -> LanguageModelResult<Value> {
    let LanguageModelInput {
        system_prompt,
        messages,
        tools,
        tool_choice,
        response_format: _,
        max_tokens,
        temperature,
        top_p,
        top_k,
        presence_penalty: _,
        frequency_penalty: _,
        seed: _,
        modalities: _,
        metadata: _,
        audio: _,
        reasoning,
        extra,
    } = input;

    let max_tokens = max_tokens.unwrap_or(4096);

    let message_params = convert_to_anthropic_messages(messages)?;

    let params = CreateMessageParams {
        max_tokens,
        messages: message_params,
        metadata: None,
        model: api::Model::String(model_id.to_string()),
        service_tier: None,
        stop_sequences: None,
        stream: Some(stream),
        system: system_prompt.map(SystemPrompt::String),
        temperature,
        thinking: reasoning
            .map(|options| convert_to_anthropic_thinking_config(&options, max_tokens)),
        tool_choice: tool_choice.map(convert_to_anthropic_tool_choice),
        tools: tools.map(|tool_list| {
            tool_list
                .into_iter()
                .map(convert_tool)
                .map(api::ToolUnion::Tool)
                .collect()
        }),
        top_k: top_k
            .map(|value| {
                u32::try_from(value).map_err(|_| {
                    LanguageModelError::InvalidInput(
                        "Anthropic top_k must be a non-negative integer".to_string(),
                    )
                })
            })
            .transpose()?,
        top_p,
    };

    let mut value = serde_json::to_value(&params).map_err(|error| {
        LanguageModelError::Invariant(
            PROVIDER,
            format!("Failed to serialize Anthropic request: {error}"),
        )
    })?;

    if let Value::Object(ref mut map) = value {
        if let Some(extra) = extra {
            let Value::Object(extra_object) = extra else {
                return Err(LanguageModelError::InvalidInput(
                    "Anthropic extra field must be a JSON object".to_string(),
                ));
            };

            for (key, val) in extra_object {
                map.insert(key, val);
            }
        }
    } else {
        return Err(LanguageModelError::Invariant(
            PROVIDER,
            "Anthropic request serialization did not produce an object".to_string(),
        ));
    }

    Ok(value)
}

fn convert_tool(tool: SdkTool) -> Tool {
    Tool {
        name: tool.name,
        description: Some(tool.description),
        input_schema: tool.parameters,
        cache_control: None,
        type_field: None,
    }
}

fn convert_to_anthropic_messages(messages: Vec<Message>) -> LanguageModelResult<Vec<InputMessage>> {
    messages
        .into_iter()
        .map(|message| match message {
            Message::User(user) => convert_message_parts_to_input_message("user", user.content),
            Message::Assistant(assistant) => {
                convert_message_parts_to_input_message("assistant", assistant.content)
            }
            Message::Tool(tool) => convert_message_parts_to_input_message("user", tool.content),
        })
        .collect()
}

fn convert_message_parts_to_input_message(
    role: &str,
    parts: Vec<Part>,
) -> LanguageModelResult<InputMessage> {
    let content_blocks = convert_parts_to_content_blocks(parts)?;
    Ok(InputMessage {
        content: InputMessageContent::Blocks(content_blocks),
        role: role.to_string(),
    })
}

fn convert_parts_to_content_blocks(
    parts: Vec<Part>,
) -> LanguageModelResult<Vec<InputContentBlock>> {
    parts
        .into_iter()
        .map(convert_part_to_content_block)
        .collect()
}

fn convert_part_to_content_block(part: Part) -> LanguageModelResult<InputContentBlock> {
    match part {
        Part::Text(text_part) => Ok(InputContentBlock::Text(create_request_text_block(
            text_part.text,
        ))),
        Part::Image(image_part) => Ok(InputContentBlock::Image(create_request_image_block(
            image_part,
        ))),
        Part::Source(source_part) => Ok(InputContentBlock::SearchResult(convert_source_part(
            source_part,
        )?)),
        Part::ToolCall(tool_call) => Ok(InputContentBlock::ToolUse(RequestToolUseBlock {
            cache_control: None,
            id: tool_call.tool_call_id,
            input: normalize_tool_args(tool_call.args)?,
            name: tool_call.tool_name,
        })),
        Part::ToolResult(tool_result) => Ok(InputContentBlock::ToolResult(
            convert_tool_result_part(tool_result)?,
        )),
        Part::Reasoning(reasoning_part) => Ok(convert_reasoning_part(reasoning_part)),
        Part::Audio(_) => Err(LanguageModelError::Unsupported(
            PROVIDER,
            "Anthropic does not support audio parts".to_string(),
        )),
    }
}

fn convert_reasoning_part(reasoning_part: ReasoningPart) -> InputContentBlock {
    // A reasoning part with no text but a signature is treated as a redacted
    // thinking block; the opaque payload round-trips through the signature field.
    if reasoning_part.text.is_empty() && reasoning_part.signature.is_some() {
        return InputContentBlock::RedactedThinking(api::RequestRedactedThinkingBlock {
            data: reasoning_part.signature.unwrap_or_default(),
        });
    }

    InputContentBlock::Thinking(RequestThinkingBlock {
        thinking: reasoning_part.text,
        signature: reasoning_part.signature.unwrap_or_default(),
    })
}

fn convert_tool_result_part(
    tool_result: ToolResultPart,
) -> LanguageModelResult<RequestToolResultBlock> {
    let mut content_blocks = Vec::new();
    for part in tool_result.content {
        let block = convert_part_to_tool_result_content_block(part)?;
        content_blocks.push(block);
    }

    let content = if content_blocks.is_empty() {
        None
    } else {
        Some(ToolResultContent::Blocks(content_blocks))
    };

    Ok(RequestToolResultBlock {
        cache_control: None,
        content,
        is_error: tool_result.is_error,
        tool_use_id: tool_result.tool_call_id,
    })
}

fn convert_part_to_tool_result_content_block(
    part: Part,
) -> LanguageModelResult<ToolResultContentBlock> {
    match part {
        Part::Text(text_part) => Ok(ToolResultContentBlock::Text(create_request_text_block(
            text_part.text,
        ))),
        Part::Image(image_part) => Ok(ToolResultContentBlock::Image(create_request_image_block(
            image_part,
        ))),
        Part::Source(source_part) => Ok(ToolResultContentBlock::SearchResult(convert_source_part(
            source_part,
        )?)),
        _ => Err(LanguageModelError::Unsupported(
            PROVIDER,
            "Cannot convert tool result part to Anthropic content".to_string(),
        )),
    }
}

fn create_request_text_block(text: String) -> RequestTextBlock {
    RequestTextBlock {
        cache_control: None,
        citations: None,
        text,
        type_field: "text".to_string(),
    }
}

fn create_request_image_block(image_part: ImagePart) -> RequestImageBlock {
    RequestImageBlock {
        cache_control: None,
        source: ImageSource::Base64(Base64ImageSource {
            data: image_part.data,
            media_type: image_part.mime_type,
        }),
    }
}

fn convert_source_part(
    source_part: crate::SourcePart,
) -> LanguageModelResult<RequestSearchResultBlock> {
    let mut content = Vec::new();
    for part in source_part.content {
        match part {
            Part::Text(text_part) => content.push(create_request_text_block(text_part.text)),
            _ => {
                return Err(LanguageModelError::Unsupported(
                    PROVIDER,
                    "Anthropic source part only supports text content".to_string(),
                ))
            }
        }
    }

    Ok(RequestSearchResultBlock {
        cache_control: None,
        citations: Some(RequestCitationsConfig {
            enabled: Some(true),
        }),
        content,
        source: source_part.source,
        title: source_part.title,
    })
}

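/// Anthropic expects `tool_use.input` to be a JSON object; `null` is coerced to
/// an empty object and any other value is rejected.
///
/// A small illustrative sketch (not compiled as a doctest):
///
/// ```ignore
/// assert_eq!(
///     normalize_tool_args(serde_json::Value::Null).unwrap(),
///     serde_json::json!({})
/// );
/// assert!(normalize_tool_args(serde_json::json!([1, 2])).is_err());
/// ```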
fn normalize_tool_args(args: Value) -> LanguageModelResult<Value> {
    match args {
        Value::Object(_) => Ok(args),
        Value::Null => Ok(Value::Object(Map::new())),
        _ => Err(LanguageModelError::InvalidInput(
            "Anthropic tool call arguments must be a JSON object".to_string(),
        )),
    }
}

fn convert_to_anthropic_tool_choice(choice: ToolChoiceOption) -> api::ToolChoice {
    match choice {
        ToolChoiceOption::Auto => api::ToolChoice::Auto(api::ToolChoiceAuto {
            disable_parallel_tool_use: None,
        }),
        ToolChoiceOption::None => api::ToolChoice::None(api::ToolChoiceNone {}),
        ToolChoiceOption::Required => api::ToolChoice::Any(api::ToolChoiceAny {
            disable_parallel_tool_use: None,
        }),
        ToolChoiceOption::Tool(tool) => api::ToolChoice::Tool(api::ToolChoiceTool {
            disable_parallel_tool_use: None,
            name: tool.tool_name,
        }),
    }
}

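/// Maps SDK reasoning options onto Anthropic's thinking configuration. When
/// reasoning is enabled without an explicit budget, the budget falls back to
/// `max_tokens - 1`, clamped to at least 1 token.
///
/// A rough sketch of the fallback; it assumes `ReasoningOptions` implements
/// `Default`, which this file does not itself guarantee:
///
/// ```ignore
/// let options = ReasoningOptions {
///     enabled: true,
///     budget_tokens: None,
///     ..Default::default()
/// };
/// let config = convert_to_anthropic_thinking_config(&options, 4096);
/// // Expected: ThinkingConfigParam::Enabled with budget_tokens == 4095.
/// ```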
fn convert_to_anthropic_thinking_config(
    reasoning: &ReasoningOptions,
    max_tokens: u32,
) -> ThinkingConfigParam {
    if !reasoning.enabled {
        return ThinkingConfigParam::Disabled(ThinkingConfigDisabled {});
    }

    let fallback = max_tokens.saturating_sub(1).max(1);
    let budget = reasoning
        .budget_tokens
        .map_or(fallback, |value| value.max(1));

    ThinkingConfigParam::Enabled(ThinkingConfigEnabled {
        budget_tokens: budget,
    })
}

fn map_anthropic_message(content: Vec<ContentBlock>) -> Vec<Part> {
    let mut parts = Vec::new();
    for block in content {
        if let Some(part) = map_content_block(block) {
            parts.push(part);
        }
    }
    parts
}

fn map_content_block(block: ContentBlock) -> Option<Part> {
    match block {
        ContentBlock::Text(text_block) => Some(Part::Text(map_text_block(text_block))),
        ContentBlock::Thinking(thinking_block) => {
            Some(Part::Reasoning(map_thinking_block(thinking_block)))
        }
        ContentBlock::RedactedThinking(redacted_block) => {
            Some(Part::Reasoning(map_redacted_thinking_block(redacted_block)))
        }
        ContentBlock::ToolUse(tool_use) => Some(Part::ToolCall(map_tool_use_block(tool_use))),
        _ => None,
    }
}

fn map_text_block(block: api::ResponseTextBlock) -> TextPart {
    let citations = map_text_citations(block.citations);
    TextPart {
        text: block.text,
        citations,
    }
}

fn map_text_citations(citations: Option<Vec<api::ResponseCitation>>) -> Option<Vec<Citation>> {
    let citations = citations?;

    let mut results = Vec::new();

    for citation in citations {
        if let api::ResponseCitation::SearchResultLocation(
            api::ResponseSearchResultLocationCitation {
                cited_text,
                end_block_index,
                search_result_index: _,
                source,
                start_block_index,
                title,
            },
        ) = citation
        {
            if source.is_empty() {
                continue;
            }

            let mapped = Citation {
                source,
                title,
                cited_text: if cited_text.is_empty() {
                    None
                } else {
                    Some(cited_text)
                },
                start_index: start_block_index,
                end_index: end_block_index,
            };

            results.push(mapped);
        }
    }

    if results.is_empty() {
        None
    } else {
        Some(results)
    }
}

fn map_thinking_block(block: api::ResponseThinkingBlock) -> ReasoningPart {
    ReasoningPart {
        text: block.thinking,
        signature: if block.signature.is_empty() {
            None
        } else {
            Some(block.signature)
        },
        id: None,
    }
}

fn map_redacted_thinking_block(block: api::ResponseRedactedThinkingBlock) -> ReasoningPart {
    ReasoningPart {
        text: String::new(),
        signature: Some(block.data),
        id: None,
    }
}

fn map_tool_use_block(block: api::ResponseToolUseBlock) -> ToolCallPart {
    ToolCallPart {
        tool_call_id: block.id,
        tool_name: block.name,
        args: block.input,
        id: None,
    }
}

fn map_anthropic_usage(usage: &Usage) -> ModelUsage {
    ModelUsage {
        input_tokens: usage.input_tokens,
        output_tokens: usage.output_tokens,
        ..Default::default()
    }
}

fn map_anthropic_message_delta_usage(usage: &MessageDeltaUsage) -> ModelUsage {
    ModelUsage {
        input_tokens: usage.input_tokens.unwrap_or(0),
        output_tokens: usage.output_tokens,
        ..Default::default()
    }
}

fn map_anthropic_content_block_start_event(
    content_block: ContentBlock,
    index: usize,
) -> LanguageModelResult<Vec<ContentDelta>> {
    if let Some(part) = map_content_block(content_block) {
        let mut delta = stream_utils::loosely_convert_part_to_part_delta(part)?;
        // Tool-use blocks start with empty arguments; subsequent `InputJsonDelta`
        // events append the serialized argument JSON.
        if let PartDelta::ToolCall(tool_call_delta) = &mut delta {
            tool_call_delta.args = Some(String::new());
        }
        Ok(vec![ContentDelta { index, part: delta }])
    } else {
        Ok(vec![])
    }
}

fn map_anthropic_content_block_delta_event(
    delta: ContentBlockDelta,
    index: usize,
) -> Option<ContentDelta> {
    let part_delta = match delta {
        ContentBlockDelta::TextDelta(delta) => PartDelta::Text(TextPartDelta {
            text: delta.text,
            citation: None,
        }),
        ContentBlockDelta::InputJsonDelta(delta) => PartDelta::ToolCall(ToolCallPartDelta {
            tool_name: None,
            args: Some(delta.partial_json),
            tool_call_id: None,
            id: None,
        }),
        ContentBlockDelta::ThinkingDelta(delta) => PartDelta::Reasoning(ReasoningPartDelta {
            text: Some(delta.thinking),
            signature: None,
            id: None,
        }),
        ContentBlockDelta::SignatureDelta(delta) => PartDelta::Reasoning(ReasoningPartDelta {
            text: None,
            signature: Some(delta.signature),
            id: None,
        }),
        ContentBlockDelta::CitationsDelta(delta) => {
            if let Some(citation) = map_citation_delta(delta.citation) {
                PartDelta::Text(TextPartDelta {
                    text: String::new(),
                    citation: Some(citation),
                })
            } else {
                return None;
            }
        }
    };

    Some(ContentDelta {
        index,
        part: part_delta,
    })
}

fn map_citation_delta(citation: api::ResponseCitation) -> Option<CitationDelta> {
    let api::ResponseCitation::SearchResultLocation(api::ResponseSearchResultLocationCitation {
        cited_text,
        end_block_index,
        search_result_index: _,
        source,
        start_block_index,
        title,
    }) = citation
    else {
        return None;
    };

    let result = CitationDelta {
        r#type: "citation".to_string(),
        source: Some(source),
        title,
        cited_text: if cited_text.is_empty() {
            None
        } else {
            Some(cited_text)
        },
        start_index: Some(start_block_index),
        end_index: Some(end_block_index),
    };

    Some(result)
}
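
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of the header-building behaviour: custom headers are
    // merged alongside the required `x-api-key` and `anthropic-version`
    // headers. The header name and values below are illustrative.
    #[test]
    fn request_headers_include_api_key_version_and_custom_headers() {
        let mut custom = HashMap::new();
        custom.insert("anthropic-beta".to_string(), "example-feature".to_string());

        let model = AnthropicModel::new(
            "claude-test",
            AnthropicModelOptions {
                api_key: "test-key".to_string(),
                headers: Some(custom),
                ..Default::default()
            },
        );

        let headers = model.request_headers().expect("headers should build");
        assert_eq!(headers.get("x-api-key").unwrap().to_str().unwrap(), "test-key");
        assert_eq!(
            headers.get("anthropic-version").unwrap().to_str().unwrap(),
            DEFAULT_API_VERSION
        );
        assert_eq!(
            headers.get("anthropic-beta").unwrap().to_str().unwrap(),
            "example-feature"
        );
    }
}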