Skip to main content

a3s_code_core/llm/
anthropic.rs

//! Anthropic Claude LLM client built on the Messages API (`/v1/messages`).
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::retry::{AttemptOutcome, RetryConfig};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use tokio::sync::mpsc;
13
/// Output-token cap that `AnthropicClient::new` starts with; override it with
/// `with_max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8_192;
16
17/// Anthropic Claude client
18pub struct AnthropicClient {
19    pub(crate) api_key: SecretString,
20    pub(crate) model: String,
21    pub(crate) base_url: String,
22    pub(crate) max_tokens: usize,
23    pub(crate) temperature: Option<f32>,
24    pub(crate) thinking_budget: Option<usize>,
25    pub(crate) http: Arc<dyn HttpClient>,
26    pub(crate) retry_config: RetryConfig,
27}
28
29impl AnthropicClient {
30    pub fn new(api_key: String, model: String) -> Self {
31        Self {
32            api_key: SecretString::new(api_key),
33            model,
34            base_url: "https://api.anthropic.com".to_string(),
35            max_tokens: DEFAULT_MAX_TOKENS,
36            temperature: None,
37            thinking_budget: None,
38            http: default_http_client(),
39            retry_config: RetryConfig::default(),
40        }
41    }
42
43    pub fn with_base_url(mut self, base_url: String) -> Self {
44        self.base_url = normalize_base_url(&base_url);
45        self
46    }
47
48    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
49        self.max_tokens = max_tokens;
50        self
51    }
52
53    pub fn with_temperature(mut self, temperature: f32) -> Self {
54        self.temperature = Some(temperature);
55        self
56    }
57
58    pub fn with_thinking_budget(mut self, budget: usize) -> Self {
59        self.thinking_budget = Some(budget);
60        self
61    }
62
63    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
64        self.retry_config = retry_config;
65        self
66    }
67
68    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
69        self.http = http;
70        self
71    }
72
73    pub(crate) fn build_request(
74        &self,
75        messages: &[Message],
76        system: Option<&str>,
77        tools: &[ToolDefinition],
78    ) -> serde_json::Value {
79        let mut request = serde_json::json!({
80            "model": self.model,
81            "max_tokens": self.max_tokens,
82            "messages": messages,
83        });
84
85        // System prompt with cache_control for prompt caching.
86        // Anthropic caches system content blocks marked with
87        // `cache_control: { type: "ephemeral" }`.
88        if let Some(sys) = system {
89            request["system"] = serde_json::json!([
90                {
91                    "type": "text",
92                    "text": sys,
93                    "cache_control": { "type": "ephemeral" }
94                }
95            ]);
96        }
97
98        if !tools.is_empty() {
99            let mut tool_defs: Vec<serde_json::Value> = tools
100                .iter()
101                .map(|t| {
102                    serde_json::json!({
103                        "name": t.name,
104                        "description": t.description,
105                        "input_schema": t.parameters,
106                    })
107                })
108                .collect();
109
110            // Mark the last tool definition with cache_control so the
111            // entire tool block is cached on subsequent requests.
112            if let Some(last) = tool_defs.last_mut() {
113                last["cache_control"] = serde_json::json!({ "type": "ephemeral" });
114            }
115
116            request["tools"] = serde_json::json!(tool_defs);
117        }
118
119        // Apply optional sampling parameters
120        if let Some(temp) = self.temperature {
121            request["temperature"] = serde_json::json!(temp);
122        }
123
124        // Extended thinking (Anthropic-specific)
125        if let Some(budget) = self.thinking_budget {
126            request["thinking"] = serde_json::json!({
127                "type": "enabled",
128                "budget_tokens": budget
129            });
130            // Thinking requires temperature=1 per Anthropic docs
131            request["temperature"] = serde_json::json!(1.0);
132        }
133
134        request
135    }
136}
137
138#[async_trait]
139impl LlmClient for AnthropicClient {
140    async fn complete(
141        &self,
142        messages: &[Message],
143        system: Option<&str>,
144        tools: &[ToolDefinition],
145    ) -> Result<LlmResponse> {
146        {
147            let request_body = self.build_request(messages, system, tools);
148            let url = format!("{}/v1/messages", self.base_url);
149
150            let headers = vec![
151                ("x-api-key", self.api_key.expose()),
152                ("anthropic-version", "2023-06-01"),
153                ("anthropic-beta", "prompt-caching-2024-07-31"),
154            ];
155
156            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
157                let http = &self.http;
158                let url = &url;
159                let headers = headers.clone();
160                let request_body = &request_body;
161                async move {
162                    match http.post(url, headers, request_body).await {
163                        Ok(resp) => {
164                            let status = reqwest::StatusCode::from_u16(resp.status)
165                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
166                            if status.is_success() {
167                                AttemptOutcome::Success(resp.body)
168                            } else if self.retry_config.is_retryable_status(status) {
169                                AttemptOutcome::Retryable {
170                                    status,
171                                    body: resp.body,
172                                    retry_after: None,
173                                }
174                            } else {
175                                AttemptOutcome::Fatal(anyhow::anyhow!(
176                                    "Anthropic API error at {} ({}): {}",
177                                    url,
178                                    status,
179                                    resp.body
180                                ))
181                            }
182                        }
183                        Err(e) => AttemptOutcome::Fatal(e),
184                    }
185                }
186            })
187            .await?;
188
189            let parsed: AnthropicResponse =
190                serde_json::from_str(&response).context("Failed to parse Anthropic response")?;
191
192            tracing::debug!("Anthropic response: {:?}", parsed);
193
194            let content: Vec<ContentBlock> = parsed
195                .content
196                .into_iter()
197                .map(|block| match block {
198                    AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
199                    AnthropicContentBlock::ToolUse { id, name, input } => {
200                        ContentBlock::ToolUse { id, name, input }
201                    }
202                })
203                .collect();
204
205            let llm_response = LlmResponse {
206                message: Message {
207                    role: "assistant".to_string(),
208                    content,
209                    reasoning_content: None,
210                },
211                usage: TokenUsage {
212                    prompt_tokens: parsed.usage.input_tokens,
213                    completion_tokens: parsed.usage.output_tokens,
214                    total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
215                    cache_read_tokens: parsed.usage.cache_read_input_tokens,
216                    cache_write_tokens: parsed.usage.cache_creation_input_tokens,
217                },
218                stop_reason: Some(parsed.stop_reason),
219            };
220
221            crate::telemetry::record_llm_usage(
222                llm_response.usage.prompt_tokens,
223                llm_response.usage.completion_tokens,
224                llm_response.usage.total_tokens,
225                llm_response.stop_reason.as_deref(),
226            );
227
228            Ok(llm_response)
229        }
230    }
231
232    async fn complete_streaming(
233        &self,
234        messages: &[Message],
235        system: Option<&str>,
236        tools: &[ToolDefinition],
237    ) -> Result<mpsc::Receiver<StreamEvent>> {
238        {
239            let mut request_body = self.build_request(messages, system, tools);
240            request_body["stream"] = serde_json::json!(true);
241
242            let url = format!("{}/v1/messages", self.base_url);
243
244            let headers = vec![
245                ("x-api-key", self.api_key.expose()),
246                ("anthropic-version", "2023-06-01"),
247                ("anthropic-beta", "prompt-caching-2024-07-31"),
248            ];
249
250            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
251                let http = &self.http;
252                let url = &url;
253                let headers = headers.clone();
254                let request_body = &request_body;
255                async move {
256                    match http.post_streaming(url, headers, request_body).await {
257                        Ok(resp) => {
258                            let status = reqwest::StatusCode::from_u16(resp.status)
259                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
260                            if status.is_success() {
261                                AttemptOutcome::Success(resp)
262                            } else {
263                                let retry_after = resp
264                                    .retry_after
265                                    .as_deref()
266                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
267                                if self.retry_config.is_retryable_status(status) {
268                                    AttemptOutcome::Retryable {
269                                        status,
270                                        body: resp.error_body,
271                                        retry_after,
272                                    }
273                                } else {
274                                    AttemptOutcome::Fatal(anyhow::anyhow!(
275                                        "Anthropic API error at {} ({}): {}",
276                                        url,
277                                        status,
278                                        resp.error_body
279                                    ))
280                                }
281                            }
282                        }
283                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
284                            "Failed to send streaming request: {}",
285                            e
286                        )),
287                    }
288                }
289            })
290            .await?;
291
292            let (tx, rx) = mpsc::channel(100);
293
294            let mut stream = streaming_resp.byte_stream;
295            tokio::spawn(async move {
296                let mut buffer = String::new();
297                let mut content_blocks: Vec<ContentBlock> = Vec::new();
298                let mut current_tool_id = String::new();
299                let mut current_tool_name = String::new();
300                let mut current_tool_input = String::new();
301                let mut usage = TokenUsage::default();
302                let mut stop_reason = None;
303
304                while let Some(chunk_result) = stream.next().await {
305                    let chunk = match chunk_result {
306                        Ok(c) => c,
307                        Err(e) => {
308                            tracing::error!("Stream error: {}", e);
309                            break;
310                        }
311                    };
312
313                    buffer.push_str(&String::from_utf8_lossy(&chunk));
314
315                    while let Some(event_end) = buffer.find("\n\n") {
316                        let event_data: String = buffer.drain(..event_end).collect();
317                        buffer.drain(..2);
318
319                        for line in event_data.lines() {
320                            if let Some(data) = line.strip_prefix("data: ") {
321                                if data == "[DONE]" {
322                                    continue;
323                                }
324
325                                if let Ok(event) =
326                                    serde_json::from_str::<AnthropicStreamEvent>(data)
327                                {
328                                    match event {
329                                        AnthropicStreamEvent::ContentBlockStart {
330                                            index: _,
331                                            content_block,
332                                        } => match content_block {
333                                            AnthropicContentBlock::Text { .. } => {}
334                                            AnthropicContentBlock::ToolUse { id, name, .. } => {
335                                                current_tool_id = id.clone();
336                                                current_tool_name = name.clone();
337                                                current_tool_input.clear();
338                                                let _ = tx
339                                                    .send(StreamEvent::ToolUseStart { id, name })
340                                                    .await;
341                                            }
342                                        },
343                                        AnthropicStreamEvent::ContentBlockDelta {
344                                            index: _,
345                                            delta,
346                                        } => match delta {
347                                            AnthropicDelta::TextDelta { text } => {
348                                                let _ = tx.send(StreamEvent::TextDelta(text)).await;
349                                            }
350                                            AnthropicDelta::InputJsonDelta { partial_json } => {
351                                                current_tool_input.push_str(&partial_json);
352                                                let _ = tx
353                                                    .send(StreamEvent::ToolUseInputDelta(
354                                                        partial_json,
355                                                    ))
356                                                    .await;
357                                            }
358                                        },
359                                        AnthropicStreamEvent::ContentBlockStop { index: _ } => {
360                                            if !current_tool_id.is_empty() {
361                                                let input: serde_json::Value =
362                                                serde_json::from_str(&current_tool_input)
363                                                    .unwrap_or_else(|e| {
364                                                        tracing::warn!(
365                                                            "Failed to parse tool input JSON for tool '{}': {}",
366                                                            current_tool_name, e
367                                                        );
368                                                        serde_json::json!({
369                                                            "__parse_error": format!(
370                                                                "Malformed tool arguments: {}. Raw input: {}",
371                                                                e, &current_tool_input
372                                                            )
373                                                        })
374                                                    });
375                                                content_blocks.push(ContentBlock::ToolUse {
376                                                    id: current_tool_id.clone(),
377                                                    name: current_tool_name.clone(),
378                                                    input,
379                                                });
380                                                current_tool_id.clear();
381                                                current_tool_name.clear();
382                                                current_tool_input.clear();
383                                            }
384                                        }
385                                        AnthropicStreamEvent::MessageStart { message } => {
386                                            usage.prompt_tokens = message.usage.input_tokens;
387                                        }
388                                        AnthropicStreamEvent::MessageDelta {
389                                            delta,
390                                            usage: msg_usage,
391                                        } => {
392                                            stop_reason = Some(delta.stop_reason);
393                                            usage.completion_tokens = msg_usage.output_tokens;
394                                            usage.total_tokens =
395                                                usage.prompt_tokens + usage.completion_tokens;
396                                        }
397                                        AnthropicStreamEvent::MessageStop => {
398                                            crate::telemetry::record_llm_usage(
399                                                usage.prompt_tokens,
400                                                usage.completion_tokens,
401                                                usage.total_tokens,
402                                                stop_reason.as_deref(),
403                                            );
404
405                                            let response = LlmResponse {
406                                                message: Message {
407                                                    role: "assistant".to_string(),
408                                                    content: std::mem::take(&mut content_blocks),
409                                                    reasoning_content: None,
410                                                },
411                                                usage: usage.clone(),
412                                                stop_reason: stop_reason.clone(),
413                                            };
414                                            let _ = tx.send(StreamEvent::Done(response)).await;
415                                        }
416                                        _ => {}
417                                    }
418                                }
419                            }
420                        }
421                    }
422                }
423            });
424
425            Ok(rx)
426        }
427    }
428}
429
430// Anthropic API response types (private)
431#[derive(Debug, Deserialize)]
432pub(crate) struct AnthropicResponse {
433    pub(crate) content: Vec<AnthropicContentBlock>,
434    pub(crate) stop_reason: String,
435    pub(crate) usage: AnthropicUsage,
436}
437
438#[derive(Debug, Deserialize)]
439#[serde(tag = "type")]
440pub(crate) enum AnthropicContentBlock {
441    #[serde(rename = "text")]
442    Text { text: String },
443    #[serde(rename = "tool_use")]
444    ToolUse {
445        id: String,
446        name: String,
447        input: serde_json::Value,
448    },
449}
450
451#[derive(Debug, Deserialize)]
452pub(crate) struct AnthropicUsage {
453    pub(crate) input_tokens: usize,
454    pub(crate) output_tokens: usize,
455    pub(crate) cache_read_input_tokens: Option<usize>,
456    pub(crate) cache_creation_input_tokens: Option<usize>,
457}
458
459#[derive(Debug, Deserialize)]
460#[serde(tag = "type")]
461#[allow(dead_code)]
462pub(crate) enum AnthropicStreamEvent {
463    #[serde(rename = "message_start")]
464    MessageStart { message: AnthropicMessageStart },
465    #[serde(rename = "content_block_start")]
466    ContentBlockStart {
467        index: usize,
468        content_block: AnthropicContentBlock,
469    },
470    #[serde(rename = "content_block_delta")]
471    ContentBlockDelta { index: usize, delta: AnthropicDelta },
472    #[serde(rename = "content_block_stop")]
473    ContentBlockStop { index: usize },
474    #[serde(rename = "message_delta")]
475    MessageDelta {
476        delta: AnthropicMessageDeltaData,
477        usage: AnthropicOutputUsage,
478    },
479    #[serde(rename = "message_stop")]
480    MessageStop,
481    #[serde(rename = "ping")]
482    Ping,
483    #[serde(rename = "error")]
484    Error { error: AnthropicError },
485}
486
487#[derive(Debug, Deserialize)]
488pub(crate) struct AnthropicMessageStart {
489    pub(crate) usage: AnthropicUsage,
490}
491
492#[derive(Debug, Deserialize)]
493#[serde(tag = "type")]
494pub(crate) enum AnthropicDelta {
495    #[serde(rename = "text_delta")]
496    TextDelta { text: String },
497    #[serde(rename = "input_json_delta")]
498    InputJsonDelta { partial_json: String },
499}
500
501#[derive(Debug, Deserialize)]
502pub(crate) struct AnthropicMessageDeltaData {
503    pub(crate) stop_reason: String,
504}
505
506#[derive(Debug, Deserialize)]
507pub(crate) struct AnthropicOutputUsage {
508    pub(crate) output_tokens: usize,
509}
510
511#[derive(Debug, Deserialize)]
512#[allow(dead_code)]
513pub(crate) struct AnthropicError {
514    #[serde(rename = "type")]
515    pub(crate) error_type: String,
516    pub(crate) message: String,
517}