// a3s_code_core/llm/anthropic.rs
1//! Anthropic Claude LLM client
2
3use super::http::{default_http_client, normalize_base_url, HttpClient};
4use super::types::*;
5use super::LlmClient;
6use crate::retry::{AttemptOutcome, RetryConfig};
7use anyhow::{Context, Result};
8use async_trait::async_trait;
9use futures::StreamExt;
10use serde::Deserialize;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::sync::mpsc;
14
/// Default max tokens for LLM responses; sent as the Anthropic
/// `max_tokens` request field unless overridden via `with_max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8192;
17
/// Anthropic Claude client
///
/// Talks to the Anthropic Messages API (`POST {base_url}/v1/messages`).
/// Construct with [`AnthropicClient::new`] and customize via the `with_*`
/// builder methods.
pub struct AnthropicClient {
    /// Provider label reported in response metadata (defaults to "anthropic").
    pub(crate) provider_name: String,
    /// API key sent in the `x-api-key` header (read via `SecretString::expose`).
    pub(crate) api_key: SecretString,
    /// Model id placed in the request body.
    pub(crate) model: String,
    /// API base URL; run through `normalize_base_url` when set via `with_base_url`.
    pub(crate) base_url: String,
    /// `max_tokens` request field; defaults to [`DEFAULT_MAX_TOKENS`].
    pub(crate) max_tokens: usize,
    /// Optional sampling temperature; forced to 1.0 when thinking is enabled.
    pub(crate) temperature: Option<f32>,
    /// Extended-thinking token budget; when set, the `thinking` block is sent.
    pub(crate) thinking_budget: Option<usize>,
    /// HTTP transport; injectable (e.g. for tests) via `with_http_client`.
    pub(crate) http: Arc<dyn HttpClient>,
    /// Retry policy applied to API calls.
    pub(crate) retry_config: RetryConfig,
}
30
31impl AnthropicClient {
32    pub fn new(api_key: String, model: String) -> Self {
33        Self {
34            provider_name: "anthropic".to_string(),
35            api_key: SecretString::new(api_key),
36            model,
37            base_url: "https://api.anthropic.com".to_string(),
38            max_tokens: DEFAULT_MAX_TOKENS,
39            temperature: None,
40            thinking_budget: None,
41            http: default_http_client(),
42            retry_config: RetryConfig::default(),
43        }
44    }
45
46    pub fn with_base_url(mut self, base_url: String) -> Self {
47        self.base_url = normalize_base_url(&base_url);
48        self
49    }
50
51    pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
52        self.provider_name = provider_name.into();
53        self
54    }
55
56    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
57        self.max_tokens = max_tokens;
58        self
59    }
60
61    pub fn with_temperature(mut self, temperature: f32) -> Self {
62        self.temperature = Some(temperature);
63        self
64    }
65
66    pub fn with_thinking_budget(mut self, budget: usize) -> Self {
67        self.thinking_budget = Some(budget);
68        self
69    }
70
71    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
72        self.retry_config = retry_config;
73        self
74    }
75
76    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
77        self.http = http;
78        self
79    }
80
81    pub(crate) fn build_request(
82        &self,
83        messages: &[Message],
84        system: Option<&str>,
85        tools: &[ToolDefinition],
86    ) -> serde_json::Value {
87        let mut request = serde_json::json!({
88            "model": self.model,
89            "max_tokens": self.max_tokens,
90            "messages": messages,
91        });
92
93        // System prompt with cache_control for prompt caching.
94        // Anthropic caches system content blocks marked with
95        // `cache_control: { type: "ephemeral" }`.
96        if let Some(sys) = system {
97            request["system"] = serde_json::json!([
98                {
99                    "type": "text",
100                    "text": sys,
101                    "cache_control": { "type": "ephemeral" }
102                }
103            ]);
104        }
105
106        if !tools.is_empty() {
107            let mut tool_defs: Vec<serde_json::Value> = tools
108                .iter()
109                .map(|t| {
110                    serde_json::json!({
111                        "name": t.name,
112                        "description": t.description,
113                        "input_schema": t.parameters,
114                    })
115                })
116                .collect();
117
118            // Mark the last tool definition with cache_control so the
119            // entire tool block is cached on subsequent requests.
120            if let Some(last) = tool_defs.last_mut() {
121                last["cache_control"] = serde_json::json!({ "type": "ephemeral" });
122            }
123
124            request["tools"] = serde_json::json!(tool_defs);
125        }
126
127        // Apply optional sampling parameters
128        if let Some(temp) = self.temperature {
129            request["temperature"] = serde_json::json!(temp);
130        }
131
132        // Extended thinking (Anthropic-specific)
133        if let Some(budget) = self.thinking_budget {
134            request["thinking"] = serde_json::json!({
135                "type": "enabled",
136                "budget_tokens": budget
137            });
138            // Thinking requires temperature=1 per Anthropic docs
139            request["temperature"] = serde_json::json!(1.0);
140        }
141
142        request
143    }
144}
145
#[async_trait]
impl LlmClient for AnthropicClient {
    /// Non-streaming completion: POSTs the built request to
    /// `{base_url}/v1/messages`, retrying retryable statuses per
    /// `retry_config`, then parses the body into [`AnthropicResponse`],
    /// records usage telemetry, and returns the normalized [`LlmResponse`].
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        {
            let request_started_at = Instant::now();
            let request_body = self.build_request(messages, system, tools);
            let url = format!("{}/v1/messages", self.base_url);

            // Auth + API version headers; the beta header opts in to prompt
            // caching, matching the cache_control blocks set in build_request.
            let headers = vec![
                ("x-api-key", self.api_key.expose()),
                ("anthropic-version", "2023-06-01"),
                ("anthropic-beta", "prompt-caching-2024-07-31"),
            ];

            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
                // Re-borrow per attempt; only the header Vec is cloned.
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request_body = &request_body;
                async move {
                    match http.post(url, headers, request_body).await {
                        Ok(resp) => {
                            // Out-of-range status codes are mapped to 500 so
                            // they fall through the retryable-status check.
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp.body)
                            } else if self.retry_config.is_retryable_status(status) {
                                // NOTE(review): unlike the streaming path, no
                                // Retry-After value is propagated here —
                                // presumably `post`'s response type lacks the
                                // header; confirm against HttpClient.
                                AttemptOutcome::Retryable {
                                    status,
                                    body: resp.body,
                                    retry_after: None,
                                }
                            } else {
                                AttemptOutcome::Fatal(anyhow::anyhow!(
                                    "Anthropic API error at {} ({}): {}",
                                    url,
                                    status,
                                    resp.body
                                ))
                            }
                        }
                        // NOTE(review): transport-level errors abort without
                        // retrying — confirm this is intended for timeouts.
                        Err(e) => AttemptOutcome::Fatal(e),
                    }
                }
            })
            .await?;

            let parsed: AnthropicResponse =
                serde_json::from_str(&response).context("Failed to parse Anthropic response")?;

            tracing::debug!("Anthropic response: {:?}", parsed);

            // Map provider content blocks onto the crate-neutral ContentBlock.
            let content: Vec<ContentBlock> = parsed
                .content
                .into_iter()
                .map(|block| match block {
                    AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
                    AnthropicContentBlock::ToolUse { id, name, input } => {
                        ContentBlock::ToolUse { id, name, input }
                    }
                })
                .collect();

            let llm_response = LlmResponse {
                message: Message {
                    role: "assistant".to_string(),
                    content,
                    reasoning_content: None,
                },
                usage: TokenUsage {
                    prompt_tokens: parsed.usage.input_tokens,
                    completion_tokens: parsed.usage.output_tokens,
                    total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
                    // Cache counters are optional; present only when prompt
                    // caching kicked in on the provider side.
                    cache_read_tokens: parsed.usage.cache_read_input_tokens,
                    cache_write_tokens: parsed.usage.cache_creation_input_tokens,
                },
                stop_reason: Some(parsed.stop_reason),
                meta: Some(LlmResponseMeta {
                    provider: Some(self.provider_name.clone()),
                    request_model: Some(self.model.clone()),
                    request_url: Some(url.clone()),
                    response_id: parsed.id,
                    response_model: parsed.model,
                    response_object: parsed.response_type,
                    // No per-token timing in the non-streaming path.
                    first_token_ms: None,
                    duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
                }),
            };

            crate::telemetry::record_llm_usage(
                llm_response.usage.prompt_tokens,
                llm_response.usage.completion_tokens,
                llm_response.usage.total_tokens,
                llm_response.stop_reason.as_deref(),
            );

            Ok(llm_response)
        }
    }

    /// Streaming completion: POSTs with `"stream": true`, then spawns a task
    /// that parses the SSE byte stream and forwards [`StreamEvent`]s on the
    /// returned channel, ending with `StreamEvent::Done` on `message_stop`.
    /// If the byte stream errors mid-way the task breaks out and the channel
    /// simply closes without a `Done` event.
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        {
            let request_started_at = Instant::now();
            let mut request_body = self.build_request(messages, system, tools);
            request_body["stream"] = serde_json::json!(true);

            let url = format!("{}/v1/messages", self.base_url);

            // Same header set as the non-streaming path.
            let headers = vec![
                ("x-api-key", self.api_key.expose()),
                ("anthropic-version", "2023-06-01"),
                ("anthropic-beta", "prompt-caching-2024-07-31"),
            ];

            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
                let http = &self.http;
                let url = &url;
                let headers = headers.clone();
                let request_body = &request_body;
                async move {
                    match http.post_streaming(url, headers, request_body).await {
                        Ok(resp) => {
                            let status = reqwest::StatusCode::from_u16(resp.status)
                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                            if status.is_success() {
                                AttemptOutcome::Success(resp)
                            } else {
                                // Honor the server's Retry-After header when
                                // backing off on retryable statuses.
                                let retry_after = resp
                                    .retry_after
                                    .as_deref()
                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                                if self.retry_config.is_retryable_status(status) {
                                    AttemptOutcome::Retryable {
                                        status,
                                        body: resp.error_body,
                                        retry_after,
                                    }
                                } else {
                                    AttemptOutcome::Fatal(anyhow::anyhow!(
                                        "Anthropic API error at {} ({}): {}",
                                        url,
                                        status,
                                        resp.error_body
                                    ))
                                }
                            }
                        }
                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Failed to send streaming request: {}",
                            e
                        )),
                    }
                }
            })
            .await?;

            // Bounded channel: sends will back-pressure a slow consumer.
            let (tx, rx) = mpsc::channel(100);

            let mut stream = streaming_resp.byte_stream;
            // Owned copies for the spawned task (it outlives `self` borrows).
            let provider_name = self.provider_name.clone();
            let request_model = self.model.clone();
            let request_url = url.clone();
            tokio::spawn(async move {
                // SSE re-assembly buffer plus accumulators for the in-flight
                // tool_use block (id/name/partial JSON input).
                let mut buffer = String::new();
                let mut content_blocks: Vec<ContentBlock> = Vec::new();
                let mut current_tool_id = String::new();
                let mut current_tool_name = String::new();
                let mut current_tool_input = String::new();
                let mut usage = TokenUsage::default();
                let mut stop_reason = None;
                let mut response_id = None;
                let mut response_model = None;
                let mut response_object = Some("message".to_string());
                let mut first_token_ms = None;

                while let Some(chunk_result) = stream.next().await {
                    let chunk = match chunk_result {
                        Ok(c) => c,
                        Err(e) => {
                            // Abort parsing; channel closes without Done.
                            tracing::error!("Stream error: {}", e);
                            break;
                        }
                    };

                    buffer.push_str(&String::from_utf8_lossy(&chunk));

                    // SSE events are separated by a blank line; consume one
                    // complete event at a time, leaving partial data buffered.
                    while let Some(event_end) = buffer.find("\n\n") {
                        let event_data: String = buffer.drain(..event_end).collect();
                        // Drop the "\n\n" delimiter itself.
                        buffer.drain(..2);

                        for line in event_data.lines() {
                            if let Some(data) = line.strip_prefix("data: ") {
                                // Defensive skip of an OpenAI-style terminator;
                                // Anthropic signals completion via message_stop.
                                if data == "[DONE]" {
                                    continue;
                                }

                                // Unparseable payloads are silently ignored.
                                if let Ok(event) =
                                    serde_json::from_str::<AnthropicStreamEvent>(data)
                                {
                                    match event {
                                        AnthropicStreamEvent::ContentBlockStart {
                                            index: _,
                                            content_block,
                                        } => match content_block {
                                            AnthropicContentBlock::Text { .. } => {}
                                            AnthropicContentBlock::ToolUse { id, name, .. } => {
                                                // Begin accumulating a new tool
                                                // call; input arrives as JSON
                                                // fragments in later deltas.
                                                current_tool_id = id.clone();
                                                current_tool_name = name.clone();
                                                current_tool_input.clear();
                                                let _ = tx
                                                    .send(StreamEvent::ToolUseStart { id, name })
                                                    .await;
                                            }
                                        },
                                        AnthropicStreamEvent::ContentBlockDelta {
                                            index: _,
                                            delta,
                                        } => match delta {
                                            AnthropicDelta::TextDelta { text } => {
                                                // Latency to first visible output.
                                                if first_token_ms.is_none() {
                                                    first_token_ms = Some(
                                                        request_started_at.elapsed().as_millis()
                                                            as u64,
                                                    );
                                                }
                                                // NOTE(review): text is forwarded
                                                // but never pushed into
                                                // content_blocks, so the final
                                                // Done message contains only
                                                // tool_use blocks — confirm the
                                                // consumer accumulates text from
                                                // the deltas.
                                                let _ = tx.send(StreamEvent::TextDelta(text)).await;
                                            }
                                            AnthropicDelta::InputJsonDelta { partial_json } => {
                                                if first_token_ms.is_none() {
                                                    first_token_ms = Some(
                                                        request_started_at.elapsed().as_millis()
                                                            as u64,
                                                    );
                                                }
                                                current_tool_input.push_str(&partial_json);
                                                let _ = tx
                                                    .send(StreamEvent::ToolUseInputDelta(
                                                        partial_json,
                                                    ))
                                                    .await;
                                            }
                                        },
                                        AnthropicStreamEvent::ContentBlockStop { index: _ } => {
                                            // Finalize the accumulated tool call
                                            // (no-op for text blocks, which leave
                                            // current_tool_id empty).
                                            if !current_tool_id.is_empty() {
                                                let input: serde_json::Value = if current_tool_input
                                                    .trim()
                                                    .is_empty()
                                                {
                                                    // No-argument tool call.
                                                    serde_json::Value::Object(Default::default())
                                                } else {
                                                    // Malformed JSON becomes a
                                                    // __parse_error payload rather
                                                    // than killing the stream.
                                                    serde_json::from_str(&current_tool_input)
                                                        .unwrap_or_else(|e| {
                                                            tracing::warn!(
                                                                "Failed to parse tool input JSON for tool '{}': {}",
                                                                current_tool_name, e
                                                            );
                                                            serde_json::json!({
                                                                "__parse_error": format!(
                                                                    "Malformed tool arguments: {}. Raw input: {}",
                                                                    e, &current_tool_input
                                                                )
                                                            })
                                                        })
                                                };
                                                content_blocks.push(ContentBlock::ToolUse {
                                                    id: current_tool_id.clone(),
                                                    name: current_tool_name.clone(),
                                                    input,
                                                });
                                                current_tool_id.clear();
                                                current_tool_name.clear();
                                                current_tool_input.clear();
                                            }
                                        }
                                        AnthropicStreamEvent::MessageStart { message } => {
                                            // Capture response metadata and the
                                            // prompt-side token count.
                                            // NOTE(review): message.usage also
                                            // carries cache_* counters that are
                                            // dropped here (complete() records
                                            // them) — confirm intended.
                                            response_id = message.id;
                                            response_model = message.model;
                                            response_object = message.message_type;
                                            usage.prompt_tokens = message.usage.input_tokens;
                                        }
                                        AnthropicStreamEvent::MessageDelta {
                                            delta,
                                            usage: msg_usage,
                                        } => {
                                            stop_reason = Some(delta.stop_reason);
                                            usage.completion_tokens = msg_usage.output_tokens;
                                            usage.total_tokens =
                                                usage.prompt_tokens + usage.completion_tokens;
                                        }
                                        AnthropicStreamEvent::MessageStop => {
                                            crate::telemetry::record_llm_usage(
                                                usage.prompt_tokens,
                                                usage.completion_tokens,
                                                usage.total_tokens,
                                                stop_reason.as_deref(),
                                            );

                                            let response = LlmResponse {
                                                message: Message {
                                                    role: "assistant".to_string(),
                                                    // Move the blocks out; the
                                                    // Vec is reset to empty.
                                                    content: std::mem::take(&mut content_blocks),
                                                    reasoning_content: None,
                                                },
                                                usage: usage.clone(),
                                                stop_reason: stop_reason.clone(),
                                                meta: Some(LlmResponseMeta {
                                                    provider: Some(provider_name.clone()),
                                                    request_model: Some(request_model.clone()),
                                                    request_url: Some(request_url.clone()),
                                                    response_id: response_id.clone(),
                                                    response_model: response_model.clone(),
                                                    response_object: response_object.clone(),
                                                    first_token_ms,
                                                    duration_ms: Some(
                                                        request_started_at.elapsed().as_millis()
                                                            as u64,
                                                    ),
                                                }),
                                            };
                                            let _ = tx.send(StreamEvent::Done(response)).await;
                                        }
                                        // Ping and Error events are ignored.
                                        _ => {}
                                    }
                                }
                            }
                        }
                    }
                }
            });

            Ok(rx)
        }
    }
}
490
// Anthropic API response types (private)

/// Top-level JSON body of a non-streaming `/v1/messages` response.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicResponse {
    /// Response id; tolerated as absent for lenient parsing.
    #[serde(default)]
    pub(crate) id: Option<String>,
    /// Model that produced the response.
    #[serde(default)]
    pub(crate) model: Option<String>,
    /// The JSON `type` field (renamed to avoid the Rust keyword).
    #[serde(rename = "type", default)]
    pub(crate) response_type: Option<String>,
    /// Ordered content blocks (text and/or tool_use).
    pub(crate) content: Vec<AnthropicContentBlock>,
    /// Why generation stopped; required in a completed response.
    pub(crate) stop_reason: String,
    /// Token accounting for the request/response pair.
    pub(crate) usage: AnthropicUsage,
}
504
/// One content block in an Anthropic message, discriminated by the JSON
/// `type` field. Also reused by the streaming `content_block_start` event.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicContentBlock {
    /// Plain text output.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model; `input` is the
    /// already-parsed JSON arguments object.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
517
/// Token usage block from the Anthropic API.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicUsage {
    pub(crate) input_tokens: usize,
    pub(crate) output_tokens: usize,
    /// Prompt-cache counters; optional because they only appear when
    /// prompt caching applies. Mapped to `TokenUsage::cache_read_tokens`
    /// / `cache_write_tokens` in `complete()`.
    pub(crate) cache_read_input_tokens: Option<usize>,
    pub(crate) cache_creation_input_tokens: Option<usize>,
}
525
/// Server-sent event payloads for the streaming `/v1/messages` API,
/// discriminated by the JSON `type` field. `Ping` and `Error` are parsed
/// but currently ignored by the stream loop.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
#[allow(dead_code)]
pub(crate) enum AnthropicStreamEvent {
    /// First event: response metadata plus prompt-side token usage.
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageStart },
    /// A new content block begins at `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    /// Incremental text or tool-input JSON for the block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    /// The block at `index` is complete.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Late-arriving stop reason and output token count.
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDeltaData,
        usage: AnthropicOutputUsage,
    },
    /// Final event of the stream.
    #[serde(rename = "message_stop")]
    MessageStop,
    /// Keep-alive; carries no data.
    #[serde(rename = "ping")]
    Ping,
    /// In-stream error report.
    #[serde(rename = "error")]
    Error { error: AnthropicError },
}
553
/// Payload of the `message_start` streaming event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageStart {
    /// Response id, if provided.
    #[serde(default)]
    pub(crate) id: Option<String>,
    /// Model that is producing the response.
    #[serde(default)]
    pub(crate) model: Option<String>,
    /// The JSON `type` field of the message object.
    #[serde(rename = "type", default)]
    pub(crate) message_type: Option<String>,
    /// Usage so far; only `input_tokens` is consumed by the stream loop.
    pub(crate) usage: AnthropicUsage,
}
564
/// Delta payload of a `content_block_delta` event, discriminated by `type`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicDelta {
    /// A fragment of assistant text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a tool call's JSON arguments; fragments are
    /// concatenated until `content_block_stop` and then parsed.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
573
/// `delta` field of a `message_delta` event; carries the stop reason.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageDeltaData {
    pub(crate) stop_reason: String,
}
578
/// `usage` field of a `message_delta` event; output-side token count only.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicOutputUsage {
    pub(crate) output_tokens: usize,
}
583
/// Payload of an in-stream `error` event. Parsed but currently unused
/// (the stream loop's catch-all arm discards it), hence `dead_code`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct AnthropicError {
    /// Error category (renamed from the JSON `type` field).
    #[serde(rename = "type")]
    pub(crate) error_type: String,
    pub(crate) message: String,
}
591
592// ============================================================================
593// Tests
594// ============================================================================
595
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{Message, ToolDefinition};

    /// Client with a fixed key/model used by every test.
    fn make_client() -> AnthropicClient {
        AnthropicClient::new("test-key".to_string(), "claude-opus-4-6".to_string())
    }

    #[test]
    fn test_build_request_basic() {
        let client = make_client();
        let messages = vec![Message::user("Hello")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["model"], "claude-opus-4-6");
        assert_eq!(req["max_tokens"], DEFAULT_MAX_TOKENS);
        assert!(req["thinking"].is_null());
    }

    #[test]
    fn test_build_request_with_thinking_budget() {
        let client = make_client().with_thinking_budget(10_000);
        let messages = vec![Message::user("Think carefully.")];
        let req = client.build_request(&messages, None, &[]);

        // thinking block must be present
        assert_eq!(req["thinking"]["type"], "enabled");
        assert_eq!(req["thinking"]["budget_tokens"], 10_000);
        // temperature must be 1.0 when thinking is enabled
        assert_eq!(req["temperature"], 1.0_f64);
    }

    #[test]
    fn test_build_request_thinking_overrides_temperature() {
        // Even if temperature was set, thinking forces it to 1.0
        let client = make_client()
            .with_temperature(0.5)
            .with_thinking_budget(5_000);
        let messages = vec![Message::user("Test")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["temperature"], 1.0_f64);
        assert_eq!(req["thinking"]["budget_tokens"], 5_000);
    }

    #[test]
    fn test_build_request_no_thinking_uses_temperature() {
        let client = make_client().with_temperature(0.7);
        let messages = vec![Message::user("Test")];
        let req = client.build_request(&messages, None, &[]);

        // Use approximate comparison for f64 (0.7_f32 widened via JSON)
        let temp = req["temperature"].as_f64().unwrap();
        assert!((temp - 0.7).abs() < 0.01);
        assert!(req["thinking"].is_null());
    }

    #[test]
    fn test_build_request_with_system_prompt() {
        let client = make_client();
        let messages = vec![Message::user("Hello")];
        let req = client.build_request(&messages, Some("You are helpful."), &[]);

        let system = &req["system"];
        assert!(system.is_array());
        assert_eq!(system[0]["type"], "text");
        assert_eq!(system[0]["text"], "You are helpful.");
        assert!(system[0]["cache_control"].is_object());
    }

    #[test]
    fn test_build_request_with_tools() {
        // Two tools, so "cache_control on the LAST tool only" is actually
        // distinguishable from "cache_control on every tool" — the previous
        // single-tool version could not tell these apart.
        let client = make_client();
        let messages = vec![Message::user("Use a tool")];
        let tools = vec![
            ToolDefinition {
                name: "read_file".to_string(),
                description: "Read a file".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            },
            ToolDefinition {
                name: "write_file".to_string(),
                description: "Write a file".to_string(),
                parameters: serde_json::json!({"type": "object", "properties": {}}),
            },
        ];
        let req = client.build_request(&messages, None, &tools);

        assert!(req["tools"].is_array());
        assert_eq!(req["tools"].as_array().unwrap().len(), 2);
        assert_eq!(req["tools"][0]["name"], "read_file");
        assert_eq!(req["tools"][1]["name"], "write_file");
        // Only the last tool definition carries cache_control.
        assert!(req["tools"][0]["cache_control"].is_null());
        assert!(req["tools"][1]["cache_control"].is_object());
    }

    #[test]
    fn test_build_request_thinking_budget_sets_max_tokens() {
        // max_tokens is still respected when thinking is enabled
        let client = make_client()
            .with_max_tokens(16_000)
            .with_thinking_budget(8_000);
        let messages = vec![Message::user("Test")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["max_tokens"], 16_000);
        assert_eq!(req["thinking"]["budget_tokens"], 8_000);
    }
}