1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::retry::{AttemptOutcome, RetryConfig};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// `max_tokens` value that `AnthropicClient::new` installs; callers can
/// override it per client with `AnthropicClient::with_max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8192;
16
/// Client for the Anthropic Messages endpoint (`{base_url}/v1/messages`).
///
/// Construct it with `AnthropicClient::new` and adjust it with the `with_*`
/// builder methods.
pub struct AnthropicClient {
    /// API key, sent as the `x-api-key` request header.
    pub(crate) api_key: SecretString,
    /// Model identifier placed in the `model` field of every request body.
    pub(crate) model: String,
    /// URL prefix that `/v1/messages` is appended to.
    pub(crate) base_url: String,
    /// Value sent as `max_tokens` in every request body.
    pub(crate) max_tokens: usize,
    /// Transport used for both the plain and the streaming POST.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy handed to `crate::retry::with_retry` and used to classify
    /// non-success HTTP statuses as retryable or fatal.
    pub(crate) retry_config: RetryConfig,
}
26
27impl AnthropicClient {
28 pub fn new(api_key: String, model: String) -> Self {
29 Self {
30 api_key: SecretString::new(api_key),
31 model,
32 base_url: "https://api.anthropic.com".to_string(),
33 max_tokens: DEFAULT_MAX_TOKENS,
34 http: default_http_client(),
35 retry_config: RetryConfig::default(),
36 }
37 }
38
39 pub fn with_base_url(mut self, base_url: String) -> Self {
40 self.base_url = normalize_base_url(&base_url);
41 self
42 }
43
44 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
45 self.max_tokens = max_tokens;
46 self
47 }
48
49 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
50 self.retry_config = retry_config;
51 self
52 }
53
54 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
55 self.http = http;
56 self
57 }
58
59 pub(crate) fn build_request(
60 &self,
61 messages: &[Message],
62 system: Option<&str>,
63 tools: &[ToolDefinition],
64 ) -> serde_json::Value {
65 let mut request = serde_json::json!({
66 "model": self.model,
67 "max_tokens": self.max_tokens,
68 "messages": messages,
69 });
70
71 if let Some(sys) = system {
72 request["system"] = serde_json::json!(sys);
73 }
74
75 if !tools.is_empty() {
76 let tool_defs: Vec<serde_json::Value> = tools
77 .iter()
78 .map(|t| {
79 serde_json::json!({
80 "name": t.name,
81 "description": t.description,
82 "input_schema": t.parameters,
83 })
84 })
85 .collect();
86 request["tools"] = serde_json::json!(tool_defs);
87 }
88
89 request
90 }
91}
92
93#[async_trait]
94impl LlmClient for AnthropicClient {
95 async fn complete(
96 &self,
97 messages: &[Message],
98 system: Option<&str>,
99 tools: &[ToolDefinition],
100 ) -> Result<LlmResponse> {
101 {
102 let request_body = self.build_request(messages, system, tools);
103 let url = format!("{}/v1/messages", self.base_url);
104
105 let headers = vec![
106 ("x-api-key", self.api_key.expose()),
107 ("anthropic-version", "2023-06-01"),
108 ];
109
110 let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
111 let http = &self.http;
112 let url = &url;
113 let headers = headers.clone();
114 let request_body = &request_body;
115 async move {
116 match http.post(url, headers, request_body).await {
117 Ok(resp) => {
118 let status = reqwest::StatusCode::from_u16(resp.status)
119 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
120 if status.is_success() {
121 AttemptOutcome::Success(resp.body)
122 } else if self.retry_config.is_retryable_status(status) {
123 AttemptOutcome::Retryable {
124 status,
125 body: resp.body,
126 retry_after: None,
127 }
128 } else {
129 AttemptOutcome::Fatal(anyhow::anyhow!(
130 "Anthropic API error at {} ({}): {}",
131 url,
132 status,
133 resp.body
134 ))
135 }
136 }
137 Err(e) => AttemptOutcome::Fatal(e),
138 }
139 }
140 })
141 .await?;
142
143 let parsed: AnthropicResponse =
144 serde_json::from_str(&response).context("Failed to parse Anthropic response")?;
145
146 tracing::debug!("Anthropic response: {:?}", parsed);
147
148 let content: Vec<ContentBlock> = parsed
149 .content
150 .into_iter()
151 .map(|block| match block {
152 AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
153 AnthropicContentBlock::ToolUse { id, name, input } => {
154 ContentBlock::ToolUse { id, name, input }
155 }
156 })
157 .collect();
158
159 let llm_response = LlmResponse {
160 message: Message {
161 role: "assistant".to_string(),
162 content,
163 reasoning_content: None,
164 },
165 usage: TokenUsage {
166 prompt_tokens: parsed.usage.input_tokens,
167 completion_tokens: parsed.usage.output_tokens,
168 total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
169 cache_read_tokens: parsed.usage.cache_read_input_tokens,
170 cache_write_tokens: parsed.usage.cache_creation_input_tokens,
171 },
172 stop_reason: Some(parsed.stop_reason),
173 };
174
175 crate::telemetry::record_llm_usage(
176 llm_response.usage.prompt_tokens,
177 llm_response.usage.completion_tokens,
178 llm_response.usage.total_tokens,
179 llm_response.stop_reason.as_deref(),
180 );
181
182 Ok(llm_response)
183 }
184 }
185
186 async fn complete_streaming(
187 &self,
188 messages: &[Message],
189 system: Option<&str>,
190 tools: &[ToolDefinition],
191 ) -> Result<mpsc::Receiver<StreamEvent>> {
192 {
193 let mut request_body = self.build_request(messages, system, tools);
194 request_body["stream"] = serde_json::json!(true);
195
196 let url = format!("{}/v1/messages", self.base_url);
197
198 let headers = vec![
199 ("x-api-key", self.api_key.expose()),
200 ("anthropic-version", "2023-06-01"),
201 ];
202
203 let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
204 let http = &self.http;
205 let url = &url;
206 let headers = headers.clone();
207 let request_body = &request_body;
208 async move {
209 match http.post_streaming(url, headers, request_body).await {
210 Ok(resp) => {
211 let status = reqwest::StatusCode::from_u16(resp.status)
212 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
213 if status.is_success() {
214 AttemptOutcome::Success(resp)
215 } else {
216 let retry_after = resp
217 .retry_after
218 .as_deref()
219 .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
220 if self.retry_config.is_retryable_status(status) {
221 AttemptOutcome::Retryable {
222 status,
223 body: resp.error_body,
224 retry_after,
225 }
226 } else {
227 AttemptOutcome::Fatal(anyhow::anyhow!(
228 "Anthropic API error at {} ({}): {}",
229 url,
230 status,
231 resp.error_body
232 ))
233 }
234 }
235 }
236 Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
237 "Failed to send streaming request: {}",
238 e
239 )),
240 }
241 }
242 })
243 .await?;
244
245 let (tx, rx) = mpsc::channel(100);
246
247 let mut stream = streaming_resp.byte_stream;
248 tokio::spawn(async move {
249 let mut buffer = String::new();
250 let mut content_blocks: Vec<ContentBlock> = Vec::new();
251 let mut current_tool_id = String::new();
252 let mut current_tool_name = String::new();
253 let mut current_tool_input = String::new();
254 let mut usage = TokenUsage::default();
255 let mut stop_reason = None;
256
257 while let Some(chunk_result) = stream.next().await {
258 let chunk = match chunk_result {
259 Ok(c) => c,
260 Err(e) => {
261 tracing::error!("Stream error: {}", e);
262 break;
263 }
264 };
265
266 buffer.push_str(&String::from_utf8_lossy(&chunk));
267
268 while let Some(event_end) = buffer.find("\n\n") {
269 let event_data: String = buffer.drain(..event_end).collect();
270 buffer.drain(..2);
271
272 for line in event_data.lines() {
273 if let Some(data) = line.strip_prefix("data: ") {
274 if data == "[DONE]" {
275 continue;
276 }
277
278 if let Ok(event) =
279 serde_json::from_str::<AnthropicStreamEvent>(data)
280 {
281 match event {
282 AnthropicStreamEvent::ContentBlockStart {
283 index: _,
284 content_block,
285 } => match content_block {
286 AnthropicContentBlock::Text { .. } => {}
287 AnthropicContentBlock::ToolUse { id, name, .. } => {
288 current_tool_id = id.clone();
289 current_tool_name = name.clone();
290 current_tool_input.clear();
291 let _ = tx
292 .send(StreamEvent::ToolUseStart { id, name })
293 .await;
294 }
295 },
296 AnthropicStreamEvent::ContentBlockDelta {
297 index: _,
298 delta,
299 } => match delta {
300 AnthropicDelta::TextDelta { text } => {
301 let _ = tx.send(StreamEvent::TextDelta(text)).await;
302 }
303 AnthropicDelta::InputJsonDelta { partial_json } => {
304 current_tool_input.push_str(&partial_json);
305 let _ = tx
306 .send(StreamEvent::ToolUseInputDelta(
307 partial_json,
308 ))
309 .await;
310 }
311 },
312 AnthropicStreamEvent::ContentBlockStop { index: _ } => {
313 if !current_tool_id.is_empty() {
314 let input: serde_json::Value =
315 serde_json::from_str(¤t_tool_input)
316 .unwrap_or_else(|e| {
317 tracing::warn!(
318 "Failed to parse tool input JSON for tool '{}': {}",
319 current_tool_name, e
320 );
321 serde_json::json!({
322 "__parse_error": format!(
323 "Malformed tool arguments: {}. Raw input: {}",
324 e, ¤t_tool_input
325 )
326 })
327 });
328 content_blocks.push(ContentBlock::ToolUse {
329 id: current_tool_id.clone(),
330 name: current_tool_name.clone(),
331 input,
332 });
333 current_tool_id.clear();
334 current_tool_name.clear();
335 current_tool_input.clear();
336 }
337 }
338 AnthropicStreamEvent::MessageStart { message } => {
339 usage.prompt_tokens = message.usage.input_tokens;
340 }
341 AnthropicStreamEvent::MessageDelta {
342 delta,
343 usage: msg_usage,
344 } => {
345 stop_reason = Some(delta.stop_reason);
346 usage.completion_tokens = msg_usage.output_tokens;
347 usage.total_tokens =
348 usage.prompt_tokens + usage.completion_tokens;
349 }
350 AnthropicStreamEvent::MessageStop => {
351 crate::telemetry::record_llm_usage(
352 usage.prompt_tokens,
353 usage.completion_tokens,
354 usage.total_tokens,
355 stop_reason.as_deref(),
356 );
357
358 let response = LlmResponse {
359 message: Message {
360 role: "assistant".to_string(),
361 content: std::mem::take(&mut content_blocks),
362 reasoning_content: None,
363 },
364 usage: usage.clone(),
365 stop_reason: stop_reason.clone(),
366 };
367 let _ = tx.send(StreamEvent::Done(response)).await;
368 }
369 _ => {}
370 }
371 }
372 }
373 }
374 }
375 }
376 });
377
378 Ok(rx)
379 }
380 }
381}
382
/// Body of a non-streaming Messages API reply, as parsed by
/// `AnthropicClient::complete`.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicResponse {
    /// Content blocks of the assistant message.
    pub(crate) content: Vec<AnthropicContentBlock>,
    /// Reason generation stopped. Declared as a plain `String`, so a missing
    /// or `null` value makes deserialization of the whole response fail.
    pub(crate) stop_reason: String,
    /// Token accounting for the request.
    pub(crate) usage: AnthropicUsage,
}
390
/// One content block of an assistant message, selected by its `type` tag.
///
/// Only `text` and `tool_use` are modeled; a block with any other `type`
/// fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicContentBlock {
    /// Plain text output (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        /// Identifier of this tool call.
        id: String,
        /// Name of the tool being called.
        name: String,
        /// Tool arguments as arbitrary JSON.
        input: serde_json::Value,
    },
}
403
/// Token usage as it appears in a full response and in the `message_start`
/// stream event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicUsage {
    /// Input token count; mapped to `TokenUsage::prompt_tokens`.
    pub(crate) input_tokens: usize,
    /// Output token count; mapped to `TokenUsage::completion_tokens` by the
    /// non-streaming path.
    pub(crate) output_tokens: usize,
    /// Cache-read token count; `None` when the field is absent.
    pub(crate) cache_read_input_tokens: Option<usize>,
    /// Cache-creation token count; `None` when the field is absent.
    pub(crate) cache_creation_input_tokens: Option<usize>,
}
411
/// One server-sent event of a streaming reply, selected by its `type` tag.
///
/// An event whose `type` is not listed here fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
// Some fields (for example `index`) are deserialized but never read.
#[allow(dead_code)]
pub(crate) enum AnthropicStreamEvent {
    /// `message_start`: carries the initial usage figures.
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageStart },
    /// `content_block_start`: opens the content block at `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    /// `content_block_delta`: an incremental update to the block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    /// `content_block_stop`: closes the content block at `index`.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// `message_delta`: carries the stop reason and the output token count.
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDeltaData,
        usage: AnthropicOutputUsage,
    },
    /// `message_stop`: final event of a message.
    #[serde(rename = "message_stop")]
    MessageStop,
    /// `ping`: carries no payload.
    #[serde(rename = "ping")]
    Ping,
    /// `error`: carries an [`AnthropicError`] payload.
    #[serde(rename = "error")]
    Error { error: AnthropicError },
}
439
/// The `message` object of a `message_start` event; only `usage` is modeled.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageStart {
    /// Usage figures available at the start of the stream.
    pub(crate) usage: AnthropicUsage,
}
444
/// Payload of a `content_block_delta` event, selected by its `type` tag.
///
/// A delta whose `type` is not listed here fails to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicDelta {
    /// `text_delta`: a fragment of assistant text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// `input_json_delta`: a fragment of a tool call's JSON arguments. The
    /// streaming handler concatenates the fragments before parsing them.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
453
/// The `delta` object of a `message_delta` event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageDeltaData {
    /// Reason generation stopped. Declared as a plain `String`, so a
    /// `message_delta` whose `stop_reason` is missing or `null` fails to
    /// deserialize.
    pub(crate) stop_reason: String,
}
458
/// The `usage` object of a `message_delta` event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicOutputUsage {
    /// Output token count; the streaming handler stores it as the completion
    /// token count.
    pub(crate) output_tokens: usize,
}
463
/// The `error` object of an `error` stream event.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct AnthropicError {
    /// Value of the JSON `type` field (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub(crate) error_type: String,
    /// Value of the JSON `message` field.
    pub(crate) message: String,
}