llm_sdk/anthropic/model.rs

use crate::{
    anthropic::api::{
        self, Base64ImageSource, ContentBlock, ContentBlockDelta, ContentBlockDeltaEvent,
        ContentBlockStartEvent, CreateMessageParams, ImageSource, InputContentBlock, InputMessage,
        InputMessageContent, Message as AnthropicMessage, MessageDeltaEvent, MessageDeltaUsage,
        MessageStartEvent, MessageStreamEvent, RequestCitationsConfig, RequestImageBlock,
        RequestSearchResultBlock, RequestTextBlock, RequestThinkingBlock, RequestToolResultBlock,
        RequestToolUseBlock, SystemPrompt, ThinkingConfigDisabled, ThinkingConfigEnabled,
        ThinkingConfigParam, Tool, ToolResultContent, ToolResultContentBlock, Usage,
    },
    client_utils, stream_utils, Citation, CitationDelta, ContentDelta, ImagePart, LanguageModel,
    LanguageModelError, LanguageModelInput, LanguageModelMetadata, LanguageModelResult,
    LanguageModelStream, Message, ModelResponse, ModelUsage, Part, PartDelta, PartialModelResponse,
    ReasoningOptions, ReasoningPart, ReasoningPartDelta, TextPart, TextPartDelta, Tool as SdkTool,
    ToolCallPart, ToolCallPartDelta, ToolChoiceOption, ToolResultPart,
};
use async_stream::try_stream;
use futures::{future::BoxFuture, StreamExt};
use reqwest::{
    header::{HeaderMap, HeaderName, HeaderValue},
    Client,
};
use serde_json::{Map, Value};
use std::{collections::HashMap, sync::Arc};

const PROVIDER: &str = "anthropic";
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const DEFAULT_API_VERSION: &str = "2023-06-01";

/// A language model backed by the Anthropic Messages API.
pub struct AnthropicModel {
    model_id: String,
    api_key: String,
    base_url: String,
    api_version: String,
    client: Client,
    metadata: Option<Arc<LanguageModelMetadata>>,
    headers: HashMap<String, String>,
}

/// Options for constructing an [`AnthropicModel`].
#[derive(Clone, Default)]
pub struct AnthropicModelOptions {
    pub base_url: Option<String>,
    pub api_key: String,
    pub api_version: Option<String>,
    pub headers: Option<HashMap<String, String>>,
    pub client: Option<Client>,
}

impl AnthropicModel {
    /// Creates a model for `model_id`, filling in the default base URL and API
    /// version when the options leave them unset.
    #[must_use]
    pub fn new(model_id: impl Into<String>, mut options: AnthropicModelOptions) -> Self {
        let base_url = options
            .base_url
            .take()
            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
            .trim_end_matches('/')
            .to_string();

        let api_version = options
            .api_version
            .take()
            .unwrap_or_else(|| DEFAULT_API_VERSION.to_string());

        let client = options.client.take().unwrap_or_default();

        let headers = options.headers.unwrap_or_default();

        Self {
            model_id: model_id.into(),
            api_key: options.api_key,
            base_url,
            api_version,
            client,
            metadata: None,
            headers,
        }
    }

    /// Attaches metadata (such as pricing) used for cost calculation.
    #[must_use]
    pub fn with_metadata(mut self, metadata: LanguageModelMetadata) -> Self {
        self.metadata = Some(Arc::new(metadata));
        self
    }

    fn request_headers(&self) -> LanguageModelResult<HeaderMap> {
        let mut headers = HeaderMap::new();

        headers.insert(
            "x-api-key",
            HeaderValue::from_str(&self.api_key).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic API key header value: {error}"
                ))
            })?,
        );
        headers.insert(
            "anthropic-version",
            HeaderValue::from_str(&self.api_version).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic version header value: {error}"
                ))
            })?,
        );

        // Caller-supplied headers are applied last, so they can override the defaults above.
        for (key, value) in &self.headers {
            let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic header name '{key}': {error}"
                ))
            })?;
            let header_value = HeaderValue::from_str(value).map_err(|error| {
                LanguageModelError::InvalidInput(format!(
                    "Invalid Anthropic header value for '{key}': {error}"
                ))
            })?;
            headers.insert(header_name, header_value);
        }

        Ok(headers)
    }
}
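
// Construction sketch (illustrative only; `..Default::default()` relies on the `Default`
// derive on `AnthropicModelOptions`, and the model id string is a placeholder):
//
//     let model = AnthropicModel::new(
//         "claude-model-id",
//         AnthropicModelOptions {
//             api_key: std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
//             ..Default::default()
//         },
//     );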

impl LanguageModel for AnthropicModel {
    fn provider(&self) -> &'static str {
        PROVIDER
    }

    fn model_id(&self) -> String {
        self.model_id.clone()
    }

    fn metadata(&self) -> Option<&LanguageModelMetadata> {
        self.metadata.as_deref()
    }

    fn generate(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
        Box::pin(async move {
            crate::opentelemetry::trace_generate(
                self.provider(),
                &self.model_id,
                input,
                |input| async move {
                    let payload = convert_to_anthropic_create_params(input, &self.model_id, false)?;

                    let headers = self.request_headers()?;

                    let response: AnthropicMessage = client_utils::send_json(
                        &self.client,
                        &format!("{}/v1/messages", self.base_url),
                        &payload,
                        headers,
                    )
                    .await?;

                    let content = map_anthropic_message(response.content);
                    let usage = Some(map_anthropic_usage(&response.usage));

                    let cost =
                        if let (Some(usage), Some(metadata)) = (usage.as_ref(), self.metadata()) {
                            metadata
                                .pricing
                                .as_ref()
                                .map(|pricing| usage.calculate_cost(pricing))
                        } else {
                            None
                        };

                    Ok(ModelResponse {
                        content,
                        usage,
                        cost,
                    })
                },
            )
            .await
        })
    }

    fn stream(
        &self,
        input: LanguageModelInput,
    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
        Box::pin(async move {
            crate::opentelemetry::trace_stream(
                self.provider(),
                &self.model_id,
                input,
                |input| async move {
                    let payload = convert_to_anthropic_create_params(input, &self.model_id, true)?;

                    let headers = self.request_headers()?;
                    let mut chunk_stream = client_utils::send_sse_stream::<_, MessageStreamEvent>(
                        &self.client,
                        &format!("{}/v1/messages", self.base_url),
                        &payload,
                        headers,
                        self.provider(),
                    )
                    .await?;

                    let metadata = self.metadata.clone();

                    let stream = try_stream! {
                        while let Some(event) = chunk_stream.next().await {
                            match event? {
                                MessageStreamEvent::MessageStart(MessageStartEvent { message }) => {
                                    let usage = map_anthropic_usage(&message.usage);
                                    let cost = metadata
                                        .as_ref()
                                        .and_then(|meta| meta.pricing.as_ref())
                                        .map(|pricing| usage.calculate_cost(pricing));

                                    yield PartialModelResponse {
                                        delta: None,
                                        usage: Some(usage),
                                        cost,
                                    };
                                }
                                MessageStreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
                                    let usage = map_anthropic_message_delta_usage(&usage);
                                    let cost = metadata
                                        .as_ref()
                                        .and_then(|meta| meta.pricing.as_ref())
                                        .map(|pricing| usage.calculate_cost(pricing));

                                    yield PartialModelResponse {
                                        delta: None,
                                        usage: Some(usage),
                                        cost,
                                    };
                                }
                                MessageStreamEvent::ContentBlockStart(ContentBlockStartEvent { content_block, index }) => {
                                    let deltas = map_anthropic_content_block_start_event(content_block, index)?;
                                    for delta in deltas {
                                        yield PartialModelResponse {
                                            delta: Some(delta),
                                            ..Default::default()
                                        };
                                    }
                                }
                                MessageStreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { delta, index }) => {
                                    if let Some(delta) = map_anthropic_content_block_delta_event(delta, index) {
                                        yield PartialModelResponse {
                                            delta: Some(delta),
                                            ..Default::default()
                                        };
                                    }
                                }
                                // Other events (e.g. content_block_stop, message_stop, ping) carry nothing to surface.
                                _ => {}
                            }
                        }
                    };

                    Ok(LanguageModelStream::from_stream(stream))
                },
            )
            .await
        })
    }
}
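
// Usage sketch for the trait methods above (illustrative only; assumes `LanguageModelInput`
// implements `Default`, that `LanguageModelStream` implements `futures::Stream`, and that
// the caller runs inside an async runtime):
//
//     let response = model.generate(input).await?;
//
//     let mut stream = model.stream(input).await?;
//     while let Some(partial) = stream.next().await {
//         let partial = partial?;
//         // partial.delta, partial.usage, and partial.cost mirror PartialModelResponse
//     }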

/// Builds the JSON body for Anthropic's `/v1/messages` endpoint. Inputs the API
/// does not support (response format, penalties, seed, modalities, audio) are
/// dropped, and any `extra` object is merged into the serialized request.
fn convert_to_anthropic_create_params(
    input: LanguageModelInput,
    model_id: &str,
    stream: bool,
) -> LanguageModelResult<Value> {
    let LanguageModelInput {
        system_prompt,
        messages,
        tools,
        tool_choice,
        response_format: _,
        max_tokens,
        temperature,
        top_p,
        top_k,
        presence_penalty: _,
        frequency_penalty: _,
        seed: _,
        modalities: _,
        metadata: _,
        audio: _,
        reasoning,
        extra,
    } = input;

    let max_tokens = max_tokens.unwrap_or(4096);

    let message_params = convert_to_anthropic_messages(messages)?;

    let params = CreateMessageParams {
        max_tokens,
        messages: message_params,
        metadata: None,
        model: api::Model::String(model_id.to_string()),
        service_tier: None,
        stop_sequences: None,
        stream: Some(stream),
        system: system_prompt.map(SystemPrompt::String),
        temperature,
        thinking: reasoning
            .map(|options| convert_to_anthropic_thinking_config(&options, max_tokens)),
        tool_choice: tool_choice.map(convert_to_anthropic_tool_choice),
        tools: tools.map(|tool_list| {
            tool_list
                .into_iter()
                .map(convert_tool)
                .map(api::ToolUnion::Tool)
                .collect()
        }),
        top_k: top_k
            .map(|value| {
                u32::try_from(value).map_err(|_| {
                    LanguageModelError::InvalidInput(
                        "Anthropic top_k must be a non-negative integer".to_string(),
                    )
                })
            })
            .transpose()?,
        top_p,
    };

    let mut value = serde_json::to_value(&params).map_err(|error| {
        LanguageModelError::Invariant(
            PROVIDER,
            format!("Failed to serialize Anthropic request: {error}"),
        )
    })?;

    if let Value::Object(ref mut map) = value {
        if let Some(extra) = extra {
            let Value::Object(extra_object) = extra else {
                return Err(LanguageModelError::InvalidInput(
                    "Anthropic extra field must be a JSON object".to_string(),
                ));
            };

            for (key, val) in extra_object {
                map.insert(key, val);
            }
        }
    } else {
        return Err(LanguageModelError::Invariant(
            PROVIDER,
            "Anthropic request serialization did not produce an object".to_string(),
        ));
    }

    Ok(value)
}
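
// Sketch of the payload produced above (illustrative only; exact field presence depends on
// the serde attributes of `CreateMessageParams`, and optional fields left as `None` are
// expected to be omitted):
//
//     {
//       "model": "claude-model-id",
//       "max_tokens": 4096,
//       "stream": false,
//       "messages": [
//         { "role": "user", "content": [{ "type": "text", "text": "Hello" }] }
//       ]
//     }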

fn convert_tool(tool: SdkTool) -> Tool {
    Tool {
        name: tool.name,
        description: Some(tool.description),
        input_schema: tool.parameters,
        cache_control: None,
        type_field: None,
    }
}

/// Converts SDK messages to Anthropic input messages. Tool messages are sent
/// with the "user" role, because the Messages API expects tool results inside
/// user turns.
fn convert_to_anthropic_messages(messages: Vec<Message>) -> LanguageModelResult<Vec<InputMessage>> {
    messages
        .into_iter()
        .map(|message| match message {
            Message::User(user) => convert_message_parts_to_input_message("user", user.content),
            Message::Assistant(assistant) => {
                convert_message_parts_to_input_message("assistant", assistant.content)
            }
            Message::Tool(tool) => convert_message_parts_to_input_message("user", tool.content),
        })
        .collect()
}

fn convert_message_parts_to_input_message(
    role: &str,
    parts: Vec<Part>,
) -> LanguageModelResult<InputMessage> {
    let content_blocks = convert_parts_to_content_blocks(parts)?;
    Ok(InputMessage {
        content: InputMessageContent::Blocks(content_blocks),
        role: role.to_string(),
    })
}

fn convert_parts_to_content_blocks(
    parts: Vec<Part>,
) -> LanguageModelResult<Vec<InputContentBlock>> {
    parts
        .into_iter()
        .map(convert_part_to_content_block)
        .collect()
}

fn convert_part_to_content_block(part: Part) -> LanguageModelResult<InputContentBlock> {
    match part {
        Part::Text(text_part) => Ok(InputContentBlock::Text(create_request_text_block(
            text_part.text,
        ))),
        Part::Image(image_part) => Ok(InputContentBlock::Image(create_request_image_block(
            image_part,
        ))),
        Part::Source(source_part) => Ok(InputContentBlock::SearchResult(convert_source_part(
            source_part,
        )?)),
        Part::ToolCall(tool_call) => Ok(InputContentBlock::ToolUse(RequestToolUseBlock {
            cache_control: None,
            id: tool_call.tool_call_id,
            input: normalize_tool_args(tool_call.args)?,
            name: tool_call.tool_name,
        })),
        Part::ToolResult(tool_result) => Ok(InputContentBlock::ToolResult(
            convert_tool_result_part(tool_result)?,
        )),
        Part::Reasoning(reasoning_part) => Ok(convert_reasoning_part(reasoning_part)),
        Part::Audio(_) => Err(LanguageModelError::Unsupported(
            PROVIDER,
            "Anthropic does not support audio parts".to_string(),
        )),
    }
}

fn convert_reasoning_part(reasoning_part: ReasoningPart) -> InputContentBlock {
    // A reasoning part with no text but a signature round-trips as a redacted thinking block.
    if reasoning_part.text.is_empty() && reasoning_part.signature.is_some() {
        return InputContentBlock::RedactedThinking(api::RequestRedactedThinkingBlock {
            data: reasoning_part.signature.unwrap_or_default(),
        });
    }

    InputContentBlock::Thinking(RequestThinkingBlock {
        thinking: reasoning_part.text,
        signature: reasoning_part.signature.unwrap_or_default(),
    })
}

fn convert_tool_result_part(
    tool_result: ToolResultPart,
) -> LanguageModelResult<RequestToolResultBlock> {
    let mut content_blocks = Vec::new();
    for part in tool_result.content {
        let block = convert_part_to_tool_result_content_block(part)?;
        content_blocks.push(block);
    }

    let content = if content_blocks.is_empty() {
        None
    } else {
        Some(ToolResultContent::Blocks(content_blocks))
    };

    Ok(RequestToolResultBlock {
        cache_control: None,
        content,
        is_error: tool_result.is_error,
        tool_use_id: tool_result.tool_call_id,
    })
}

fn convert_part_to_tool_result_content_block(
    part: Part,
) -> LanguageModelResult<ToolResultContentBlock> {
    match part {
        Part::Text(text_part) => Ok(ToolResultContentBlock::Text(create_request_text_block(
            text_part.text,
        ))),
        Part::Image(image_part) => Ok(ToolResultContentBlock::Image(create_request_image_block(
            image_part,
        ))),
        Part::Source(source_part) => Ok(ToolResultContentBlock::SearchResult(convert_source_part(
            source_part,
        )?)),
        _ => Err(LanguageModelError::Unsupported(
            PROVIDER,
            "Cannot convert tool result part to Anthropic content".to_string(),
        )),
    }
}

fn create_request_text_block(text: String) -> RequestTextBlock {
    RequestTextBlock {
        cache_control: None,
        citations: None,
        text,
        type_field: "text".to_string(),
    }
}

fn create_request_image_block(image_part: ImagePart) -> RequestImageBlock {
    RequestImageBlock {
        cache_control: None,
        source: ImageSource::Base64(Base64ImageSource {
            data: image_part.data,
            media_type: image_part.mime_type,
        }),
    }
}

fn convert_source_part(
    source_part: crate::SourcePart,
) -> LanguageModelResult<RequestSearchResultBlock> {
    let mut content = Vec::new();
    for part in source_part.content {
        match part {
            Part::Text(text_part) => content.push(create_request_text_block(text_part.text)),
            _ => {
                return Err(LanguageModelError::Unsupported(
                    PROVIDER,
                    "Anthropic source part only supports text content".to_string(),
                ))
            }
        }
    }

    Ok(RequestSearchResultBlock {
        cache_control: None,
        citations: Some(RequestCitationsConfig {
            enabled: Some(true),
        }),
        content,
        source: source_part.source,
        title: source_part.title,
    })
}

fn normalize_tool_args(args: Value) -> LanguageModelResult<Value> {
    match args {
        Value::Object(_) => Ok(args),
        Value::Null => Ok(Value::Object(Map::new())),
        _ => Err(LanguageModelError::InvalidInput(
            "Anthropic tool call arguments must be a JSON object".to_string(),
        )),
    }
}

fn convert_to_anthropic_tool_choice(choice: ToolChoiceOption) -> api::ToolChoice {
    match choice {
        ToolChoiceOption::Auto => api::ToolChoice::Auto(api::ToolChoiceAuto {
            disable_parallel_tool_use: None,
        }),
        ToolChoiceOption::None => api::ToolChoice::None(api::ToolChoiceNone {}),
        ToolChoiceOption::Required => api::ToolChoice::Any(api::ToolChoiceAny {
            disable_parallel_tool_use: None,
        }),
        ToolChoiceOption::Tool(tool) => api::ToolChoice::Tool(api::ToolChoiceTool {
            disable_parallel_tool_use: None,
            name: tool.tool_name,
        }),
    }
}

fn convert_to_anthropic_thinking_config(
    reasoning: &ReasoningOptions,
    max_tokens: u32,
) -> ThinkingConfigParam {
    if !reasoning.enabled {
        return ThinkingConfigParam::Disabled(ThinkingConfigDisabled {});
    }

    let fallback = max_tokens.saturating_sub(1).max(1);
    let budget = reasoning
        .budget_tokens
        .map_or(fallback, |value| value.max(1));

    ThinkingConfigParam::Enabled(ThinkingConfigEnabled {
        budget_tokens: budget,
    })
}
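
// Worked example of the fallback above: with reasoning enabled, no explicit budget, and
// max_tokens = 4096, the thinking budget becomes 4095 (max_tokens - 1); an explicit
// budget_tokens of 0 is clamped up to 1.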

fn map_anthropic_message(content: Vec<ContentBlock>) -> Vec<Part> {
    let mut parts = Vec::new();
    for block in content {
        if let Some(part) = map_content_block(block) {
            parts.push(part);
        }
    }
    parts
}

fn map_content_block(block: ContentBlock) -> Option<Part> {
    match block {
        ContentBlock::Text(text_block) => Some(Part::Text(map_text_block(text_block))),
        ContentBlock::Thinking(thinking_block) => {
            Some(Part::Reasoning(map_thinking_block(thinking_block)))
        }
        ContentBlock::RedactedThinking(redacted_block) => {
            Some(Part::Reasoning(map_redacted_thinking_block(redacted_block)))
        }
        ContentBlock::ToolUse(tool_use) => Some(Part::ToolCall(map_tool_use_block(tool_use))),
        _ => None,
    }
}

fn map_text_block(block: api::ResponseTextBlock) -> TextPart {
    let citations = map_text_citations(block.citations);
    TextPart {
        text: block.text,
        citations,
    }
}

fn map_text_citations(citations: Option<Vec<api::ResponseCitation>>) -> Option<Vec<Citation>> {
    let citations = citations?;

    let mut results = Vec::new();

    for citation in citations {
        if let api::ResponseCitation::SearchResultLocation(
            api::ResponseSearchResultLocationCitation {
                cited_text,
                end_block_index,
                search_result_index: _,
                source,
                start_block_index,
                title,
            },
        ) = citation
        {
            if source.is_empty() {
                continue;
            }

            let mapped = Citation {
                source,
                title,
                cited_text: if cited_text.is_empty() {
                    None
                } else {
                    Some(cited_text)
                },
                start_index: start_block_index,
                end_index: end_block_index,
            };

            results.push(mapped);
        }
    }

    if results.is_empty() {
        None
    } else {
        Some(results)
    }
}

fn map_thinking_block(block: api::ResponseThinkingBlock) -> ReasoningPart {
    ReasoningPart {
        text: block.thinking,
        signature: if block.signature.is_empty() {
            None
        } else {
            Some(block.signature)
        },
        id: None,
    }
}

fn map_redacted_thinking_block(block: api::ResponseRedactedThinkingBlock) -> ReasoningPart {
    ReasoningPart {
        text: String::new(),
        signature: Some(block.data),
        id: None,
    }
}

fn map_tool_use_block(block: api::ResponseToolUseBlock) -> ToolCallPart {
    ToolCallPart {
        tool_call_id: block.id,
        tool_name: block.name,
        args: block.input,
        id: None,
    }
}

fn map_anthropic_usage(usage: &Usage) -> ModelUsage {
    ModelUsage {
        input_tokens: usage.input_tokens,
        output_tokens: usage.output_tokens,
        ..Default::default()
    }
}

fn map_anthropic_message_delta_usage(usage: &MessageDeltaUsage) -> ModelUsage {
    ModelUsage {
        input_tokens: usage.input_tokens.unwrap_or(0),
        output_tokens: usage.output_tokens,
        ..Default::default()
    }
}

fn map_anthropic_content_block_start_event(
    content_block: ContentBlock,
    index: usize,
) -> LanguageModelResult<Vec<ContentDelta>> {
    if let Some(part) = map_content_block(content_block) {
        let mut delta = stream_utils::loosely_convert_part_to_part_delta(part)?;
        if let PartDelta::ToolCall(tool_call_delta) = &mut delta {
            // Tool-call arguments arrive through subsequent input_json_delta events,
            // so the start event contributes an empty args string.
            tool_call_delta.args = Some(String::new());
        }
        Ok(vec![ContentDelta { index, part: delta }])
    } else {
        Ok(vec![])
    }
}

fn map_anthropic_content_block_delta_event(
    delta: ContentBlockDelta,
    index: usize,
) -> Option<ContentDelta> {
    let part_delta = match delta {
        ContentBlockDelta::TextDelta(delta) => PartDelta::Text(TextPartDelta {
            text: delta.text,
            citation: None,
        }),
        ContentBlockDelta::InputJsonDelta(delta) => PartDelta::ToolCall(ToolCallPartDelta {
            tool_name: None,
            args: Some(delta.partial_json),
            tool_call_id: None,
            id: None,
        }),
        ContentBlockDelta::ThinkingDelta(delta) => PartDelta::Reasoning(ReasoningPartDelta {
            text: Some(delta.thinking),
            signature: None,
            id: None,
        }),
        ContentBlockDelta::SignatureDelta(delta) => PartDelta::Reasoning(ReasoningPartDelta {
            text: None,
            signature: Some(delta.signature),
            id: None,
        }),
        ContentBlockDelta::CitationsDelta(delta) => {
            if let Some(citation) = map_citation_delta(delta.citation) {
                PartDelta::Text(TextPartDelta {
                    text: String::new(),
                    citation: Some(citation),
                })
            } else {
                return None;
            }
        }
    };

    Some(ContentDelta {
        index,
        part: part_delta,
    })
}

fn map_citation_delta(citation: api::ResponseCitation) -> Option<CitationDelta> {
    let api::ResponseCitation::SearchResultLocation(api::ResponseSearchResultLocationCitation {
        cited_text,
        end_block_index,
        search_result_index: _,
        source,
        start_block_index,
        title,
    }) = citation
    else {
        return None;
    };

    let result = CitationDelta {
        r#type: "citation".to_string(),
        source: Some(source),
        title,
        cited_text: if cited_text.is_empty() {
            None
        } else {
            Some(cited_text)
        },
        start_index: Some(start_block_index),
        end_index: Some(end_block_index),
    };

    Some(result)
}