// codetether_agent/provider/anthropic.rs

//! Anthropic provider implementation using the Messages API
//!
//! Supports Claude Sonnet 4, Claude Opus 4, and other Claude models.
//! Uses the native Anthropic API format (not OpenAI-compatible).
//! Reference: https://docs.anthropic.com/en/api/messages
use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};

/// Default Anthropic API origin; `with_base_url` allows pointing at
/// Anthropic-compatible gateways (e.g. MiniMax) instead.
const ANTHROPIC_API_BASE: &str = "https://api.anthropic.com";
/// Value sent in the required `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";

/// Provider that speaks the Anthropic Messages API, either against the
/// official endpoint or an API-compatible gateway (e.g. MiniMax).
pub struct AnthropicProvider {
    client: Client,
    // Sent as the `x-api-key` header; redacted by the manual Debug impl.
    api_key: String,
    // API origin; trailing slashes are trimmed when building request URLs.
    base_url: String,
    // Logical provider id, e.g. "anthropic", "minimax", "minimax-credits".
    provider_name: String,
    // When true, `cache_control: {"type": "ephemeral"}` markers are attached
    // to the request (see convert_messages / convert_tools).
    enable_prompt_caching: bool,
}

28impl std::fmt::Debug for AnthropicProvider {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("AnthropicProvider")
31            .field("api_key", &"<REDACTED>")
32            .field("api_key_len", &self.api_key.len())
33            .field("base_url", &self.base_url)
34            .field("provider_name", &self.provider_name)
35            .field("enable_prompt_caching", &self.enable_prompt_caching)
36            .finish()
37    }
38}
39
impl AnthropicProvider {
    /// Build a provider against the official Anthropic endpoint.
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_base_url(api_key, ANTHROPIC_API_BASE.to_string(), "anthropic")
    }

    /// Build a provider against any Anthropic-compatible base URL.
    ///
    /// Prompt caching is resolved as follows: the
    /// `CODETETHER_ANTHROPIC_PROMPT_CACHING` env var wins when it parses as a
    /// boolean; otherwise caching defaults to enabled for the
    /// "minimax"/"minimax-credits" provider names and disabled elsewhere.
    pub fn with_base_url(
        api_key: String,
        base_url: String,
        provider_name: impl Into<String>,
    ) -> Result<Self> {
        let provider_name = provider_name.into();
        let enable_prompt_caching = std::env::var("CODETETHER_ANTHROPIC_PROMPT_CACHING")
            .ok()
            .and_then(|v| parse_bool_env(&v))
            .unwrap_or_else(|| {
                provider_name.eq_ignore_ascii_case("minimax")
                    || provider_name.eq_ignore_ascii_case("minimax-credits")
            });

        tracing::debug!(
            provider = %provider_name,
            api_key_len = api_key.len(),
            base_url = %base_url,
            enable_prompt_caching,
            "Creating Anthropic provider"
        );

        Ok(Self {
            client: Client::new(),
            api_key,
            base_url,
            provider_name,
            enable_prompt_caching,
        })
    }

    /// Fail fast with a clear error instead of sending an unauthenticated
    /// request to the API.
    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Anthropic API key is empty");
        }
        Ok(())
    }

    /// Convert our generic messages to Anthropic Messages API format.
    ///
    /// Anthropic uses a different format:
    /// - system prompt is a top-level field, not a message
    /// - tool results go in user messages with type "tool_result"
    /// - tool calls appear in assistant messages with type "tool_use"
    ///
    /// Returns `(system_blocks, messages)`; `system_blocks` is `None` when no
    /// system content was present. When `enable_prompt_caching` is set, an
    /// ephemeral `cache_control` marker is attached to the last system block
    /// and to the last content part of the last message.
    fn convert_messages(
        messages: &[Message],
        enable_prompt_caching: bool,
    ) -> (Option<Vec<Value>>, Vec<Value>) {
        let mut system_blocks: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    // System content becomes top-level `system` blocks rather
                    // than a message entry.
                    // NOTE(review): "thinking" blocks are forwarded here and
                    // for user messages below; confirm the target backend
                    // accepts them outside assistant content.
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                system_blocks.push(json!({
                                    "type": "text",
                                    "text": text,
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                system_blocks.push(json!({
                                    "type": "thinking",
                                    "thinking": text,
                                }));
                            }
                            // Other part kinds (tool calls/results, etc.) are
                            // not representable in a system prompt; dropped.
                            _ => {}
                        }
                    }
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                content_parts.push(json!({
                                    "type": "text",
                                    "text": text,
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                content_parts.push(json!({
                                    "type": "thinking",
                                    "thinking": text,
                                }));
                            }
                            _ => {}
                        }
                    }
                    // The API rejects empty content arrays; use a one-space
                    // placeholder text block.
                    if content_parts.is_empty() {
                        content_parts.push(json!({"type": "text", "text": " "}));
                    }
                    api_messages.push(json!({
                        "role": "user",
                        "content": content_parts
                    }));
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();

                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                content_parts.push(json!({
                                    "type": "text",
                                    "text": text
                                }));
                            }
                            ContentPart::Thinking { text } => {
                                content_parts.push(json!({
                                    "type": "thinking",
                                    "thinking": text
                                }));
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                                ..
                            } => {
                                // Arguments are stored as a JSON string; if it
                                // doesn't parse, wrap the raw text so nothing
                                // is silently lost.
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "type": "tool_use",
                                    "id": id,
                                    "name": name,
                                    "input": input
                                }));
                            }
                            _ => {}
                        }
                    }

                    // Same empty-content guard as for user messages above.
                    if content_parts.is_empty() {
                        content_parts.push(json!({"type": "text", "text": " "}));
                    }

                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Tool results must be sent back as "tool_result" blocks
                    // inside a *user* message, keyed by the originating
                    // tool_use id.
                    let mut tool_results: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            tool_results.push(json!({
                                "type": "tool_result",
                                "tool_use_id": tool_call_id,
                                "content": content
                            }));
                        }
                    }
                    if !tool_results.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": tool_results
                        }));
                    }
                }
            }
        }

        if enable_prompt_caching {
            // Mark the cacheable prefix boundary: the final content part of
            // the final message (searching backwards for a message whose
            // content is an array), plus the final system block.
            if let Some(last_tool_or_text_msg) = api_messages.iter_mut().rev().find_map(|msg| {
                msg.get_mut("content")
                    .and_then(Value::as_array_mut)
                    .and_then(|parts| parts.last_mut())
            }) {
                Self::add_ephemeral_cache_control(last_tool_or_text_msg);
            }
            if let Some(last_system) = system_blocks.last_mut() {
                Self::add_ephemeral_cache_control(last_system);
            }
        }

        let system = if system_blocks.is_empty() {
            None
        } else {
            Some(system_blocks)
        };

        (system, api_messages)
    }

    /// Map generic tool definitions to the API's `tools` entries
    /// (`name`/`description`/`input_schema`). With caching enabled, the last
    /// tool gets an ephemeral `cache_control` marker so the tool list is part
    /// of the cached prefix.
    fn convert_tools(tools: &[ToolDefinition], enable_prompt_caching: bool) -> Vec<Value> {
        let mut converted: Vec<Value> = tools
            .iter()
            .map(|t| {
                json!({
                    "name": t.name,
                    "description": t.description,
                    "input_schema": t.parameters
                })
            })
            .collect();

        if enable_prompt_caching && let Some(last_tool) = converted.last_mut() {
            Self::add_ephemeral_cache_control(last_tool);
        }

        converted
    }

    /// Insert `"cache_control": {"type": "ephemeral"}` into a JSON object
    /// block; silently does nothing for non-object values.
    fn add_ephemeral_cache_control(block: &mut Value) {
        if let Some(obj) = block.as_object_mut() {
            obj.insert("cache_control".to_string(), json!({ "type": "ephemeral" }));
        }
    }
}

/// Return at most `max_chars` leading characters of `input` as an owned
/// `String`, without ever slicing through a multi-byte UTF-8 sequence.
fn safe_char_prefix(input: &str, max_chars: usize) -> String {
    // Find the byte offset of the character *after* the prefix; if there is
    // none, the whole string fits.
    match input.char_indices().nth(max_chars) {
        Some((byte_idx, _)) => input[..byte_idx].to_string(),
        None => input.to_string(),
    }
}

/// Top-level successful response body from `POST /v1/messages`.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    model: String,
    // Ordered content blocks (text / thinking / tool_use / unknown).
    content: Vec<AnthropicContent>,
    // e.g. "end_turn", "max_tokens", "tool_use"; optional for leniency.
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<AnthropicUsage>,
}

/// One content block in an Anthropic response, discriminated by its `type`
/// field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum AnthropicContent {
    #[serde(rename = "text")]
    Text { text: String },
    // Extended-thinking block. Both payload spellings are accepted because
    // some Anthropic-compatible backends use `text` instead of `thinking`.
    #[serde(rename = "thinking")]
    Thinking {
        #[serde(default)]
        thinking: Option<String>,
        #[serde(default)]
        text: Option<String>,
    },
    // Model-requested tool invocation; `input` holds free-form JSON args.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: Value,
    },
    // Forward compatibility: unknown block types deserialize instead of
    // failing the whole response.
    #[serde(other)]
    Unknown,
}

/// Token accounting attached to a response; all fields are defaulted so a
/// partial usage object still deserializes.
#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    // Present only when prompt caching wrote a new cache entry.
    #[serde(default)]
    cache_creation_input_tokens: Option<usize>,
    // Present only when prompt caching served tokens from an existing entry.
    #[serde(default)]
    cache_read_input_tokens: Option<usize>,
}

/// Error envelope returned by the API on non-2xx responses.
#[derive(Debug, Deserialize)]
struct AnthropicError {
    error: AnthropicErrorDetail,
}

/// Payload inside [`AnthropicError`].
#[derive(Debug, Deserialize)]
struct AnthropicErrorDetail {
    message: String,
    // e.g. "invalid_request_error", "authentication_error".
    #[serde(default, rename = "type")]
    error_type: Option<String>,
}

325#[async_trait]
326impl Provider for AnthropicProvider {
327    fn name(&self) -> &str {
328        &self.provider_name
329    }
330
331    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
332        self.validate_api_key()?;
333
334        if self.provider_name.eq_ignore_ascii_case("minimax-credits") {
335            // Credits-based key: highspeed models only (not on coding plan)
336            return Ok(vec![
337                ModelInfo {
338                    id: "MiniMax-M2.5-highspeed".to_string(),
339                    name: "MiniMax M2.5 Highspeed".to_string(),
340                    provider: self.provider_name.clone(),
341                    context_window: 200_000,
342                    max_output_tokens: Some(65_536),
343                    supports_vision: false,
344                    supports_tools: true,
345                    supports_streaming: true,
346                    input_cost_per_million: Some(0.6),
347                    output_cost_per_million: Some(2.4),
348                },
349                ModelInfo {
350                    id: "MiniMax-M2.1-highspeed".to_string(),
351                    name: "MiniMax M2.1 Highspeed".to_string(),
352                    provider: self.provider_name.clone(),
353                    context_window: 200_000,
354                    max_output_tokens: Some(65_536),
355                    supports_vision: false,
356                    supports_tools: true,
357                    supports_streaming: true,
358                    input_cost_per_million: Some(0.6),
359                    output_cost_per_million: Some(2.4),
360                },
361            ]);
362        }
363
364        if self.provider_name.eq_ignore_ascii_case("minimax") {
365            // Coding plan key: regular (non-highspeed) models
366            return Ok(vec![
367                ModelInfo {
368                    id: "MiniMax-M2.5".to_string(),
369                    name: "MiniMax M2.5".to_string(),
370                    provider: self.provider_name.clone(),
371                    context_window: 200_000,
372                    max_output_tokens: Some(65_536),
373                    supports_vision: false,
374                    supports_tools: true,
375                    supports_streaming: true,
376                    input_cost_per_million: Some(0.3),
377                    output_cost_per_million: Some(1.2),
378                },
379                ModelInfo {
380                    id: "MiniMax-M2.1".to_string(),
381                    name: "MiniMax M2.1".to_string(),
382                    provider: self.provider_name.clone(),
383                    context_window: 200_000,
384                    max_output_tokens: Some(65_536),
385                    supports_vision: false,
386                    supports_tools: true,
387                    supports_streaming: true,
388                    input_cost_per_million: Some(0.3),
389                    output_cost_per_million: Some(1.2),
390                },
391                ModelInfo {
392                    id: "MiniMax-M2".to_string(),
393                    name: "MiniMax M2".to_string(),
394                    provider: self.provider_name.clone(),
395                    context_window: 200_000,
396                    max_output_tokens: Some(65_536),
397                    supports_vision: false,
398                    supports_tools: true,
399                    supports_streaming: true,
400                    input_cost_per_million: Some(0.3),
401                    output_cost_per_million: Some(1.2),
402                },
403            ]);
404        }
405
406        Ok(vec![
407            ModelInfo {
408                id: "claude-sonnet-4-6".to_string(),
409                name: "Claude Sonnet 4.6".to_string(),
410                provider: self.provider_name.clone(),
411                context_window: 200_000,
412                max_output_tokens: Some(128_000),
413                supports_vision: true,
414                supports_tools: true,
415                supports_streaming: true,
416                input_cost_per_million: Some(3.0),
417                output_cost_per_million: Some(15.0),
418            },
419            ModelInfo {
420                id: "claude-sonnet-4-20250514".to_string(),
421                name: "Claude Sonnet 4".to_string(),
422                provider: self.provider_name.clone(),
423                context_window: 200_000,
424                max_output_tokens: Some(64_000),
425                supports_vision: true,
426                supports_tools: true,
427                supports_streaming: true,
428                input_cost_per_million: Some(3.0),
429                output_cost_per_million: Some(15.0),
430            },
431            ModelInfo {
432                id: "claude-opus-4-20250514".to_string(),
433                name: "Claude Opus 4".to_string(),
434                provider: self.provider_name.clone(),
435                context_window: 200_000,
436                max_output_tokens: Some(32_000),
437                supports_vision: true,
438                supports_tools: true,
439                supports_streaming: true,
440                input_cost_per_million: Some(15.0),
441                output_cost_per_million: Some(75.0),
442            },
443            ModelInfo {
444                id: "claude-haiku-3-5-20241022".to_string(),
445                name: "Claude 3.5 Haiku".to_string(),
446                provider: self.provider_name.clone(),
447                context_window: 200_000,
448                max_output_tokens: Some(8_192),
449                supports_vision: true,
450                supports_tools: true,
451                supports_streaming: true,
452                input_cost_per_million: Some(0.80),
453                output_cost_per_million: Some(4.0),
454            },
455        ])
456    }
457
458    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
459        tracing::debug!(
460            provider = %self.provider_name,
461            model = %request.model,
462            message_count = request.messages.len(),
463            tool_count = request.tools.len(),
464            "Starting completion request"
465        );
466
467        self.validate_api_key()?;
468
469        let (system_prompt, messages) =
470            Self::convert_messages(&request.messages, self.enable_prompt_caching);
471        let tools = Self::convert_tools(&request.tools, self.enable_prompt_caching);
472
473        let mut body = json!({
474            "model": request.model,
475            "messages": messages,
476            "max_tokens": request.max_tokens.unwrap_or(8192),
477        });
478
479        if let Some(system) = system_prompt {
480            body["system"] = json!(system);
481        }
482        if !tools.is_empty() {
483            body["tools"] = json!(tools);
484        }
485        if let Some(temp) = request.temperature {
486            body["temperature"] = json!(temp);
487        }
488        if let Some(top_p) = request.top_p {
489            body["top_p"] = json!(top_p);
490        }
491
492        tracing::debug!("Anthropic request to model {}", request.model);
493
494        let response = self
495            .client
496            .post(format!(
497                "{}/v1/messages",
498                self.base_url.trim_end_matches('/')
499            ))
500            .header("x-api-key", &self.api_key)
501            .header("anthropic-version", ANTHROPIC_VERSION)
502            .header("content-type", "application/json")
503            .json(&body)
504            .send()
505            .await
506            .context("Failed to send request to Anthropic")?;
507
508        let status = response.status();
509        let text = response
510            .text()
511            .await
512            .context("Failed to read Anthropic response")?;
513
514        if !status.is_success() {
515            if let Ok(err) = serde_json::from_str::<AnthropicError>(&text) {
516                anyhow::bail!(
517                    "Anthropic API error: {} ({:?})",
518                    err.error.message,
519                    err.error.error_type
520                );
521            }
522            anyhow::bail!("Anthropic API error: {} {}", status, text);
523        }
524
525        let response: AnthropicResponse = serde_json::from_str(&text).context(format!(
526            "Failed to parse Anthropic response: {}",
527            safe_char_prefix(&text, 200)
528        ))?;
529
530        tracing::debug!(
531            response_id = %response.id,
532            model = %response.model,
533            stop_reason = ?response.stop_reason,
534            "Received Anthropic response"
535        );
536
537        let mut content = Vec::new();
538        let mut has_tool_calls = false;
539
540        for part in &response.content {
541            match part {
542                AnthropicContent::Text { text } => {
543                    if !text.is_empty() {
544                        content.push(ContentPart::Text { text: text.clone() });
545                    }
546                }
547                AnthropicContent::Thinking { thinking, text } => {
548                    let reasoning = thinking
549                        .as_deref()
550                        .or(text.as_deref())
551                        .unwrap_or_default()
552                        .trim()
553                        .to_string();
554                    if !reasoning.is_empty() {
555                        content.push(ContentPart::Thinking { text: reasoning });
556                    }
557                }
558                AnthropicContent::ToolUse { id, name, input } => {
559                    has_tool_calls = true;
560                    content.push(ContentPart::ToolCall {
561                        id: id.clone(),
562                        name: name.clone(),
563                        arguments: serde_json::to_string(input).unwrap_or_default(),
564                        thought_signature: None,
565                    });
566                }
567                AnthropicContent::Unknown => {}
568            }
569        }
570
571        let finish_reason = if has_tool_calls {
572            FinishReason::ToolCalls
573        } else {
574            match response.stop_reason.as_deref() {
575                Some("end_turn") | Some("stop") => FinishReason::Stop,
576                Some("max_tokens") => FinishReason::Length,
577                Some("tool_use") => FinishReason::ToolCalls,
578                Some("content_filter") => FinishReason::ContentFilter,
579                _ => FinishReason::Stop,
580            }
581        };
582
583        let usage = response.usage.as_ref();
584
585        Ok(CompletionResponse {
586            message: Message {
587                role: Role::Assistant,
588                content,
589            },
590            usage: Usage {
591                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
592                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
593                total_tokens: usage.map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
594                cache_read_tokens: usage.and_then(|u| u.cache_read_input_tokens),
595                cache_write_tokens: usage.and_then(|u| u.cache_creation_input_tokens),
596            },
597            finish_reason,
598        })
599    }
600
601    async fn complete_stream(
602        &self,
603        request: CompletionRequest,
604    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
605        // Fall back to non-streaming for now
606        let response = self.complete(request).await?;
607        let text = response
608            .message
609            .content
610            .iter()
611            .filter_map(|p| match p {
612                ContentPart::Text { text } => Some(text.clone()),
613                _ => None,
614            })
615            .collect::<Vec<_>>()
616            .join("");
617
618        Ok(Box::pin(futures::stream::once(async move {
619            StreamChunk::Text(text)
620        })))
621    }
622}
623
/// Interpret a boolean-ish environment-variable value.
///
/// Recognizes the usual truthy ("1", "true", "yes", "on", "enabled") and
/// falsy ("0", "false", "no", "off", "disabled") spellings, ignoring ASCII
/// case and surrounding whitespace. Anything else yields `None`.
fn parse_bool_env(value: &str) -> Option<bool> {
    const TRUTHY: [&str; 5] = ["1", "true", "yes", "on", "enabled"];
    const FALSY: [&str; 5] = ["0", "false", "no", "off", "disabled"];

    let normalized = value.trim().to_ascii_lowercase();
    if TRUTHY.contains(&normalized.as_str()) {
        Some(true)
    } else if FALSY.contains(&normalized.as_str()) {
        Some(false)
    } else {
        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // With caching enabled, convert_messages / convert_tools should attach an
    // ephemeral cache_control marker to the last system block, the last
    // content part of the last message, and the last tool definition.
    #[test]
    fn adds_cache_control_to_last_tool_system_and_message_block() {
        let messages = vec![
            Message {
                role: Role::System,
                content: vec![ContentPart::Text {
                    text: "static instruction".to_string(),
                }],
            },
            Message {
                role: Role::User,
                content: vec![ContentPart::Text {
                    text: "dynamic input".to_string(),
                }],
            },
        ];

        let (system, converted_messages) = AnthropicProvider::convert_messages(&messages, true);
        let mut converted_tools = AnthropicProvider::convert_tools(
            &[ToolDefinition {
                name: "get_weather".to_string(),
                description: "Get weather".to_string(),
                parameters: json!({"type": "object"}),
            }],
            true,
        );

        let system = system.expect("system blocks should be present");
        let system_cache = system
            .last()
            .and_then(|v| v.get("cache_control"))
            .and_then(|v| v.get("type"))
            .and_then(Value::as_str);
        assert_eq!(system_cache, Some("ephemeral"));

        let message_cache = converted_messages
            .last()
            .and_then(|msg| msg.get("content"))
            .and_then(Value::as_array)
            .and_then(|parts| parts.last())
            .and_then(|part| part.get("cache_control"))
            .and_then(|v| v.get("type"))
            .and_then(Value::as_str);
        assert_eq!(message_cache, Some("ephemeral"));

        let tool_cache = converted_tools
            .pop()
            .and_then(|tool| tool.get("cache_control").cloned())
            .and_then(|v| v.get("type").cloned())
            .and_then(|v| v.as_str().map(str::to_string));
        assert_eq!(tool_cache.as_deref(), Some("ephemeral"));
    }

    // Absent the env override, a provider named "minimax" should enable
    // prompt caching by default.
    #[test]
    fn minimax_provider_name_enables_prompt_caching_by_default() {
        let provider = AnthropicProvider::with_base_url(
            "test-key".to_string(),
            "https://api.minimax.io/anthropic".to_string(),
            "minimax",
        )
        .expect("provider should initialize");

        assert_eq!(provider.name(), "minimax");
        assert!(provider.enable_prompt_caching);
    }

    // The prefix helper must count characters, not bytes, so multi-byte
    // UTF-8 sequences are never split.
    #[test]
    fn safe_char_prefix_handles_multibyte_characters() {
        let s = "abc✓def";
        assert_eq!(safe_char_prefix(s, 4), "abc✓");
        assert_eq!(safe_char_prefix(s, 7), s);
    }
}
709}