// a3s_code_core/llm/anthropic.rs

//! Anthropic Claude LLM client

use super::http::{default_http_client, normalize_base_url, HttpClient};
use super::types::*;
use super::LlmClient;
use crate::retry::{AttemptOutcome, RetryConfig};
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use serde::Deserialize;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;

/// Default `max_tokens` for LLM responses when not overridden via
/// `AnthropicClient::with_max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8192;

/// Anthropic Claude client for the `/v1/messages` API.
///
/// Construct with [`AnthropicClient::new`] and customize through the
/// `with_*` builder methods.
pub struct AnthropicClient {
    // Label reported in response metadata; "anthropic" by default.
    pub(crate) provider_name: String,
    // API key sent in the `x-api-key` header.
    pub(crate) api_key: SecretString,
    // Model identifier included in every request body.
    pub(crate) model: String,
    // Base URL without trailing slash (normalized in `with_base_url`).
    pub(crate) base_url: String,
    // Upper bound on generated tokens per request.
    pub(crate) max_tokens: usize,
    // Optional sampling temperature; overridden to 1.0 when thinking is enabled.
    pub(crate) temperature: Option<f32>,
    // Extended-thinking token budget; presence enables the `thinking` block.
    pub(crate) thinking_budget: Option<usize>,
    // HTTP transport (injectable for tests via `with_http_client`).
    pub(crate) http: Arc<dyn HttpClient>,
    // Retry policy applied to retryable HTTP statuses.
    pub(crate) retry_config: RetryConfig,
}

31impl AnthropicClient {
32    pub fn new(api_key: String, model: String) -> Self {
33        Self {
34            provider_name: "anthropic".to_string(),
35            api_key: SecretString::new(api_key),
36            model,
37            base_url: "https://api.anthropic.com".to_string(),
38            max_tokens: DEFAULT_MAX_TOKENS,
39            temperature: None,
40            thinking_budget: None,
41            http: default_http_client(),
42            retry_config: RetryConfig::default(),
43        }
44    }
45
46    pub fn with_base_url(mut self, base_url: String) -> Self {
47        self.base_url = normalize_base_url(&base_url);
48        self
49    }
50
51    pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
52        self.provider_name = provider_name.into();
53        self
54    }
55
56    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
57        self.max_tokens = max_tokens;
58        self
59    }
60
61    pub fn with_temperature(mut self, temperature: f32) -> Self {
62        self.temperature = Some(temperature);
63        self
64    }
65
66    pub fn with_thinking_budget(mut self, budget: usize) -> Self {
67        self.thinking_budget = Some(budget);
68        self
69    }
70
71    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
72        self.retry_config = retry_config;
73        self
74    }
75
76    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
77        self.http = http;
78        self
79    }
80
81    fn initial_tool_input_json(input: &serde_json::Value) -> Option<String> {
82        match input {
83            serde_json::Value::Object(map) if map.is_empty() => None,
84            serde_json::Value::Null => None,
85            value => serde_json::to_string(value).ok(),
86        }
87    }
88
89    pub(crate) fn build_request(
90        &self,
91        messages: &[Message],
92        system: Option<&str>,
93        tools: &[ToolDefinition],
94    ) -> serde_json::Value {
95        let mut request = serde_json::json!({
96            "model": self.model,
97            "max_tokens": self.max_tokens,
98            "messages": messages,
99        });
100
101        // System prompt with cache_control for prompt caching.
102        // Anthropic caches system content blocks marked with
103        // `cache_control: { type: "ephemeral" }`.
104        if let Some(sys) = system {
105            request["system"] = serde_json::json!([
106                {
107                    "type": "text",
108                    "text": sys,
109                    "cache_control": { "type": "ephemeral" }
110                }
111            ]);
112        }
113
114        if !tools.is_empty() {
115            let mut tool_defs: Vec<serde_json::Value> = tools
116                .iter()
117                .map(|t| {
118                    serde_json::json!({
119                        "name": t.name,
120                        "description": t.description,
121                        "input_schema": t.parameters,
122                    })
123                })
124                .collect();
125
126            // Mark the last tool definition with cache_control so the
127            // entire tool block is cached on subsequent requests.
128            if let Some(last) = tool_defs.last_mut() {
129                last["cache_control"] = serde_json::json!({ "type": "ephemeral" });
130            }
131
132            request["tools"] = serde_json::json!(tool_defs);
133        }
134
135        // Apply optional sampling parameters
136        if let Some(temp) = self.temperature {
137            request["temperature"] = serde_json::json!(temp);
138        }
139
140        // Extended thinking (Anthropic-specific)
141        if let Some(budget) = self.thinking_budget {
142            request["thinking"] = serde_json::json!({
143                "type": "enabled",
144                "budget_tokens": budget
145            });
146            // Thinking requires temperature=1 per Anthropic docs
147            request["temperature"] = serde_json::json!(1.0);
148        }
149
150        request
151    }
152}
153
154#[async_trait]
155impl LlmClient for AnthropicClient {
156    async fn complete(
157        &self,
158        messages: &[Message],
159        system: Option<&str>,
160        tools: &[ToolDefinition],
161    ) -> Result<LlmResponse> {
162        {
163            let request_started_at = Instant::now();
164            let request_body = self.build_request(messages, system, tools);
165            let url = format!("{}/v1/messages", self.base_url);
166
167            let headers = vec![
168                ("x-api-key", self.api_key.expose()),
169                ("anthropic-version", "2023-06-01"),
170                ("anthropic-beta", "prompt-caching-2024-07-31"),
171            ];
172
173            let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
174                let http = &self.http;
175                let url = &url;
176                let headers = headers.clone();
177                let request_body = &request_body;
178                async move {
179                    match http.post(url, headers, request_body).await {
180                        Ok(resp) => {
181                            let status = reqwest::StatusCode::from_u16(resp.status)
182                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
183                            if status.is_success() {
184                                AttemptOutcome::Success(resp.body)
185                            } else if self.retry_config.is_retryable_status(status) {
186                                AttemptOutcome::Retryable {
187                                    status,
188                                    body: resp.body,
189                                    retry_after: None,
190                                }
191                            } else {
192                                AttemptOutcome::Fatal(anyhow::anyhow!(
193                                    "Anthropic API error at {} ({}): {}",
194                                    url,
195                                    status,
196                                    resp.body
197                                ))
198                            }
199                        }
200                        Err(e) => AttemptOutcome::Fatal(e),
201                    }
202                }
203            })
204            .await?;
205
206            let parsed: AnthropicResponse =
207                serde_json::from_str(&response).context("Failed to parse Anthropic response")?;
208
209            tracing::debug!("Anthropic response: {:?}", parsed);
210
211            let content: Vec<ContentBlock> = parsed
212                .content
213                .into_iter()
214                .map(|block| match block {
215                    AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
216                    AnthropicContentBlock::ToolUse { id, name, input } => {
217                        ContentBlock::ToolUse { id, name, input }
218                    }
219                })
220                .collect();
221
222            let llm_response = LlmResponse {
223                message: Message {
224                    role: "assistant".to_string(),
225                    content,
226                    reasoning_content: None,
227                },
228                usage: TokenUsage {
229                    prompt_tokens: parsed.usage.input_tokens,
230                    completion_tokens: parsed.usage.output_tokens,
231                    total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
232                    cache_read_tokens: parsed.usage.cache_read_input_tokens,
233                    cache_write_tokens: parsed.usage.cache_creation_input_tokens,
234                },
235                stop_reason: Some(parsed.stop_reason),
236                meta: Some(LlmResponseMeta {
237                    provider: Some(self.provider_name.clone()),
238                    request_model: Some(self.model.clone()),
239                    request_url: Some(url.clone()),
240                    response_id: parsed.id,
241                    response_model: parsed.model,
242                    response_object: parsed.response_type,
243                    first_token_ms: None,
244                    duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
245                }),
246            };
247
248            crate::telemetry::record_llm_usage(
249                llm_response.usage.prompt_tokens,
250                llm_response.usage.completion_tokens,
251                llm_response.usage.total_tokens,
252                llm_response.stop_reason.as_deref(),
253            );
254
255            Ok(llm_response)
256        }
257    }
258
259    async fn complete_streaming(
260        &self,
261        messages: &[Message],
262        system: Option<&str>,
263        tools: &[ToolDefinition],
264    ) -> Result<mpsc::Receiver<StreamEvent>> {
265        {
266            let request_started_at = Instant::now();
267            let mut request_body = self.build_request(messages, system, tools);
268            request_body["stream"] = serde_json::json!(true);
269
270            let url = format!("{}/v1/messages", self.base_url);
271
272            let headers = vec![
273                ("x-api-key", self.api_key.expose()),
274                ("anthropic-version", "2023-06-01"),
275                ("anthropic-beta", "prompt-caching-2024-07-31"),
276            ];
277
278            let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
279                let http = &self.http;
280                let url = &url;
281                let headers = headers.clone();
282                let request_body = &request_body;
283                async move {
284                    match http.post_streaming(url, headers, request_body).await {
285                        Ok(resp) => {
286                            let status = reqwest::StatusCode::from_u16(resp.status)
287                                .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
288                            if status.is_success() {
289                                AttemptOutcome::Success(resp)
290                            } else {
291                                let retry_after = resp
292                                    .retry_after
293                                    .as_deref()
294                                    .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
295                                if self.retry_config.is_retryable_status(status) {
296                                    AttemptOutcome::Retryable {
297                                        status,
298                                        body: resp.error_body,
299                                        retry_after,
300                                    }
301                                } else {
302                                    AttemptOutcome::Fatal(anyhow::anyhow!(
303                                        "Anthropic API error at {} ({}): {}",
304                                        url,
305                                        status,
306                                        resp.error_body
307                                    ))
308                                }
309                            }
310                        }
311                        Err(e) => AttemptOutcome::Fatal(anyhow::anyhow!(
312                            "Failed to send streaming request: {}",
313                            e
314                        )),
315                    }
316                }
317            })
318            .await?;
319
320            let (tx, rx) = mpsc::channel(100);
321
322            let mut stream = streaming_resp.byte_stream;
323            let provider_name = self.provider_name.clone();
324            let request_model = self.model.clone();
325            let request_url = url.clone();
326            tokio::spawn(async move {
327                let mut buffer = String::new();
328                let mut content_blocks: Vec<ContentBlock> = Vec::new();
329                let mut text_content = String::new();
330                let mut current_tool_id = String::new();
331                let mut current_tool_name = String::new();
332                let mut current_tool_input = String::new();
333                let mut usage = TokenUsage::default();
334                let mut stop_reason = None;
335                let mut response_id = None;
336                let mut response_model = None;
337                let mut response_object = Some("message".to_string());
338                let mut first_token_ms = None;
339
340                while let Some(chunk_result) = stream.next().await {
341                    let chunk = match chunk_result {
342                        Ok(c) => c,
343                        Err(e) => {
344                            tracing::error!("Stream error: {}", e);
345                            break;
346                        }
347                    };
348
349                    buffer.push_str(&String::from_utf8_lossy(&chunk));
350
351                    while let Some(event_end) = buffer.find("\n\n") {
352                        let event_data: String = buffer.drain(..event_end).collect();
353                        buffer.drain(..2);
354
355                        for line in event_data.lines() {
356                            if let Some(data) = line.strip_prefix("data: ") {
357                                if data == "[DONE]" {
358                                    continue;
359                                }
360
361                                if let Ok(event) =
362                                    serde_json::from_str::<AnthropicStreamEvent>(data)
363                                {
364                                    match event {
365                                        AnthropicStreamEvent::ContentBlockStart {
366                                            index: _,
367                                            content_block,
368                                        } => match content_block {
369                                            AnthropicContentBlock::Text { .. } => {}
370                                            AnthropicContentBlock::ToolUse { id, name, input } => {
371                                                if !text_content.is_empty() {
372                                                    content_blocks.push(ContentBlock::Text {
373                                                        text: std::mem::take(&mut text_content),
374                                                    });
375                                                }
376                                                current_tool_id = id.clone();
377                                                current_tool_name = name.clone();
378                                                current_tool_input =
379                                                    Self::initial_tool_input_json(&input)
380                                                        .unwrap_or_default();
381                                                let _ = tx
382                                                    .send(StreamEvent::ToolUseStart { id, name })
383                                                    .await;
384                                                if !current_tool_input.is_empty() {
385                                                    if first_token_ms.is_none() {
386                                                        first_token_ms = Some(
387                                                            request_started_at.elapsed().as_millis()
388                                                                as u64,
389                                                        );
390                                                    }
391                                                    let _ = tx
392                                                        .send(StreamEvent::ToolUseInputDelta(
393                                                            current_tool_input.clone(),
394                                                        ))
395                                                        .await;
396                                                }
397                                            }
398                                        },
399                                        AnthropicStreamEvent::ContentBlockDelta {
400                                            index: _,
401                                            delta,
402                                        } => match delta {
403                                            AnthropicDelta::TextDelta { text } => {
404                                                if first_token_ms.is_none() {
405                                                    first_token_ms = Some(
406                                                        request_started_at.elapsed().as_millis()
407                                                            as u64,
408                                                    );
409                                                }
410                                                text_content.push_str(&text);
411                                                let _ = tx.send(StreamEvent::TextDelta(text)).await;
412                                            }
413                                            AnthropicDelta::InputJsonDelta { partial_json } => {
414                                                if first_token_ms.is_none() {
415                                                    first_token_ms = Some(
416                                                        request_started_at.elapsed().as_millis()
417                                                            as u64,
418                                                    );
419                                                }
420                                                current_tool_input.push_str(&partial_json);
421                                                let _ = tx
422                                                    .send(StreamEvent::ToolUseInputDelta(
423                                                        partial_json,
424                                                    ))
425                                                    .await;
426                                            }
427                                        },
428                                        AnthropicStreamEvent::ContentBlockStop { index: _ } => {
429                                            if !current_tool_id.is_empty() {
430                                                let input: serde_json::Value = if current_tool_input
431                                                    .trim()
432                                                    .is_empty()
433                                                {
434                                                    serde_json::Value::Object(Default::default())
435                                                } else {
436                                                    serde_json::from_str(&current_tool_input)
437                                                        .unwrap_or_else(|e| {
438                                                            tracing::warn!(
439                                                                "Failed to parse tool input JSON for tool '{}': {}",
440                                                                current_tool_name, e
441                                                            );
442                                                            serde_json::json!({
443                                                                "__parse_error": format!(
444                                                                    "Malformed tool arguments: {}. Raw input: {}",
445                                                                    e, &current_tool_input
446                                                                )
447                                                            })
448                                                        })
449                                                };
450                                                content_blocks.push(ContentBlock::ToolUse {
451                                                    id: current_tool_id.clone(),
452                                                    name: current_tool_name.clone(),
453                                                    input,
454                                                });
455                                                current_tool_id.clear();
456                                                current_tool_name.clear();
457                                                current_tool_input.clear();
458                                            }
459                                        }
460                                        AnthropicStreamEvent::MessageStart { message } => {
461                                            response_id = message.id;
462                                            response_model = message.model;
463                                            response_object = message.message_type;
464                                            usage.prompt_tokens = message.usage.input_tokens;
465                                        }
466                                        AnthropicStreamEvent::MessageDelta {
467                                            delta,
468                                            usage: msg_usage,
469                                        } => {
470                                            stop_reason = Some(delta.stop_reason);
471                                            usage.completion_tokens = msg_usage.output_tokens;
472                                            usage.total_tokens =
473                                                usage.prompt_tokens + usage.completion_tokens;
474                                        }
475                                        AnthropicStreamEvent::MessageStop => {
476                                            if !text_content.is_empty() {
477                                                content_blocks.push(ContentBlock::Text {
478                                                    text: std::mem::take(&mut text_content),
479                                                });
480                                            }
481                                            crate::telemetry::record_llm_usage(
482                                                usage.prompt_tokens,
483                                                usage.completion_tokens,
484                                                usage.total_tokens,
485                                                stop_reason.as_deref(),
486                                            );
487
488                                            let response = LlmResponse {
489                                                message: Message {
490                                                    role: "assistant".to_string(),
491                                                    content: std::mem::take(&mut content_blocks),
492                                                    reasoning_content: None,
493                                                },
494                                                usage: usage.clone(),
495                                                stop_reason: stop_reason.clone(),
496                                                meta: Some(LlmResponseMeta {
497                                                    provider: Some(provider_name.clone()),
498                                                    request_model: Some(request_model.clone()),
499                                                    request_url: Some(request_url.clone()),
500                                                    response_id: response_id.clone(),
501                                                    response_model: response_model.clone(),
502                                                    response_object: response_object.clone(),
503                                                    first_token_ms,
504                                                    duration_ms: Some(
505                                                        request_started_at.elapsed().as_millis()
506                                                            as u64,
507                                                    ),
508                                                }),
509                                            };
510                                            let _ = tx.send(StreamEvent::Done(response)).await;
511                                        }
512                                        _ => {}
513                                    }
514                                }
515                            }
516                        }
517                    }
518                }
519            });
520
521            Ok(rx)
522        }
523    }
524}
525
// Anthropic API response types (private)

/// Body of a successful non-streaming `/v1/messages` response.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicResponse {
    // Server-assigned message id, when present.
    #[serde(default)]
    pub(crate) id: Option<String>,
    // Model that served the request, when reported.
    #[serde(default)]
    pub(crate) model: Option<String>,
    // Value of the response's `type` field.
    #[serde(rename = "type", default)]
    pub(crate) response_type: Option<String>,
    // Ordered content blocks (text and/or tool_use).
    pub(crate) content: Vec<AnthropicContentBlock>,
    // Why generation stopped (values defined by the Anthropic API).
    pub(crate) stop_reason: String,
    // Token accounting for this request.
    pub(crate) usage: AnthropicUsage,
}

/// A single content block, discriminated by its `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicContentBlock {
    /// Plain assistant text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation request with its JSON arguments.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

/// Token usage as reported by the API; the cache fields are optional and
/// only populated when the server includes them.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicUsage {
    pub(crate) input_tokens: usize,
    pub(crate) output_tokens: usize,
    // Tokens read from the prompt cache, when reported.
    pub(crate) cache_read_input_tokens: Option<usize>,
    // Tokens written to the prompt cache, when reported.
    pub(crate) cache_creation_input_tokens: Option<usize>,
}

/// Server-sent events from the streaming Messages API, discriminated by
/// their `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
#[allow(dead_code)]
pub(crate) enum AnthropicStreamEvent {
    /// First event: message metadata plus input-side token usage.
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageStart },
    /// A new content block begins at `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    /// Incremental payload for the block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    /// The block at `index` is complete.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Late-arriving stop reason and output-token usage.
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDeltaData,
        usage: AnthropicOutputUsage,
    },
    /// Terminal event for the message.
    #[serde(rename = "message_stop")]
    MessageStop,
    /// Keep-alive; carries no data.
    #[serde(rename = "ping")]
    Ping,
    /// In-stream error report.
    #[serde(rename = "error")]
    Error { error: AnthropicError },
}

/// Payload of a `message_start` event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageStart {
    #[serde(default)]
    pub(crate) id: Option<String>,
    #[serde(default)]
    pub(crate) model: Option<String>,
    #[serde(rename = "type", default)]
    pub(crate) message_type: Option<String>,
    // Input-side usage; output tokens arrive later via `message_delta`.
    pub(crate) usage: AnthropicUsage,
}

/// Incremental content carried by `content_block_delta`, tagged by `type`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicDelta {
    /// Appended assistant text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// Appended fragment of a tool call's JSON arguments.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}

/// The `delta` payload of a `message_delta` event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageDeltaData {
    pub(crate) stop_reason: String,
}

/// The `usage` payload of a `message_delta` event.
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicOutputUsage {
    pub(crate) output_tokens: usize,
}

/// Error payload of an `error` stream event.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct AnthropicError {
    #[serde(rename = "type")]
    pub(crate) error_type: String,
    pub(crate) message: String,
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{Message, ToolDefinition};

    /// Fixture: a client with a fixed key and model used by every test.
    fn make_client() -> AnthropicClient {
        AnthropicClient::new("test-key".to_string(), "claude-opus-4-6".to_string())
    }

    #[test]
    fn test_build_request_basic() {
        let request = make_client().build_request(&[Message::user("Hello")], None, &[]);

        assert_eq!(request["model"], "claude-opus-4-6");
        assert_eq!(request["max_tokens"], DEFAULT_MAX_TOKENS);
        assert!(request["thinking"].is_null());
    }

    #[test]
    fn test_build_request_with_thinking_budget() {
        let request = make_client()
            .with_thinking_budget(10_000)
            .build_request(&[Message::user("Think carefully.")], None, &[]);

        let thinking = &request["thinking"];
        assert_eq!(thinking["type"], "enabled");
        assert_eq!(thinking["budget_tokens"], 10_000);
        // Enabling thinking pins temperature at 1.0.
        assert_eq!(request["temperature"], 1.0_f64);
    }

    #[test]
    fn test_build_request_thinking_overrides_temperature() {
        // A user-supplied temperature loses to the thinking requirement.
        let request = make_client()
            .with_temperature(0.5)
            .with_thinking_budget(5_000)
            .build_request(&[Message::user("Test")], None, &[]);

        assert_eq!(request["temperature"], 1.0_f64);
        assert_eq!(request["thinking"]["budget_tokens"], 5_000);
    }

    #[test]
    fn test_build_request_no_thinking_uses_temperature() {
        let request = make_client()
            .with_temperature(0.7)
            .build_request(&[Message::user("Test")], None, &[]);

        // f32 -> f64 round-trip, so compare approximately.
        let temperature = request["temperature"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 0.01);
        assert!(request["thinking"].is_null());
    }

    #[test]
    fn test_build_request_with_system_prompt() {
        let request =
            make_client().build_request(&[Message::user("Hello")], Some("You are helpful."), &[]);

        let system = &request["system"];
        assert!(system.is_array());
        let block = &system[0];
        assert_eq!(block["type"], "text");
        assert_eq!(block["text"], "You are helpful.");
        assert!(block["cache_control"].is_object());
    }

    #[test]
    fn test_build_request_with_tools() {
        let tool = ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({"type": "object", "properties": {}}),
        };
        let request =
            make_client().build_request(&[Message::user("Use a tool")], None, &[tool]);

        assert!(request["tools"].is_array());
        let first = &request["tools"][0];
        assert_eq!(first["name"], "read_file");
        // The sole (and therefore last) tool carries the cache_control marker.
        assert!(first["cache_control"].is_object());
    }

    #[test]
    fn test_build_request_thinking_budget_sets_max_tokens() {
        // max_tokens survives unchanged when thinking is enabled.
        let request = make_client()
            .with_max_tokens(16_000)
            .with_thinking_budget(8_000)
            .build_request(&[Message::user("Test")], None, &[]);

        assert_eq!(request["max_tokens"], 16_000);
        assert_eq!(request["thinking"]["budget_tokens"], 8_000);
    }
}