Skip to main content

a3s_code_core/llm/
anthropic.rs

1//! Anthropic Claude LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::retry::{AttemptOutcome, RetryConfig};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// Default `max_tokens` request value used when no explicit limit is configured.
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8_192;
16
17/// Anthropic Claude client
18pub struct AnthropicClient {
19    pub(crate) api_key: SecretString,
20    pub(crate) model: String,
21    pub(crate) base_url: String,
22    pub(crate) max_tokens: usize,
23    pub(crate) http: Arc<dyn HttpClient>,
24    pub(crate) retry_config: RetryConfig,
25}
26
27impl AnthropicClient {
28    pub fn new(api_key: String, model: String) -> Self {
29        Self {
30            api_key: SecretString::new(api_key),
31            model,
32            base_url: "https://api.anthropic.com".to_string(),
33            max_tokens: DEFAULT_MAX_TOKENS,
34            http: default_http_client(),
35            retry_config: RetryConfig::default(),
36        }
37    }
38
39    pub fn with_base_url(mut self, base_url: String) -> Self {
40        self.base_url = normalize_base_url(&base_url);
41        self
42    }
43
44    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
45        self.max_tokens = max_tokens;
46        self
47    }
48
49    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
50        self.retry_config = retry_config;
51        self
52    }
53
54    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
55        self.http = http;
56        self
57    }
58
59    pub(crate) fn build_request(
60        &self,
61        messages: &[Message],
62        system: Option<&str>,
63        tools: &[ToolDefinition],
64    ) -> serde_json::Value {
65        let mut request = serde_json::json!({
66            "model": self.model,
67            "max_tokens": self.max_tokens,
68            "messages": messages,
69        });
70
71        if let Some(sys) = system {
72            request["system"] = serde_json::json!(sys);
73        }
74
75        if !tools.is_empty() {
76            let tool_defs: Vec<serde_json::Value> = tools
77                .iter()
78                .map(|t| {
79                    serde_json::json!({
80                        "name": t.name,
81                        "description": t.description,
82                        "input_schema": t.parameters,
83                    })
84                })
85                .collect();
86            request["tools"] = serde_json::json!(tool_defs);
87        }
88
89        request
90    }
91}
92
93#[async_trait]
94impl LlmClient for AnthropicClient {
95    async fn complete(
96        &self,
97        messages: &[Message],
98        system: Option<&str>,
99        tools: &[ToolDefinition],
100    ) -> Result<LlmResponse> {
101        {
102            let request_body = self.build_request(messages, system, tools);
103            let url = format!("{}/v1/messages", self.base_url);
104
105            let headers = vec![
106                ("x-api-key", self.api_key.expose()),
107                ("anthropic-version", "2023-06-01"),
108            ];
109
110            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
111                let http = &self.http;
112                let url = &url;
113                let headers = headers.clone();
114                let request_body = &request_body;
115                async move {
116                    match http.post(url, headers, request_body).await {
117                        Ok(resp) => {
118                            let status = reqwest::StatusCode::from_u16(resp.status)
119                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
120                            if status.is_success() {
121                                AttemptOutcome::Success(resp.body)
122                            } else if self.retry_config.is_retryable_status(status) {
123                                AttemptOutcome::Retryable {
124                                    status,
125                                    body: resp.body,
126                                    retry_after: None,
127                                }
128                            } else {
129                                AttemptOutcome::Fatal(anyhow::anyhow!(
130                                    "Anthropic API error at {} ({}): {}",
131                                    url,
132                                    status,
133                                    resp.body
134                                ))
135                            }
136                        }
137                        Err(e) => AttemptOutcome::Fatal(e),
138                    }
139                }
140            })
141            .await?;
142
143            let parsed: AnthropicResponse =
144                serde_json::from_str(&response).context("Failed to parse Anthropic response")?;
145
146            tracing::debug!("Anthropic response: {:?}", parsed);
147
148            let content: Vec<ContentBlock> = parsed
149                .content
150                .into_iter()
151                .map(|block| match block {
152                    AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
153                    AnthropicContentBlock::ToolUse { id, name, input } => {
154                        ContentBlock::ToolUse { id, name, input }
155                    }
156                })
157                .collect();
158
159            let llm_response = LlmResponse {
160                message: Message {
161                    role: "assistant".to_string(),
162                    content,
163                    reasoning_content: None,
164                },
165                usage: TokenUsage {
166                    prompt_tokens: parsed.usage.input_tokens,
167                    completion_tokens: parsed.usage.output_tokens,
168                    total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
169                    cache_read_tokens: parsed.usage.cache_read_input_tokens,
170                    cache_write_tokens: parsed.usage.cache_creation_input_tokens,
171                },
172                stop_reason: Some(parsed.stop_reason),
173            };
174
175            crate::telemetry::record_llm_usage(
176                llm_response.usage.prompt_tokens,
177                llm_response.usage.completion_tokens,
178                llm_response.usage.total_tokens,
179                llm_response.stop_reason.as_deref(),
180            );
181
182            Ok(llm_response)
183        }
184    }
185
186    async fn complete_streaming(
187        &self,
188        messages: &[Message],
189        system: Option<&str>,
190        tools: &[ToolDefinition],
191    ) -> Result<mpsc::Receiver<StreamEvent>> {
192        {
193            let mut request_body = self.build_request(messages, system, tools);
194            request_body["stream"] = serde_json::json!(true);
195
196            let url = format!("{}/v1/messages", self.base_url);
197
198            let headers = vec![
199                ("x-api-key", self.api_key.expose()),
200                ("anthropic-version", "2023-06-01"),
201            ];
202
203            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
204                let http = &self.http;
205                let url = &url;
206                let headers = headers.clone();
207                let request_body = &request_body;
208                async move {
209                    match http.post_streaming(url, headers, request_body).await {
210                        Ok(resp) => {
211                            let status = reqwest::StatusCode::from_u16(resp.status)
212                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
213                            if status.is_success() {
214                                AttemptOutcome::Success(resp)
215                            } else {
216                                let retry_after = resp
217                                    .retry_after
218                                    .as_deref()
219                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
220                                if self.retry_config.is_retryable_status(status) {
221                                    AttemptOutcome::Retryable {
222                                        status,
223                                        body: resp.error_body,
224                                        retry_after,
225                                    }
226                                } else {
227                                    AttemptOutcome::Fatal(anyhow::anyhow!(
228                                        "Anthropic API error at {} ({}): {}",
229                                        url,
230                                        status,
231                                        resp.error_body
232                                    ))
233                                }
234                            }
235                        }
236                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
237                            "Failed to send streaming request: {}",
238                            e
239                        )),
240                    }
241                }
242            })
243            .await?;
244
245            let (tx, rx) = mpsc::channel(100);
246
247            let mut stream = streaming_resp.byte_stream;
248            tokio::spawn(async move {
249                let mut buffer = String::new();
250                let mut content_blocks: Vec<ContentBlock> = Vec::new();
251                let mut current_tool_id = String::new();
252                let mut current_tool_name = String::new();
253                let mut current_tool_input = String::new();
254                let mut usage = TokenUsage::default();
255                let mut stop_reason = None;
256
257                while let Some(chunk_result) = stream.next().await {
258                    let chunk = match chunk_result {
259                        Ok(c) => c,
260                        Err(e) => {
261                            tracing::error!("Stream error: {}", e);
262                            break;
263                        }
264                    };
265
266                    buffer.push_str(&String::from_utf8_lossy(&chunk));
267
268                    while let Some(event_end) = buffer.find("\n\n") {
269                        let event_data: String = buffer.drain(..event_end).collect();
270                        buffer.drain(..2);
271
272                        for line in event_data.lines() {
273                            if let Some(data) = line.strip_prefix("data: ") {
274                                if data == "[DONE]" {
275                                    continue;
276                                }
277
278                                if let Ok(event) =
279                                    serde_json::from_str::<AnthropicStreamEvent>(data)
280                                {
281                                    match event {
282                                        AnthropicStreamEvent::ContentBlockStart {
283                                            index: _,
284                                            content_block,
285                                        } => match content_block {
286                                            AnthropicContentBlock::Text { .. } => {}
287                                            AnthropicContentBlock::ToolUse { id, name, .. } => {
288                                                current_tool_id = id.clone();
289                                                current_tool_name = name.clone();
290                                                current_tool_input.clear();
291                                                let _ = tx
292                                                    .send(StreamEvent::ToolUseStart { id, name })
293                                                    .await;
294                                            }
295                                        },
296                                        AnthropicStreamEvent::ContentBlockDelta {
297                                            index: _,
298                                            delta,
299                                        } => match delta {
300                                            AnthropicDelta::TextDelta { text } => {
301                                                let _ = tx.send(StreamEvent::TextDelta(text)).await;
302                                            }
303                                            AnthropicDelta::InputJsonDelta { partial_json } => {
304                                                current_tool_input.push_str(&partial_json);
305                                                let _ = tx
306                                                    .send(StreamEvent::ToolUseInputDelta(
307                                                        partial_json,
308                                                    ))
309                                                    .await;
310                                            }
311                                        },
312                                        AnthropicStreamEvent::ContentBlockStop { index: _ } => {
313                                            if !current_tool_id.is_empty() {
314                                                let input: serde_json::Value =
315                                                serde_json::from_str(&current_tool_input)
316                                                    .unwrap_or_else(|e| {
317                                                        tracing::warn!(
318                                                            "Failed to parse tool input JSON for tool '{}': {}",
319                                                            current_tool_name, e
320                                                        );
321                                                        serde_json::json!({
322                                                            "__parse_error": format!(
323                                                                "Malformed tool arguments: {}. Raw input: {}",
324                                                                e, &current_tool_input
325                                                            )
326                                                        })
327                                                    });
328                                                content_blocks.push(ContentBlock::ToolUse {
329                                                    id: current_tool_id.clone(),
330                                                    name: current_tool_name.clone(),
331                                                    input,
332                                                });
333                                                current_tool_id.clear();
334                                                current_tool_name.clear();
335                                                current_tool_input.clear();
336                                            }
337                                        }
338                                        AnthropicStreamEvent::MessageStart { message } => {
339                                            usage.prompt_tokens = message.usage.input_tokens;
340                                        }
341                                        AnthropicStreamEvent::MessageDelta {
342                                            delta,
343                                            usage: msg_usage,
344                                        } => {
345                                            stop_reason = Some(delta.stop_reason);
346                                            usage.completion_tokens = msg_usage.output_tokens;
347                                            usage.total_tokens =
348                                                usage.prompt_tokens + usage.completion_tokens;
349                                        }
350                                        AnthropicStreamEvent::MessageStop => {
351                                            crate::telemetry::record_llm_usage(
352                                                usage.prompt_tokens,
353                                                usage.completion_tokens,
354                                                usage.total_tokens,
355                                                stop_reason.as_deref(),
356                                            );
357
358                                            let response = LlmResponse {
359                                                message: Message {
360                                                    role: "assistant".to_string(),
361                                                    content: std::mem::take(&mut content_blocks),
362                                                    reasoning_content: None,
363                                                },
364                                                usage: usage.clone(),
365                                                stop_reason: stop_reason.clone(),
366                                            };
367                                            let _ = tx.send(StreamEvent::Done(response)).await;
368                                        }
369                                        _ => {}
370                                    }
371                                }
372                            }
373                        }
374                    }
375                }
376            });
377
378            Ok(rx)
379        }
380    }
381}
382
383// Anthropic API response types (private)
384#[derive(Debug, Deserialize)]
385pub(crate) struct AnthropicResponse {
386    pub(crate) content: Vec<AnthropicContentBlock>,
387    pub(crate) stop_reason: String,
388    pub(crate) usage: AnthropicUsage,
389}
390
391#[derive(Debug, Deserialize)]
392#[serde(tag = "type")]
393pub(crate) enum AnthropicContentBlock {
394    #[serde(rename = "text")]
395    Text { text: String },
396    #[serde(rename = "tool_use")]
397    ToolUse {
398        id: String,
399        name: String,
400        input: serde_json::Value,
401    },
402}
403
404#[derive(Debug, Deserialize)]
405pub(crate) struct AnthropicUsage {
406    pub(crate) input_tokens: usize,
407    pub(crate) output_tokens: usize,
408    pub(crate) cache_read_input_tokens: Option<usize>,
409    pub(crate) cache_creation_input_tokens: Option<usize>,
410}
411
412#[derive(Debug, Deserialize)]
413#[serde(tag = "type")]
414#[allow(dead_code)]
415pub(crate) enum AnthropicStreamEvent {
416    #[serde(rename = "message_start")]
417    MessageStart { message: AnthropicMessageStart },
418    #[serde(rename = "content_block_start")]
419    ContentBlockStart {
420        index: usize,
421        content_block: AnthropicContentBlock,
422    },
423    #[serde(rename = "content_block_delta")]
424    ContentBlockDelta { index: usize, delta: AnthropicDelta },
425    #[serde(rename = "content_block_stop")]
426    ContentBlockStop { index: usize },
427    #[serde(rename = "message_delta")]
428    MessageDelta {
429        delta: AnthropicMessageDeltaData,
430        usage: AnthropicOutputUsage,
431    },
432    #[serde(rename = "message_stop")]
433    MessageStop,
434    #[serde(rename = "ping")]
435    Ping,
436    #[serde(rename = "error")]
437    Error { error: AnthropicError },
438}
439
440#[derive(Debug, Deserialize)]
441pub(crate) struct AnthropicMessageStart {
442    pub(crate) usage: AnthropicUsage,
443}
444
445#[derive(Debug, Deserialize)]
446#[serde(tag = "type")]
447pub(crate) enum AnthropicDelta {
448    #[serde(rename = "text_delta")]
449    TextDelta { text: String },
450    #[serde(rename = "input_json_delta")]
451    InputJsonDelta { partial_json: String },
452}
453
454#[derive(Debug, Deserialize)]
455pub(crate) struct AnthropicMessageDeltaData {
456    pub(crate) stop_reason: String,
457}
458
459#[derive(Debug, Deserialize)]
460pub(crate) struct AnthropicOutputUsage {
461    pub(crate) output_tokens: usize,
462}
463
464#[derive(Debug, Deserialize)]
465#[allow(dead_code)]
466pub(crate) struct AnthropicError {
467    #[serde(rename = "type")]
468    pub(crate) error_type: String,
469    pub(crate) message: String,
470}