Skip to main content

agent_code_lib/llm/
anthropic.rs

//! Anthropic Messages API provider.
//!
//! Native support for Claude models. Uses the Anthropic-specific
//! wire format: top-level system param, content block arrays,
//! tool definitions with input_schema, and SSE streaming with
//! content_block_start/delta/stop events.

8use async_trait::async_trait;
9use futures::StreamExt;
10use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
14use super::message::{messages_to_api_params, messages_to_api_params_cached};
15use super::provider::{Provider, ProviderError, ProviderRequest};
16use super::stream::{RawSseEvent, StreamEvent, StreamParser};
17
/// Anthropic Messages API provider (Claude models, Bedrock, Vertex).
pub struct AnthropicProvider {
    /// Shared HTTP client, reused across all requests from this provider.
    http: reqwest::Client,
    /// API endpoint base URL, stored without a trailing slash
    /// (see `new`, which trims it) so paths can be appended cleanly.
    base_url: String,
    /// Secret API key, sent in the `x-api-key` request header.
    api_key: String,
}
24
25impl AnthropicProvider {
26    pub fn new(base_url: &str, api_key: &str) -> Self {
27        let http = reqwest::Client::builder()
28            .timeout(std::time::Duration::from_secs(300))
29            .build()
30            .expect("failed to build HTTP client");
31
32        Self {
33            http,
34            base_url: base_url.trim_end_matches('/').to_string(),
35            api_key: api_key.to_string(),
36        }
37    }
38}
39
40#[async_trait]
41impl Provider for AnthropicProvider {
42    fn name(&self) -> &str {
43        "anthropic"
44    }
45
46    async fn stream(
47        &self,
48        request: &ProviderRequest,
49    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
50        let url = format!("{}/messages", self.base_url);
51
52        let mut headers = HeaderMap::new();
53        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
54        headers.insert(
55            "x-api-key",
56            HeaderValue::from_str(&self.api_key).map_err(|e| ProviderError::Auth(e.to_string()))?,
57        );
58        headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
59
60        // Enable beta features.
61        let mut betas = Vec::new();
62        betas.push("interleaved-thinking-2025-05-14"); // Extended thinking.
63        if request.enable_caching {
64            betas.push("prompt-caching-2024-07-31");
65        }
66        if !betas.is_empty() {
67            headers.insert(
68                "anthropic-beta",
69                HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
70            );
71        }
72
73        // Build tool definitions in Anthropic format.
74        // When caching is enabled, add cache_control to the last tool definition.
75        // This causes the API to cache the entire tools block (system prompt +
76        // tools are cached together as a prefix).
77        let tool_count = request.tools.len();
78        let tools: Vec<serde_json::Value> = request
79            .tools
80            .iter()
81            .enumerate()
82            .map(|(i, t)| {
83                let mut tool = serde_json::json!({
84                    "name": t.name,
85                    "description": t.description,
86                    "input_schema": t.input_schema,
87                });
88                if request.enable_caching && i == tool_count - 1 && tool_count > 0 {
89                    tool["cache_control"] = serde_json::json!({"type": "ephemeral"});
90                }
91                tool
92            })
93            .collect();
94
95        // System prompt with optional cache control.
96        let system = if request.enable_caching {
97            serde_json::json!([{
98                "type": "text",
99                "text": request.system_prompt,
100                "cache_control": { "type": "ephemeral" }
101            }])
102        } else {
103            serde_json::json!(request.system_prompt)
104        };
105
106        let mut body = serde_json::json!({
107            "model": request.model,
108            "max_tokens": request.max_tokens,
109            "stream": true,
110            "system": system,
111            "messages": if request.enable_caching {
112                messages_to_api_params_cached(&request.messages)
113            } else {
114                messages_to_api_params(&request.messages)
115            },
116            "tools": tools,
117        });
118
119        if let Some(temp) = request.temperature {
120            body["temperature"] = serde_json::json!(temp);
121        }
122
123        // Tool choice.
124        if !request.tools.is_empty() {
125            use super::provider::ToolChoice;
126            match &request.tool_choice {
127                ToolChoice::Auto => {
128                    body["tool_choice"] = serde_json::json!({"type": "auto"});
129                }
130                ToolChoice::Any => {
131                    body["tool_choice"] = serde_json::json!({"type": "any"});
132                }
133                ToolChoice::None => {
134                    // Anthropic doesn't have a "none" tool_choice — just omit tools.
135                    body.as_object_mut().unwrap().remove("tools");
136                }
137                ToolChoice::Specific(name) => {
138                    body["tool_choice"] = serde_json::json!({
139                        "type": "tool",
140                        "name": name
141                    });
142                }
143            }
144        }
145
146        // Metadata (e.g., user_id for analytics).
147        if let Some(ref meta) = request.metadata {
148            body["metadata"] = meta.clone();
149        }
150
151        // Thinking configuration (adaptive or budgeted).
152        let thinking_budget =
153            crate::services::tokens::max_thinking_tokens_for_model(&request.model);
154        body["thinking"] = serde_json::json!({
155            "type": "enabled",
156            "budget_tokens": thinking_budget,
157        });
158
159        debug!("Anthropic request to {url} (thinking budget: {thinking_budget})");
160
161        let response = self
162            .http
163            .post(&url)
164            .headers(headers)
165            .json(&body)
166            .send()
167            .await
168            .map_err(|e| ProviderError::Network(e.to_string()))?;
169
170        let status = response.status();
171        if !status.is_success() {
172            let body_text = response.text().await.unwrap_or_default();
173            return match status.as_u16() {
174                401 | 403 => Err(ProviderError::Auth(body_text)),
175                429 => {
176                    let retry = parse_retry_after(&body_text);
177                    Err(ProviderError::RateLimited {
178                        retry_after_ms: retry,
179                    })
180                }
181                529 => Err(ProviderError::Overloaded),
182                413 => Err(ProviderError::RequestTooLarge(body_text)),
183                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
184            };
185        }
186
187        // Parse Anthropic SSE stream (reuses existing StreamParser).
188        let (tx, rx) = mpsc::channel(64);
189        tokio::spawn(async move {
190            let mut parser = StreamParser::new();
191            let mut byte_stream = response.bytes_stream();
192            let mut buffer = String::new();
193            let start = std::time::Instant::now();
194            let mut first_token = false;
195
196            while let Some(chunk_result) = byte_stream.next().await {
197                let chunk = match chunk_result {
198                    Ok(c) => c,
199                    Err(e) => {
200                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
201                        break;
202                    }
203                };
204
205                buffer.push_str(&String::from_utf8_lossy(&chunk));
206
207                while let Some(pos) = buffer.find("\n\n") {
208                    let event_text = buffer[..pos].to_string();
209                    buffer = buffer[pos + 2..].to_string();
210
211                    if let Some(data) = extract_sse_data(&event_text) {
212                        if data == "[DONE]" {
213                            return;
214                        }
215
216                        match serde_json::from_str::<RawSseEvent>(data) {
217                            Ok(raw) => {
218                                let events = parser.process(raw);
219                                for event in events {
220                                    if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
221                                        first_token = true;
222                                        let ttft = start.elapsed().as_millis() as u64;
223                                        let _ = tx.send(StreamEvent::Ttft(ttft)).await;
224                                    }
225                                    if tx.send(event).await.is_err() {
226                                        return;
227                                    }
228                                }
229                            }
230                            Err(e) => {
231                                warn!("SSE parse error: {e}");
232                            }
233                        }
234                    }
235                }
236            }
237        });
238
239        Ok(rx)
240    }
241}
242
/// Return the payload of the first `data:` line in an SSE event block.
///
/// A `data: ` prefix (with space) is stripped verbatim; a bare `data:`
/// prefix additionally has any remaining leading whitespace trimmed.
/// Returns `None` when no line carries a `data` field.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        line.strip_prefix("data: ")
            .or_else(|| line.strip_prefix("data:").map(str::trim_start))
    })
}
254
255fn parse_retry_after(body: &str) -> u64 {
256    if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
257        && let Some(retry) = v
258            .get("error")
259            .and_then(|e| e.get("retry_after"))
260            .and_then(|r| r.as_f64())
261    {
262        return (retry * 1000.0) as u64;
263    }
264    1000
265}