// dot/provider/anthropic/mod.rs

1mod auth;
2mod stream;
3mod types;
4
5use auth::{AnthropicAuth, AuthResolved, refresh_oauth_token};
6use stream::process_sse_stream;
7use types::{AnthropicRequest, convert_messages, convert_tools};
8
9use std::{
10    collections::HashMap,
11    future::Future,
12    pin::Pin,
13    time::{SystemTime, UNIX_EPOCH},
14};
15
16use anyhow::Context;
17use tokio::sync::{mpsc, mpsc::UnboundedReceiver};
18use tracing::warn;
19
20use crate::provider::{Message, Provider, StreamEvent, StreamEventType, ToolDefinition};
21
22pub struct AnthropicProvider {
23    client: reqwest::Client,
24    model: String,
25    auth: tokio::sync::Mutex<AnthropicAuth>,
26    cached_models: std::sync::Mutex<Option<Vec<String>>>,
27    context_windows: std::sync::Mutex<HashMap<String, u32>>,
28}
29
30impl AnthropicProvider {
31    pub fn new_with_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
32        Self {
33            client: reqwest::Client::builder()
34                .user_agent("dot/0.1.0")
35                .build()
36                .expect("Failed to build reqwest client"),
37            model: model.into(),
38            auth: tokio::sync::Mutex::new(AnthropicAuth::ApiKey(api_key.into())),
39            cached_models: std::sync::Mutex::new(None),
40            context_windows: std::sync::Mutex::new(HashMap::new()),
41        }
42    }
43
44    pub fn new_with_oauth(
45        access_token: impl Into<String>,
46        refresh_token: impl Into<String>,
47        expires_at: i64,
48        model: impl Into<String>,
49    ) -> Self {
50        Self {
51            client: reqwest::Client::builder()
52                .user_agent("claude-code/2.1.49 (external, cli)")
53                .build()
54                .expect("Failed to build reqwest client"),
55            model: model.into(),
56            auth: tokio::sync::Mutex::new(AnthropicAuth::OAuth {
57                access_token: access_token.into(),
58                refresh_token: refresh_token.into(),
59                expires_at,
60            }),
61            cached_models: std::sync::Mutex::new(None),
62            context_windows: std::sync::Mutex::new(HashMap::new()),
63        }
64    }
65
66    async fn fetch_model_context_window(&self, model: &str) -> anyhow::Result<u32> {
67        let auth = self.resolve_auth().await?;
68        let url = format!("https://api.anthropic.com/v1/models/{model}");
69        let mut req = self
70            .client
71            .get(&url)
72            .header(&auth.header_name, &auth.header_value)
73            .header("anthropic-version", "2023-06-01");
74        if auth.is_oauth {
75            req = req
76                .header(
77                    "anthropic-beta",
78                    "oauth-2025-04-20,interleaved-thinking-2025-05-14",
79                )
80                .header("user-agent", "claude-code/2.1.49 (external, cli)");
81        }
82        let resp = req.send().await.context("Failed to fetch model info")?;
83        if !resp.status().is_success() {
84            let status = resp.status();
85            let body = resp.text().await.unwrap_or_default();
86            return Err(anyhow::anyhow!(
87                "Anthropic model API error {status}: {body}"
88            ));
89        }
90        let data: serde_json::Value = resp.json().await?;
91        data["context_window"]
92            .as_u64()
93            .map(|v| v as u32)
94            .ok_or_else(|| anyhow::anyhow!("context_window not found in model response"))
95    }
96
97    async fn resolve_auth(&self) -> anyhow::Result<AuthResolved> {
98        let mut auth = self.auth.lock().await;
99        match &*auth {
100            AnthropicAuth::ApiKey(key) => Ok(AuthResolved {
101                header_name: "x-api-key".to_string(),
102                header_value: key.clone(),
103                is_oauth: false,
104            }),
105            AnthropicAuth::OAuth {
106                access_token,
107                refresh_token,
108                expires_at,
109            } => {
110                let now = SystemTime::now()
111                    .duration_since(UNIX_EPOCH)
112                    .unwrap_or_default()
113                    .as_secs() as i64;
114                // Handle legacy millis-format expires_at from older credentials
115                let expires_at_secs = if *expires_at > 1_000_000_000_000 {
116                    *expires_at / 1000
117                } else {
118                    *expires_at
119                };
120
121                let token = if now >= expires_at_secs - 60 {
122                    let rt = refresh_token.clone();
123                    match refresh_oauth_token(&self.client, &rt).await {
124                        Ok((new_token, new_expires_at)) => {
125                            if let AnthropicAuth::OAuth {
126                                access_token,
127                                expires_at,
128                                ..
129                            } = &mut *auth
130                            {
131                                *access_token = new_token.clone();
132                                *expires_at = new_expires_at;
133                            }
134                            new_token
135                        }
136                        Err(e) => {
137                            warn!("OAuth token refresh failed: {e}");
138                            access_token.clone()
139                        }
140                    }
141                } else {
142                    access_token.clone()
143                };
144
145                Ok(AuthResolved {
146                    header_name: "Authorization".to_string(),
147                    header_value: format!("Bearer {token}"),
148                    is_oauth: true,
149                })
150            }
151        }
152    }
153}
154
155impl Provider for AnthropicProvider {
156    fn name(&self) -> &str {
157        "anthropic"
158    }
159
160    fn model(&self) -> &str {
161        &self.model
162    }
163
164    fn set_model(&mut self, model: String) {
165        self.model = model;
166    }
167
168    fn available_models(&self) -> Vec<String> {
169        let cache = self.cached_models.lock().unwrap();
170        cache.clone().unwrap_or_default()
171    }
172
173    fn context_window(&self) -> u32 {
174        let cw = self.context_windows.lock().unwrap();
175        cw.get(&self.model).copied().unwrap_or(0)
176    }
177
178    fn fetch_context_window(
179        &self,
180    ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
181        Box::pin(async move {
182            {
183                let cw = self.context_windows.lock().unwrap();
184                if let Some(&val) = cw.get(&self.model) {
185                    return Ok(val);
186                }
187            }
188            let val = self.fetch_model_context_window(&self.model).await?;
189            let mut cw = self.context_windows.lock().unwrap();
190            cw.insert(self.model.clone(), val);
191            Ok(val)
192        })
193    }
194
195    fn fetch_models(
196        &self,
197    ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
198        Box::pin(async move {
199            {
200                let cache = self.cached_models.lock().unwrap();
201                if let Some(ref models) = *cache {
202                    return Ok(models.clone());
203                }
204            }
205            let auth = self.resolve_auth().await?;
206            let mut all_models: Vec<String> = Vec::new();
207            let mut cw_map: HashMap<String, u32> = HashMap::new();
208            let mut after_id: Option<String> = None;
209
210            loop {
211                let mut url = "https://api.anthropic.com/v1/models?limit=1000".to_string();
212                if let Some(ref cursor) = after_id {
213                    url.push_str(&format!("&after_id={cursor}"));
214                }
215
216                let mut req = self
217                    .client
218                    .get(&url)
219                    .header(&auth.header_name, &auth.header_value)
220                    .header("anthropic-version", "2023-06-01");
221
222                if auth.is_oauth {
223                    req = req
224                        .header(
225                            "anthropic-beta",
226                            "oauth-2025-04-20,interleaved-thinking-2025-05-14",
227                        )
228                        .header("user-agent", "claude-code/2.1.49 (external, cli)");
229                }
230
231                let resp = req
232                    .send()
233                    .await
234                    .context("Failed to fetch Anthropic models")?;
235
236                if !resp.status().is_success() {
237                    let status = resp.status();
238                    let body = resp.text().await.unwrap_or_default();
239                    return Err(anyhow::anyhow!(
240                        "Anthropic models API error {status}: {body}"
241                    ));
242                }
243
244                let data: serde_json::Value = resp
245                    .json()
246                    .await
247                    .context("Failed to parse Anthropic models response")?;
248
249                if let Some(arr) = data["data"].as_array() {
250                    for m in arr {
251                        if let Some(id) = m["id"].as_str() {
252                            all_models.push(id.to_string());
253                            if let Some(cw) = m["context_window"].as_u64() {
254                                cw_map.insert(id.to_string(), cw as u32);
255                            }
256                        }
257                    }
258                }
259
260                let has_more = data["has_more"].as_bool().unwrap_or(false);
261                if !has_more {
262                    break;
263                }
264
265                match data["last_id"].as_str() {
266                    Some(last) => after_id = Some(last.to_string()),
267                    None => break,
268                }
269            }
270
271            if all_models.is_empty() {
272                return Err(anyhow::anyhow!("Anthropic models API returned empty list"));
273            }
274
275            all_models.sort();
276            let mut cache = self.cached_models.lock().unwrap();
277            *cache = Some(all_models.clone());
278            drop(cache);
279
280            let mut cw_cache = self.context_windows.lock().unwrap();
281            *cw_cache = cw_map;
282
283            Ok(all_models)
284        })
285    }
286
287    fn stream(
288        &self,
289        messages: &[Message],
290        system: Option<&str>,
291        tools: &[ToolDefinition],
292        max_tokens: u32,
293        thinking_budget: u32,
294    ) -> Pin<Box<dyn Future<Output = anyhow::Result<UnboundedReceiver<StreamEvent>>> + Send + '_>>
295    {
296        self.stream_with_model(
297            &self.model,
298            messages,
299            system,
300            tools,
301            max_tokens,
302            thinking_budget,
303        )
304    }
305
306    fn stream_with_model(
307        &self,
308        model: &str,
309        messages: &[Message],
310        system: Option<&str>,
311        tools: &[ToolDefinition],
312        max_tokens: u32,
313        thinking_budget: u32,
314    ) -> Pin<Box<dyn Future<Output = anyhow::Result<UnboundedReceiver<StreamEvent>>> + Send + '_>>
315    {
316        let messages = messages.to_vec();
317        let system = system.map(String::from);
318        let tools = tools.to_vec();
319        let model = model.to_string();
320
321        Box::pin(async move {
322            let auth = self.resolve_auth().await?;
323
324            let url = if auth.is_oauth {
325                "https://api.anthropic.com/v1/messages?beta=true".to_string()
326            } else {
327                "https://api.anthropic.com/v1/messages".to_string()
328            };
329
330            let thinking = if thinking_budget >= 1024 {
331                Some(serde_json::json!({
332                    "type": "enabled",
333                    "budget_tokens": thinking_budget,
334                }))
335            } else {
336                None
337            };
338
339            let effective_max_tokens = if thinking_budget >= 1024 {
340                max_tokens.max(thinking_budget.saturating_add(4096))
341            } else {
342                max_tokens
343            };
344
345            let body = AnthropicRequest {
346                model: &model,
347                messages: convert_messages(&messages),
348                max_tokens: effective_max_tokens,
349                stream: true,
350                system: system.as_deref(),
351                tools: convert_tools(&tools),
352                temperature: 1.0,
353                thinking,
354            };
355
356            let mut req_builder = self
357                .client
358                .post(&url)
359                .header(&auth.header_name, &auth.header_value)
360                .header("anthropic-version", "2023-06-01")
361                .header("content-type", "application/json");
362
363            if auth.is_oauth {
364                req_builder = req_builder
365                    .header(
366                        "anthropic-beta",
367                        "oauth-2025-04-20,interleaved-thinking-2025-05-14",
368                    )
369                    .header("user-agent", "claude-code/2.1.49 (external, cli)");
370            } else if thinking_budget >= 1024 {
371                req_builder =
372                    req_builder.header("anthropic-beta", "interleaved-thinking-2025-05-14");
373            }
374
375            let response = req_builder
376                .json(&body)
377                .send()
378                .await
379                .context("Failed to connect to Anthropic API")?;
380
381            if !response.status().is_success() {
382                let status = response.status();
383                let body_text = response.text().await.unwrap_or_default();
384                return Err(anyhow::anyhow!("Anthropic API error {status}: {body_text}"));
385            }
386
387            let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();
388            let tx_clone = tx.clone();
389
390            tokio::spawn(async move {
391                if let Err(e) = process_sse_stream(response, tx_clone.clone()).await {
392                    let _ = tx_clone.send(StreamEvent {
393                        event_type: StreamEventType::Error(e.to_string()),
394                    });
395                }
396            });
397
398            Ok(rx)
399        })
400    }
401}