Skip to main content

agent_sdk_rs/llm/
anthropic.rs

1use anthropic_ai_sdk::client::AnthropicClient;
2use anthropic_ai_sdk::types::message::{
3    ContentBlock, CreateMessageParams, CreateMessageResponse, Message, MessageClient, MessageError,
4    RequiredMessageParams, Role, Thinking, ThinkingType, Tool, ToolChoice,
5};
6use async_trait::async_trait;
7
8use crate::error::ProviderError;
9use crate::llm::{
10    ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
11    ModelUsage,
12};
13
14#[cfg(test)]
15use anthropic_ai_sdk::types::message::ContentBlockDelta;
16#[cfg(test)]
17use anthropic_ai_sdk::types::message::{MessageStartContent, StopReason, StreamEvent};
18
#[derive(Debug, Clone)]
/// Runtime configuration for [`AnthropicModel`].
pub struct AnthropicModelConfig {
    /// Anthropic API key.
    pub api_key: String,
    /// Model id (for example `claude-sonnet-4-5`).
    pub model: String,
    /// Anthropic API version header value.
    pub api_version: String,
    /// Optional base URL override for proxies or compatible endpoints.
    pub api_base_url: Option<String>,
    /// Maximum output tokens per call (4096 when built via [`AnthropicModelConfig::new`]).
    pub max_tokens: u32,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional nucleus sampling parameter.
    pub top_p: Option<f32>,
    /// Optional budget for extended thinking tokens; when set, extended
    /// thinking is enabled on every request.
    pub thinking_budget_tokens: Option<usize>,
}
39
40impl AnthropicModelConfig {
41    /// Creates a config with sensible defaults.
42    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
43        Self {
44            api_key: api_key.into(),
45            model: model.into(),
46            api_version: AnthropicClient::DEFAULT_API_VERSION.to_string(),
47            api_base_url: None,
48            max_tokens: 4096,
49            temperature: None,
50            top_p: None,
51            thinking_budget_tokens: None,
52        }
53    }
54}
55
#[derive(Debug, Clone)]
/// Anthropic provider adapter implementing [`ChatModel`].
pub struct AnthropicModel {
    /// HTTP client from the Anthropic SDK, built once in [`AnthropicModel::new`].
    client: AnthropicClient,
    /// Model id, token limits, and sampling options applied to every request.
    config: AnthropicModelConfig,
}
62
63impl AnthropicModel {
64    /// Creates a model adapter from explicit config.
65    pub fn new(config: AnthropicModelConfig) -> Result<Self, ProviderError> {
66        let mut builder =
67            AnthropicClient::builder(config.api_key.clone(), config.api_version.clone());
68        if let Some(url) = &config.api_base_url {
69            builder = builder.with_api_base_url(url.clone());
70        }
71
72        let client = builder
73            .build::<MessageError>()
74            .map_err(|err| ProviderError::Request(err.to_string()))?;
75
76        Ok(Self { client, config })
77    }
78
79    /// Creates a model adapter using `ANTHROPIC_API_KEY` from the environment.
80    pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
81        let api_key = std::env::var("ANTHROPIC_API_KEY")
82            .map_err(|_| ProviderError::Request("ANTHROPIC_API_KEY is not set".to_string()))?;
83        Self::new(AnthropicModelConfig::new(api_key, model))
84    }
85}
86
87#[async_trait]
88impl ChatModel for AnthropicModel {
89    async fn invoke(
90        &self,
91        messages: &[ModelMessage],
92        tools: &[ModelToolDefinition],
93        tool_choice: ModelToolChoice,
94    ) -> Result<ModelCompletion, ProviderError> {
95        let (history, system) = to_anthropic_messages(messages);
96
97        let required = RequiredMessageParams {
98            model: self.config.model.clone(),
99            messages: history,
100            max_tokens: self.config.max_tokens,
101        };
102
103        let mut request = CreateMessageParams::new(required).with_stream(false);
104
105        if let Some(system_prompt) = system {
106            request = request.with_system(system_prompt);
107        }
108
109        if let Some(temperature) = self.config.temperature {
110            request = request.with_temperature(temperature);
111        }
112
113        if let Some(top_p) = self.config.top_p {
114            request = request.with_top_p(top_p);
115        }
116
117        if let Some(budget_tokens) = self.config.thinking_budget_tokens {
118            request = request.with_thinking(Thinking {
119                budget_tokens,
120                type_: ThinkingType::Enabled,
121            });
122        }
123
124        if !tools.is_empty() {
125            let anthropic_tools = tools
126                .iter()
127                .map(|tool| Tool {
128                    name: tool.name.clone(),
129                    description: Some(tool.description.clone()),
130                    input_schema: tool.parameters.clone(),
131                })
132                .collect::<Vec<_>>();
133
134            request = request.with_tools(anthropic_tools);
135            request = request.with_tool_choice(match tool_choice {
136                ModelToolChoice::Auto => ToolChoice::Auto,
137                ModelToolChoice::Required => ToolChoice::Any,
138                ModelToolChoice::None => ToolChoice::None,
139                ModelToolChoice::Tool(name) => ToolChoice::Tool { name },
140            });
141        }
142
143        let response = self
144            .client
145            .create_message(Some(&request))
146            .await
147            .map_err(|err| ProviderError::Request(err.to_string()))?;
148
149        Ok(normalize_response(&response))
150    }
151}
152
153fn to_anthropic_messages(messages: &[ModelMessage]) -> (Vec<Message>, Option<String>) {
154    let mut system_lines = Vec::new();
155    let mut anthropic_messages = Vec::new();
156
157    for message in messages {
158        match message {
159            ModelMessage::System(content) => system_lines.push(content.clone()),
160            ModelMessage::User(content) => {
161                anthropic_messages.push(Message::new_text(Role::User, content.clone()));
162            }
163            ModelMessage::Assistant {
164                content,
165                tool_calls,
166            } => {
167                let mut blocks = Vec::new();
168                if let Some(content) = content {
169                    if !content.is_empty() {
170                        blocks.push(ContentBlock::Text {
171                            text: content.clone(),
172                        });
173                    }
174                }
175                for call in tool_calls {
176                    blocks.push(ContentBlock::ToolUse {
177                        id: call.id.clone(),
178                        name: call.name.clone(),
179                        input: call.arguments.clone(),
180                    });
181                }
182                if !blocks.is_empty() {
183                    anthropic_messages.push(Message::new_blocks(Role::Assistant, blocks));
184                }
185            }
186            ModelMessage::ToolResult {
187                tool_call_id,
188                tool_name: _,
189                content,
190                is_error,
191            } => {
192                let rendered = if *is_error {
193                    format!("Error: {content}")
194                } else {
195                    content.clone()
196                };
197                anthropic_messages.push(Message::new_blocks(
198                    Role::User,
199                    vec![ContentBlock::ToolResult {
200                        tool_use_id: tool_call_id.clone(),
201                        content: rendered,
202                    }],
203                ));
204            }
205        }
206    }
207
208    let system = if system_lines.is_empty() {
209        None
210    } else {
211        Some(system_lines.join("\n\n"))
212    };
213
214    (anthropic_messages, system)
215}
216
217fn normalize_response(response: &CreateMessageResponse) -> ModelCompletion {
218    let mut text_parts = Vec::new();
219    let mut thinking_parts = Vec::new();
220    let mut tool_calls = Vec::new();
221
222    for block in &response.content {
223        match block {
224            ContentBlock::Text { text } => text_parts.push(text.clone()),
225            ContentBlock::ToolUse { id, name, input } => tool_calls.push(ModelToolCall {
226                id: id.clone(),
227                name: name.clone(),
228                arguments: input.clone(),
229            }),
230            ContentBlock::Thinking { thinking, .. } => thinking_parts.push(thinking.clone()),
231            ContentBlock::RedactedThinking { data } => {
232                thinking_parts.push(format!("[redacted:{} bytes]", data.len()))
233            }
234            _ => {}
235        }
236    }
237
238    let text = if text_parts.is_empty() {
239        None
240    } else {
241        Some(text_parts.join("\n"))
242    };
243
244    let thinking = if thinking_parts.is_empty() {
245        None
246    } else {
247        Some(thinking_parts.join("\n"))
248    };
249
250    ModelCompletion {
251        text,
252        thinking,
253        tool_calls,
254        usage: Some(ModelUsage {
255            input_tokens: response.usage.input_tokens,
256            output_tokens: response.usage.output_tokens,
257        }),
258    }
259}
260
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
/// Provider-agnostic view of a single Anthropic streaming event, produced by
/// [`normalize_stream_event`] for the streaming tests below.
pub(crate) enum AnthropicStreamChunk {
    /// Text delta for the content block at `index`.
    Text {
        index: usize,
        text: String,
    },
    /// Extended-thinking delta for the content block at `index`.
    Thinking {
        index: usize,
        content: String,
    },
    /// Partial JSON fragment of a tool call's input arguments.
    ToolInputJson {
        index: usize,
        partial_json: String,
    },
    /// Start of a tool-use block, carrying its initial input value.
    ToolCallStart {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Signature delta for the content block at `index`.
    Signature {
        index: usize,
        signature: String,
    },
    /// End of the message; `stop_reason` is set when the event carried one.
    MessageStop {
        stop_reason: Option<String>,
    },
    /// Error event reported by the stream.
    Error {
        message: String,
    },
}
292
#[cfg(test)]
/// Maps an SDK [`StreamEvent`] onto an [`AnthropicStreamChunk`], returning
/// `None` for events the adapter does not surface (message start, content
/// block stop, pings, and non-tool block starts).
pub(crate) fn normalize_stream_event(event: &StreamEvent) -> Option<AnthropicStreamChunk> {
    match event {
        StreamEvent::ContentBlockStart {
            index: _,
            content_block,
        } => match content_block {
            // Only tool-use starts matter; other block starts carry no data
            // the tests consume.
            ContentBlock::ToolUse { id, name, input } => {
                Some(AnthropicStreamChunk::ToolCallStart {
                    id: id.clone(),
                    name: name.clone(),
                    input: input.clone(),
                })
            }
            _ => None,
        },
        StreamEvent::ContentBlockDelta { index, delta } => {
            let index = *index;
            let chunk = match delta {
                ContentBlockDelta::TextDelta { text } => AnthropicStreamChunk::Text {
                    index,
                    text: text.clone(),
                },
                ContentBlockDelta::ThinkingDelta { thinking } => AnthropicStreamChunk::Thinking {
                    index,
                    content: thinking.clone(),
                },
                ContentBlockDelta::InputJsonDelta { partial_json } => {
                    AnthropicStreamChunk::ToolInputJson {
                        index,
                        partial_json: partial_json.clone(),
                    }
                }
                ContentBlockDelta::SignatureDelta { signature } => {
                    AnthropicStreamChunk::Signature {
                        index,
                        signature: signature.clone(),
                    }
                }
            };
            Some(chunk)
        }
        StreamEvent::MessageDelta { delta, usage: _ } => Some(AnthropicStreamChunk::MessageStop {
            stop_reason: delta.stop_reason.as_ref().map(stop_reason_name),
        }),
        StreamEvent::MessageStop => Some(AnthropicStreamChunk::MessageStop { stop_reason: None }),
        StreamEvent::Error { error } => Some(AnthropicStreamChunk::Error {
            message: error.message.clone(),
        }),
        StreamEvent::MessageStart {
            message: MessageStartContent { .. },
        }
        | StreamEvent::ContentBlockStop { .. }
        | StreamEvent::Ping => None,
    }
}
346
#[cfg(test)]
/// Renders a [`StopReason`] as its wire-format snake_case name.
fn stop_reason_name(stop_reason: &StopReason) -> String {
    String::from(match stop_reason {
        StopReason::EndTurn => "end_turn",
        StopReason::MaxTokens => "max_tokens",
        StopReason::StopSequence => "stop_sequence",
        StopReason::ToolUse => "tool_use",
        StopReason::Refusal => "refusal",
    })
}
358
#[cfg(test)]
mod tests {
    use anthropic_ai_sdk::types::message::MessageContent;
    use serde_json::json;

    use super::*;
    use crate::llm::ModelMessage;

    // A response mixing a text block and a tool-use block yields both the
    // text and the extracted tool call.
    #[test]
    fn normalize_response_extracts_tool_calls_and_text() {
        let response = CreateMessageResponse {
            content: vec![
                ContentBlock::Text {
                    text: "Looking up".to_string(),
                },
                ContentBlock::ToolUse {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    input: json!({"query": "rust"}),
                },
            ],
            id: "msg_1".to_string(),
            model: "claude-test".to_string(),
            role: Role::Assistant,
            stop_reason: Some(StopReason::ToolUse),
            stop_sequence: None,
            type_: "message".to_string(),
            usage: anthropic_ai_sdk::types::message::Usage {
                input_tokens: 1,
                output_tokens: 1,
            },
        };

        let completion = normalize_response(&response);
        assert_eq!(completion.text.as_deref(), Some("Looking up"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "search");
    }

    // System messages are split out into the returned system prompt, and a
    // failed tool result is rendered as a user-role tool_result block with
    // an "Error:" prefix.
    #[test]
    fn to_anthropic_messages_serializes_tool_result() {
        let history = vec![
            ModelMessage::System("sys".to_string()),
            ModelMessage::User("u1".to_string()),
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "search".to_string(),
                content: "failed".to_string(),
                is_error: true,
            },
        ];

        let (messages, system) = to_anthropic_messages(&history);
        assert_eq!(system.as_deref(), Some("sys"));
        // The system message is removed from the stream: user + tool result.
        assert_eq!(messages.len(), 2);

        let MessageContent::Blocks { content } = &messages[1].content else {
            panic!("expected blocks")
        };
        assert_eq!(
            content[0],
            ContentBlock::ToolResult {
                tool_use_id: "call_1".to_string(),
                content: "Error: failed".to_string(),
            }
        );
    }

    // Text and thinking deltas map to their chunk variants, preserving the
    // content-block index.
    #[test]
    fn normalize_stream_event_maps_deltas() {
        let text_event = StreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentBlockDelta::TextDelta {
                text: "hi".to_string(),
            },
        };
        let mapped_text = normalize_stream_event(&text_event);
        assert_eq!(
            mapped_text,
            Some(AnthropicStreamChunk::Text {
                index: 0,
                text: "hi".to_string(),
            })
        );

        let thinking_event = StreamEvent::ContentBlockDelta {
            index: 1,
            delta: ContentBlockDelta::ThinkingDelta {
                thinking: "plan".to_string(),
            },
        };
        let mapped_thinking = normalize_stream_event(&thinking_event);
        assert_eq!(
            mapped_thinking,
            Some(AnthropicStreamChunk::Thinking {
                index: 1,
                content: "plan".to_string(),
            })
        );
    }

    // A thinking-only response produces `thinking` content with no text.
    #[test]
    fn normalize_response_handles_thinking_without_text() {
        let response = CreateMessageResponse {
            content: vec![ContentBlock::Thinking {
                thinking: "I should call a tool".to_string(),
                signature: "sig".to_string(),
            }],
            id: "msg_2".to_string(),
            model: "claude-test".to_string(),
            role: Role::Assistant,
            stop_reason: Some(StopReason::EndTurn),
            stop_sequence: None,
            type_: "message".to_string(),
            usage: anthropic_ai_sdk::types::message::Usage {
                input_tokens: 1,
                output_tokens: 1,
            },
        };

        let completion = normalize_response(&response);
        assert!(completion.text.is_none());
        assert_eq!(
            completion.thinking,
            Some("I should call a tool".to_string())
        );
    }

    // A content_block_start carrying a tool-use block maps to ToolCallStart
    // with the initial input preserved.
    #[test]
    fn normalize_stream_event_extracts_tool_call_start() {
        let event = StreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "lookup".to_string(),
                input: json!({"x": 1}),
            },
        };

        let mapped = normalize_stream_event(&event);
        assert_eq!(
            mapped,
            Some(AnthropicStreamChunk::ToolCallStart {
                id: "tool_1".to_string(),
                name: "lookup".to_string(),
                input: json!({"x": 1}),
            })
        );
    }
}