// codetether_agent/provider/anthropic.rs

//! Anthropic provider implementation using the Messages API.
//!
//! Supports Claude Sonnet 4, Claude Opus 4, and other Claude models.
//! Uses the native Anthropic API format (not OpenAI-compatible).
//! Reference: https://docs.anthropic.com/en/api/messages
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
/// Root URL of the Anthropic REST API; endpoint paths such as `/v1/messages` are appended to it.
const ANTHROPIC_API_BASE: &str = "https://api.anthropic.com";
/// API version pinned via the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
19
20pub struct AnthropicProvider {
21    client: Client,
22    api_key: String,
23}
24
25impl std::fmt::Debug for AnthropicProvider {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("AnthropicProvider")
28            .field("api_key", &"<REDACTED>")
29            .field("api_key_len", &self.api_key.len())
30            .finish()
31    }
32}
33
34impl AnthropicProvider {
35    pub fn new(api_key: String) -> Result<Self> {
36        tracing::debug!(
37            provider = "anthropic",
38            api_key_len = api_key.len(),
39            "Creating Anthropic provider"
40        );
41        Ok(Self {
42            client: Client::new(),
43            api_key,
44        })
45    }
46
47    fn validate_api_key(&self) -> Result<()> {
48        if self.api_key.is_empty() {
49            anyhow::bail!("Anthropic API key is empty");
50        }
51        Ok(())
52    }
53
54    /// Convert our generic messages to Anthropic Messages API format.
55    ///
56    /// Anthropic uses a different format:
57    /// - system prompt is a top-level field, not a message
58    /// - tool results go in user messages with type "tool_result"
59    /// - tool calls appear in assistant messages with type "tool_use"
60    fn convert_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
61        let mut system_prompt = None;
62        let mut api_messages: Vec<Value> = Vec::new();
63
64        for msg in messages {
65            match msg.role {
66                Role::System => {
67                    let text: String = msg
68                        .content
69                        .iter()
70                        .filter_map(|p| match p {
71                            ContentPart::Text { text } => Some(text.clone()),
72                            _ => None,
73                        })
74                        .collect::<Vec<_>>()
75                        .join("\n");
76                    system_prompt = Some(match system_prompt {
77                        Some(existing) => format!("{}\n{}", existing, text),
78                        None => text,
79                    });
80                }
81                Role::User => {
82                    let text: String = msg
83                        .content
84                        .iter()
85                        .filter_map(|p| match p {
86                            ContentPart::Text { text } => Some(text.clone()),
87                            _ => None,
88                        })
89                        .collect::<Vec<_>>()
90                        .join("\n");
91                    api_messages.push(json!({
92                        "role": "user",
93                        "content": text
94                    }));
95                }
96                Role::Assistant => {
97                    let mut content_parts: Vec<Value> = Vec::new();
98
99                    for part in &msg.content {
100                        match part {
101                            ContentPart::Text { text } => {
102                                if !text.is_empty() {
103                                    content_parts.push(json!({
104                                        "type": "text",
105                                        "text": text
106                                    }));
107                                }
108                            }
109                            ContentPart::ToolCall {
110                                id,
111                                name,
112                                arguments,
113                            } => {
114                                let input: Value = serde_json::from_str(arguments)
115                                    .unwrap_or_else(|_| json!({"raw": arguments}));
116                                content_parts.push(json!({
117                                    "type": "tool_use",
118                                    "id": id,
119                                    "name": name,
120                                    "input": input
121                                }));
122                            }
123                            _ => {}
124                        }
125                    }
126
127                    if content_parts.is_empty() {
128                        content_parts.push(json!({"type": "text", "text": ""}));
129                    }
130
131                    api_messages.push(json!({
132                        "role": "assistant",
133                        "content": content_parts
134                    }));
135                }
136                Role::Tool => {
137                    let mut tool_results: Vec<Value> = Vec::new();
138                    for part in &msg.content {
139                        if let ContentPart::ToolResult {
140                            tool_call_id,
141                            content,
142                        } = part
143                        {
144                            tool_results.push(json!({
145                                "type": "tool_result",
146                                "tool_use_id": tool_call_id,
147                                "content": content
148                            }));
149                        }
150                    }
151                    if !tool_results.is_empty() {
152                        api_messages.push(json!({
153                            "role": "user",
154                            "content": tool_results
155                        }));
156                    }
157                }
158            }
159        }
160
161        (system_prompt, api_messages)
162    }
163
164    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
165        tools
166            .iter()
167            .map(|t| {
168                json!({
169                    "name": t.name,
170                    "description": t.description,
171                    "input_schema": t.parameters
172                })
173            })
174            .collect()
175    }
176}
177
178#[derive(Debug, Deserialize)]
179struct AnthropicResponse {
180    #[allow(dead_code)]
181    id: String,
182    #[allow(dead_code)]
183    model: String,
184    content: Vec<AnthropicContent>,
185    #[serde(default)]
186    stop_reason: Option<String>,
187    #[serde(default)]
188    usage: Option<AnthropicUsage>,
189}
190
191#[derive(Debug, Deserialize)]
192#[serde(tag = "type")]
193enum AnthropicContent {
194    #[serde(rename = "text")]
195    Text { text: String },
196    #[serde(rename = "tool_use")]
197    ToolUse {
198        id: String,
199        name: String,
200        input: Value,
201    },
202}
203
204#[derive(Debug, Deserialize)]
205struct AnthropicUsage {
206    #[serde(default)]
207    input_tokens: usize,
208    #[serde(default)]
209    output_tokens: usize,
210    #[serde(default)]
211    cache_creation_input_tokens: Option<usize>,
212    #[serde(default)]
213    cache_read_input_tokens: Option<usize>,
214}
215
216#[derive(Debug, Deserialize)]
217struct AnthropicError {
218    error: AnthropicErrorDetail,
219}
220
221#[derive(Debug, Deserialize)]
222struct AnthropicErrorDetail {
223    message: String,
224    #[serde(default, rename = "type")]
225    error_type: Option<String>,
226}
227
228#[async_trait]
229impl Provider for AnthropicProvider {
230    fn name(&self) -> &str {
231        "anthropic"
232    }
233
234    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
235        self.validate_api_key()?;
236
237        Ok(vec![
238            ModelInfo {
239                id: "claude-sonnet-4-20250514".to_string(),
240                name: "Claude Sonnet 4".to_string(),
241                provider: "anthropic".to_string(),
242                context_window: 200_000,
243                max_output_tokens: Some(64_000),
244                supports_vision: true,
245                supports_tools: true,
246                supports_streaming: true,
247                input_cost_per_million: Some(3.0),
248                output_cost_per_million: Some(15.0),
249            },
250            ModelInfo {
251                id: "claude-opus-4-20250514".to_string(),
252                name: "Claude Opus 4".to_string(),
253                provider: "anthropic".to_string(),
254                context_window: 200_000,
255                max_output_tokens: Some(32_000),
256                supports_vision: true,
257                supports_tools: true,
258                supports_streaming: true,
259                input_cost_per_million: Some(15.0),
260                output_cost_per_million: Some(75.0),
261            },
262            ModelInfo {
263                id: "claude-haiku-3-5-20241022".to_string(),
264                name: "Claude 3.5 Haiku".to_string(),
265                provider: "anthropic".to_string(),
266                context_window: 200_000,
267                max_output_tokens: Some(8_192),
268                supports_vision: true,
269                supports_tools: true,
270                supports_streaming: true,
271                input_cost_per_million: Some(0.80),
272                output_cost_per_million: Some(4.0),
273            },
274        ])
275    }
276
277    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
278        tracing::debug!(
279            provider = "anthropic",
280            model = %request.model,
281            message_count = request.messages.len(),
282            tool_count = request.tools.len(),
283            "Starting completion request"
284        );
285
286        self.validate_api_key()?;
287
288        let (system_prompt, messages) = Self::convert_messages(&request.messages);
289        let tools = Self::convert_tools(&request.tools);
290
291        let mut body = json!({
292            "model": request.model,
293            "messages": messages,
294            "max_tokens": request.max_tokens.unwrap_or(8192),
295        });
296
297        if let Some(system) = system_prompt {
298            body["system"] = json!(system);
299        }
300        if !tools.is_empty() {
301            body["tools"] = json!(tools);
302        }
303        if let Some(temp) = request.temperature {
304            body["temperature"] = json!(temp);
305        }
306        if let Some(top_p) = request.top_p {
307            body["top_p"] = json!(top_p);
308        }
309
310        tracing::debug!("Anthropic request to model {}", request.model);
311
312        let response = self
313            .client
314            .post(format!("{}/v1/messages", ANTHROPIC_API_BASE))
315            .header("x-api-key", &self.api_key)
316            .header("anthropic-version", ANTHROPIC_VERSION)
317            .header("content-type", "application/json")
318            .json(&body)
319            .send()
320            .await
321            .context("Failed to send request to Anthropic")?;
322
323        let status = response.status();
324        let text = response
325            .text()
326            .await
327            .context("Failed to read Anthropic response")?;
328
329        if !status.is_success() {
330            if let Ok(err) = serde_json::from_str::<AnthropicError>(&text) {
331                anyhow::bail!(
332                    "Anthropic API error: {} ({:?})",
333                    err.error.message,
334                    err.error.error_type
335                );
336            }
337            anyhow::bail!("Anthropic API error: {} {}", status, text);
338        }
339
340        let response: AnthropicResponse = serde_json::from_str(&text).context(format!(
341            "Failed to parse Anthropic response: {}",
342            &text[..text.len().min(200)]
343        ))?;
344
345        tracing::debug!(
346            response_id = %response.id,
347            model = %response.model,
348            stop_reason = ?response.stop_reason,
349            "Received Anthropic response"
350        );
351
352        let mut content = Vec::new();
353        let mut has_tool_calls = false;
354
355        for part in &response.content {
356            match part {
357                AnthropicContent::Text { text } => {
358                    if !text.is_empty() {
359                        content.push(ContentPart::Text { text: text.clone() });
360                    }
361                }
362                AnthropicContent::ToolUse { id, name, input } => {
363                    has_tool_calls = true;
364                    content.push(ContentPart::ToolCall {
365                        id: id.clone(),
366                        name: name.clone(),
367                        arguments: serde_json::to_string(input).unwrap_or_default(),
368                    });
369                }
370            }
371        }
372
373        let finish_reason = if has_tool_calls {
374            FinishReason::ToolCalls
375        } else {
376            match response.stop_reason.as_deref() {
377                Some("end_turn") | Some("stop") => FinishReason::Stop,
378                Some("max_tokens") => FinishReason::Length,
379                Some("tool_use") => FinishReason::ToolCalls,
380                Some("content_filter") => FinishReason::ContentFilter,
381                _ => FinishReason::Stop,
382            }
383        };
384
385        let usage = response.usage.as_ref();
386
387        Ok(CompletionResponse {
388            message: Message {
389                role: Role::Assistant,
390                content,
391            },
392            usage: Usage {
393                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
394                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
395                total_tokens: usage.map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
396                cache_read_tokens: usage.and_then(|u| u.cache_read_input_tokens),
397                cache_write_tokens: usage.and_then(|u| u.cache_creation_input_tokens),
398            },
399            finish_reason,
400        })
401    }
402
403    async fn complete_stream(
404        &self,
405        request: CompletionRequest,
406    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
407        // Fall back to non-streaming for now
408        let response = self.complete(request).await?;
409        let text = response
410            .message
411            .content
412            .iter()
413            .filter_map(|p| match p {
414                ContentPart::Text { text } => Some(text.clone()),
415                _ => None,
416            })
417            .collect::<Vec<_>>()
418            .join("");
419
420        Ok(Box::pin(futures::stream::once(async move {
421            StreamChunk::Text(text)
422        })))
423    }
424}