// codetether_agent/provider/anthropic.rs

1//! Anthropic provider implementation using the Messages API
2//!
3//! Supports Claude Sonnet 4, Claude Opus 4, and other Claude models.
4//! Uses the native Anthropic API format (not OpenAI-compatible).
5//! Reference: https://docs.anthropic.com/en/api/messages
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
/// Default Anthropic API root; overridable via `with_base_url` for
/// Anthropic-compatible endpoints (e.g. MiniMax).
const ANTHROPIC_API_BASE: &str = "https://api.anthropic.com";
/// Value sent in the required `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
19
/// Provider backed by the Anthropic Messages API.
///
/// Also used for Anthropic-compatible endpoints (MiniMax) by pointing
/// `base_url` elsewhere and changing `provider_name`.
pub struct AnthropicProvider {
    client: Client,
    // Secret; the manual `Debug` impl redacts it.
    api_key: String,
    // API root without the `/v1/messages` suffix; trailing '/' tolerated.
    base_url: String,
    // Logical name ("anthropic", "minimax", "minimax-credits"); selects the
    // model catalog and the prompt-caching default.
    provider_name: String,
    // When true, `cache_control: ephemeral` markers are attached to requests.
    enable_prompt_caching: bool,
}
27
28impl std::fmt::Debug for AnthropicProvider {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("AnthropicProvider")
31            .field("api_key", &"<REDACTED>")
32            .field("api_key_len", &self.api_key.len())
33            .field("base_url", &self.base_url)
34            .field("provider_name", &self.provider_name)
35            .field("enable_prompt_caching", &self.enable_prompt_caching)
36            .finish()
37    }
38}
39
impl AnthropicProvider {
    /// Create a provider pointing at the official Anthropic endpoint.
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_base_url(api_key, ANTHROPIC_API_BASE.to_string(), "anthropic")
    }

    /// Create a provider for any Anthropic-compatible endpoint.
    ///
    /// Prompt caching defaults to on only for the MiniMax providers; the
    /// `CODETETHER_ANTHROPIC_PROMPT_CACHING` env var (parsed by
    /// `parse_bool_env`) overrides the default either way.
    pub fn with_base_url(
        api_key: String,
        base_url: String,
        provider_name: impl Into<String>,
    ) -> Result<Self> {
        let provider_name = provider_name.into();
        // Env override wins; otherwise enable caching for MiniMax names only.
        let enable_prompt_caching = std::env::var("CODETETHER_ANTHROPIC_PROMPT_CACHING")
            .ok()
            .and_then(|v| parse_bool_env(&v))
            .unwrap_or_else(|| {
                provider_name.eq_ignore_ascii_case("minimax")
                    || provider_name.eq_ignore_ascii_case("minimax-credits")
            });

        tracing::debug!(
            provider = %provider_name,
            api_key_len = api_key.len(),
            base_url = %base_url,
            enable_prompt_caching,
            "Creating Anthropic provider"
        );

        Ok(Self {
            client: Client::new(),
            api_key,
            base_url,
            provider_name,
            enable_prompt_caching,
        })
    }

    /// Fail fast with a clear message instead of a confusing HTTP 401 later.
    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Anthropic API key is empty");
        }
        Ok(())
    }

    /// Convert our generic messages to Anthropic Messages API format.
    ///
    /// Anthropic uses a different format:
    /// - system prompt is a top-level field, not a message
    /// - tool results go in user messages with type "tool_result"
    /// - tool calls appear in assistant messages with type "tool_use"
    ///
    /// When `enable_prompt_caching` is set, a `cache_control: ephemeral`
    /// marker is attached to the last system block and to the last content
    /// part of the last message, so the prefix up to those points is cached.
    ///
    /// Returns `(system_blocks, api_messages)`; `system_blocks` is `None`
    /// when no system messages were present.
    fn convert_messages(
        messages: &[Message],
        enable_prompt_caching: bool,
    ) -> (Option<Vec<Value>>, Vec<Value>) {
        let mut system_blocks: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    // System content becomes top-level blocks, not a message.
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                system_blocks.push(json!({
                                    "type": "text",
                                    "text": text,
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                system_blocks.push(json!({
                                    "type": "thinking",
                                    "thinking": text,
                                }));
                            }
                            // Other part kinds (tool calls/results, etc.) are
                            // meaningless in a system prompt — dropped.
                            _ => {}
                        }
                    }
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                content_parts.push(json!({
                                    "type": "text",
                                    "text": text,
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                content_parts.push(json!({
                                    "type": "thinking",
                                    "thinking": text,
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        // The API rejects empty content arrays; send a space.
                        content_parts.push(json!({"type": "text", "text": " "}));
                    }
                    api_messages.push(json!({
                        "role": "user",
                        "content": content_parts
                    }));
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();

                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                content_parts.push(json!({
                                    "type": "text",
                                    "text": text
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                content_parts.push(json!({
                                    "type": "thinking",
                                    "thinking": text
                                }));
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                                ..
                            } => {
                                // Arguments are stored as a JSON string; if it
                                // isn't valid JSON, wrap it so nothing is lost.
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "type": "tool_use",
                                    "id": id,
                                    "name": name,
                                    "input": input
                                }));
                            }
                            _ => {}
                        }
                    }

                    if content_parts.is_empty() {
                        content_parts.push(json!({"type": "text", "text": " "}));
                    }

                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Anthropic expects tool results inside a *user* message
                    // as `tool_result` blocks referencing the tool_use id.
                    let mut tool_results: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            tool_results.push(json!({
                                "type": "tool_result",
                                "tool_use_id": tool_call_id,
                                "content": content
                            }));
                        }
                    }
                    if !tool_results.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": tool_results
                        }));
                    }
                }
            }
        }

        if enable_prompt_caching {
            // Mark the last content part of the last message so the whole
            // conversation prefix becomes a cache breakpoint.
            if let Some(last_tool_or_text_msg) = api_messages.iter_mut().rev().find_map(|msg| {
                msg.get_mut("content")
                    .and_then(Value::as_array_mut)
                    .and_then(|parts| parts.last_mut())
            }) {
                Self::add_ephemeral_cache_control(last_tool_or_text_msg);
            }
            if let Some(last_system) = system_blocks.last_mut() {
                Self::add_ephemeral_cache_control(last_system);
            }
        }

        let system = if system_blocks.is_empty() {
            None
        } else {
            Some(system_blocks)
        };

        (system, api_messages)
    }

    /// Convert tool definitions to the Anthropic `tools` array shape.
    ///
    /// With caching enabled, the last tool gets a `cache_control` marker so
    /// the (typically static) tool list is cached as one prefix.
    fn convert_tools(tools: &[ToolDefinition], enable_prompt_caching: bool) -> Vec<Value> {
        let mut converted: Vec<Value> = tools
            .iter()
            .map(|t| {
                json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.parameters
                })
            })
            .collect();

        if enable_prompt_caching && let Some(last_tool) = converted.last_mut() {
            Self::add_ephemeral_cache_control(last_tool);
        }

        converted
    }

    /// Insert `"cache_control": {"type": "ephemeral"}` into a JSON object
    /// block. No-op if `block` is not a JSON object.
    fn add_ephemeral_cache_control(block: &mut Value) {
        if let Some(obj) = block.as_object_mut() {
            obj.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
        }
    }
}
261
/// Top-level success payload from `POST /v1/messages`.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    model: String,
    /// Ordered content blocks (text / thinking / tool_use / unknown).
    content: Vec<AnthropicContent>,
    /// e.g. "end_turn", "max_tokens", "tool_use"; absent on some responses.
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}
274
/// One content block in an Anthropic response, discriminated by `"type"`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum AnthropicContent {
    #[serde(rename = "text")]
    Text { text: String },
    /// Extended-thinking block. Both field spellings are accepted because
    /// compatible backends differ on `thinking` vs `text`.
    #[serde(rename = "thinking")]
    Thinking {
        #[serde(default)]
        thinking: Option<String>,
        #[serde(default)]
        text: Option<String>,
    },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    /// Forward-compat: any unrecognized block type is ignored, not an error.
    #[serde(other)]
    Unknown,
}
296
/// Token accounting attached to a response; all fields default to 0/None so
/// partial payloads still deserialize.
#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    /// Tokens written to the prompt cache (only present when caching is used).
    #[serde(default)]
    cache_creation_input_tokens: Option<usize>,
    /// Tokens served from the prompt cache.
    #[serde(default)]
    cache_read_input_tokens: Option<usize>,
}
308
/// Error envelope returned on non-2xx responses: `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    error: AnthropicErrorDetail,
}

#[derive(Debug, Deserialize)]
struct AnthropicErrorDetail {
    message: String,
    /// e.g. "invalid_request_error", "authentication_error".
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}
320
321#[async_trait]
322impl Provider for AnthropicProvider {
323    fn name(&self) -> &str {
324        &self.provider_name
325    }
326
327    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
328        self.validate_api_key()?;
329
330        if self.provider_name.eq_ignore_ascii_case("minimax-credits") {
331            // Credits-based key: highspeed models only (not on coding plan)
332            return Ok(vec![
333                ModelInfo {
334                    id: "MiniMax-M2.5-highspeed".to_string(),
335                    name: "MiniMax M2.5 Highspeed".to_string(),
336                    provider: self.provider_name.clone(),
337                    context_window: 200_000,
338                    max_output_tokens: Some(65_536),
339                    supports_vision: false,
340                    supports_tools: true,
341                    supports_streaming: true,
342                    input_cost_per_million: Some(0.6),
343                    output_cost_per_million: Some(2.4),
344                },
345                ModelInfo {
346                    id: "MiniMax-M2.1-highspeed".to_string(),
347                    name: "MiniMax M2.1 Highspeed".to_string(),
348                    provider: self.provider_name.clone(),
349                    context_window: 200_000,
350                    max_output_tokens: Some(65_536),
351                    supports_vision: false,
352                    supports_tools: true,
353                    supports_streaming: true,
354                    input_cost_per_million: Some(0.6),
355                    output_cost_per_million: Some(2.4),
356                },
357            ]);
358        }
359
360        if self.provider_name.eq_ignore_ascii_case("minimax") {
361            // Coding plan key: regular (non-highspeed) models
362            return Ok(vec![
363                ModelInfo {
364                    id: "MiniMax-M2.5".to_string(),
365                    name: "MiniMax M2.5".to_string(),
366                    provider: self.provider_name.clone(),
367                    context_window: 200_000,
368                    max_output_tokens: Some(65_536),
369                    supports_vision: false,
370                    supports_tools: true,
371                    supports_streaming: true,
372                    input_cost_per_million: Some(0.3),
373                    output_cost_per_million: Some(1.2),
374                },
375                ModelInfo {
376                    id: "MiniMax-M2.1".to_string(),
377                    name: "MiniMax M2.1".to_string(),
378                    provider: self.provider_name.clone(),
379                    context_window: 200_000,
380                    max_output_tokens: Some(65_536),
381                    supports_vision: false,
382                    supports_tools: true,
383                    supports_streaming: true,
384                    input_cost_per_million: Some(0.3),
385                    output_cost_per_million: Some(1.2),
386                },
387                ModelInfo {
388                    id: "MiniMax-M2".to_string(),
389                    name: "MiniMax M2".to_string(),
390                    provider: self.provider_name.clone(),
391                    context_window: 200_000,
392                    max_output_tokens: Some(65_536),
393                    supports_vision: false,
394                    supports_tools: true,
395                    supports_streaming: true,
396                    input_cost_per_million: Some(0.3),
397                    output_cost_per_million: Some(1.2),
398                },
399            ]);
400        }
401
402        Ok(vec![
403            ModelInfo {
404                id: "claude-sonnet-4-6".to_string(),
405                name: "Claude Sonnet 4.6".to_string(),
406                provider: self.provider_name.clone(),
407                context_window: 200_000,
408                max_output_tokens: Some(128_000),
409                supports_vision: true,
410                supports_tools: true,
411                supports_streaming: true,
412                input_cost_per_million: Some(3.0),
413                output_cost_per_million: Some(15.0),
414            },
415            ModelInfo {
416                id: "claude-sonnet-4-20250514".to_string(),
417                name: "Claude Sonnet 4".to_string(),
418                provider: self.provider_name.clone(),
419                context_window: 200_000,
420                max_output_tokens: Some(64_000),
421                supports_vision: true,
422                supports_tools: true,
423                supports_streaming: true,
424                input_cost_per_million: Some(3.0),
425                output_cost_per_million: Some(15.0),
426            },
427            ModelInfo {
428                id: "claude-opus-4-20250514".to_string(),
429                name: "Claude Opus 4".to_string(),
430                provider: self.provider_name.clone(),
431                context_window: 200_000,
432                max_output_tokens: Some(32_000),
433                supports_vision: true,
434                supports_tools: true,
435                supports_streaming: true,
436                input_cost_per_million: Some(15.0),
437                output_cost_per_million: Some(75.0),
438            },
439            ModelInfo {
440                id: "claude-haiku-3-5-20241022".to_string(),
441                name: "Claude 3.5 Haiku".to_string(),
442                provider: self.provider_name.clone(),
443                context_window: 200_000,
444                max_output_tokens: Some(8_192),
445                supports_vision: true,
446                supports_tools: true,
447                supports_streaming: true,
448                input_cost_per_million: Some(0.80),
449                output_cost_per_million: Some(4.0),
450            },
451        ])
452    }
453
454    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
455        tracing::debug!(
456            provider = %self.provider_name,
457            model = %request.model,
458            message_count = request.messages.len(),
459            tool_count = request.tools.len(),
460            "Starting completion request"
461        );
462
463        self.validate_api_key()?;
464
465        let (system_prompt, messages) =
466            Self::convert_messages(&request.messages, self.enable_prompt_caching);
467        let tools = Self::convert_tools(&request.tools, self.enable_prompt_caching);
468
469        let mut body = json!({
470            "model": request.model,
471            "messages": messages,
472            "max_tokens": request.max_tokens.unwrap_or(8192),
473        });
474
475        if let Some(system) = system_prompt {
476            body["system"] = json!(system);
477        }
478        if !tools.is_empty() {
479            body["tools"] = json!(tools);
480        }
481        if let Some(temp) = request.temperature {
482            body["temperature"] = json!(temp);
483        }
484        if let Some(top_p) = request.top_p {
485            body["top_p"] = json!(top_p);
486        }
487
488        tracing::debug!("Anthropic request to model {}", request.model);
489
490        let response = self
491            .client
492            .post(format!(
493                "{}/v1/messages",
494                self.base_url.trim_end_matches('/')
495            ))
496            .header("x-api-key", &self.api_key)
497            .header("anthropic-version", ANTHROPIC_VERSION)
498            .header("content-type", "application/json")
499            .json(&body)
500            .send()
501            .await
502            .context("Failed to send request to Anthropic")?;
503
504        let status = response.status();
505        let text = response
506            .text()
507            .await
508            .context("Failed to read Anthropic response")?;
509
510        if !status.is_success() {
511            if let Ok(err) = serde_json::from_str::<AnthropicError>(&text) {
512                anyhow::bail!(
513                    "Anthropic API error: {} ({:?})",
514                    err.error.message,
515                    err.error.error_type
516                );
517            }
518            anyhow::bail!("Anthropic API error: {} {}", status, text);
519        }
520
521        let response: AnthropicResponse = serde_json::from_str(&text).context(format!(
522            "Failed to parse Anthropic response: {}",
523            &text[..text.len().min(200)]
524        ))?;
525
526        tracing::debug!(
527            response_id = %response.id,
528            model = %response.model,
529            stop_reason = ?response.stop_reason,
530            "Received Anthropic response"
531        );
532
533        let mut content = Vec::new();
534        let mut has_tool_calls = false;
535
536        for part in &response.content {
537            match part {
538                AnthropicContent::Text { text } => {
539                    if !text.is_empty() {
540                        content.push(ContentPart::Text { text: text.clone() });
541                    }
542                }
543                AnthropicContent::Thinking { thinking, text } => {
544                    let reasoning = thinking
545                        .as_deref()
546                        .or(text.as_deref())
547                        .unwrap_or_default()
548                        .trim()
549                        .to_string();
550                    if !reasoning.is_empty() {
551                        content.push(ContentPart::Thinking { text: reasoning });
552                    }
553                }
554                AnthropicContent::ToolUse { id, name, input } => {
555                    has_tool_calls = true;
556                    content.push(ContentPart::ToolCall {
557                        id: id.clone(),
558                        name: name.clone(),
559                        arguments: serde_json::to_string(input).unwrap_or_default(),
560                        thought_signature: None,
561                    });
562                }
563                AnthropicContent::Unknown => {}
564            }
565        }
566
567        let finish_reason = if has_tool_calls {
568            FinishReason::ToolCalls
569        } else {
570            match response.stop_reason.as_deref() {
571                Some("end_turn") | Some("stop") => FinishReason::Stop,
572                Some("max_tokens") => FinishReason::Length,
573                Some("tool_use") => FinishReason::ToolCalls,
574                Some("content_filter") => FinishReason::ContentFilter,
575                _ => FinishReason::Stop,
576            }
577        };
578
579        let usage = response.usage.as_ref();
580
581        Ok(CompletionResponse {
582            message: Message {
583                role: Role::Assistant,
584                content,
585            },
586            usage: Usage {
587                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
588                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
589                total_tokens: usage.map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
590                cache_read_tokens: usage.and_then(|u| u.cache_read_input_tokens),
591                cache_write_tokens: usage.and_then(|u| u.cache_creation_input_tokens),
592            },
593            finish_reason,
594        })
595    }
596
597    async fn complete_stream(
598        &self,
599        request: CompletionRequest,
600    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
601        // Fall back to non-streaming for now
602        let response = self.complete(request).await?;
603        let text = response
604            .message
605            .content
606            .iter()
607            .filter_map(|p| match p {
608                ContentPart::Text { text } => Some(text.clone()),
609                _ => None,
610            })
611            .collect::<Vec<_>>()
612            .join("");
613
614        Ok(Box::pin(futures::stream::once(async move {
615            StreamChunk::Text(text)
616        })))
617    }
618}
619
/// Interpret a boolean-ish environment-variable value.
///
/// Truthy spellings: "1", "true", "yes", "on", "enabled"; falsy: "0",
/// "false", "no", "off", "disabled". Matching is ASCII-case-insensitive and
/// ignores surrounding whitespace. Anything else yields `None`.
fn parse_bool_env(value: &str) -> Option<bool> {
    const TRUTHY: [&str; 5] = ["1", "true", "yes", "on", "enabled"];
    const FALSY: [&str; 5] = ["0", "false", "no", "off", "disabled"];

    let candidate = value.trim();
    if TRUTHY.iter().any(|t| candidate.eq_ignore_ascii_case(t)) {
        Some(true)
    } else if FALSY.iter().any(|t| candidate.eq_ignore_ascii_case(t)) {
        Some(false)
    } else {
        None
    }
}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// With caching enabled, `cache_control: {type: ephemeral}` must land on
    /// the last system block, the last content part of the last message, and
    /// the last tool definition.
    #[test]
    fn adds_cache_control_to_last_tool_system_and_message_block() {
        let messages = vec![
            Message {
                role: Role::System,
                content: vec![ContentPart::Text {
                    text: "static instruction".to_string(),
                }],
            },
            Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "dynamic input".to_string(),
                }],
            },
        ];

        let (system, converted_messages) = AnthropicProvider::convert_messages(&messages, true);
        let mut converted_tools = AnthropicProvider::convert_tools(
            &[ToolDefinition {
                name: "get_weather".to_string(),
                description: "Get weather".to_string(),
                parameters: json!({"type": "object"}),
            }],
            true,
        );

        let system = system.expect("system blocks should be present");
        let system_cache = system
            .last()
            .and_then(|v| v.get("cache_control"))
            .and_then(|v| v.get("type"))
            .and_then(Value::as_str);
        assert_eq!(system_cache, Some("ephemeral"));

        let message_cache = converted_messages
            .last()
            .and_then(|msg| msg.get("content"))
            .and_then(Value::as_array)
            .and_then(|parts| parts.last())
            .and_then(|part| part.get("cache_control"))
            .and_then(|v| v.get("type"))
            .and_then(Value::as_str);
        assert_eq!(message_cache, Some("ephemeral"));

        let tool_cache = converted_tools
            .pop()
            .and_then(|tool| tool.get("cache_control").cloned())
            .and_then(|v| v.get("type").cloned())
            .and_then(|v| v.as_str().map(str::to_string));
        assert_eq!(tool_cache.as_deref(), Some("ephemeral"));
    }

    /// The "minimax" provider name should flip prompt caching on without any
    /// env-var override. NOTE: assumes CODETETHER_ANTHROPIC_PROMPT_CACHING is
    /// unset in the test environment.
    #[test]
    fn minimax_provider_name_enables_prompt_caching_by_default() {
        let provider = AnthropicProvider::with_base_url(
            "test-key".to_string(),
            "https://api.minimax.io/anthropic".to_string(),
            "minimax",
        )
        .expect("provider should initialize");

        assert_eq!(provider.name(), "minimax");
        assert!(provider.enable_prompt_caching);
    }
}