agent_code_lib/llm/
anthropic.rs1use async_trait::async_trait;
9use futures::StreamExt;
10use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
14use super::message::{messages_to_api_params, messages_to_api_params_cached};
15use super::provider::{Provider, ProviderError, ProviderRequest};
16use super::stream::{RawSseEvent, StreamEvent, StreamParser};
17
18pub struct AnthropicProvider {
20 http: reqwest::Client,
21 base_url: String,
22 api_key: String,
23}
24
25impl AnthropicProvider {
26 pub fn new(base_url: &str, api_key: &str) -> Self {
27 let http = reqwest::Client::builder()
28 .timeout(std::time::Duration::from_secs(300))
29 .build()
30 .expect("failed to build HTTP client");
31
32 Self {
33 http,
34 base_url: base_url.trim_end_matches('/').to_string(),
35 api_key: api_key.to_string(),
36 }
37 }
38}
39
40#[async_trait]
41impl Provider for AnthropicProvider {
42 fn name(&self) -> &str {
43 "anthropic"
44 }
45
46 async fn stream(
47 &self,
48 request: &ProviderRequest,
49 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
50 let url = format!("{}/messages", self.base_url);
51
52 let mut headers = HeaderMap::new();
53 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
54 headers.insert(
55 "x-api-key",
56 HeaderValue::from_str(&self.api_key).map_err(|e| ProviderError::Auth(e.to_string()))?,
57 );
58 headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
59
60 let mut betas = Vec::new();
62 betas.push("interleaved-thinking-2025-05-14"); if request.enable_caching {
64 betas.push("prompt-caching-2024-07-31");
65 }
66 if !betas.is_empty() {
67 headers.insert(
68 "anthropic-beta",
69 HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
70 );
71 }
72
73 let tools: Vec<serde_json::Value> = request
75 .tools
76 .iter()
77 .map(|t| {
78 serde_json::json!({
79 "name": t.name,
80 "description": t.description,
81 "input_schema": t.input_schema,
82 })
83 })
84 .collect();
85
86 let system = if request.enable_caching {
88 serde_json::json!([{
89 "type": "text",
90 "text": request.system_prompt,
91 "cache_control": { "type": "ephemeral" }
92 }])
93 } else {
94 serde_json::json!(request.system_prompt)
95 };
96
97 let mut body = serde_json::json!({
98 "model": request.model,
99 "max_tokens": request.max_tokens,
100 "stream": true,
101 "system": system,
102 "messages": if request.enable_caching {
103 messages_to_api_params_cached(&request.messages)
104 } else {
105 messages_to_api_params(&request.messages)
106 },
107 "tools": tools,
108 });
109
110 if let Some(temp) = request.temperature {
111 body["temperature"] = serde_json::json!(temp);
112 }
113
114 if !request.tools.is_empty() {
116 use super::provider::ToolChoice;
117 match &request.tool_choice {
118 ToolChoice::Auto => {
119 body["tool_choice"] = serde_json::json!({"type": "auto"});
120 }
121 ToolChoice::Any => {
122 body["tool_choice"] = serde_json::json!({"type": "any"});
123 }
124 ToolChoice::None => {
125 body.as_object_mut().unwrap().remove("tools");
127 }
128 ToolChoice::Specific(name) => {
129 body["tool_choice"] = serde_json::json!({
130 "type": "tool",
131 "name": name
132 });
133 }
134 }
135 }
136
137 if let Some(ref meta) = request.metadata {
139 body["metadata"] = meta.clone();
140 }
141
142 let thinking_budget =
144 crate::services::tokens::max_thinking_tokens_for_model(&request.model);
145 body["thinking"] = serde_json::json!({
146 "type": "enabled",
147 "budget_tokens": thinking_budget,
148 });
149
150 debug!("Anthropic request to {url} (thinking budget: {thinking_budget})");
151
152 let response = self
153 .http
154 .post(&url)
155 .headers(headers)
156 .json(&body)
157 .send()
158 .await
159 .map_err(|e| ProviderError::Network(e.to_string()))?;
160
161 let status = response.status();
162 if !status.is_success() {
163 let body_text = response.text().await.unwrap_or_default();
164 return match status.as_u16() {
165 401 | 403 => Err(ProviderError::Auth(body_text)),
166 429 => {
167 let retry = parse_retry_after(&body_text);
168 Err(ProviderError::RateLimited {
169 retry_after_ms: retry,
170 })
171 }
172 529 => Err(ProviderError::Overloaded),
173 413 => Err(ProviderError::RequestTooLarge(body_text)),
174 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
175 };
176 }
177
178 let (tx, rx) = mpsc::channel(64);
180 tokio::spawn(async move {
181 let mut parser = StreamParser::new();
182 let mut byte_stream = response.bytes_stream();
183 let mut buffer = String::new();
184 let start = std::time::Instant::now();
185 let mut first_token = false;
186
187 while let Some(chunk_result) = byte_stream.next().await {
188 let chunk = match chunk_result {
189 Ok(c) => c,
190 Err(e) => {
191 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
192 break;
193 }
194 };
195
196 buffer.push_str(&String::from_utf8_lossy(&chunk));
197
198 while let Some(pos) = buffer.find("\n\n") {
199 let event_text = buffer[..pos].to_string();
200 buffer = buffer[pos + 2..].to_string();
201
202 if let Some(data) = extract_sse_data(&event_text) {
203 if data == "[DONE]" {
204 return;
205 }
206
207 match serde_json::from_str::<RawSseEvent>(data) {
208 Ok(raw) => {
209 let events = parser.process(raw);
210 for event in events {
211 if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
212 first_token = true;
213 let ttft = start.elapsed().as_millis() as u64;
214 let _ = tx.send(StreamEvent::Ttft(ttft)).await;
215 }
216 if tx.send(event).await.is_err() {
217 return;
218 }
219 }
220 }
221 Err(e) => {
222 warn!("SSE parse error: {e}");
223 }
224 }
225 }
226 }
227 }
228 });
229
230 Ok(rx)
231 }
232}
233
/// Returns the payload of the first `data:` line in an SSE event, if any.
///
/// A line starting with `data: ` yields everything after that single space. A line
/// starting with `data:` without the space yields the remainder with leading
/// whitespace trimmed. Only the first matching line is considered; `None` is returned
/// when no line carries a `data:` field.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        line.strip_prefix("data: ")
            .or_else(|| line.strip_prefix("data:").map(str::trim_start))
    })
}
245
246fn parse_retry_after(body: &str) -> u64 {
247 if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
248 && let Some(retry) = v
249 .get("error")
250 .and_then(|e| e.get("retry_after"))
251 .and_then(|r| r.as_f64())
252 {
253 return (retry * 1000.0) as u64;
254 }
255 1000
256}