1use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use anyhow::{Context, Result};
7use async_trait::async_trait;
8use crate::retry::{AttemptOutcome, RetryConfig};
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13use tracing::Instrument;
14
/// Default `max_tokens` sent with every request unless overridden via
/// [`AnthropicClient::with_max_tokens`].
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8192;
17
/// Client for the Anthropic Messages API (`POST {base_url}/v1/messages`).
///
/// Build with [`AnthropicClient::new`], then customize with the `with_*`
/// builder methods. Implements the crate's `LlmClient` trait for both
/// blocking and streaming completions.
pub struct AnthropicClient {
    /// API key, sent as the `x-api-key` header.
    pub(crate) api_key: SecretString,
    /// Model identifier forwarded verbatim in the request body.
    pub(crate) model: String,
    /// Base URL; defaults to `https://api.anthropic.com`, normalized when
    /// overridden via `with_base_url`.
    pub(crate) base_url: String,
    /// Maximum output tokens requested per completion.
    pub(crate) max_tokens: usize,
    /// HTTP transport; injectable via `with_http_client` (e.g. for tests).
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied around each API call.
    pub(crate) retry_config: RetryConfig,
}
27
28impl AnthropicClient {
29 pub fn new(api_key: String, model: String) -> Self {
30 Self {
31 api_key: SecretString::new(api_key),
32 model,
33 base_url: "https://api.anthropic.com".to_string(),
34 max_tokens: DEFAULT_MAX_TOKENS,
35 http: default_http_client(),
36 retry_config: RetryConfig::default(),
37 }
38 }
39
40 pub fn with_base_url(mut self, base_url: String) -> Self {
41 self.base_url = normalize_base_url(&base_url);
42 self
43 }
44
45 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
46 self.max_tokens = max_tokens;
47 self
48 }
49
50 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
51 self.retry_config = retry_config;
52 self
53 }
54
55 pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
56 self.http = http;
57 self
58 }
59
60 pub(crate) fn build_request(
61 &self,
62 messages: &[Message],
63 system: Option<&str>,
64 tools: &[ToolDefinition],
65 ) -> serde_json::Value {
66 let mut request = serde_json::json!({
67 "model": self.model,
68 "max_tokens": self.max_tokens,
69 "messages": messages,
70 });
71
72 if let Some(sys) = system {
73 request["system"] = serde_json::json!(sys);
74 }
75
76 if !tools.is_empty() {
77 let tool_defs: Vec<serde_json::Value> = tools
78 .iter()
79 .map(|t| {
80 serde_json::json!({
81 "name": t.name,
82 "description": t.description,
83 "input_schema": t.parameters,
84 })
85 })
86 .collect();
87 request["tools"] = serde_json::json!(tool_defs);
88 }
89
90 request
91 }
92}
93
94#[async_trait]
95impl LlmClient for AnthropicClient {
96 async fn complete(
97 &self,
98 messages: &[Message],
99 system: Option<&str>,
100 tools: &[ToolDefinition],
101 ) -> Result<LlmResponse> {
102 let span = tracing::info_span!(
103 "a3s.llm.completion",
104 "a3s.llm.provider" = "anthropic",
105 "a3s.llm.model" = %self.model,
106 "a3s.llm.streaming" = false,
107 "a3s.llm.prompt_tokens" = tracing::field::Empty,
108 "a3s.llm.completion_tokens" = tracing::field::Empty,
109 "a3s.llm.total_tokens" = tracing::field::Empty,
110 "a3s.llm.stop_reason" = tracing::field::Empty,
111 );
112 async {
113 let request_body = self.build_request(messages, system, tools);
114 let url = format!("{}/v1/messages", self.base_url);
115
116 let headers = vec![
117 ("x-api-key", self.api_key.expose()),
118 ("anthropic-version", "2023-06-01"),
119 ];
120
121 let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
122 let http = &self.http;
123 let url = &url;
124 let headers = headers.clone();
125 let request_body = &request_body;
126 async move {
127 match http.post(url, headers, request_body).await {
128 Ok(resp) => {
129 let status = reqwest::StatusCode::from_u16(resp.status)
130 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
131 if status.is_success() {
132 AttemptOutcome::Success(resp.body)
133 } else if self.retry_config.is_retryable_status(status) {
134 AttemptOutcome::Retryable {
135 status,
136 body: resp.body,
137 retry_after: None,
138 }
139 } else {
140 AttemptOutcome::Fatal(anyhow::anyhow!(
141 "Anthropic API error at {} ({}): {}",
142 url,
143 status,
144 resp.body
145 ))
146 }
147 }
148 Err(e) => AttemptOutcome::Fatal(e),
149 }
150 }
151 })
152 .await?;
153
154 let parsed: AnthropicResponse =
155 serde_json::from_str(&response).context("Failed to parse Anthropic response")?;
156
157 tracing::debug!("Anthropic response: {:?}", parsed);
158
159 let content: Vec<ContentBlock> = parsed
160 .content
161 .into_iter()
162 .map(|block| match block {
163 AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
164 AnthropicContentBlock::ToolUse { id, name, input } => {
165 ContentBlock::ToolUse { id, name, input }
166 }
167 })
168 .collect();
169
170 let llm_response = LlmResponse {
171 message: Message {
172 role: "assistant".to_string(),
173 content,
174 reasoning_content: None,
175 },
176 usage: TokenUsage {
177 prompt_tokens: parsed.usage.input_tokens,
178 completion_tokens: parsed.usage.output_tokens,
179 total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
180 cache_read_tokens: parsed.usage.cache_read_input_tokens,
181 cache_write_tokens: parsed.usage.cache_creation_input_tokens,
182 },
183 stop_reason: Some(parsed.stop_reason),
184 };
185
186 crate::telemetry::record_llm_usage(
187 llm_response.usage.prompt_tokens,
188 llm_response.usage.completion_tokens,
189 llm_response.usage.total_tokens,
190 llm_response.stop_reason.as_deref(),
191 );
192
193 Ok(llm_response)
194 }
195 .instrument(span)
196 .await
197 }
198
199 async fn complete_streaming(
200 &self,
201 messages: &[Message],
202 system: Option<&str>,
203 tools: &[ToolDefinition],
204 ) -> Result<mpsc::Receiver<StreamEvent>> {
205 let span = tracing::info_span!(
206 "a3s.llm.completion",
207 "a3s.llm.provider" = "anthropic",
208 "a3s.llm.model" = %self.model,
209 "a3s.llm.streaming" = true,
210 "a3s.llm.prompt_tokens" = tracing::field::Empty,
211 "a3s.llm.completion_tokens" = tracing::field::Empty,
212 "a3s.llm.total_tokens" = tracing::field::Empty,
213 "a3s.llm.stop_reason" = tracing::field::Empty,
214 );
215 async {
216 let mut request_body = self.build_request(messages, system, tools);
217 request_body["stream"] = serde_json::json!(true);
218
219 let url = format!("{}/v1/messages", self.base_url);
220
221 let headers = vec![
222 ("x-api-key", self.api_key.expose()),
223 ("anthropic-version", "2023-06-01"),
224 ];
225
226 let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
227 let http = &self.http;
228 let url = &url;
229 let headers = headers.clone();
230 let request_body = &request_body;
231 async move {
232 match http.post_streaming(url, headers, request_body).await {
233 Ok(resp) => {
234 let status = reqwest::StatusCode::from_u16(resp.status)
235 .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
236 if status.is_success() {
237 AttemptOutcome::Success(resp)
238 } else {
239 let retry_after = resp
240 .retry_after
241 .as_deref()
242 .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
243 if self.retry_config.is_retryable_status(status) {
244 AttemptOutcome::Retryable {
245 status,
246 body: resp.error_body,
247 retry_after,
248 }
249 } else {
250 AttemptOutcome::Fatal(anyhow::anyhow!(
251 "Anthropic API error at {} ({}): {}", url, status, resp.error_body
252 ))
253 }
254 }
255 }
256 Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
257 "Failed to send streaming request: {}", e
258 )),
259 }
260 }
261 })
262 .await?;
263
264 let (tx, rx) = mpsc::channel(100);
265
266 let mut stream = streaming_resp.byte_stream;
267 tokio::spawn(async move {
268 let mut buffer = String::new();
269 let mut content_blocks: Vec<ContentBlock> = Vec::new();
270 let mut current_tool_id = String::new();
271 let mut current_tool_name = String::new();
272 let mut current_tool_input = String::new();
273 let mut usage = TokenUsage::default();
274 let mut stop_reason = None;
275
276 while let Some(chunk_result) = stream.next().await {
277 let chunk = match chunk_result {
278 Ok(c) => c,
279 Err(e) => {
280 tracing::error!("Stream error: {}", e);
281 break;
282 }
283 };
284
285 buffer.push_str(&String::from_utf8_lossy(&chunk));
286
287 while let Some(event_end) = buffer.find("\n\n") {
288 let event_data: String = buffer.drain(..event_end).collect();
289 buffer.drain(..2);
290
291 for line in event_data.lines() {
292 if let Some(data) = line.strip_prefix("data: ") {
293 if data == "[DONE]" {
294 continue;
295 }
296
297 if let Ok(event) = serde_json::from_str::<AnthropicStreamEvent>(data) {
298 match event {
299 AnthropicStreamEvent::ContentBlockStart {
300 index: _,
301 content_block,
302 } => match content_block {
303 AnthropicContentBlock::Text { .. } => {}
304 AnthropicContentBlock::ToolUse { id, name, .. } => {
305 current_tool_id = id.clone();
306 current_tool_name = name.clone();
307 current_tool_input.clear();
308 let _ = tx
309 .send(StreamEvent::ToolUseStart { id, name })
310 .await;
311 }
312 },
313 AnthropicStreamEvent::ContentBlockDelta { index: _, delta } => {
314 match delta {
315 AnthropicDelta::TextDelta { text } => {
316 let _ = tx.send(StreamEvent::TextDelta(text)).await;
317 }
318 AnthropicDelta::InputJsonDelta { partial_json } => {
319 current_tool_input.push_str(&partial_json);
320 let _ = tx
321 .send(StreamEvent::ToolUseInputDelta(
322 partial_json,
323 ))
324 .await;
325 }
326 }
327 }
328 AnthropicStreamEvent::ContentBlockStop { index: _ } => {
329 if !current_tool_id.is_empty() {
330 let input: serde_json::Value =
331 serde_json::from_str(¤t_tool_input)
332 .unwrap_or_else(|e| {
333 tracing::warn!(
334 "Failed to parse tool input JSON for tool '{}': {}",
335 current_tool_name, e
336 );
337 serde_json::json!({
338 "__parse_error": format!(
339 "Malformed tool arguments: {}. Raw input: {}",
340 e, ¤t_tool_input
341 )
342 })
343 });
344 content_blocks.push(ContentBlock::ToolUse {
345 id: current_tool_id.clone(),
346 name: current_tool_name.clone(),
347 input,
348 });
349 current_tool_id.clear();
350 current_tool_name.clear();
351 current_tool_input.clear();
352 }
353 }
354 AnthropicStreamEvent::MessageStart { message } => {
355 usage.prompt_tokens = message.usage.input_tokens;
356 }
357 AnthropicStreamEvent::MessageDelta {
358 delta,
359 usage: msg_usage,
360 } => {
361 stop_reason = Some(delta.stop_reason);
362 usage.completion_tokens = msg_usage.output_tokens;
363 usage.total_tokens =
364 usage.prompt_tokens + usage.completion_tokens;
365 }
366 AnthropicStreamEvent::MessageStop => {
367 crate::telemetry::record_llm_usage(
368 usage.prompt_tokens,
369 usage.completion_tokens,
370 usage.total_tokens,
371 stop_reason.as_deref(),
372 );
373
374 let response = LlmResponse {
375 message: Message {
376 role: "assistant".to_string(),
377 content: std::mem::take(&mut content_blocks),
378 reasoning_content: None,
379 },
380 usage: usage.clone(),
381 stop_reason: stop_reason.clone(),
382 };
383 let _ = tx.send(StreamEvent::Done(response)).await;
384 }
385 _ => {}
386 }
387 }
388 }
389 }
390 }
391 }
392 });
393
394 Ok(rx)
395 }
396 .instrument(span)
397 .await
398 }
399}
400
/// Deserialized body of a non-streaming `/v1/messages` response.
/// Only the fields consumed by `complete` are modeled; serde ignores the rest.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicResponse {
    /// Content blocks (text and/or tool-use) produced by the model.
    pub(crate) content: Vec<AnthropicContentBlock>,
    /// Why generation stopped, forwarded verbatim as the response stop reason.
    pub(crate) stop_reason: String,
    /// Token accounting for this request.
    pub(crate) usage: AnthropicUsage,
}
408
/// One content block of an Anthropic message, discriminated by the JSON
/// `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicContentBlock {
    /// Plain text output.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model; `input` is the raw JSON
    /// arguments object.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
421
/// Token usage as reported by the API.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicUsage {
    pub(crate) input_tokens: usize,
    pub(crate) output_tokens: usize,
    // Optional: presumably absent when prompt caching is not involved —
    // TODO confirm against live API responses.
    pub(crate) cache_read_input_tokens: Option<usize>,
    pub(crate) cache_creation_input_tokens: Option<usize>,
}
429
/// Server-sent event payloads for a streaming `/v1/messages` call,
/// discriminated by the JSON `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
#[allow(dead_code)]
pub(crate) enum AnthropicStreamEvent {
    /// First event of a stream; carries initial (input-side) usage.
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageStart },
    /// A new content block (text or tool-use) begins at `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    /// Incremental payload (text or partial tool-input JSON) for a block.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    /// The block at `index` is complete.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Late-message update: stop reason plus output-token count.
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDeltaData,
        usage: AnthropicOutputUsage,
    },
    /// Terminal event of the stream.
    #[serde(rename = "message_stop")]
    MessageStop,
    /// Keep-alive; carries no data.
    #[serde(rename = "ping")]
    Ping,
    /// Server-reported error event.
    #[serde(rename = "error")]
    Error { error: AnthropicError },
}
457
/// Payload of a `message_start` event. Only `usage` is modeled; all other
/// message fields are ignored during deserialization.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageStart {
    pub(crate) usage: AnthropicUsage,
}
462
/// Delta payload inside a `content_block_delta` event, discriminated by the
/// JSON `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicDelta {
    /// A fragment of text output.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of the tool-use arguments; fragments concatenate into one
    /// JSON document.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
471
/// `delta` portion of a `message_delta` event; only the stop reason is used.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageDeltaData {
    pub(crate) stop_reason: String,
}
476
/// `usage` portion of a `message_delta` event (output-side token count only).
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicOutputUsage {
    pub(crate) output_tokens: usize,
}
481
/// Error payload carried by an `error` stream event.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct AnthropicError {
    /// Error category string from the API (`"type"` in JSON).
    #[serde(rename = "type")]
    pub(crate) error_type: String,
    /// Human-readable error description.
    pub(crate) message: String,
}