Skip to main content

wraith_api/providers/
anthropic.rs

1use std::collections::VecDeque;
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use runtime::{
5    load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
6    OAuthTokenExchangeRequest,
7};
8use serde::Deserialize;
9
10use crate::error::ApiError;
11
12use super::{Provider, ProviderFuture};
13use crate::sse::SseParser;
14use crate::types::{MessageRequest, MessageResponse, StreamEvent};
15
/// Anthropic API endpoint used when `ANTHROPIC_BASE_URL` does not provide one.
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Value of the `anthropic-version` header sent with every Messages API request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Response header consulted first for a request identifier.
const REQUEST_ID_HEADER: &str = "request-id";
/// Response header consulted when `REQUEST_ID_HEADER` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
/// Delay before the first retry; each later retry doubles it.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
/// Ceiling applied to every computed retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(2_000);
/// Retries allowed after the initial attempt.
const DEFAULT_MAX_RETRIES: u32 = 2;
23
/// Credentials attached to outgoing Anthropic API requests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthSource {
    /// No credentials are attached.
    None,
    /// Sent as the `x-api-key` header.
    ApiKey(String),
    /// Sent as an `Authorization: Bearer …` header.
    BearerToken(String),
    /// Both the `x-api-key` header and a bearer `Authorization` header are sent.
    ApiKeyAndBearer {
        api_key: String,
        bearer_token: String,
    },
}
34
35impl AuthSource {
36    pub fn from_env() -> Result<Self, ApiError> {
37        let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
38        let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
39        match (api_key, auth_token) {
40            (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
41                api_key,
42                bearer_token,
43            }),
44            (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
45            (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
46            (None, None) => Err(ApiError::missing_credentials(
47                "Wraith",
48                &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
49            )),
50        }
51    }
52
53    #[must_use]
54    pub fn api_key(&self) -> Option<&str> {
55        match self {
56            Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
57            Self::None | Self::BearerToken(_) => None,
58        }
59    }
60
61    #[must_use]
62    pub fn bearer_token(&self) -> Option<&str> {
63        match self {
64            Self::BearerToken(token)
65            | Self::ApiKeyAndBearer {
66                bearer_token: token,
67                ..
68            } => Some(token),
69            Self::None | Self::ApiKey(_) => None,
70        }
71    }
72
73    #[must_use]
74    pub fn masked_authorization_header(&self) -> &'static str {
75        if self.bearer_token().is_some() {
76            "Bearer [REDACTED]"
77        } else {
78            "<absent>"
79        }
80    }
81
82    pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
83        if let Some(api_key) = self.api_key() {
84            request_builder = request_builder.header("x-api-key", api_key);
85        }
86        if let Some(token) = self.bearer_token() {
87            request_builder = request_builder.bearer_auth(token);
88        }
89        request_builder
90    }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
94pub struct OAuthTokenSet {
95    pub access_token: String,
96    pub refresh_token: Option<String>,
97    pub expires_at: Option<u64>,
98    #[serde(default)]
99    pub scopes: Vec<String>,
100}
101
102impl From<OAuthTokenSet> for AuthSource {
103    fn from(value: OAuthTokenSet) -> Self {
104        Self::BearerToken(value.access_token)
105    }
106}
107
108#[derive(Debug, Clone)]
109pub struct AnthropicClient {
110    http: reqwest::Client,
111    auth: AuthSource,
112    base_url: String,
113    max_retries: u32,
114    initial_backoff: Duration,
115    max_backoff: Duration,
116}
117
118impl AnthropicClient {
119    #[must_use]
120    pub fn new(api_key: impl Into<String>) -> Self {
121        Self {
122            http: reqwest::Client::new(),
123            auth: AuthSource::ApiKey(api_key.into()),
124            base_url: DEFAULT_BASE_URL.to_string(),
125            max_retries: DEFAULT_MAX_RETRIES,
126            initial_backoff: DEFAULT_INITIAL_BACKOFF,
127            max_backoff: DEFAULT_MAX_BACKOFF,
128        }
129    }
130
131    #[must_use]
132    pub fn from_auth(auth: AuthSource) -> Self {
133        Self {
134            http: reqwest::Client::new(),
135            auth,
136            base_url: DEFAULT_BASE_URL.to_string(),
137            max_retries: DEFAULT_MAX_RETRIES,
138            initial_backoff: DEFAULT_INITIAL_BACKOFF,
139            max_backoff: DEFAULT_MAX_BACKOFF,
140        }
141    }
142
143    pub fn from_env() -> Result<Self, ApiError> {
144        Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
145    }
146
147    #[must_use]
148    pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
149        self.auth = auth;
150        self
151    }
152
153    #[must_use]
154    pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
155        match (
156            self.auth.api_key().map(ToOwned::to_owned),
157            auth_token.filter(|token| !token.is_empty()),
158        ) {
159            (Some(api_key), Some(bearer_token)) => {
160                self.auth = AuthSource::ApiKeyAndBearer {
161                    api_key,
162                    bearer_token,
163                };
164            }
165            (Some(api_key), None) => {
166                self.auth = AuthSource::ApiKey(api_key);
167            }
168            (None, Some(bearer_token)) => {
169                self.auth = AuthSource::BearerToken(bearer_token);
170            }
171            (None, None) => {
172                self.auth = AuthSource::None;
173            }
174        }
175        self
176    }
177
178    #[must_use]
179    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
180        self.base_url = base_url.into();
181        self
182    }
183
184    #[must_use]
185    pub fn with_retry_policy(
186        mut self,
187        max_retries: u32,
188        initial_backoff: Duration,
189        max_backoff: Duration,
190    ) -> Self {
191        self.max_retries = max_retries;
192        self.initial_backoff = initial_backoff;
193        self.max_backoff = max_backoff;
194        self
195    }
196
197    #[must_use]
198    pub fn auth_source(&self) -> &AuthSource {
199        &self.auth
200    }
201
202    pub async fn send_message(
203        &self,
204        request: &MessageRequest,
205    ) -> Result<MessageResponse, ApiError> {
206        let request = MessageRequest {
207            stream: false,
208            ..request.clone()
209        };
210        let response = self.send_with_retry(&request).await?;
211        let request_id = request_id_from_headers(response.headers());
212        let mut response = response
213            .json::<MessageResponse>()
214            .await
215            .map_err(ApiError::from)?;
216        if response.request_id.is_none() {
217            response.request_id = request_id;
218        }
219        Ok(response)
220    }
221
222    pub async fn stream_message(
223        &self,
224        request: &MessageRequest,
225    ) -> Result<MessageStream, ApiError> {
226        let response = self
227            .send_with_retry(&request.clone().with_streaming())
228            .await?;
229        Ok(MessageStream {
230            request_id: request_id_from_headers(response.headers()),
231            response,
232            parser: SseParser::new(),
233            pending: VecDeque::new(),
234            done: false,
235        })
236    }
237
238    pub async fn exchange_oauth_code(
239        &self,
240        config: &OAuthConfig,
241        request: &OAuthTokenExchangeRequest,
242    ) -> Result<OAuthTokenSet, ApiError> {
243        let response = self
244            .http
245            .post(&config.token_url)
246            .header("content-type", "application/x-www-form-urlencoded")
247            .form(&request.form_params())
248            .send()
249            .await
250            .map_err(ApiError::from)?;
251        let response = expect_success(response).await?;
252        response
253            .json::<OAuthTokenSet>()
254            .await
255            .map_err(ApiError::from)
256    }
257
258    pub async fn refresh_oauth_token(
259        &self,
260        config: &OAuthConfig,
261        request: &OAuthRefreshRequest,
262    ) -> Result<OAuthTokenSet, ApiError> {
263        let response = self
264            .http
265            .post(&config.token_url)
266            .header("content-type", "application/x-www-form-urlencoded")
267            .form(&request.form_params())
268            .send()
269            .await
270            .map_err(ApiError::from)?;
271        let response = expect_success(response).await?;
272        response
273            .json::<OAuthTokenSet>()
274            .await
275            .map_err(ApiError::from)
276    }
277
278    async fn send_with_retry(
279        &self,
280        request: &MessageRequest,
281    ) -> Result<reqwest::Response, ApiError> {
282        let mut attempts = 0;
283        let mut last_error: Option<ApiError>;
284
285        loop {
286            attempts += 1;
287            match self.send_raw_request(request).await {
288                Ok(response) => match expect_success(response).await {
289                    Ok(response) => return Ok(response),
290                    Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
291                        last_error = Some(error);
292                    }
293                    Err(error) => return Err(error),
294                },
295                Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
296                    last_error = Some(error);
297                }
298                Err(error) => return Err(error),
299            }
300
301            if attempts > self.max_retries {
302                break;
303            }
304
305            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
306        }
307
308        Err(ApiError::RetriesExhausted {
309            attempts,
310            last_error: Box::new(last_error.expect("retry loop must capture an error")),
311        })
312    }
313
314    async fn send_raw_request(
315        &self,
316        request: &MessageRequest,
317    ) -> Result<reqwest::Response, ApiError> {
318        let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
319        let request_builder = self
320            .http
321            .post(&request_url)
322            .header("anthropic-version", ANTHROPIC_VERSION)
323            .header("content-type", "application/json");
324        let mut request_builder = self.auth.apply(request_builder);
325
326        request_builder = request_builder.json(request);
327        request_builder.send().await.map_err(ApiError::from)
328    }
329
330    fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
331        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
332            return Err(ApiError::BackoffOverflow {
333                attempt,
334                base_delay: self.initial_backoff,
335            });
336        };
337        Ok(self
338            .initial_backoff
339            .checked_mul(multiplier)
340            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
341    }
342}
343
344impl AuthSource {
345    pub fn from_env_or_saved() -> Result<Self, ApiError> {
346        if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
347            return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
348                Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
349                    api_key,
350                    bearer_token,
351                }),
352                None => Ok(Self::ApiKey(api_key)),
353            };
354        }
355        if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
356            return Ok(Self::BearerToken(bearer_token));
357        }
358        match load_saved_oauth_token() {
359            Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
360                if token_set.refresh_token.is_some() {
361                    Err(ApiError::Auth(
362                        "saved OAuth token is expired; load runtime OAuth config to refresh it"
363                            .to_string(),
364                    ))
365                } else {
366                    Err(ApiError::ExpiredOAuthToken)
367                }
368            }
369            Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
370            Ok(None) => Err(ApiError::missing_credentials(
371                "Wraith",
372                &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
373            )),
374            Err(error) => Err(error),
375        }
376    }
377}
378
379#[must_use]
380pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
381    token_set
382        .expires_at
383        .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
384}
385
386pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
387    let Some(token_set) = load_saved_oauth_token()? else {
388        return Ok(None);
389    };
390    resolve_saved_oauth_token_set(config, token_set).map(Some)
391}
392
393pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
394    Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
395        || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
396        || load_saved_oauth_token()?.is_some())
397}
398
399pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
400where
401    F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
402{
403    if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
404        return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
405            Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
406                api_key,
407                bearer_token,
408            }),
409            None => Ok(AuthSource::ApiKey(api_key)),
410        };
411    }
412    if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
413        return Ok(AuthSource::BearerToken(bearer_token));
414    }
415
416    let Some(token_set) = load_saved_oauth_token()? else {
417        return Err(ApiError::missing_credentials(
418            "Wraith",
419            &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
420        ));
421    };
422    if !oauth_token_is_expired(&token_set) {
423        return Ok(AuthSource::BearerToken(token_set.access_token));
424    }
425    if token_set.refresh_token.is_none() {
426        return Err(ApiError::ExpiredOAuthToken);
427    }
428
429    let Some(config) = load_oauth_config()? else {
430        return Err(ApiError::Auth(
431            "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
432        ));
433    };
434    Ok(AuthSource::from(resolve_saved_oauth_token_set(
435        &config, token_set,
436    )?))
437}
438
439fn resolve_saved_oauth_token_set(
440    config: &OAuthConfig,
441    token_set: OAuthTokenSet,
442) -> Result<OAuthTokenSet, ApiError> {
443    if !oauth_token_is_expired(&token_set) {
444        return Ok(token_set);
445    }
446    let Some(refresh_token) = token_set.refresh_token.clone() else {
447        return Err(ApiError::ExpiredOAuthToken);
448    };
449    let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
450    let refreshed = client_runtime_block_on(async {
451        client
452            .refresh_oauth_token(
453                config,
454                &OAuthRefreshRequest::from_config(
455                    config,
456                    refresh_token,
457                    Some(token_set.scopes.clone()),
458                ),
459            )
460            .await
461    })?;
462    let resolved = OAuthTokenSet {
463        access_token: refreshed.access_token,
464        refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
465        expires_at: refreshed.expires_at,
466        scopes: refreshed.scopes,
467    };
468    save_oauth_credentials(&runtime::OAuthTokenSet {
469        access_token: resolved.access_token.clone(),
470        refresh_token: resolved.refresh_token.clone(),
471        expires_at: resolved.expires_at,
472        scopes: resolved.scopes.clone(),
473    })
474    .map_err(ApiError::from)?;
475    Ok(resolved)
476}
477
478fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
479where
480    F: std::future::Future<Output = Result<T, ApiError>>,
481{
482    tokio::runtime::Runtime::new()
483        .map_err(ApiError::from)?
484        .block_on(future)
485}
486
487fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
488    let token_set = load_oauth_credentials().map_err(ApiError::from)?;
489    Ok(token_set.map(|token_set| OAuthTokenSet {
490        access_token: token_set.access_token,
491        refresh_token: token_set.refresh_token,
492        expires_at: token_set.expires_at,
493        scopes: token_set.scopes,
494    }))
495}
496
/// Current wall-clock time in whole seconds since the Unix epoch. Returns 0 if
/// the system clock reports a time before the epoch.
fn now_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
502
503fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
504    match std::env::var(key) {
505        Ok(value) if !value.is_empty() => Ok(Some(value)),
506        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
507        Err(error) => Err(ApiError::from(error)),
508    }
509}
510
/// Test helper: the API key if present, otherwise the bearer token, from
/// [`AuthSource::from_env_or_saved`].
///
/// # Errors
/// Errors from `from_env_or_saved`, or `ApiError::missing_credentials` when the
/// resolved source carries neither credential.
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
    let auth = AuthSource::from_env_or_saved()?;
    auth.api_key()
        .or_else(|| auth.bearer_token())
        .map(ToOwned::to_owned)
        // Build the error lazily; `ok_or` would construct it even on success.
        .ok_or_else(|| {
            ApiError::missing_credentials("Wraith", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])
        })
}
522
/// Test helper: the non-empty `ANTHROPIC_AUTH_TOKEN`, with read errors mapped
/// to `None`.
#[cfg(test)]
fn read_auth_token() -> Option<String> {
    read_env_non_empty("ANTHROPIC_AUTH_TOKEN").ok().flatten()
}
529
530#[must_use]
531pub fn read_base_url() -> String {
532    std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
533}
534
535fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
536    headers
537        .get(REQUEST_ID_HEADER)
538        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
539        .and_then(|value| value.to_str().ok())
540        .map(ToOwned::to_owned)
541}
542
543impl Provider for AnthropicClient {
544    type Stream = MessageStream;
545
546    fn send_message<'a>(
547        &'a self,
548        request: &'a MessageRequest,
549    ) -> ProviderFuture<'a, MessageResponse> {
550        Box::pin(async move { self.send_message(request).await })
551    }
552
553    fn stream_message<'a>(
554        &'a self,
555        request: &'a MessageRequest,
556    ) -> ProviderFuture<'a, Self::Stream> {
557        Box::pin(async move { self.stream_message(request).await })
558    }
559}
560
561#[derive(Debug)]
562pub struct MessageStream {
563    request_id: Option<String>,
564    response: reqwest::Response,
565    parser: SseParser,
566    pending: VecDeque<StreamEvent>,
567    done: bool,
568}
569
570impl MessageStream {
571    #[must_use]
572    pub fn request_id(&self) -> Option<&str> {
573        self.request_id.as_deref()
574    }
575
576    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
577        loop {
578            if let Some(event) = self.pending.pop_front() {
579                return Ok(Some(event));
580            }
581
582            if self.done {
583                let remaining = self.parser.finish()?;
584                self.pending.extend(remaining);
585                if let Some(event) = self.pending.pop_front() {
586                    return Ok(Some(event));
587                }
588                return Ok(None);
589            }
590
591            match self.response.chunk().await? {
592                Some(chunk) => {
593                    self.pending.extend(self.parser.push(&chunk)?);
594                }
595                None => {
596                    self.done = true;
597                }
598            }
599        }
600    }
601}
602
603async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
604    let status = response.status();
605    if status.is_success() {
606        return Ok(response);
607    }
608
609    let body = response.text().await.unwrap_or_else(|_| String::new());
610    let parsed_error = serde_json::from_str::<ApiErrorEnvelope>(&body).ok();
611    let retryable = is_retryable_status(status);
612
613    Err(ApiError::Api {
614        status,
615        error_type: parsed_error
616            .as_ref()
617            .map(|error| error.error.error_type.clone()),
618        message: parsed_error
619            .as_ref()
620            .map(|error| error.error.message.clone()),
621        body,
622        retryable,
623    })
624}
625
626const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
627    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
628}
629
630#[derive(Debug, Deserialize)]
631struct ApiErrorEnvelope {
632    error: ApiErrorBody,
633}
634
635#[derive(Debug, Deserialize)]
636struct ApiErrorBody {
637    #[serde(rename = "type")]
638    error_type: String,
639    message: String,
640}
641
642#[cfg(test)]
643mod tests {
644    use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
645    use std::io::{Read, Write};
646    use std::net::TcpListener;
647    use std::sync::{Mutex, OnceLock};
648    use std::thread;
649    use std::time::{Duration, SystemTime, UNIX_EPOCH};
650
651    use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
652
653    use super::{
654        now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
655        resolve_startup_auth_source, AuthSource, AnthropicClient, OAuthTokenSet,
656    };
657    use crate::types::{ContentBlockDelta, MessageRequest};
658
659    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
660        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
661        LOCK.get_or_init(|| Mutex::new(()))
662            .lock()
663            .unwrap_or_else(std::sync::PoisonError::into_inner)
664    }
665
666    fn temp_config_home() -> std::path::PathBuf {
667        std::env::temp_dir().join(format!(
668            "api-oauth-test-{}-{}",
669            std::process::id(),
670            SystemTime::now()
671                .duration_since(UNIX_EPOCH)
672                .expect("time")
673                .as_nanos()
674        ))
675    }
676
677    fn cleanup_temp_config_home(config_home: &std::path::Path) {
678        match std::fs::remove_dir_all(config_home) {
679            Ok(()) => {}
680            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
681            Err(error) => panic!("cleanup temp dir: {error}"),
682        }
683    }
684
685    fn sample_oauth_config(token_url: String) -> OAuthConfig {
686        OAuthConfig {
687            client_id: "runtime-client".to_string(),
688            authorize_url: "https://console.test/oauth/authorize".to_string(),
689            token_url,
690            callback_port: Some(4545),
691            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
692            scopes: vec!["org:read".to_string(), "user:write".to_string()],
693        }
694    }
695
696    fn spawn_token_server(response_body: &'static str) -> String {
697        let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
698        let address = listener.local_addr().expect("local addr");
699        thread::spawn(move || {
700            let (mut stream, _) = listener.accept().expect("accept connection");
701            let mut buffer = [0_u8; 4096];
702            let _ = stream.read(&mut buffer).expect("read request");
703            let response = format!(
704                "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
705                response_body.len(),
706                response_body
707            );
708            stream
709                .write_all(response.as_bytes())
710                .expect("write response");
711        });
712        format!("http://{address}/oauth/token")
713    }
714
715    #[test]
716    fn read_api_key_requires_presence() {
717        let _guard = env_lock();
718        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
719        std::env::remove_var("ANTHROPIC_API_KEY");
720        std::env::remove_var("WRAITH_CONFIG_HOME");
721        let error = super::read_api_key().expect_err("missing key should error");
722        assert!(matches!(
723            error,
724            crate::error::ApiError::MissingCredentials { .. }
725        ));
726    }
727
728    #[test]
729    fn read_api_key_requires_non_empty_value() {
730        let _guard = env_lock();
731        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
732        std::env::remove_var("ANTHROPIC_API_KEY");
733        let error = super::read_api_key().expect_err("empty key should error");
734        assert!(matches!(
735            error,
736            crate::error::ApiError::MissingCredentials { .. }
737        ));
738        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
739    }
740
741    #[test]
742    fn read_api_key_prefers_api_key_env() {
743        let _guard = env_lock();
744        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
745        std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
746        assert_eq!(
747            super::read_api_key().expect("api key should load"),
748            "legacy-key"
749        );
750        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
751        std::env::remove_var("ANTHROPIC_API_KEY");
752    }
753
754    #[test]
755    fn read_auth_token_reads_auth_token_env() {
756        let _guard = env_lock();
757        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
758        assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
759        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
760    }
761
762    #[test]
763    fn oauth_token_maps_to_bearer_auth_source() {
764        let auth = AuthSource::from(OAuthTokenSet {
765            access_token: "access-token".to_string(),
766            refresh_token: Some("refresh".to_string()),
767            expires_at: Some(123),
768            scopes: vec!["scope:a".to_string()],
769        });
770        assert_eq!(auth.bearer_token(), Some("access-token"));
771        assert_eq!(auth.api_key(), None);
772    }
773
774    #[test]
775    fn auth_source_from_env_combines_api_key_and_bearer_token() {
776        let _guard = env_lock();
777        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
778        std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
779        let auth = AuthSource::from_env().expect("env auth");
780        assert_eq!(auth.api_key(), Some("legacy-key"));
781        assert_eq!(auth.bearer_token(), Some("auth-token"));
782        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
783        std::env::remove_var("ANTHROPIC_API_KEY");
784    }
785
786    #[test]
787    fn auth_source_from_saved_oauth_when_env_absent() {
788        let _guard = env_lock();
789        let config_home = temp_config_home();
790        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
791        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
792        std::env::remove_var("ANTHROPIC_API_KEY");
793        save_oauth_credentials(&runtime::OAuthTokenSet {
794            access_token: "saved-access-token".to_string(),
795            refresh_token: Some("refresh".to_string()),
796            expires_at: Some(now_unix_timestamp() + 300),
797            scopes: vec!["scope:a".to_string()],
798        })
799        .expect("save oauth credentials");
800
801        let auth = AuthSource::from_env_or_saved().expect("saved auth");
802        assert_eq!(auth.bearer_token(), Some("saved-access-token"));
803
804        clear_oauth_credentials().expect("clear credentials");
805        std::env::remove_var("WRAITH_CONFIG_HOME");
806        cleanup_temp_config_home(&config_home);
807    }
808
809    #[test]
810    fn oauth_token_expiry_uses_expires_at_timestamp() {
811        assert!(oauth_token_is_expired(&OAuthTokenSet {
812            access_token: "access-token".to_string(),
813            refresh_token: None,
814            expires_at: Some(1),
815            scopes: Vec::new(),
816        }));
817        assert!(!oauth_token_is_expired(&OAuthTokenSet {
818            access_token: "access-token".to_string(),
819            refresh_token: None,
820            expires_at: Some(now_unix_timestamp() + 60),
821            scopes: Vec::new(),
822        }));
823    }
824
825    #[test]
826    fn resolve_saved_oauth_token_refreshes_expired_credentials() {
827        let _guard = env_lock();
828        let config_home = temp_config_home();
829        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
830        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
831        std::env::remove_var("ANTHROPIC_API_KEY");
832        save_oauth_credentials(&runtime::OAuthTokenSet {
833            access_token: "expired-access-token".to_string(),
834            refresh_token: Some("refresh-token".to_string()),
835            expires_at: Some(1),
836            scopes: vec!["scope:a".to_string()],
837        })
838        .expect("save expired oauth credentials");
839
840        let token_url = spawn_token_server(
841            "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
842        );
843        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
844            .expect("resolve refreshed token")
845            .expect("token set present");
846        assert_eq!(resolved.access_token, "refreshed-token");
847        let stored = runtime::load_oauth_credentials()
848            .expect("load stored credentials")
849            .expect("stored token set");
850        assert_eq!(stored.access_token, "refreshed-token");
851
852        clear_oauth_credentials().expect("clear credentials");
853        std::env::remove_var("WRAITH_CONFIG_HOME");
854        cleanup_temp_config_home(&config_home);
855    }
856
857    #[test]
858    fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
859        let _guard = env_lock();
860        let config_home = temp_config_home();
861        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
862        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
863        std::env::remove_var("ANTHROPIC_API_KEY");
864        save_oauth_credentials(&runtime::OAuthTokenSet {
865            access_token: "saved-access-token".to_string(),
866            refresh_token: Some("refresh".to_string()),
867            expires_at: Some(now_unix_timestamp() + 300),
868            scopes: vec!["scope:a".to_string()],
869        })
870        .expect("save oauth credentials");
871
872        let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
873            .expect("startup auth");
874        assert_eq!(auth.bearer_token(), Some("saved-access-token"));
875
876        clear_oauth_credentials().expect("clear credentials");
877        std::env::remove_var("WRAITH_CONFIG_HOME");
878        cleanup_temp_config_home(&config_home);
879    }
880
881    #[test]
882    fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
883        let _guard = env_lock();
884        let config_home = temp_config_home();
885        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
886        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
887        std::env::remove_var("ANTHROPIC_API_KEY");
888        save_oauth_credentials(&runtime::OAuthTokenSet {
889            access_token: "expired-access-token".to_string(),
890            refresh_token: Some("refresh-token".to_string()),
891            expires_at: Some(1),
892            scopes: vec!["scope:a".to_string()],
893        })
894        .expect("save expired oauth credentials");
895
896        let error =
897            resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
898        assert!(
899            matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
900        );
901
902        let stored = runtime::load_oauth_credentials()
903            .expect("load stored credentials")
904            .expect("stored token set");
905        assert_eq!(stored.access_token, "expired-access-token");
906        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
907
908        clear_oauth_credentials().expect("clear credentials");
909        std::env::remove_var("WRAITH_CONFIG_HOME");
910        cleanup_temp_config_home(&config_home);
911    }
912
913    #[test]
914    fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
915        let _guard = env_lock();
916        let config_home = temp_config_home();
917        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
918        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
919        std::env::remove_var("ANTHROPIC_API_KEY");
920        save_oauth_credentials(&runtime::OAuthTokenSet {
921            access_token: "expired-access-token".to_string(),
922            refresh_token: Some("refresh-token".to_string()),
923            expires_at: Some(1),
924            scopes: vec!["scope:a".to_string()],
925        })
926        .expect("save expired oauth credentials");
927
928        let token_url = spawn_token_server(
929            "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
930        );
931        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
932            .expect("resolve refreshed token")
933            .expect("token set present");
934        assert_eq!(resolved.access_token, "refreshed-token");
935        assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
936        let stored = runtime::load_oauth_credentials()
937            .expect("load stored credentials")
938            .expect("stored token set");
939        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
940
941        clear_oauth_credentials().expect("clear credentials");
942        std::env::remove_var("WRAITH_CONFIG_HOME");
943        cleanup_temp_config_home(&config_home);
944    }
945
946    #[test]
947    fn message_request_stream_helper_sets_stream_true() {
948        let request = MessageRequest {
949            model: "claude-opus-4-6".to_string(),
950            max_tokens: 64,
951            messages: vec![],
952            system: None,
953            tools: None,
954            tool_choice: None,
955            stream: false,
956        };
957
958        assert!(request.with_streaming().stream);
959    }
960
961    #[test]
962    fn backoff_doubles_until_maximum() {
963        let client = AnthropicClient::new("test-key").with_retry_policy(
964            3,
965            Duration::from_millis(10),
966            Duration::from_millis(25),
967        );
968        assert_eq!(
969            client.backoff_for_attempt(1).expect("attempt 1"),
970            Duration::from_millis(10)
971        );
972        assert_eq!(
973            client.backoff_for_attempt(2).expect("attempt 2"),
974            Duration::from_millis(20)
975        );
976        assert_eq!(
977            client.backoff_for_attempt(3).expect("attempt 3"),
978            Duration::from_millis(25)
979        );
980    }
981
982    #[test]
983    fn retryable_statuses_are_detected() {
984        assert!(super::is_retryable_status(
985            reqwest::StatusCode::TOO_MANY_REQUESTS
986        ));
987        assert!(super::is_retryable_status(
988            reqwest::StatusCode::INTERNAL_SERVER_ERROR
989        ));
990        assert!(!super::is_retryable_status(
991            reqwest::StatusCode::UNAUTHORIZED
992        ));
993    }
994
995    #[test]
996    fn tool_delta_variant_round_trips() {
997        let delta = ContentBlockDelta::InputJsonDelta {
998            partial_json: "{\"city\":\"Paris\"}".to_string(),
999        };
1000        let encoded = serde_json::to_string(&delta).expect("delta should serialize");
1001        let decoded: ContentBlockDelta =
1002            serde_json::from_str(&encoded).expect("delta should deserialize");
1003        assert_eq!(decoded, delta);
1004    }
1005
1006    #[test]
1007    fn request_id_uses_primary_or_fallback_header() {
1008        let mut headers = reqwest::header::HeaderMap::new();
1009        headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
1010        assert_eq!(
1011            super::request_id_from_headers(&headers).as_deref(),
1012            Some("req_primary")
1013        );
1014
1015        headers.clear();
1016        headers.insert(
1017            ALT_REQUEST_ID_HEADER,
1018            "req_fallback".parse().expect("header"),
1019        );
1020        assert_eq!(
1021            super::request_id_from_headers(&headers).as_deref(),
1022            Some("req_fallback")
1023        );
1024    }
1025
1026    #[test]
1027    fn auth_source_applies_headers() {
1028        let auth = AuthSource::ApiKeyAndBearer {
1029            api_key: "test-key".to_string(),
1030            bearer_token: "proxy-token".to_string(),
1031        };
1032        let request = auth
1033            .apply(reqwest::Client::new().post("https://example.test"))
1034            .build()
1035            .expect("request build");
1036        let headers = request.headers();
1037        assert_eq!(
1038            headers.get("x-api-key").and_then(|v| v.to_str().ok()),
1039            Some("test-key")
1040        );
1041        assert_eq!(
1042            headers.get("authorization").and_then(|v| v.to_str().ok()),
1043            Some("Bearer proxy-token")
1044        );
1045    }
1046}