agent_code_lib/llm/
anthropic.rs1use async_trait::async_trait;
9use futures::StreamExt;
10use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
14use super::message::{messages_to_api_params, messages_to_api_params_cached};
15use super::provider::{Provider, ProviderError, ProviderRequest};
16use super::stream::{RawSseEvent, StreamEvent, StreamParser};
17
18pub struct AnthropicProvider {
19 http: reqwest::Client,
20 base_url: String,
21 api_key: String,
22}
23
24impl AnthropicProvider {
25 pub fn new(base_url: &str, api_key: &str) -> Self {
26 let http = reqwest::Client::builder()
27 .timeout(std::time::Duration::from_secs(300))
28 .build()
29 .expect("failed to build HTTP client");
30
31 Self {
32 http,
33 base_url: base_url.trim_end_matches('/').to_string(),
34 api_key: api_key.to_string(),
35 }
36 }
37}
38
39#[async_trait]
40impl Provider for AnthropicProvider {
41 fn name(&self) -> &str {
42 "anthropic"
43 }
44
45 async fn stream(
46 &self,
47 request: &ProviderRequest,
48 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
49 let url = format!("{}/messages", self.base_url);
50
51 let mut headers = HeaderMap::new();
52 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
53 headers.insert(
54 "x-api-key",
55 HeaderValue::from_str(&self.api_key).map_err(|e| ProviderError::Auth(e.to_string()))?,
56 );
57 headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
58
59 let mut betas = Vec::new();
61 betas.push("interleaved-thinking-2025-05-14"); if request.enable_caching {
63 betas.push("prompt-caching-2024-07-31");
64 }
65 if !betas.is_empty() {
66 headers.insert(
67 "anthropic-beta",
68 HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
69 );
70 }
71
72 let tools: Vec<serde_json::Value> = request
74 .tools
75 .iter()
76 .map(|t| {
77 serde_json::json!({
78 "name": t.name,
79 "description": t.description,
80 "input_schema": t.input_schema,
81 })
82 })
83 .collect();
84
85 let system = if request.enable_caching {
87 serde_json::json!([{
88 "type": "text",
89 "text": request.system_prompt,
90 "cache_control": { "type": "ephemeral" }
91 }])
92 } else {
93 serde_json::json!(request.system_prompt)
94 };
95
96 let mut body = serde_json::json!({
97 "model": request.model,
98 "max_tokens": request.max_tokens,
99 "stream": true,
100 "system": system,
101 "messages": if request.enable_caching {
102 messages_to_api_params_cached(&request.messages)
103 } else {
104 messages_to_api_params(&request.messages)
105 },
106 "tools": tools,
107 });
108
109 if let Some(temp) = request.temperature {
110 body["temperature"] = serde_json::json!(temp);
111 }
112
113 if !request.tools.is_empty() {
115 use super::provider::ToolChoice;
116 match &request.tool_choice {
117 ToolChoice::Auto => {
118 body["tool_choice"] = serde_json::json!({"type": "auto"});
119 }
120 ToolChoice::Any => {
121 body["tool_choice"] = serde_json::json!({"type": "any"});
122 }
123 ToolChoice::None => {
124 body.as_object_mut().unwrap().remove("tools");
126 }
127 ToolChoice::Specific(name) => {
128 body["tool_choice"] = serde_json::json!({
129 "type": "tool",
130 "name": name
131 });
132 }
133 }
134 }
135
136 if let Some(ref meta) = request.metadata {
138 body["metadata"] = meta.clone();
139 }
140
141 let thinking_budget =
143 crate::services::tokens::max_thinking_tokens_for_model(&request.model);
144 body["thinking"] = serde_json::json!({
145 "type": "enabled",
146 "budget_tokens": thinking_budget,
147 });
148
149 debug!("Anthropic request to {url} (thinking budget: {thinking_budget})");
150
151 let response = self
152 .http
153 .post(&url)
154 .headers(headers)
155 .json(&body)
156 .send()
157 .await
158 .map_err(|e| ProviderError::Network(e.to_string()))?;
159
160 let status = response.status();
161 if !status.is_success() {
162 let body_text = response.text().await.unwrap_or_default();
163 return match status.as_u16() {
164 401 | 403 => Err(ProviderError::Auth(body_text)),
165 429 => {
166 let retry = parse_retry_after(&body_text);
167 Err(ProviderError::RateLimited {
168 retry_after_ms: retry,
169 })
170 }
171 529 => Err(ProviderError::Overloaded),
172 413 => Err(ProviderError::RequestTooLarge(body_text)),
173 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
174 };
175 }
176
177 let (tx, rx) = mpsc::channel(64);
179 tokio::spawn(async move {
180 let mut parser = StreamParser::new();
181 let mut byte_stream = response.bytes_stream();
182 let mut buffer = String::new();
183 let start = std::time::Instant::now();
184 let mut first_token = false;
185
186 while let Some(chunk_result) = byte_stream.next().await {
187 let chunk = match chunk_result {
188 Ok(c) => c,
189 Err(e) => {
190 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
191 break;
192 }
193 };
194
195 buffer.push_str(&String::from_utf8_lossy(&chunk));
196
197 while let Some(pos) = buffer.find("\n\n") {
198 let event_text = buffer[..pos].to_string();
199 buffer = buffer[pos + 2..].to_string();
200
201 if let Some(data) = extract_sse_data(&event_text) {
202 if data == "[DONE]" {
203 return;
204 }
205
206 match serde_json::from_str::<RawSseEvent>(data) {
207 Ok(raw) => {
208 let events = parser.process(raw);
209 for event in events {
210 if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
211 first_token = true;
212 let ttft = start.elapsed().as_millis() as u64;
213 let _ = tx.send(StreamEvent::Ttft(ttft)).await;
214 }
215 if tx.send(event).await.is_err() {
216 return;
217 }
218 }
219 }
220 Err(e) => {
221 warn!("SSE parse error: {e}");
222 }
223 }
224 }
225 }
226 }
227 });
228
229 Ok(rx)
230 }
231}
232
/// Returns the payload of the first `data:` line in one SSE event, if any.
///
/// A line starting with `data: ` yields everything after that single space. A line
/// starting with `data:` and no space yields the remainder with leading whitespace
/// trimmed. Lines that do not start with `data:` are skipped, and only the first
/// matching line is used.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        line.strip_prefix("data: ")
            .or_else(|| line.strip_prefix("data:").map(str::trim_start))
    })
}
244
245fn parse_retry_after(body: &str) -> u64 {
246 if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
247 && let Some(retry) = v
248 .get("error")
249 .and_then(|e| e.get("retry_after"))
250 .and_then(|r| r.as_f64())
251 {
252 return (retry * 1000.0) as u64;
253 }
254 1000
255}