agent_code_lib/llm/
anthropic.rs1use async_trait::async_trait;
9use futures::StreamExt;
10use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
11use tokio::sync::mpsc;
12use tracing::{debug, warn};
13
14use super::message::{messages_to_api_params, messages_to_api_params_cached};
15use super::provider::{Provider, ProviderError, ProviderRequest};
16use super::stream::{RawSseEvent, StreamEvent, StreamParser};
17
18pub struct AnthropicProvider {
20 http: reqwest::Client,
21 base_url: String,
22 api_key: String,
23}
24
25impl AnthropicProvider {
26 pub fn new(base_url: &str, api_key: &str) -> Self {
27 let http = reqwest::Client::builder()
28 .timeout(std::time::Duration::from_secs(300))
29 .build()
30 .expect("failed to build HTTP client");
31
32 Self {
33 http,
34 base_url: base_url.trim_end_matches('/').to_string(),
35 api_key: api_key.to_string(),
36 }
37 }
38}
39
40#[async_trait]
41impl Provider for AnthropicProvider {
42 fn name(&self) -> &str {
43 "anthropic"
44 }
45
46 async fn stream(
47 &self,
48 request: &ProviderRequest,
49 ) -> Result<mpsc::Receiver<StreamEvent>, ProviderError> {
50 let url = format!("{}/messages", self.base_url);
51
52 let mut headers = HeaderMap::new();
53 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
54 headers.insert(
55 "x-api-key",
56 HeaderValue::from_str(&self.api_key).map_err(|e| ProviderError::Auth(e.to_string()))?,
57 );
58 headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
59
60 let mut betas = Vec::new();
62 betas.push("interleaved-thinking-2025-05-14"); if request.enable_caching {
64 betas.push("prompt-caching-2024-07-31");
65 }
66 if !betas.is_empty() {
67 headers.insert(
68 "anthropic-beta",
69 HeaderValue::from_str(&betas.join(",")).unwrap_or(HeaderValue::from_static("")),
70 );
71 }
72
73 let tool_count = request.tools.len();
78 let tools: Vec<serde_json::Value> = request
79 .tools
80 .iter()
81 .enumerate()
82 .map(|(i, t)| {
83 let mut tool = serde_json::json!({
84 "name": t.name,
85 "description": t.description,
86 "input_schema": t.input_schema,
87 });
88 if request.enable_caching && i == tool_count - 1 && tool_count > 0 {
89 tool["cache_control"] = serde_json::json!({"type": "ephemeral"});
90 }
91 tool
92 })
93 .collect();
94
95 let system = if request.enable_caching {
97 serde_json::json!([{
98 "type": "text",
99 "text": request.system_prompt,
100 "cache_control": { "type": "ephemeral" }
101 }])
102 } else {
103 serde_json::json!(request.system_prompt)
104 };
105
106 let mut body = serde_json::json!({
107 "model": request.model,
108 "max_tokens": request.max_tokens,
109 "stream": true,
110 "system": system,
111 "messages": if request.enable_caching {
112 messages_to_api_params_cached(&request.messages)
113 } else {
114 messages_to_api_params(&request.messages)
115 },
116 "tools": tools,
117 });
118
119 if let Some(temp) = request.temperature {
120 body["temperature"] = serde_json::json!(temp);
121 }
122
123 if !request.tools.is_empty() {
125 use super::provider::ToolChoice;
126 match &request.tool_choice {
127 ToolChoice::Auto => {
128 body["tool_choice"] = serde_json::json!({"type": "auto"});
129 }
130 ToolChoice::Any => {
131 body["tool_choice"] = serde_json::json!({"type": "any"});
132 }
133 ToolChoice::None => {
134 body.as_object_mut().unwrap().remove("tools");
136 }
137 ToolChoice::Specific(name) => {
138 body["tool_choice"] = serde_json::json!({
139 "type": "tool",
140 "name": name
141 });
142 }
143 }
144 }
145
146 if let Some(ref meta) = request.metadata {
148 body["metadata"] = meta.clone();
149 }
150
151 let thinking_budget =
153 crate::services::tokens::max_thinking_tokens_for_model(&request.model);
154 body["thinking"] = serde_json::json!({
155 "type": "enabled",
156 "budget_tokens": thinking_budget,
157 });
158
159 debug!("Anthropic request to {url} (thinking budget: {thinking_budget})");
160
161 let response = self
162 .http
163 .post(&url)
164 .headers(headers)
165 .json(&body)
166 .send()
167 .await
168 .map_err(|e| ProviderError::Network(e.to_string()))?;
169
170 let status = response.status();
171 if !status.is_success() {
172 let body_text = response.text().await.unwrap_or_default();
173 return match status.as_u16() {
174 401 | 403 => Err(ProviderError::Auth(body_text)),
175 429 => {
176 let retry = parse_retry_after(&body_text);
177 Err(ProviderError::RateLimited {
178 retry_after_ms: retry,
179 })
180 }
181 529 => Err(ProviderError::Overloaded),
182 413 => Err(ProviderError::RequestTooLarge(body_text)),
183 _ => Err(ProviderError::Network(format!("{status}: {body_text}"))),
184 };
185 }
186
187 let (tx, rx) = mpsc::channel(64);
189 tokio::spawn(async move {
190 let mut parser = StreamParser::new();
191 let mut byte_stream = response.bytes_stream();
192 let mut buffer = String::new();
193 let start = std::time::Instant::now();
194 let mut first_token = false;
195
196 while let Some(chunk_result) = byte_stream.next().await {
197 let chunk = match chunk_result {
198 Ok(c) => c,
199 Err(e) => {
200 let _ = tx.send(StreamEvent::Error(e.to_string())).await;
201 break;
202 }
203 };
204
205 buffer.push_str(&String::from_utf8_lossy(&chunk));
206
207 while let Some(pos) = buffer.find("\n\n") {
208 let event_text = buffer[..pos].to_string();
209 buffer = buffer[pos + 2..].to_string();
210
211 if let Some(data) = extract_sse_data(&event_text) {
212 if data == "[DONE]" {
213 return;
214 }
215
216 match serde_json::from_str::<RawSseEvent>(data) {
217 Ok(raw) => {
218 let events = parser.process(raw);
219 for event in events {
220 if !first_token && matches!(event, StreamEvent::TextDelta(_)) {
221 first_token = true;
222 let ttft = start.elapsed().as_millis() as u64;
223 let _ = tx.send(StreamEvent::Ttft(ttft)).await;
224 }
225 if tx.send(event).await.is_err() {
226 return;
227 }
228 }
229 }
230 Err(e) => {
231 warn!("SSE parse error: {e}");
232 }
233 }
234 }
235 }
236 }
237 });
238
239 Ok(rx)
240 }
241}
242
/// Returns the payload of the first `data:` field found in one SSE event.
///
/// A single space directly after the colon is dropped. When the colon is not followed
/// by a space, all leading whitespace of the payload is trimmed instead. Lines that do
/// not start with `data:` are skipped; `None` is returned when no such line exists.
fn extract_sse_data(event_text: &str) -> Option<&str> {
    event_text.lines().find_map(|line| {
        line.strip_prefix("data:").map(|payload| {
            payload
                .strip_prefix(' ')
                .unwrap_or_else(|| payload.trim_start())
        })
    })
}
254
255fn parse_retry_after(body: &str) -> u64 {
256 if let Ok(v) = serde_json::from_str::<serde_json::Value>(body)
257 && let Some(retry) = v
258 .get("error")
259 .and_then(|e| e.get("retry_after"))
260 .and_then(|r| r.as_f64())
261 {
262 return (retry * 1000.0) as u64;
263 }
264 1000
265}