// codetether_agent/provider/anthropic.rs
//! Anthropic provider implementation using the Messages API
//!
//! Supports Claude Sonnet 4, Claude Opus 4, and other Claude models.
//! Uses the native Anthropic API format (not OpenAI-compatible).
//! Reference: https://docs.anthropic.com/en/api/messages

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
/// Default public endpoint for the Anthropic Messages API.
const ANTHROPIC_API_BASE: &str = "https://api.anthropic.com";
/// Value sent in the required `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Provider backed by the Anthropic Messages API (or any compatible endpoint,
/// e.g. MiniMax's Anthropic-compatible gateway).
pub struct AnthropicProvider {
    // Shared HTTP client reused across all requests.
    client: Client,
    // Secret API key; redacted by the manual `Debug` impl.
    api_key: String,
    // Endpoint base; "/v1/messages" is appended at request time.
    base_url: String,
    // Logical provider name ("anthropic", "minimax", ...); also selects the
    // model catalog and the prompt-caching default.
    provider_name: String,
    // When true, ephemeral `cache_control` markers are attached to requests.
    enable_prompt_caching: bool,
}

28impl std::fmt::Debug for AnthropicProvider {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("AnthropicProvider")
31            .field("api_key", &"<REDACTED>")
32            .field("api_key_len", &self.api_key.len())
33            .field("base_url", &self.base_url)
34            .field("provider_name", &self.provider_name)
35            .field("enable_prompt_caching", &self.enable_prompt_caching)
36            .finish()
37    }
38}
39
impl AnthropicProvider {
    /// Build a provider against the public Anthropic endpoint under the
    /// default provider name `"anthropic"`.
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_base_url(api_key, ANTHROPIC_API_BASE.to_string(), "anthropic")
    }

    /// Build a provider against an arbitrary Anthropic-compatible endpoint.
    ///
    /// Prompt caching defaults to on only when `provider_name` is "minimax"
    /// (case-insensitive); the `CODETETHER_ANTHROPIC_PROMPT_CACHING`
    /// environment variable (see `parse_bool_env` for accepted spellings)
    /// overrides that default in either direction.
    pub fn with_base_url(
        api_key: String,
        base_url: String,
        provider_name: impl Into<String>,
    ) -> Result<Self> {
        let provider_name = provider_name.into();
        let enable_prompt_caching = std::env::var("CODETETHER_ANTHROPIC_PROMPT_CACHING")
            .ok()
            .and_then(|v| parse_bool_env(&v))
            .unwrap_or_else(|| provider_name.eq_ignore_ascii_case("minimax"));

        // Log the key length only — never the key itself.
        tracing::debug!(
            provider = %provider_name,
            api_key_len = api_key.len(),
            base_url = %base_url,
            enable_prompt_caching,
            "Creating Anthropic provider"
        );

        Ok(Self {
            client: Client::new(),
            api_key,
            base_url,
            provider_name,
            enable_prompt_caching,
        })
    }

    /// Reject obviously unusable credentials before making any network call.
    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Anthropic API key is empty");
        }
        Ok(())
    }

    /// Convert our generic messages to Anthropic Messages API format.
    ///
    /// Anthropic uses a different format:
    /// - system prompt is a top-level field, not a message
    /// - tool results go in user messages with type "tool_result"
    /// - tool calls appear in assistant messages with type "tool_use"
    ///
    /// Returns `(system_blocks, messages)`; `system_blocks` is `None` when no
    /// system content was present. When `enable_prompt_caching` is set, an
    /// ephemeral `cache_control` marker is attached to the last system block
    /// and to the last content part of the last message.
    fn convert_messages(
        messages: &[Message],
        enable_prompt_caching: bool,
    ) -> (Option<Vec<Value>>, Vec<Value>) {
        let mut system_blocks: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                // System content is hoisted into top-level blocks, not a message.
                Role::System => {
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                system_blocks.push(json!({
                                    "type": "text",
                                    "text": text,
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                system_blocks.push(json!({
                                    "type": "thinking",
                                    "thinking": text,
                                }));
                            }
                            // Other part kinds (tool calls/results, etc.) are
                            // not meaningful in a system prompt and are dropped.
                            _ => {}
                        }
                    }
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                content_parts.push(json!({
                                    "type": "text",
                                    "text": text,
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                content_parts.push(json!({
                                    "type": "thinking",
                                    "thinking": text,
                                }));
                            }
                            _ => {}
                        }
                    }
                    // Never emit an empty content array; a single-space text
                    // block keeps the message well-formed (presumably the API
                    // rejects empty content — TODO confirm).
                    if content_parts.is_empty() {
                        content_parts.push(json!({"type": "text", "text": " "}));
                    }
                    api_messages.push(json!({
                        "role": "user",
                        "content": content_parts
                    }));
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();

                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                content_parts.push(json!({
                                    "type": "text",
                                    "text": text
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                content_parts.push(json!({
                                    "type": "thinking",
                                    "thinking": text
                                }));
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                // Arguments are stored as a JSON string; if they
                                // fail to parse, wrap the raw text rather than
                                // dropping the call.
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "type": "tool_use",
                                    "id": id,
                                    "name": name,
                                    "input": input
                                }));
                            }
                            _ => {}
                        }
                    }

                    if content_parts.is_empty() {
                        content_parts.push(json!({"type": "text", "text": " "}));
                    }

                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                // Tool results are delivered back to the API as a *user*
                // message containing "tool_result" blocks.
                Role::Tool => {
                    let mut tool_results: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            tool_results.push(json!({
                                "type": "tool_result",
                                "tool_use_id": tool_call_id,
                                "content": content
                            }));
                        }
                    }
                    if !tool_results.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": tool_results
                        }));
                    }
                }
            }
        }

        if enable_prompt_caching {
            // Mark the last content part of the last message as a cache
            // breakpoint, then do the same for the last system block.
            if let Some(last_tool_or_text_msg) = api_messages.iter_mut().rev().find_map(|msg| {
                msg.get_mut("content")
                    .and_then(Value::as_array_mut)
                    .and_then(|parts| parts.last_mut())
            }) {
                Self::add_ephemeral_cache_control(last_tool_or_text_msg);
            }
            if let Some(last_system) = system_blocks.last_mut() {
                Self::add_ephemeral_cache_control(last_system);
            }
        }

        let system = if system_blocks.is_empty() {
            None
        } else {
            Some(system_blocks)
        };

        (system, api_messages)
    }

    /// Convert generic tool definitions to Anthropic's `tools` array format.
    ///
    /// When caching is enabled, the last tool gets an ephemeral
    /// `cache_control` marker so the whole tool list becomes a cache prefix.
    fn convert_tools(tools: &[ToolDefinition], enable_prompt_caching: bool) -> Vec<Value> {
        let mut converted: Vec<Value> = tools
            .iter()
            .map(|t| {
                json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.parameters
                })
            })
            .collect();

        if enable_prompt_caching && let Some(last_tool) = converted.last_mut() {
            Self::add_ephemeral_cache_control(last_tool);
        }

        converted
    }

    /// Insert `"cache_control": {"type": "ephemeral"}` into a JSON object
    /// block; non-object values are left untouched.
    fn add_ephemeral_cache_control(block: &mut Value) {
        if let Some(obj) = block.as_object_mut() {
            obj.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
        }
    }
}

/// Subset of the Messages API response body that this provider consumes.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    model: String,
    // Ordered content blocks (text / thinking / tool_use / unknown).
    content: Vec<AnthropicContent>,
    // Why generation stopped; values observed in `complete` include
    // "end_turn", "max_tokens", "tool_use", "content_filter".
    #[serde(default)]
    stop_reason: Option<String>,
    // Token accounting; optional so a missing block doesn't fail parsing.
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}

/// One content block in an Anthropic response, discriminated by its
/// `"type"` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum AnthropicContent {
    #[serde(rename = "text")]
    Text { text: String },
    // Extended-thinking block. Both fields are optional because some
    // endpoints appear to use `thinking` and others `text` for the payload
    // — TODO confirm which servers emit which shape.
    #[serde(rename = "thinking")]
    Thinking {
        #[serde(default)]
        thinking: Option<String>,
        #[serde(default)]
        text: Option<String>,
    },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    // Any block type we don't recognize is tolerated and silently ignored.
    #[serde(other)]
    Unknown,
}

/// Token-usage block from the Messages API; all fields default so partial
/// or absent accounting never fails deserialization.
#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    // Tokens written to the prompt cache (only present when caching is used).
    #[serde(default)]
    cache_creation_input_tokens: Option<usize>,
    // Tokens served from the prompt cache.
    #[serde(default)]
    cache_read_input_tokens: Option<usize>,
}

/// Top-level error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    error: AnthropicErrorDetail,
}

/// Error payload: human-readable message plus an optional machine type
/// (the wire field is named `type`).
#[derive(Debug, Deserialize)]
struct AnthropicErrorDetail {
    message: String,
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}

317#[async_trait]
318impl Provider for AnthropicProvider {
319    fn name(&self) -> &str {
320        &self.provider_name
321    }
322
323    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
324        self.validate_api_key()?;
325
326        if self.provider_name.eq_ignore_ascii_case("minimax") {
327            return Ok(vec![
328                ModelInfo {
329                    id: "MiniMax-M2.5".to_string(),
330                    name: "MiniMax M2.5".to_string(),
331                    provider: self.provider_name.clone(),
332                    context_window: 200_000,
333                    max_output_tokens: Some(65_536),
334                    supports_vision: false,
335                    supports_tools: true,
336                    supports_streaming: true,
337                    input_cost_per_million: Some(0.3),
338                    output_cost_per_million: Some(1.2),
339                },
340                ModelInfo {
341                    id: "MiniMax-M2.5-lightning".to_string(),
342                    name: "MiniMax M2.5 Lightning".to_string(),
343                    provider: self.provider_name.clone(),
344                    context_window: 200_000,
345                    max_output_tokens: Some(65_536),
346                    supports_vision: false,
347                    supports_tools: true,
348                    supports_streaming: true,
349                    input_cost_per_million: Some(0.3),
350                    output_cost_per_million: Some(2.4),
351                },
352                ModelInfo {
353                    id: "MiniMax-M2.1".to_string(),
354                    name: "MiniMax M2.1".to_string(),
355                    provider: self.provider_name.clone(),
356                    context_window: 200_000,
357                    max_output_tokens: Some(65_536),
358                    supports_vision: false,
359                    supports_tools: true,
360                    supports_streaming: true,
361                    input_cost_per_million: Some(0.3),
362                    output_cost_per_million: Some(1.2),
363                },
364                ModelInfo {
365                    id: "MiniMax-M2.1-lightning".to_string(),
366                    name: "MiniMax M2.1 Lightning".to_string(),
367                    provider: self.provider_name.clone(),
368                    context_window: 200_000,
369                    max_output_tokens: Some(65_536),
370                    supports_vision: false,
371                    supports_tools: true,
372                    supports_streaming: true,
373                    input_cost_per_million: Some(0.3),
374                    output_cost_per_million: Some(2.4),
375                },
376                ModelInfo {
377                    id: "MiniMax-M2".to_string(),
378                    name: "MiniMax M2".to_string(),
379                    provider: self.provider_name.clone(),
380                    context_window: 200_000,
381                    max_output_tokens: Some(65_536),
382                    supports_vision: false,
383                    supports_tools: true,
384                    supports_streaming: true,
385                    input_cost_per_million: Some(0.3),
386                    output_cost_per_million: Some(1.2),
387                },
388            ]);
389        }
390
391        Ok(vec![
392            ModelInfo {
393                id: "claude-sonnet-4-20250514".to_string(),
394                name: "Claude Sonnet 4".to_string(),
395                provider: self.provider_name.clone(),
396                context_window: 200_000,
397                max_output_tokens: Some(64_000),
398                supports_vision: true,
399                supports_tools: true,
400                supports_streaming: true,
401                input_cost_per_million: Some(3.0),
402                output_cost_per_million: Some(15.0),
403            },
404            ModelInfo {
405                id: "claude-opus-4-20250514".to_string(),
406                name: "Claude Opus 4".to_string(),
407                provider: self.provider_name.clone(),
408                context_window: 200_000,
409                max_output_tokens: Some(32_000),
410                supports_vision: true,
411                supports_tools: true,
412                supports_streaming: true,
413                input_cost_per_million: Some(15.0),
414                output_cost_per_million: Some(75.0),
415            },
416            ModelInfo {
417                id: "claude-haiku-3-5-20241022".to_string(),
418                name: "Claude 3.5 Haiku".to_string(),
419                provider: self.provider_name.clone(),
420                context_window: 200_000,
421                max_output_tokens: Some(8_192),
422                supports_vision: true,
423                supports_tools: true,
424                supports_streaming: true,
425                input_cost_per_million: Some(0.80),
426                output_cost_per_million: Some(4.0),
427            },
428        ])
429    }
430
431    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
432        tracing::debug!(
433            provider = %self.provider_name,
434            model = %request.model,
435            message_count = request.messages.len(),
436            tool_count = request.tools.len(),
437            "Starting completion request"
438        );
439
440        self.validate_api_key()?;
441
442        let (system_prompt, messages) =
443            Self::convert_messages(&request.messages, self.enable_prompt_caching);
444        let tools = Self::convert_tools(&request.tools, self.enable_prompt_caching);
445
446        let mut body = json!({
447            "model": request.model,
448            "messages": messages,
449            "max_tokens": request.max_tokens.unwrap_or(8192),
450        });
451
452        if let Some(system) = system_prompt {
453            body["system"] = json!(system);
454        }
455        if !tools.is_empty() {
456            body["tools"] = json!(tools);
457        }
458        if let Some(temp) = request.temperature {
459            body["temperature"] = json!(temp);
460        }
461        if let Some(top_p) = request.top_p {
462            body["top_p"] = json!(top_p);
463        }
464
465        tracing::debug!("Anthropic request to model {}", request.model);
466
467        let response = self
468            .client
469            .post(format!(
470                "{}/v1/messages",
471                self.base_url.trim_end_matches('/')
472            ))
473            .header("x-api-key", &self.api_key)
474            .header("anthropic-version", ANTHROPIC_VERSION)
475            .header("content-type", "application/json")
476            .json(&body)
477            .send()
478            .await
479            .context("Failed to send request to Anthropic")?;
480
481        let status = response.status();
482        let text = response
483            .text()
484            .await
485            .context("Failed to read Anthropic response")?;
486
487        if !status.is_success() {
488            if let Ok(err) = serde_json::from_str::<AnthropicError>(&text) {
489                anyhow::bail!(
490                    "Anthropic API error: {} ({:?})",
491                    err.error.message,
492                    err.error.error_type
493                );
494            }
495            anyhow::bail!("Anthropic API error: {} {}", status, text);
496        }
497
498        let response: AnthropicResponse = serde_json::from_str(&text).context(format!(
499            "Failed to parse Anthropic response: {}",
500            &text[..text.len().min(200)]
501        ))?;
502
503        tracing::debug!(
504            response_id = %response.id,
505            model = %response.model,
506            stop_reason = ?response.stop_reason,
507            "Received Anthropic response"
508        );
509
510        let mut content = Vec::new();
511        let mut has_tool_calls = false;
512
513        for part in &response.content {
514            match part {
515                AnthropicContent::Text { text } => {
516                    if !text.is_empty() {
517                        content.push(ContentPart::Text { text: text.clone() });
518                    }
519                }
520                AnthropicContent::Thinking { thinking, text } => {
521                    let reasoning = thinking
522                        .as_deref()
523                        .or(text.as_deref())
524                        .unwrap_or_default()
525                        .trim()
526                        .to_string();
527                    if !reasoning.is_empty() {
528                        content.push(ContentPart::Thinking { text: reasoning });
529                    }
530                }
531                AnthropicContent::ToolUse { id, name, input } => {
532                    has_tool_calls = true;
533                    content.push(ContentPart::ToolCall {
534                        id: id.clone(),
535                        name: name.clone(),
536                        arguments: serde_json::to_string(input).unwrap_or_default(),
537                    });
538                }
539                AnthropicContent::Unknown => {}
540            }
541        }
542
543        let finish_reason = if has_tool_calls {
544            FinishReason::ToolCalls
545        } else {
546            match response.stop_reason.as_deref() {
547                Some("end_turn") | Some("stop") => FinishReason::Stop,
548                Some("max_tokens") => FinishReason::Length,
549                Some("tool_use") => FinishReason::ToolCalls,
550                Some("content_filter") => FinishReason::ContentFilter,
551                _ => FinishReason::Stop,
552            }
553        };
554
555        let usage = response.usage.as_ref();
556
557        Ok(CompletionResponse {
558            message: Message {
559                role: Role::Assistant,
560                content,
561            },
562            usage: Usage {
563                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
564                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
565                total_tokens: usage.map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
566                cache_read_tokens: usage.and_then(|u| u.cache_read_input_tokens),
567                cache_write_tokens: usage.and_then(|u| u.cache_creation_input_tokens),
568            },
569            finish_reason,
570        })
571    }
572
573    async fn complete_stream(
574        &self,
575        request: CompletionRequest,
576    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
577        // Fall back to non-streaming for now
578        let response = self.complete(request).await?;
579        let text = response
580            .message
581            .content
582            .iter()
583            .filter_map(|p| match p {
584                ContentPart::Text { text } => Some(text.clone()),
585                _ => None,
586            })
587            .collect::<Vec<_>>()
588            .join("");
589
590        Ok(Box::pin(futures::stream::once(async move {
591            StreamChunk::Text(text)
592        })))
593    }
594}
595
/// Parse a boolean-ish environment-variable value.
///
/// Leading/trailing whitespace is ignored and matching is case-insensitive.
/// Recognized truthy spellings: `1`, `true`, `yes`, `on`, `enabled`; falsy:
/// `0`, `false`, `no`, `off`, `disabled`. Anything else yields `None`.
fn parse_bool_env(value: &str) -> Option<bool> {
    const TRUTHY: [&str; 5] = ["1", "true", "yes", "on", "enabled"];
    const FALSY: [&str; 5] = ["0", "false", "no", "off", "disabled"];

    let normalized = value.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Some(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Some(false)
    } else {
        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// With caching enabled, the last system block, the last content part of
    /// the last message, and the last tool definition must each carry an
    /// ephemeral `cache_control` marker.
    #[test]
    fn adds_cache_control_to_last_tool_system_and_message_block() {
        let messages = vec![
            Message {
                role: Role::System,
                content: vec![ContentPart::Text {
                    text: "static instruction".to_string(),
                }],
            },
            Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "dynamic input".to_string(),
                }],
            },
        ];

        let (system, converted_messages) = AnthropicProvider::convert_messages(&messages, true);
        let mut converted_tools = AnthropicProvider::convert_tools(
            &[ToolDefinition {
                name: "get_weather".to_string(),
                description: "Get weather".to_string(),
                parameters: json!({"type": "object"}),
            }],
            true,
        );

        // Last system block gets the cache marker.
        let system = system.expect("system blocks should be present");
        let system_cache = system
            .last()
            .and_then(|v| v.get("cache_control"))
            .and_then(|v| v.get("type"))
            .and_then(Value::as_str);
        assert_eq!(system_cache, Some("ephemeral"));

        // Last content part of the last message gets the cache marker.
        let message_cache = converted_messages
            .last()
            .and_then(|msg| msg.get("content"))
            .and_then(Value::as_array)
            .and_then(|parts| parts.last())
            .and_then(|part| part.get("cache_control"))
            .and_then(|v| v.get("type"))
            .and_then(Value::as_str);
        assert_eq!(message_cache, Some("ephemeral"));

        // Last tool definition gets the cache marker.
        let tool_cache = converted_tools
            .pop()
            .and_then(|tool| tool.get("cache_control").cloned())
            .and_then(|v| v.get("type").cloned())
            .and_then(|v| v.as_str().map(str::to_string));
        assert_eq!(tool_cache.as_deref(), Some("ephemeral"));
    }

    /// The "minimax" provider name should turn prompt caching on without any
    /// environment-variable override.
    #[test]
    fn minimax_provider_name_enables_prompt_caching_by_default() {
        let provider = AnthropicProvider::with_base_url(
            "test-key".to_string(),
            "https://api.minimax.io/anthropic".to_string(),
            "minimax",
        )
        .expect("provider should initialize");

        assert_eq!(provider.name(), "minimax");
        assert!(provider.enable_prompt_caching);
    }
}
674}