use super::http::{default_http_client, normalize_base_url, HttpClient};
use super::types::*;
use super::LlmClient;
use crate::retry::{AttemptOutcome, RetryConfig};
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::StreamExt;
use serde::Deserialize;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
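/// Default completion budget sent as `max_tokens` unless overridden via
/// [`AnthropicClient::with_max_tokens`].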
pub(crate) const DEFAULT_MAX_TOKENS: usize = 8192;
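/// Client for the Anthropic Messages API (`/v1/messages`).
///
/// Construct with [`AnthropicClient::new`] and customize through the
/// `with_*` builder methods.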
pub struct AnthropicClient {
    pub(crate) provider_name: String,
    pub(crate) api_key: SecretString,
    pub(crate) model: String,
    pub(crate) base_url: String,
    pub(crate) max_tokens: usize,
    pub(crate) temperature: Option<f32>,
    pub(crate) thinking_budget: Option<usize>,
    pub(crate) http: Arc<dyn HttpClient>,
    pub(crate) retry_config: RetryConfig,
}
impl AnthropicClient {
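    /// Creates a client targeting the hosted Anthropic API with default
    /// retry, token, and sampling settings.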
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            provider_name: "anthropic".to_string(),
            api_key: SecretString::new(api_key),
            model,
            base_url: "https://api.anthropic.com".to_string(),
            max_tokens: DEFAULT_MAX_TOKENS,
            temperature: None,
            thinking_budget: None,
            http: default_http_client(),
            retry_config: RetryConfig::default(),
        }
    }
    pub fn with_base_url(mut self, base_url: String) -> Self {
        self.base_url = normalize_base_url(&base_url);
        self
    }

    pub fn with_provider_name(mut self, provider_name: impl Into<String>) -> Self {
        self.provider_name = provider_name.into();
        self
    }

    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }
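    /// Enables extended thinking with the given token budget.
    ///
    /// Note that `build_request` forces `temperature` to 1.0 whenever a
    /// thinking budget is set, since the Messages API requires temperature
    /// = 1 for thinking-enabled requests.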
    pub fn with_thinking_budget(mut self, budget: usize) -> Self {
        self.thinking_budget = Some(budget);
        self
    }

    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }

    pub fn with_http_client(mut self, http: Arc<dyn HttpClient>) -> Self {
        self.http = http;
        self
    }
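    /// Serializes the `input` of a streamed `content_block_start` tool-use
    /// block, returning `None` for `null` or an empty object so the
    /// incremental JSON buffer starts out empty.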
    fn initial_tool_input_json(input: &serde_json::Value) -> Option<String> {
        match input {
            serde_json::Value::Object(map) if map.is_empty() => None,
            serde_json::Value::Null => None,
            value => serde_json::to_string(value).ok(),
        }
    }
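    /// Builds the JSON body for a `/v1/messages` request, attaching
    /// `cache_control` breakpoints to the system prompt and the last tool
    /// definition so both are eligible for prompt caching.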
    pub(crate) fn build_request(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> serde_json::Value {
        let mut request = serde_json::json!({
            "model": self.model,
            "max_tokens": self.max_tokens,
            "messages": messages,
        });

        if let Some(sys) = system {
            request["system"] = serde_json::json!([
                {
                    "type": "text",
                    "text": sys,
                    "cache_control": { "type": "ephemeral" }
                }
            ]);
        }

        if !tools.is_empty() {
            let mut tool_defs: Vec<serde_json::Value> = tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "input_schema": t.parameters,
                    })
                })
                .collect();

            // Mark only the last tool definition as a cache breakpoint; the
            // breakpoint covers the entire prompt prefix before it.
            if let Some(last) = tool_defs.last_mut() {
                last["cache_control"] = serde_json::json!({ "type": "ephemeral" });
            }

            request["tools"] = serde_json::json!(tool_defs);
        }

        if let Some(temp) = self.temperature {
            request["temperature"] = serde_json::json!(temp);
        }

        if let Some(budget) = self.thinking_budget {
            request["thinking"] = serde_json::json!({
                "type": "enabled",
                "budget_tokens": budget
            });
            // Extended thinking requires temperature = 1; override any
            // configured value.
            request["temperature"] = serde_json::json!(1.0);
        }

        request
    }
}
#[async_trait]
impl LlmClient for AnthropicClient {
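    /// Sends a non-streaming completion request, retrying retryable HTTP
    /// statuses according to `retry_config`.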
    async fn complete(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
    ) -> Result<LlmResponse> {
        let request_started_at = Instant::now();
        let request_body = self.build_request(messages, system, tools);
        let url = format!("{}/v1/messages", self.base_url);

        let headers = vec![
            ("x-api-key", self.api_key.expose()),
            ("anthropic-version", "2023-06-01"),
            ("anthropic-beta", "prompt-caching-2024-07-31"),
        ];
        let response = crate::retry::with_retry(&self.retry_config, |_attempt| {
            let http = &self.http;
            let url = &url;
            let headers = headers.clone();
            let request_body = &request_body;
            async move {
                match http
                    .post(url, headers, request_body, CancellationToken::new())
                    .await
                {
                    Ok(resp) => {
                        let status = reqwest::StatusCode::from_u16(resp.status)
                            .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                        if status.is_success() {
                            AttemptOutcome::Success(resp.body)
                        } else if self.retry_config.is_retryable_status(status) {
                            AttemptOutcome::Retryable {
                                status,
                                body: resp.body,
                                retry_after: None,
                            }
                        } else {
                            AttemptOutcome::Fatal(anyhow::anyhow!(
                                "Anthropic API error at {} ({}): {}",
                                url,
                                status,
                                resp.body
                            ))
                        }
                    }
                    Err(e) => AttemptOutcome::Fatal(e),
                }
            }
        })
        .await?;
        let parsed: AnthropicResponse =
            serde_json::from_str(&response).context("Failed to parse Anthropic response")?;

        tracing::debug!("Anthropic response: {:?}", parsed);

        let content: Vec<ContentBlock> = parsed
            .content
            .into_iter()
            .map(|block| match block {
                AnthropicContentBlock::Text { text } => ContentBlock::Text { text },
                AnthropicContentBlock::ToolUse { id, name, input } => {
                    ContentBlock::ToolUse { id, name, input }
                }
            })
            .collect();

        let llm_response = LlmResponse {
            message: Message {
                role: "assistant".to_string(),
                content,
                reasoning_content: None,
            },
            usage: TokenUsage {
                prompt_tokens: parsed.usage.input_tokens,
                completion_tokens: parsed.usage.output_tokens,
                total_tokens: parsed.usage.input_tokens + parsed.usage.output_tokens,
                cache_read_tokens: parsed.usage.cache_read_input_tokens,
                cache_write_tokens: parsed.usage.cache_creation_input_tokens,
            },
            stop_reason: Some(parsed.stop_reason),
            meta: Some(LlmResponseMeta {
                provider: Some(self.provider_name.clone()),
                request_model: Some(self.model.clone()),
                request_url: Some(url.clone()),
                response_id: parsed.id,
                response_model: parsed.model,
                response_object: parsed.response_type,
                first_token_ms: None,
                duration_ms: Some(request_started_at.elapsed().as_millis() as u64),
            }),
        };

        crate::telemetry::record_llm_usage(
            llm_response.usage.prompt_tokens,
            llm_response.usage.completion_tokens,
            llm_response.usage.total_tokens,
            llm_response.stop_reason.as_deref(),
        );

        Ok(llm_response)
    }
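    /// Sends a streaming completion request and returns a channel of
    /// [`StreamEvent`]s: text and tool-input deltas as they arrive, followed
    /// by a final [`StreamEvent::Done`] carrying the assembled response.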
    async fn complete_streaming(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        cancel_token: CancellationToken,
    ) -> Result<mpsc::Receiver<StreamEvent>> {
        let request_started_at = Instant::now();
        let mut request_body = self.build_request(messages, system, tools);
        request_body["stream"] = serde_json::json!(true);

        let url = format!("{}/v1/messages", self.base_url);

        let headers = vec![
            ("x-api-key", self.api_key.expose()),
            ("anthropic-version", "2023-06-01"),
            ("anthropic-beta", "prompt-caching-2024-07-31"),
        ];
        let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| {
            let http = &self.http;
            let url = &url;
            let headers = headers.clone();
            let request_body = &request_body;
            let cancel_token = cancel_token.clone();
            async move {
                let resp = tokio::select! {
                    _ = cancel_token.cancelled() => {
                        return AttemptOutcome::Fatal(anyhow::anyhow!("HTTP request cancelled"));
                    }
                    result = http.post_streaming(url, headers, request_body, cancel_token.clone()) => {
                        match result {
                            Ok(r) => r,
                            Err(e) => {
                                return AttemptOutcome::Fatal(anyhow::anyhow!("HTTP request failed: {}", e));
                            }
                        }
                    }
                };
                let status = reqwest::StatusCode::from_u16(resp.status)
                    .unwrap_or(reqwest::StatusCode::INTERNAL_SERVER_ERROR);
                if status.is_success() {
                    AttemptOutcome::Success(resp)
                } else {
                    let retry_after = resp
                        .retry_after
                        .as_deref()
                        .and_then(|v| RetryConfig::parse_retry_after(Some(v)));
                    if self.retry_config.is_retryable_status(status) {
                        AttemptOutcome::Retryable {
                            status,
                            body: resp.error_body,
                            retry_after,
                        }
                    } else {
                        AttemptOutcome::Fatal(anyhow::anyhow!(
                            "Anthropic API error at {} ({}): {}",
                            url,
                            status,
                            resp.error_body
                        ))
                    }
                }
            }
        })
        .await?;
        let (tx, rx) = mpsc::channel(100);

        let mut stream = streaming_resp.byte_stream;
        let provider_name = self.provider_name.clone();
        let request_model = self.model.clone();
        let request_url = url.clone();
        // Parse the SSE byte stream in a background task, forwarding events
        // through the channel and accumulating the final response.
        tokio::spawn(async move {
            let mut buffer = String::new();
            let mut content_blocks: Vec<ContentBlock> = Vec::new();
            let mut text_content = String::new();
            let mut current_tool_id = String::new();
            let mut current_tool_name = String::new();
            let mut current_tool_input = String::new();
            let mut usage = TokenUsage::default();
            let mut stop_reason = None;
            let mut response_id = None;
            let mut response_model = None;
            let mut response_object = Some("message".to_string());
            let mut first_token_ms = None;
            while let Some(chunk_result) = stream.next().await {
                let chunk = match chunk_result {
                    Ok(c) => c,
                    Err(e) => {
                        tracing::error!("Stream error: {}", e);
                        break;
                    }
                };

                buffer.push_str(&String::from_utf8_lossy(&chunk));

                // SSE events are separated by a blank line; drain one complete
                // event at a time, leaving any partial event in the buffer.
                while let Some(event_end) = buffer.find("\n\n") {
                    let event_data: String = buffer.drain(..event_end).collect();
                    buffer.drain(..2); // consume the "\n\n" separator
                    for line in event_data.lines() {
                        if let Some(data) = line.strip_prefix("data: ") {
                            if data == "[DONE]" {
                                continue;
                            }

                            if let Ok(event) =
                                serde_json::from_str::<AnthropicStreamEvent>(data)
                            {
                                match event {
                                    AnthropicStreamEvent::ContentBlockStart {
                                        index: _,
                                        content_block,
                                    } => match content_block {
                                        AnthropicContentBlock::Text { .. } => {}
                                        AnthropicContentBlock::ToolUse { id, name, input } => {
                                            // Flush accumulated text first so block order
                                            // in the final message is preserved.
                                            if !text_content.is_empty() {
                                                content_blocks.push(ContentBlock::Text {
                                                    text: std::mem::take(&mut text_content),
                                                });
                                            }
                                            current_tool_id = id.clone();
                                            current_tool_name = name.clone();
                                            current_tool_input =
                                                Self::initial_tool_input_json(&input)
                                                    .unwrap_or_default();
                                            let _ = tx
                                                .send(StreamEvent::ToolUseStart { id, name })
                                                .await;
                                            if !current_tool_input.is_empty() {
                                                if first_token_ms.is_none() {
                                                    first_token_ms = Some(
                                                        request_started_at.elapsed().as_millis()
                                                            as u64,
                                                    );
                                                }
                                                let _ = tx
                                                    .send(StreamEvent::ToolUseInputDelta(
                                                        current_tool_input.clone(),
                                                    ))
                                                    .await;
                                            }
                                        }
                                    },
                                    AnthropicStreamEvent::ContentBlockDelta {
                                        index: _,
                                        delta,
                                    } => match delta {
                                        AnthropicDelta::TextDelta { text } => {
                                            if first_token_ms.is_none() {
                                                first_token_ms = Some(
                                                    request_started_at.elapsed().as_millis()
                                                        as u64,
                                                );
                                            }
                                            text_content.push_str(&text);
                                            let _ = tx.send(StreamEvent::TextDelta(text)).await;
                                        }
                                        AnthropicDelta::InputJsonDelta { partial_json } => {
                                            if first_token_ms.is_none() {
                                                first_token_ms = Some(
                                                    request_started_at.elapsed().as_millis()
                                                        as u64,
                                                );
                                            }
                                            current_tool_input.push_str(&partial_json);
                                            let _ = tx
                                                .send(StreamEvent::ToolUseInputDelta(
                                                    partial_json,
                                                ))
                                                .await;
                                        }
                                    },
                                    AnthropicStreamEvent::ContentBlockStop { index: _ }
                                        if !current_tool_id.is_empty() =>
                                    {
                                        let input: serde_json::Value = if current_tool_input
                                            .trim()
                                            .is_empty()
                                        {
                                            serde_json::Value::Object(Default::default())
                                        } else {
                                            // Keep malformed tool arguments visible via a
                                            // sentinel object rather than dropping the call.
                                            serde_json::from_str(&current_tool_input)
                                                .unwrap_or_else(|e| {
                                                    tracing::warn!(
                                                        "Failed to parse tool input JSON for tool '{}': {}",
                                                        current_tool_name, e
                                                    );
                                                    serde_json::json!({
                                                        "__parse_error": format!(
                                                            "Malformed tool arguments: {}. Raw input: {}",
                                                            e, current_tool_input
                                                        )
                                                    })
                                                })
                                        };
                                        content_blocks.push(ContentBlock::ToolUse {
                                            id: current_tool_id.clone(),
                                            name: current_tool_name.clone(),
                                            input,
                                        });
                                        current_tool_id.clear();
                                        current_tool_name.clear();
                                        current_tool_input.clear();
                                    }
                                    AnthropicStreamEvent::MessageStart { message } => {
                                        response_id = message.id;
                                        response_model = message.model;
                                        response_object = message.message_type;
                                        usage.prompt_tokens = message.usage.input_tokens;
                                    }
                                    AnthropicStreamEvent::MessageDelta {
                                        delta,
                                        usage: msg_usage,
                                    } => {
                                        stop_reason = Some(delta.stop_reason);
                                        usage.completion_tokens = msg_usage.output_tokens;
                                        usage.total_tokens =
                                            usage.prompt_tokens + usage.completion_tokens;
                                    }
                                    AnthropicStreamEvent::MessageStop => {
                                        // Flush trailing text, record usage, and emit the
                                        // fully assembled response.
                                        if !text_content.is_empty() {
                                            content_blocks.push(ContentBlock::Text {
                                                text: std::mem::take(&mut text_content),
                                            });
                                        }
                                        crate::telemetry::record_llm_usage(
                                            usage.prompt_tokens,
                                            usage.completion_tokens,
                                            usage.total_tokens,
                                            stop_reason.as_deref(),
                                        );

                                        let response = LlmResponse {
                                            message: Message {
                                                role: "assistant".to_string(),
                                                content: std::mem::take(&mut content_blocks),
                                                reasoning_content: None,
                                            },
                                            usage: usage.clone(),
                                            stop_reason: stop_reason.clone(),
                                            meta: Some(LlmResponseMeta {
                                                provider: Some(provider_name.clone()),
                                                request_model: Some(request_model.clone()),
                                                request_url: Some(request_url.clone()),
                                                response_id: response_id.clone(),
                                                response_model: response_model.clone(),
                                                response_object: response_object.clone(),
                                                first_token_ms,
                                                duration_ms: Some(
                                                    request_started_at.elapsed().as_millis()
                                                        as u64,
                                                ),
                                            }),
                                        };
                                        let _ = tx.send(StreamEvent::Done(response)).await;
                                    }
                                    _ => {}
                                }
                            }
                        }
                    }
                }
            }
        });
        Ok(rx)
    }
}
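/// Deserialized body of a non-streaming `/v1/messages` response.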
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicResponse {
    #[serde(default)]
    pub(crate) id: Option<String>,
    #[serde(default)]
    pub(crate) model: Option<String>,
    #[serde(rename = "type", default)]
    pub(crate) response_type: Option<String>,
    pub(crate) content: Vec<AnthropicContentBlock>,
    pub(crate) stop_reason: String,
    pub(crate) usage: AnthropicUsage,
}
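/// One entry in a response's `content` array: plain text or a tool call.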
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
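/// Token counts reported by the API, including prompt-cache read/write
/// statistics.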
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicUsage {
    pub(crate) input_tokens: usize,
    pub(crate) output_tokens: usize,
    pub(crate) cache_read_input_tokens: Option<usize>,
    pub(crate) cache_creation_input_tokens: Option<usize>,
}
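/// Server-sent events emitted by a streaming `/v1/messages` request.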
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
#[allow(dead_code)]
pub(crate) enum AnthropicStreamEvent {
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageStart },
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDeltaData,
        usage: AnthropicOutputUsage,
    },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "error")]
    Error { error: AnthropicError },
}
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageStart {
    #[serde(default)]
    pub(crate) id: Option<String>,
    #[serde(default)]
    pub(crate) model: Option<String>,
    #[serde(rename = "type", default)]
    pub(crate) message_type: Option<String>,
    pub(crate) usage: AnthropicUsage,
}
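/// Incremental payloads carried by `content_block_delta` events.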
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub(crate) enum AnthropicDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicMessageDeltaData {
    pub(crate) stop_reason: String,
}

#[derive(Debug, Deserialize)]
pub(crate) struct AnthropicOutputUsage {
    pub(crate) output_tokens: usize,
}

#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct AnthropicError {
    #[serde(rename = "type")]
    pub(crate) error_type: String,
    pub(crate) message: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{Message, ToolDefinition};

    fn make_client() -> AnthropicClient {
        AnthropicClient::new("test-key".to_string(), "claude-opus-4-6".to_string())
    }
    #[test]
    fn test_build_request_basic() {
        let client = make_client();
        let messages = vec![Message::user("Hello")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["model"], "claude-opus-4-6");
        assert_eq!(req["max_tokens"], DEFAULT_MAX_TOKENS);
        assert!(req["thinking"].is_null());
    }

    #[test]
    fn test_build_request_with_thinking_budget() {
        let client = make_client().with_thinking_budget(10_000);
        let messages = vec![Message::user("Think carefully.")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["thinking"]["type"], "enabled");
        assert_eq!(req["thinking"]["budget_tokens"], 10_000);
        assert_eq!(req["temperature"], 1.0_f64);
    }
    #[test]
    fn test_build_request_thinking_overrides_temperature() {
        let client = make_client()
            .with_temperature(0.5)
            .with_thinking_budget(5_000);
        let messages = vec![Message::user("Test")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["temperature"], 1.0_f64);
        assert_eq!(req["thinking"]["budget_tokens"], 5_000);
    }

    #[test]
    fn test_build_request_no_thinking_uses_temperature() {
        let client = make_client().with_temperature(0.7);
        let messages = vec![Message::user("Test")];
        let req = client.build_request(&messages, None, &[]);

        let temp = req["temperature"].as_f64().unwrap();
        assert!((temp - 0.7).abs() < 0.01);
        assert!(req["thinking"].is_null());
    }
    #[test]
    fn test_build_request_with_system_prompt() {
        let client = make_client();
        let messages = vec![Message::user("Hello")];
        let req = client.build_request(&messages, Some("You are helpful."), &[]);

        let system = &req["system"];
        assert!(system.is_array());
        assert_eq!(system[0]["type"], "text");
        assert_eq!(system[0]["text"], "You are helpful.");
        assert!(system[0]["cache_control"].is_object());
    }

    #[test]
    fn test_build_request_with_tools() {
        let client = make_client();
        let messages = vec![Message::user("Use a tool")];
        let tools = vec![ToolDefinition {
            name: "read_file".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({"type": "object", "properties": {}}),
        }];
        let req = client.build_request(&messages, None, &tools);

        assert!(req["tools"].is_array());
        assert_eq!(req["tools"][0]["name"], "read_file");
        assert!(req["tools"][0]["cache_control"].is_object());
    }
    #[test]
    fn test_build_request_thinking_budget_keeps_max_tokens() {
        let client = make_client()
            .with_max_tokens(16_000)
            .with_thinking_budget(8_000);
        let messages = vec![Message::user("Test")];
        let req = client.build_request(&messages, None, &[]);

        assert_eq!(req["max_tokens"], 16_000);
        assert_eq!(req["thinking"]["budget_tokens"], 8_000);
    }
}