Skip to main content

agent_code_lib/llm/
anthropic.rs

1//! Anthropic Messages API provider.
2//!
3//! Native support for Claude models. Uses the Anthropic-specific
4//! wire format: top-level system param, content block arrays,
5//! tool definitions with input_schema, and SSE streaming with
6//! content_block_start/delta/stop events.
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
14use super::message::{messages_to_api_params, messages_to_api_params_cached};
15use super::provider::{Provider, ProviderError, ProviderRequest};
16use super::stream::{RawSseEvent, StreamEvent, StreamParser};
17
18/// Anthropic Messages API provider (Claude models, Bedrock, Vertex).
19pub struct AnthropicProvider {
20    http: reqwest::Client,
21    base_url: String,
22    api_key: String,
23}
24
25impl AnthropicProvider {
26    pub fn new(base_url: &str, api_key: &str) -> Self {
27        let http = reqwest::Client::builder()
28            .timeout(std::time::Duration::from_secs(300))
29            .build()
30            .expect("failed to build HTTP client");
31
32        Self {
33            http,
34            base_url: base_url.trim_end_matches('/').to_string(),
35            api_key: api_key.to_string(),
36        }
37    }
38}
39
40#[async_trait]
41impl Provider for AnthropicProvider {
42    fn name(&self) -> &str {
43        "anthropic"
44    }
45
46    async fn stream(
47        &self,
48        request: &ProviderRequest,
49    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
50        let url = format!("{}/messages", self.base_url);
51
52        let mut headers = HeaderMap::new();
53        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
54        headers.insert(
55            "x-api-key",
56            HeaderValue::from_str(&self.api_key).map_err(|e| ProviderError::Auth(e.to_string()))?,
57        );
58        headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
59
60        // Enable beta features.
61        let mut betas = Vec::new();
62        betas.push("interleaved-thinking-2025-05-14"); // Extended thinking.
63        if request.enable_caching {
64            betas.push("prompt-caching-2024-07-31");
65        }
66        if !betas.is_empty() {
67            headers.insert(
68                "anthropic-beta",
69                HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
70            );
71        }
72
73        // Build tool definitions in Anthropic format.
74        let tools: Vec<serde_json::Value> = request
75            .tools
76            .iter()
77            .map(|t| {
78                serde_json::json!({
79                    "name": t.name,
80                    "description": t.description,
81                    "input_schema": t.input_schema,
82                })
83            })
84            .collect();
85
86        // System prompt with optional cache control.
87        let system = if request.enable_caching {
88            serde_json::json!([{
89                "type": "text",
90                "text": request.system_prompt,
91                "cache_control": { "type": "ephemeral" }
92            }])
93        } else {
94            serde_json::json!(request.system_prompt)
95        };
96
97        let mut body = serde_json::json!({
98            "model": request.model,
99            "max_tokens": request.max_tokens,
100            "stream": true,
101            "system": system,
102            "messages": if request.enable_caching {
103                messages_to_api_params_cached(&request.messages)
104            } else {
105                messages_to_api_params(&request.messages)
106            },
107            "tools": tools,
108        });
109
110        if let Some(temp) = request.temperature {
111            body["temperature"] = serde_json::json!(temp);
112        }
113
114        // Tool choice.
115        if !request.tools.is_empty() {
116            use super::provider::ToolChoice;
117            match &request.tool_choice {
118                ToolChoice::Auto => {
119                    body["tool_choice"] = serde_json::json!({"type": "auto"});
120                }
121                ToolChoice::Any => {
122                    body["tool_choice"] = serde_json::json!({"type": "any"});
123                }
124                ToolChoice::None => {
125                    // Anthropic doesn't have a "none" tool_choice — just omit tools.
126                    body.as_object_mut().unwrap().remove("tools");
127                }
128                ToolChoice::Specific(name) => {
129                    body["tool_choice"] = serde_json::json!({
130                        "type": "tool",
131                        "name": name
132                    });
133                }
134            }
135        }
136
137        // Metadata (e.g., user_id for analytics).
138        if let Some(ref meta) = request.metadata {
139            body["metadata"] = meta.clone();
140        }
141
142        // Thinking configuration (adaptive or budgeted).
143        let thinking_budget =
144            crate::services::tokens::max_thinking_tokens_for_model(&request.model);
145        body["thinking"] = serde_json::json!({
146            "type": "enabled",
147            "budget_tokens": thinking_budget,
148        });
149
150        debug!("Anthropic request to {url} (thinking budget: {thinking_budget})");
151
152        let response = self
153            .http
154            .post(&url)
155            .headers(headers)
156            .json(&body)
157            .send()
158            .await
159            .map_err(|e| ProviderError::Network(e.to_string()))?;
160
161        let status = response.status();
162        if !status.is_success() {
163            let body_text = response.text().await.unwrap_or_default();
164            return match status.as_u16() {
165                401 | 403 => Err(ProviderError::Auth(body_text)),
166                429 => {
167                    let retry = parse_retry_after(&body_text);
168                    Err(ProviderError::RateLimited {
169                        retry_after_ms: retry,
170                    })
171                }
172                529 => Err(ProviderError::Overloaded),
173                413 => Err(ProviderError::RequestTooLarge(body_text)),
174                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
175            };
176        }
177
178        // Parse Anthropic SSE stream (reuses existing StreamParser).
179        let (tx, rx) = mpsc::channel(64);
180        tokio::spawn(async move {
181            let mut parser = StreamParser::new();
182            let mut byte_stream = response.bytes_stream();
183            let mut buffer = String::new();
184            let start = std::time::Instant::now();
185            let mut first_token = false;
186
187            while let Some(chunk_result) = byte_stream.next().await {
188                let chunk = match chunk_result {
189                    Ok(c) => c,
190                    Err(e) => {
191                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
192                        break;
193                    }
194                };
195
196                buffer.push_str(&String::from_utf8_lossy(&chunk));
197
198                while let Some(pos) = buffer.find("\n\n") {
199                    let event_text = buffer[..pos].to_string();
200                    buffer = buffer[pos + 2..].to_string();
201
202                    if let Some(data) = extract_sse_data(&event_text) {
203                        if data == "[DONE]" {
204                            return;
205                        }
206
207                        match serde_json::from_str::<RawSseEvent>(data) {
208                            Ok(raw) => {
209                                let events = parser.process(raw);
210                                for event in events {
211                                    if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
212                                        first_token = true;
213                                        let ttft = start.elapsed().as_millis() as u64;
214                                        let _ = tx.send(StreamEvent::Ttft(ttft)).await;
215                                    }
216                                    if tx.send(event).await.is_err() {
217                                        return;
218                                    }
219                                }
220                            }
221                            Err(e) => {
222                                warn!("SSE parse error: {e}");
223                            }
224                        }
225                    }
226                }
227            }
228        });
229
230        Ok(rx)
231    }
232}
233
/// Returns the payload of the first `data` field in one SSE event block.
///
/// Lines are scanned in order:
/// - A `data: ` prefix is removed exactly, so any further spaces are kept.
/// - A bare `data:` prefix is removed and all leading whitespace is trimmed.
///
/// Only the first `data` line is returned. Returns `None` when no line carries a
/// `data` field.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        line.strip_prefix("data: ")
            .or_else(|| line.strip_prefix("data:").map(str::trim_start))
    })
}
245
246fn parse_retry_after(body: &str) -> u64 {
247    if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
248        && let Some(retry) = v
249            .get("error")
250            .and_then(|e| e.get("retry_after"))
251            .and_then(|r| r.as_f64())
252    {
253        return (retry * 1000.0) as u64;
254    }
255    1000
256}