// aster/providers/anthropic.rs
1use anyhow::Result;
2use async_stream::try_stream;
3use async_trait::async_trait;
4use futures::TryStreamExt;
5use reqwest::StatusCode;
6use serde_json::Value;
7use std::io;
8use tokio::pin;
9use tokio_util::io::StreamReader;
10
11use super::api_client::{ApiClient, ApiResponse, AuthMethod};
12use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderMetadata, ProviderUsage};
13use super::errors::ProviderError;
14use super::formats::anthropic::{
15    create_request, get_usage, response_to_message, response_to_streaming_message,
16};
17use super::utils::{get_model, handle_status_openai_compat, map_http_error_to_provider_error};
18use crate::config::declarative_providers::DeclarativeProviderConfig;
19use crate::conversation::message::Message;
20use crate::model::ModelConfig;
21use crate::providers::retry::ProviderRetry;
22use crate::providers::utils::RequestLog;
23use rmcp::model::Tool;
24
/// Model used when the caller does not request a specific one.
pub const ANTHROPIC_DEFAULT_MODEL: &str = "claude-sonnet-4-5";
/// Cheaper companion model installed via `ModelConfig::with_fast` in `from_env`.
const ANTHROPIC_DEFAULT_FAST_MODEL: &str = "claude-haiku-4-5";
/// Models advertised in provider metadata (each with a 200k context window);
/// `fetch_supported_models` may return a larger, live list from the API.
const ANTHROPIC_KNOWN_MODELS: &[&str] = &[
    // Claude 4.5 models with aliases
    "claude-sonnet-4-5",
    "claude-sonnet-4-5-20250929",
    "claude-haiku-4-5",
    "claude-haiku-4-5-20251001",
    "claude-opus-4-5",
    "claude-opus-4-5-20251101",
];

/// Documentation link surfaced in provider metadata.
const ANTHROPIC_DOC_URL: &str = "https://docs.anthropic.com/en/docs/about-claude/models";
/// Value sent in the required `anthropic-version` request header.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";
39
/// Provider backed by Anthropic's Messages API (`v1/messages`).
///
/// Construct with [`AnthropicProvider::from_env`] (global config / env vars)
/// or [`AnthropicProvider::from_custom_config`] (declarative provider config).
#[derive(serde::Serialize)]
pub struct AnthropicProvider {
    // Skipped during serialization: holds the live HTTP client and auth state.
    #[serde(skip)]
    api_client: ApiClient,
    // Active model configuration (includes the fast-model alias).
    model: ModelConfig,
    // True for from_env; configurable for declarative providers.
    supports_streaming: bool,
    // Display name: the metadata name, or the declarative config's name.
    name: String,
}
48
49impl AnthropicProvider {
50    pub async fn from_env(model: ModelConfig) -> Result<Self> {
51        let model = model.with_fast(ANTHROPIC_DEFAULT_FAST_MODEL.to_string());
52
53        let config = crate::config::Config::global();
54        let api_key: String = config.get_secret("ANTHROPIC_API_KEY")?;
55        let host: String = config
56            .get_param("ANTHROPIC_HOST")
57            .or_else(|_| config.get_param("ANTHROPIC_BASE_URL"))
58            .unwrap_or_else(|_| "https://api.anthropic.com".to_string());
59
60        let auth = AuthMethod::ApiKey {
61            header_name: "x-api-key".to_string(),
62            key: api_key,
63        };
64
65        let api_client =
66            ApiClient::new(host, auth)?.with_header("anthropic-version", ANTHROPIC_API_VERSION)?;
67
68        Ok(Self {
69            api_client,
70            model,
71            supports_streaming: true,
72            name: Self::metadata().name,
73        })
74    }
75
76    pub fn from_custom_config(
77        model: ModelConfig,
78        config: DeclarativeProviderConfig,
79    ) -> Result<Self> {
80        let global_config = crate::config::Config::global();
81        let api_key: String = global_config
82            .get_secret(&config.api_key_env)
83            .map_err(|_| anyhow::anyhow!("Missing API key: {}", config.api_key_env))?;
84
85        let auth = AuthMethod::ApiKey {
86            header_name: "x-api-key".to_string(),
87            key: api_key,
88        };
89
90        let api_client = ApiClient::new(config.base_url, auth)?
91            .with_header("anthropic-version", ANTHROPIC_API_VERSION)?;
92
93        Ok(Self {
94            api_client,
95            model,
96            supports_streaming: config.supports_streaming.unwrap_or(true),
97            name: config.name.clone(),
98        })
99    }
100
101    fn get_conditional_headers(&self) -> Vec<(&str, &str)> {
102        let mut headers = Vec::new();
103
104        let is_thinking_enabled = std::env::var("CLAUDE_THINKING_ENABLED").is_ok();
105        if self.model.model_name.starts_with("claude-3-7-sonnet-") {
106            if is_thinking_enabled {
107                headers.push(("anthropic-beta", "output-128k-2025-02-19"));
108            }
109            headers.push(("anthropic-beta", "token-efficient-tools-2025-02-19"));
110        }
111
112        headers
113    }
114
115    async fn post(&self, payload: &Value) -> Result<ApiResponse, ProviderError> {
116        let mut request = self.api_client.request("v1/messages");
117
118        for (key, value) in self.get_conditional_headers() {
119            request = request.header(key, value)?;
120        }
121
122        Ok(request.api_post(payload).await?)
123    }
124
125    fn anthropic_api_call_result(response: ApiResponse) -> Result<Value, ProviderError> {
126        match response.status {
127            StatusCode::OK => response.payload.ok_or_else(|| {
128                ProviderError::RequestFailed("Response body is not valid JSON".to_string())
129            }),
130            _ => {
131                if response.status == StatusCode::BAD_REQUEST {
132                    if let Some(error_msg) = response
133                        .payload
134                        .as_ref()
135                        .and_then(|p| p.get("error"))
136                        .and_then(|e| e.get("message"))
137                        .and_then(|m| m.as_str())
138                    {
139                        let msg = error_msg.to_string();
140                        if msg.to_lowercase().contains("too long")
141                            || msg.to_lowercase().contains("too many")
142                        {
143                            return Err(ProviderError::ContextLengthExceeded(msg));
144                        }
145                    }
146                }
147                Err(map_http_error_to_provider_error(
148                    response.status,
149                    response.payload,
150                ))
151            }
152        }
153    }
154}
155
#[async_trait]
impl Provider for AnthropicProvider {
    /// Static provider metadata: id, display name, the known-model list
    /// (each advertised with a 200k context window), and the config keys the
    /// host application should prompt for.
    fn metadata() -> ProviderMetadata {
        let models: Vec<ModelInfo> = ANTHROPIC_KNOWN_MODELS
            .iter()
            .map(|&model_name| ModelInfo::new(model_name, 200_000))
            .collect();

        ProviderMetadata::with_models(
            "anthropic",
            "Anthropic",
            "Claude and other models from Anthropic",
            ANTHROPIC_DEFAULT_MODEL,
            models,
            ANTHROPIC_DOC_URL,
            vec![
                // (key, required, secret, default) — presumably; confirm
                // against ConfigKey::new's signature.
                ConfigKey::new("ANTHROPIC_API_KEY", true, true, None),
                ConfigKey::new(
                    "ANTHROPIC_HOST",
                    true,
                    false,
                    Some("https://api.anthropic.com"),
                ),
            ],
        )
    }

    /// Instance name: the metadata name, or the declarative config's name.
    fn get_name(&self) -> &str {
        &self.name
    }

    /// Active model configuration (cloned for the caller).
    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }

    /// Non-streaming completion: builds the request, POSTs with retry,
    /// parses the message and usage, and logs the exchange.
    ///
    /// # Errors
    /// Propagates request/parse failures as `ProviderError`, including
    /// `ContextLengthExceeded` mapped from Anthropic's 400 responses.
    #[tracing::instrument(
        skip(self, model_config, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let payload = create_request(model_config, system, messages, tools)?;

        // Retry wraps only the HTTP call; request construction is not retried.
        let response = self
            .with_retry(|| async { self.post(&payload).await })
            .await?;

        let json_response = Self::anthropic_api_call_result(response)?;

        let message = response_to_message(&json_response)?;
        let usage = get_usage(&json_response)?;
        tracing::debug!("🔍 Anthropic non-streaming parsed usage: input_tokens={:?}, output_tokens={:?}, total_tokens={:?}",
                usage.input_tokens, usage.output_tokens, usage.total_tokens);

        // Report the model name the server actually used, which may differ
        // from the requested alias.
        let response_model = get_model(&json_response);
        let mut log = RequestLog::start(&self.model, &payload)?;
        log.write(&json_response, Some(&usage))?;
        let provider_usage = ProviderUsage::new(response_model, usage);
        tracing::debug!(
            "🔍 Anthropic non-streaming returning ProviderUsage: {:?}",
            provider_usage
        );
        Ok((message, provider_usage))
    }

    /// Lists model ids from `GET v1/models`, sorted ascending.
    ///
    /// Returns `Ok(None)` when the response has no `data` array (callers
    /// should then fall back to the static known-model list).
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        let response = self.api_client.api_get("v1/models").await?;

        if response.status != StatusCode::OK {
            return Err(map_http_error_to_provider_error(
                response.status,
                response.payload,
            ));
        }

        let json = response.payload.unwrap_or_default();
        let arr = match json.get("data").and_then(|v| v.as_array()) {
            Some(arr) => arr,
            None => return Ok(None),
        };

        // Entries without a string "id" are skipped rather than failing.
        let mut models: Vec<String> = arr
            .iter()
            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
            .collect();
        models.sort();
        Ok(Some(models))
    }

    /// Streaming completion via SSE: sets `"stream": true` on the payload,
    /// frames the byte stream into lines, and decodes incremental
    /// (message, usage) pairs, logging each as it arrives.
    ///
    /// NOTE(review): unlike `complete_with_model`, this path has no retry
    /// wrapper — confirm whether that asymmetry is intentional.
    async fn stream(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        let mut payload = create_request(&self.model, system, messages, tools)?;
        // create_request builds a JSON object, so as_object_mut cannot fail here.
        payload
            .as_object_mut()
            .unwrap()
            .insert("stream".to_string(), Value::Bool(true));

        let mut request = self.api_client.request("v1/messages");
        let mut log = RequestLog::start(&self.model, &payload)?;

        for (key, value) in self.get_conditional_headers() {
            request = request.header(key, value)?;
        }

        // Log request/response-phase errors before propagating them.
        let resp = request.response_post(&payload).await.inspect_err(|e| {
            let _ = log.error(e);
        })?;
        let response = handle_status_openai_compat(resp).await.inspect_err(|e| {
            let _ = log.error(e);
        })?;

        let stream = response.bytes_stream().map_err(io::Error::other);

        Ok(Box::pin(try_stream! {
            // Bytes -> AsyncRead -> line frames; SSE events are line-delimited.
            let stream_reader = StreamReader::new(stream);
            let framed = tokio_util::codec::FramedRead::new(stream_reader, tokio_util::codec::LinesCodec::new()).map_err(anyhow::Error::from);

            let message_stream = response_to_streaming_message(framed);
            pin!(message_stream);
            while let Some(message) = futures::StreamExt::next(&mut message_stream).await {
                let (message, usage) = message.map_err(|e| ProviderError::RequestFailed(format!("Stream decode error: {}", e)))?;
                // Usage may be absent on intermediate chunks; log what we have.
                log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
                yield (message, usage);
            }
        }))
    }

    /// Whether this provider instance supports the streaming path above.
    fn supports_streaming(&self) -> bool {
        self.supports_streaming
    }
}