Skip to main content

agent_code_lib/llm/
anthropic.rs

1//! Anthropic Messages API provider.
2//!
3//! Native support for Claude models. Uses the Anthropic-specific
4//! wire format: top-level system param, content block arrays,
5//! tool definitions with input_schema, and SSE streaming with
6//! content_block_start/delta/stop events.
7
8use async_trait::async_trait;
9use futures::StreamExt;
10use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
14use super::message::{messages_to_api_params, messages_to_api_params_cached};
15use super::provider::{Provider, ProviderError, ProviderRequest};
16use super::stream::{RawSseEvent, StreamEvent, StreamParser};
17
18pub struct AnthropicProvider {
19    http: reqwest::Client,
20    base_url: String,
21    api_key: String,
22}
23
24impl AnthropicProvider {
25    pub fn new(base_url: &str, api_key: &str) -> Self {
26        let http = reqwest::Client::builder()
27            .timeout(std::time::Duration::from_secs(300))
28            .build()
29            .expect("failed to build HTTP client");
30
31        Self {
32            http,
33            base_url: base_url.trim_end_matches('/').to_string(),
34            api_key: api_key.to_string(),
35        }
36    }
37}
38
39#[async_trait]
40impl Provider for AnthropicProvider {
41    fn name(&self) -> &str {
42        "anthropic"
43    }
44
45    async fn stream(
46        &self,
47        request: &ProviderRequest,
48    ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
49        let url = format!("{}/messages", self.base_url);
50
51        let mut headers = HeaderMap::new();
52        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
53        headers.insert(
54            "x-api-key",
55            HeaderValue::from_str(&self.api_key).map_err(|e| ProviderError::Auth(e.to_string()))?,
56        );
57        headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
58
59        // Enable beta features.
60        let mut betas = Vec::new();
61        betas.push("interleaved-thinking-2025-05-14"); // Extended thinking.
62        if request.enable_caching {
63            betas.push("prompt-caching-2024-07-31");
64        }
65        if !betas.is_empty() {
66            headers.insert(
67                "anthropic-beta",
68                HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
69            );
70        }
71
72        // Build tool definitions in Anthropic format.
73        let tools: Vec<serde_json::Value> = request
74            .tools
75            .iter()
76            .map(|t| {
77                serde_json::json!({
78                    "name": t.name,
79                    "description": t.description,
80                    "input_schema": t.input_schema,
81                })
82            })
83            .collect();
84
85        // System prompt with optional cache control.
86        let system = if request.enable_caching {
87            serde_json::json!([{
88                "type": "text",
89                "text": request.system_prompt,
90                "cache_control": { "type": "ephemeral" }
91            }])
92        } else {
93            serde_json::json!(request.system_prompt)
94        };
95
96        let mut body = serde_json::json!({
97            "model": request.model,
98            "max_tokens": request.max_tokens,
99            "stream": true,
100            "system": system,
101            "messages": if request.enable_caching {
102                messages_to_api_params_cached(&request.messages)
103            } else {
104                messages_to_api_params(&request.messages)
105            },
106            "tools": tools,
107        });
108
109        if let Some(temp) = request.temperature {
110            body["temperature"] = serde_json::json!(temp);
111        }
112
113        // Tool choice.
114        if !request.tools.is_empty() {
115            use super::provider::ToolChoice;
116            match &request.tool_choice {
117                ToolChoice::Auto => {
118                    body["tool_choice"] = serde_json::json!({"type": "auto"});
119                }
120                ToolChoice::Any => {
121                    body["tool_choice"] = serde_json::json!({"type": "any"});
122                }
123                ToolChoice::None => {
124                    // Anthropic doesn't have a "none" tool_choice — just omit tools.
125                    body.as_object_mut().unwrap().remove("tools");
126                }
127                ToolChoice::Specific(name) => {
128                    body["tool_choice"] = serde_json::json!({
129                        "type": "tool",
130                        "name": name
131                    });
132                }
133            }
134        }
135
136        // Metadata (e.g., user_id for analytics).
137        if let Some(ref meta) = request.metadata {
138            body["metadata"] = meta.clone();
139        }
140
141        // Thinking configuration (adaptive or budgeted).
142        let thinking_budget =
143            crate::services::tokens::max_thinking_tokens_for_model(&request.model);
144        body["thinking"] = serde_json::json!({
145            "type": "enabled",
146            "budget_tokens": thinking_budget,
147        });
148
149        debug!("Anthropic request to {url} (thinking budget: {thinking_budget})");
150
151        let response = self
152            .http
153            .post(&url)
154            .headers(headers)
155            .json(&body)
156            .send()
157            .await
158            .map_err(|e| ProviderError::Network(e.to_string()))?;
159
160        let status = response.status();
161        if !status.is_success() {
162            let body_text = response.text().await.unwrap_or_default();
163            return match status.as_u16() {
164                401 | 403 => Err(ProviderError::Auth(body_text)),
165                429 => {
166                    let retry = parse_retry_after(&body_text);
167                    Err(ProviderError::RateLimited {
168                        retry_after_ms: retry,
169                    })
170                }
171                529 => Err(ProviderError::Overloaded),
172                413 => Err(ProviderError::RequestTooLarge(body_text)),
173                _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
174            };
175        }
176
177        // Parse Anthropic SSE stream (reuses existing StreamParser).
178        let (tx, rx) = mpsc::channel(64);
179        tokio::spawn(async move {
180            let mut parser = StreamParser::new();
181            let mut byte_stream = response.bytes_stream();
182            let mut buffer = String::new();
183            let start = std::time::Instant::now();
184            let mut first_token = false;
185
186            while let Some(chunk_result) = byte_stream.next().await {
187                let chunk = match chunk_result {
188                    Ok(c) => c,
189                    Err(e) => {
190                        let _ = tx.send(StreamEvent::Error(e.to_string())).await;
191                        break;
192                    }
193                };
194
195                buffer.push_str(&String::from_utf8_lossy(&chunk));
196
197                while let Some(pos) = buffer.find("\n\n") {
198                    let event_text = buffer[..pos].to_string();
199                    buffer = buffer[pos + 2..].to_string();
200
201                    if let Some(data) = extract_sse_data(&event_text) {
202                        if data == "[DONE]" {
203                            return;
204                        }
205
206                        match serde_json::from_str::<RawSseEvent>(data) {
207                            Ok(raw) => {
208                                let events = parser.process(raw);
209                                for event in events {
210                                    if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
211                                        first_token = true;
212                                        let ttft = start.elapsed().as_millis() as u64;
213                                        let _ = tx.send(StreamEvent::Ttft(ttft)).await;
214                                    }
215                                    if tx.send(event).await.is_err() {
216                                        return;
217                                    }
218                                }
219                            }
220                            Err(e) => {
221                                warn!("SSE parse error: {e}");
222                            }
223                        }
224                    }
225                }
226            }
227        });
228
229        Ok(rx)
230    }
231}
232
233fn extract_sse_data(event_text: &str) -> Option<&str> {
234    for line in event_text.lines() {
235        if let Some(data) = line.strip_prefix("data: ") {
236            return Some(data);
237        }
238        if let Some(data) = line.strip_prefix("data:") {
239            return Some(data.trim_start());
240        }
241    }
242    None
243}
244
245fn parse_retry_after(body: &str) -> u64 {
246    if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
247        && let Some(retry) = v
248            .get("error")
249            .and_then(|e| e.get("retry_after"))
250            .and_then(|r| r.as_f64())
251    {
252        return (retry * 1000.0) as u64;
253    }
254    1000
255}