1use anthropic_ai_sdk::client::AnthropicClient;
2use anthropic_ai_sdk::types::message::{
3 ContentBlock, CreateMessageParams, CreateMessageResponse, Message, MessageClient, MessageError,
4 RequiredMessageParams, Role, Thinking, ThinkingType, Tool, ToolChoice,
5};
6use async_trait::async_trait;
7
8use crate::error::ProviderError;
9use crate::llm::{
10 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
11 ModelUsage,
12};
13
14#[cfg(test)]
15use anthropic_ai_sdk::types::message::ContentBlockDelta;
16#[cfg(test)]
17use anthropic_ai_sdk::types::message::{MessageStartContent, StopReason, StreamEvent};
18
/// Configuration for [`AnthropicModel`]: credentials, model selection, and
/// per-request sampling parameters.
#[derive(Debug, Clone)]
pub struct AnthropicModelConfig {
    /// API key used to authenticate requests.
    pub api_key: String,
    /// Model identifier sent with every request (e.g. a Claude model name).
    pub model: String,
    /// Anthropic API version string passed to the SDK client.
    pub api_version: String,
    /// Optional base-URL override; the SDK default endpoint is used when `None`.
    pub api_base_url: Option<String>,
    /// Maximum number of tokens the model may generate in one response.
    pub max_tokens: u32,
    /// Sampling temperature; omitted from the request when `None`.
    pub temperature: Option<f32>,
    /// Nucleus-sampling parameter; omitted from the request when `None`.
    pub top_p: Option<f32>,
    /// When set, extended thinking is enabled with this token budget.
    pub thinking_budget_tokens: Option<usize>,
}
39
40impl AnthropicModelConfig {
41 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
43 Self {
44 api_key: api_key.into(),
45 model: model.into(),
46 api_version: AnthropicClient::DEFAULT_API_VERSION.to_string(),
47 api_base_url: None,
48 max_tokens: 4096,
49 temperature: None,
50 top_p: None,
51 thinking_budget_tokens: None,
52 }
53 }
54}
55
/// Chat model backed by the Anthropic Messages API.
#[derive(Debug, Clone)]
pub struct AnthropicModel {
    /// SDK client used to issue `create_message` requests.
    client: AnthropicClient,
    /// Model name and request parameters applied to every invocation.
    config: AnthropicModelConfig,
}
62
63impl AnthropicModel {
64 pub fn new(config: AnthropicModelConfig) -> Result<Self, ProviderError> {
66 let mut builder =
67 AnthropicClient::builder(config.api_key.clone(), config.api_version.clone());
68 if let Some(url) = &config.api_base_url {
69 builder = builder.with_api_base_url(url.clone());
70 }
71
72 let client = builder
73 .build::<MessageError>()
74 .map_err(|err| ProviderError::Request(err.to_string()))?;
75
76 Ok(Self { client, config })
77 }
78
79 pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
81 let api_key = std::env::var("ANTHROPIC_API_KEY")
82 .map_err(|_| ProviderError::Request("ANTHROPIC_API_KEY is not set".to_string()))?;
83 Self::new(AnthropicModelConfig::new(api_key, model))
84 }
85}
86
87#[async_trait]
88impl ChatModel for AnthropicModel {
89 async fn invoke(
90 &self,
91 messages: &[ModelMessage],
92 tools: &[ModelToolDefinition],
93 tool_choice: ModelToolChoice,
94 ) -> Result<ModelCompletion, ProviderError> {
95 let (history, system) = to_anthropic_messages(messages);
96
97 let required = RequiredMessageParams {
98 model: self.config.model.clone(),
99 messages: history,
100 max_tokens: self.config.max_tokens,
101 };
102
103 let mut request = CreateMessageParams::new(required).with_stream(false);
104
105 if let Some(system_prompt) = system {
106 request = request.with_system(system_prompt);
107 }
108
109 if let Some(temperature) = self.config.temperature {
110 request = request.with_temperature(temperature);
111 }
112
113 if let Some(top_p) = self.config.top_p {
114 request = request.with_top_p(top_p);
115 }
116
117 if let Some(budget_tokens) = self.config.thinking_budget_tokens {
118 request = request.with_thinking(Thinking {
119 budget_tokens,
120 type_: ThinkingType::Enabled,
121 });
122 }
123
124 if !tools.is_empty() {
125 let anthropic_tools = tools
126 .iter()
127 .map(|tool| Tool {
128 name: tool.name.clone(),
129 description: Some(tool.description.clone()),
130 input_schema: tool.parameters.clone(),
131 })
132 .collect::<Vec<_>>();
133
134 request = request.with_tools(anthropic_tools);
135 request = request.with_tool_choice(match tool_choice {
136 ModelToolChoice::Auto => ToolChoice::Auto,
137 ModelToolChoice::Required => ToolChoice::Any,
138 ModelToolChoice::None => ToolChoice::None,
139 ModelToolChoice::Tool(name) => ToolChoice::Tool { name },
140 });
141 }
142
143 let response = self
144 .client
145 .create_message(Some(&request))
146 .await
147 .map_err(|err| ProviderError::Request(err.to_string()))?;
148
149 Ok(normalize_response(&response))
150 }
151}
152
153fn to_anthropic_messages(messages: &[ModelMessage]) -> (Vec<Message>, Option<String>) {
154 let mut system_lines = Vec::new();
155 let mut anthropic_messages = Vec::new();
156
157 for message in messages {
158 match message {
159 ModelMessage::System(content) => system_lines.push(content.clone()),
160 ModelMessage::User(content) => {
161 anthropic_messages.push(Message::new_text(Role::User, content.clone()));
162 }
163 ModelMessage::Assistant {
164 content,
165 tool_calls,
166 } => {
167 let mut blocks = Vec::new();
168 if let Some(content) = content {
169 if !content.is_empty() {
170 blocks.push(ContentBlock::Text {
171 text: content.clone(),
172 });
173 }
174 }
175 for call in tool_calls {
176 blocks.push(ContentBlock::ToolUse {
177 id: call.id.clone(),
178 name: call.name.clone(),
179 input: call.arguments.clone(),
180 });
181 }
182 if !blocks.is_empty() {
183 anthropic_messages.push(Message::new_blocks(Role::Assistant, blocks));
184 }
185 }
186 ModelMessage::ToolResult {
187 tool_call_id,
188 tool_name: _,
189 content,
190 is_error,
191 } => {
192 let rendered = if *is_error {
193 format!("Error: {content}")
194 } else {
195 content.clone()
196 };
197 anthropic_messages.push(Message::new_blocks(
198 Role::User,
199 vec![ContentBlock::ToolResult {
200 tool_use_id: tool_call_id.clone(),
201 content: rendered,
202 }],
203 ));
204 }
205 }
206 }
207
208 let system = if system_lines.is_empty() {
209 None
210 } else {
211 Some(system_lines.join("\n\n"))
212 };
213
214 (anthropic_messages, system)
215}
216
217fn normalize_response(response: &CreateMessageResponse) -> ModelCompletion {
218 let mut text_parts = Vec::new();
219 let mut thinking_parts = Vec::new();
220 let mut tool_calls = Vec::new();
221
222 for block in &response.content {
223 match block {
224 ContentBlock::Text { text } => text_parts.push(text.clone()),
225 ContentBlock::ToolUse { id, name, input } => tool_calls.push(ModelToolCall {
226 id: id.clone(),
227 name: name.clone(),
228 arguments: input.clone(),
229 }),
230 ContentBlock::Thinking { thinking, .. } => thinking_parts.push(thinking.clone()),
231 ContentBlock::RedactedThinking { data } => {
232 thinking_parts.push(format!("[redacted:{} bytes]", data.len()))
233 }
234 _ => {}
235 }
236 }
237
238 let text = if text_parts.is_empty() {
239 None
240 } else {
241 Some(text_parts.join("\n"))
242 };
243
244 let thinking = if thinking_parts.is_empty() {
245 None
246 } else {
247 Some(thinking_parts.join("\n"))
248 };
249
250 ModelCompletion {
251 text,
252 thinking,
253 tool_calls,
254 usage: Some(ModelUsage {
255 input_tokens: response.usage.input_tokens,
256 output_tokens: response.usage.output_tokens,
257 }),
258 }
259}
260
/// Normalized view of an Anthropic streaming event, used by the streaming
/// tests below.
#[cfg(test)]
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum AnthropicStreamChunk {
    /// Incremental assistant text for the content block at `index`.
    Text {
        index: usize,
        text: String,
    },
    /// Incremental thinking text for the content block at `index`.
    Thinking {
        index: usize,
        content: String,
    },
    /// Partial JSON fragment of a tool call's input arguments.
    ToolInputJson {
        index: usize,
        partial_json: String,
    },
    /// Start of a tool-use content block, with any initial input payload.
    ToolCallStart {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Signature fragment attached to a thinking block.
    Signature {
        index: usize,
        signature: String,
    },
    /// End of the message; `stop_reason` is present only on `MessageDelta`.
    MessageStop {
        stop_reason: Option<String>,
    },
    /// Error event reported by the stream.
    Error {
        message: String,
    },
}
292
/// Maps a raw SDK [`StreamEvent`] onto an [`AnthropicStreamChunk`], or
/// `None` for events that carry no payload of interest (message start,
/// content-block stop, ping, and non-tool block starts).
#[cfg(test)]
pub(crate) fn normalize_stream_event(event: &StreamEvent) -> Option<AnthropicStreamChunk> {
    match event {
        StreamEvent::ContentBlockStart {
            index: _,
            content_block,
        } => match content_block {
            // Only tool-use starts are surfaced; text/thinking starts are
            // covered by their subsequent deltas.
            ContentBlock::ToolUse { id, name, input } => {
                Some(AnthropicStreamChunk::ToolCallStart {
                    id: id.clone(),
                    name: name.clone(),
                    input: input.clone(),
                })
            }
            _ => None,
        },
        StreamEvent::ContentBlockDelta { index, delta } => {
            let index = *index;
            let chunk = match delta {
                ContentBlockDelta::TextDelta { text } => AnthropicStreamChunk::Text {
                    index,
                    text: text.clone(),
                },
                ContentBlockDelta::ThinkingDelta { thinking } => AnthropicStreamChunk::Thinking {
                    index,
                    content: thinking.clone(),
                },
                ContentBlockDelta::InputJsonDelta { partial_json } => {
                    AnthropicStreamChunk::ToolInputJson {
                        index,
                        partial_json: partial_json.clone(),
                    }
                }
                ContentBlockDelta::SignatureDelta { signature } => {
                    AnthropicStreamChunk::Signature {
                        index,
                        signature: signature.clone(),
                    }
                }
            };
            Some(chunk)
        }
        StreamEvent::MessageDelta { delta, usage: _ } => Some(AnthropicStreamChunk::MessageStop {
            stop_reason: delta.stop_reason.as_ref().map(stop_reason_name),
        }),
        StreamEvent::MessageStop => Some(AnthropicStreamChunk::MessageStop { stop_reason: None }),
        StreamEvent::Error { error } => Some(AnthropicStreamChunk::Error {
            message: error.message.clone(),
        }),
        StreamEvent::MessageStart {
            message: MessageStartContent { .. },
        }
        | StreamEvent::ContentBlockStop { .. }
        | StreamEvent::Ping => None,
    }
}
346
/// Returns the wire-format name (`snake_case`) for a [`StopReason`].
#[cfg(test)]
fn stop_reason_name(stop_reason: &StopReason) -> String {
    let name = match stop_reason {
        StopReason::EndTurn => "end_turn",
        StopReason::MaxTokens => "max_tokens",
        StopReason::StopSequence => "stop_sequence",
        StopReason::ToolUse => "tool_use",
        StopReason::Refusal => "refusal",
    };
    name.to_owned()
}
358
#[cfg(test)]
mod tests {
    use anthropic_ai_sdk::types::message::MessageContent;
    use serde_json::json;

    use super::*;
    use crate::llm::ModelMessage;

    // A response containing a text block and a tool-use block yields both the
    // joined text and one ModelToolCall.
    #[test]
    fn normalize_response_extracts_tool_calls_and_text() {
        let response = CreateMessageResponse {
            content: vec![
                ContentBlock::Text {
                    text: "Looking up".to_string(),
                },
                ContentBlock::ToolUse {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    input: json!({"query": "rust"}),
                },
            ],
            id: "msg_1".to_string(),
            model: "claude-test".to_string(),
            role: Role::Assistant,
            stop_reason: Some(StopReason::ToolUse),
            stop_sequence: None,
            type_: "message".to_string(),
            usage: anthropic_ai_sdk::types::message::Usage {
                input_tokens: 1,
                output_tokens: 1,
            },
        };

        let completion = normalize_response(&response);
        assert_eq!(completion.text.as_deref(), Some("Looking up"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "search");
    }

    // System messages are lifted out of the history, and an error tool result
    // becomes a user-role ToolResult block with an "Error: " prefix.
    #[test]
    fn to_anthropic_messages_serializes_tool_result() {
        let history = vec![
            ModelMessage::System("sys".to_string()),
            ModelMessage::User("u1".to_string()),
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "search".to_string(),
                content: "failed".to_string(),
                is_error: true,
            },
        ];

        let (messages, system) = to_anthropic_messages(&history);
        assert_eq!(system.as_deref(), Some("sys"));
        // Only the user message and the tool result survive; the system
        // message is moved into `system`.
        assert_eq!(messages.len(), 2);

        let MessageContent::Blocks { content } = &messages[1].content else {
            panic!("expected blocks")
        };
        assert_eq!(
            content[0],
            ContentBlock::ToolResult {
                tool_use_id: "call_1".to_string(),
                content: "Error: failed".to_string(),
            }
        );
    }

    // Text and thinking deltas map to their corresponding chunk variants,
    // preserving the content-block index.
    #[test]
    fn normalize_stream_event_maps_deltas() {
        let text_event = StreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentBlockDelta::TextDelta {
                text: "hi".to_string(),
            },
        };
        let mapped_text = normalize_stream_event(&text_event);
        assert_eq!(
            mapped_text,
            Some(AnthropicStreamChunk::Text {
                index: 0,
                text: "hi".to_string(),
            })
        );

        let thinking_event = StreamEvent::ContentBlockDelta {
            index: 1,
            delta: ContentBlockDelta::ThinkingDelta {
                thinking: "plan".to_string(),
            },
        };
        let mapped_thinking = normalize_stream_event(&thinking_event);
        assert_eq!(
            mapped_thinking,
            Some(AnthropicStreamChunk::Thinking {
                index: 1,
                content: "plan".to_string(),
            })
        );
    }

    // A thinking-only response produces `thinking` without any `text`.
    #[test]
    fn normalize_response_handles_thinking_without_text() {
        let response = CreateMessageResponse {
            content: vec![ContentBlock::Thinking {
                thinking: "I should call a tool".to_string(),
                signature: "sig".to_string(),
            }],
            id: "msg_2".to_string(),
            model: "claude-test".to_string(),
            role: Role::Assistant,
            stop_reason: Some(StopReason::EndTurn),
            stop_sequence: None,
            type_: "message".to_string(),
            usage: anthropic_ai_sdk::types::message::Usage {
                input_tokens: 1,
                output_tokens: 1,
            },
        };

        let completion = normalize_response(&response);
        assert!(completion.text.is_none());
        assert_eq!(
            completion.thinking,
            Some("I should call a tool".to_string())
        );
    }

    // A tool-use ContentBlockStart is surfaced as ToolCallStart with its id,
    // name, and initial input payload.
    #[test]
    fn normalize_stream_event_extracts_tool_call_start() {
        let event = StreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "lookup".to_string(),
                input: json!({"x": 1}),
            },
        };

        let mapped = normalize_stream_event(&event);
        assert_eq!(
            mapped,
            Some(AnthropicStreamChunk::ToolCallStart {
                id: "tool_1".to_string(),
                name: "lookup".to_string(),
                input: json!({"x": 1}),
            })
        );
    }
}