Skip to main content

codineer_api/providers/
codineer_provider.rs

1use std::collections::VecDeque;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use runtime::{
5    load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
6    OAuthTokenExchangeRequest,
7};
8use serde::Deserialize;
9
10use crate::error::ApiError;
11use crate::providers::RetryPolicy;
12use crate::sse::SseParser;
13use crate::types::{MessageRequest, MessageResponse, StreamEvent};
14
/// Default Anthropic API root; `read_base_url` returns it unless `ANTHROPIC_BASE_URL` overrides it.
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header on every messages request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Response header consulted first for the server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback response header consulted when `request-id` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
19
/// Credentials attached to outgoing Anthropic API requests.
///
/// An API key travels in the `x-api-key` header and a bearer token in the
/// `Authorization` header; `ApiKeyAndBearer` sends both.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AuthSource {
    /// No credentials at all.
    None,
    /// Only an API key (`x-api-key`).
    ApiKey(String),
    /// Only a bearer token (`Authorization: Bearer …`).
    BearerToken(String),
    /// Both an API key and a bearer token.
    ApiKeyAndBearer {
        api_key: String,
        bearer_token: String,
    },
}
30
31impl AuthSource {
32    pub fn from_env() -> Result<Self, ApiError> {
33        let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
34        let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
35        match (api_key, auth_token) {
36            (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
37                api_key,
38                bearer_token,
39            }),
40            (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
41            (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
42            (None, None) => Err(ApiError::missing_credentials(
43                "Anthropic",
44                &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
45            )),
46        }
47    }
48
49    #[must_use]
50    pub fn api_key(&self) -> Option<&str> {
51        match self {
52            Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
53            Self::None | Self::BearerToken(_) => None,
54        }
55    }
56
57    #[must_use]
58    pub fn bearer_token(&self) -> Option<&str> {
59        match self {
60            Self::BearerToken(token)
61            | Self::ApiKeyAndBearer {
62                bearer_token: token,
63                ..
64            } => Some(token),
65            Self::None | Self::ApiKey(_) => None,
66        }
67    }
68
69    #[must_use]
70    pub fn masked_authorization_header(&self) -> &'static str {
71        if self.bearer_token().is_some() {
72            "Bearer [REDACTED]"
73        } else {
74            "<absent>"
75        }
76    }
77
78    pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
79        if let Some(api_key) = self.api_key() {
80            request_builder = request_builder.header("x-api-key", api_key);
81        }
82        if let Some(token) = self.bearer_token() {
83            request_builder = request_builder.bearer_auth(token);
84        }
85        request_builder
86    }
87}
88
/// OAuth token set, decoded from a token-endpoint JSON response or converted
/// from saved runtime credentials.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct OAuthTokenSet {
    /// Token sent as the bearer credential.
    pub access_token: String,
    /// Token used to obtain a fresh access token once this one has expired.
    pub refresh_token: Option<String>,
    /// Expiry in seconds since the Unix epoch; `None` is treated as never expiring.
    pub expires_at: Option<u64>,
    /// Granted scopes; an absent JSON field deserializes to an empty list.
    #[serde(default)]
    pub scopes: Vec<String>,
}
97
98impl From<OAuthTokenSet> for AuthSource {
99    fn from(value: OAuthTokenSet) -> Self {
100        Self::BearerToken(value.access_token)
101    }
102}
103
/// HTTP client for the Anthropic messages API, with retry and OAuth helpers.
#[derive(Debug, Clone)]
pub struct CodineerApiClient {
    /// `reqwest` client used for every outgoing request.
    http: reqwest::Client,
    /// Credentials applied to messages requests.
    auth: AuthSource,
    /// API root; `/v1/messages` is appended for each messages request.
    base_url: String,
    /// Retry and backoff settings for messages requests.
    retry: RetryPolicy,
}
111
112impl CodineerApiClient {
113    #[must_use]
114    pub fn new(api_key: impl Into<String>) -> Self {
115        Self {
116            http: reqwest::Client::new(),
117            auth: AuthSource::ApiKey(api_key.into()),
118            base_url: DEFAULT_BASE_URL.to_string(),
119            retry: RetryPolicy::default(),
120        }
121    }
122
123    #[must_use]
124    pub fn from_auth(auth: AuthSource) -> Self {
125        Self {
126            http: reqwest::Client::new(),
127            auth,
128            base_url: DEFAULT_BASE_URL.to_string(),
129            retry: RetryPolicy::default(),
130        }
131    }
132
133    pub fn from_env() -> Result<Self, ApiError> {
134        Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
135    }
136
137    #[must_use]
138    pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
139        self.auth = auth;
140        self
141    }
142
143    #[must_use]
144    pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
145        match (
146            self.auth.api_key().map(ToOwned::to_owned),
147            auth_token.filter(|token| !token.is_empty()),
148        ) {
149            (Some(api_key), Some(bearer_token)) => {
150                self.auth = AuthSource::ApiKeyAndBearer {
151                    api_key,
152                    bearer_token,
153                };
154            }
155            (Some(api_key), None) => {
156                self.auth = AuthSource::ApiKey(api_key);
157            }
158            (None, Some(bearer_token)) => {
159                self.auth = AuthSource::BearerToken(bearer_token);
160            }
161            (None, None) => {
162                self.auth = AuthSource::None;
163            }
164        }
165        self
166    }
167
168    #[must_use]
169    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
170        self.base_url = base_url.into();
171        self
172    }
173
174    #[must_use]
175    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
176        self.retry = retry;
177        self
178    }
179
180    #[must_use]
181    pub fn auth_source(&self) -> &AuthSource {
182        &self.auth
183    }
184
185    pub async fn send_message(
186        &self,
187        request: &MessageRequest,
188    ) -> Result<MessageResponse, ApiError> {
189        let request = MessageRequest {
190            stream: false,
191            ..request.clone()
192        };
193        let response = self.send_with_retry(&request).await?;
194        let request_id = request_id_from_headers(response.headers());
195        let mut response = response
196            .json::<MessageResponse>()
197            .await
198            .map_err(ApiError::from)?;
199        if response.request_id.is_none() {
200            response.request_id = request_id;
201        }
202        Ok(response)
203    }
204
205    pub async fn stream_message(
206        &self,
207        request: &MessageRequest,
208    ) -> Result<MessageStream, ApiError> {
209        let response = self
210            .send_with_retry(&request.clone().with_streaming())
211            .await?;
212        Ok(MessageStream {
213            request_id: request_id_from_headers(response.headers()),
214            response,
215            parser: SseParser::new(),
216            pending: VecDeque::new(),
217            done: false,
218        })
219    }
220
221    pub async fn exchange_oauth_code(
222        &self,
223        config: &OAuthConfig,
224        request: &OAuthTokenExchangeRequest,
225    ) -> Result<OAuthTokenSet, ApiError> {
226        let response = self
227            .http
228            .post(&config.token_url)
229            .header("content-type", "application/x-www-form-urlencoded")
230            .form(&request.form_params())
231            .send()
232            .await
233            .map_err(ApiError::from)?;
234        let response = expect_success(response).await?;
235        response
236            .json::<OAuthTokenSet>()
237            .await
238            .map_err(ApiError::from)
239    }
240
241    pub async fn refresh_oauth_token(
242        &self,
243        config: &OAuthConfig,
244        request: &OAuthRefreshRequest,
245    ) -> Result<OAuthTokenSet, ApiError> {
246        let response = self
247            .http
248            .post(&config.token_url)
249            .header("content-type", "application/x-www-form-urlencoded")
250            .form(&request.form_params())
251            .send()
252            .await
253            .map_err(ApiError::from)?;
254        let response = expect_success(response).await?;
255        response
256            .json::<OAuthTokenSet>()
257            .await
258            .map_err(ApiError::from)
259    }
260
261    async fn send_with_retry(
262        &self,
263        request: &MessageRequest,
264    ) -> Result<reqwest::Response, ApiError> {
265        let mut attempts = 0;
266        let mut last_error: Option<ApiError>;
267
268        loop {
269            attempts += 1;
270            match self.send_raw_request(request).await {
271                Ok(response) => match expect_success(response).await {
272                    Ok(response) => return Ok(response),
273                    Err(error)
274                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
275                    {
276                        last_error = Some(error);
277                    }
278                    Err(error) => return Err(error),
279                },
280                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
281                    last_error = Some(error);
282                }
283                Err(error) => return Err(error),
284            }
285
286            if attempts > self.retry.max_retries {
287                break;
288            }
289
290            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
291        }
292
293        Err(ApiError::RetriesExhausted {
294            attempts,
295            last_error: Box::new(last_error.unwrap_or(ApiError::Auth(
296                "retry loop exited without capturing an error".into(),
297            ))),
298        })
299    }
300
301    async fn send_raw_request(
302        &self,
303        request: &MessageRequest,
304    ) -> Result<reqwest::Response, ApiError> {
305        let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
306        let request_builder = self
307            .http
308            .post(&request_url)
309            .header("anthropic-version", ANTHROPIC_VERSION)
310            .header("content-type", "application/json");
311        let mut request_builder = self.auth.apply(request_builder);
312
313        request_builder = request_builder.json(request);
314        request_builder.send().await.map_err(ApiError::from)
315    }
316
317    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
318        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
319            return Err(ApiError::BackoffOverflow {
320                attempt,
321                base_delay: self.retry.initial_backoff,
322            });
323        };
324        Ok(self
325            .retry
326            .initial_backoff
327            .checked_mul(multiplier)
328            .map_or(self.retry.max_backoff, |delay| {
329                delay.min(self.retry.max_backoff)
330            }))
331    }
332}
333
334impl AuthSource {
335    pub fn from_env_or_saved() -> Result<Self, ApiError> {
336        if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
337            return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
338                Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
339                    api_key,
340                    bearer_token,
341                }),
342                None => Ok(Self::ApiKey(api_key)),
343            };
344        }
345        if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
346            return Ok(Self::BearerToken(bearer_token));
347        }
348        match load_saved_oauth_token() {
349            Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
350                if token_set.refresh_token.is_some() {
351                    Err(ApiError::Auth(
352                        "saved OAuth token is expired; load runtime OAuth config to refresh it"
353                            .to_string(),
354                    ))
355                } else {
356                    Err(ApiError::ExpiredOAuthToken)
357                }
358            }
359            Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
360            Ok(None) => Err(ApiError::missing_credentials(
361                "Anthropic",
362                &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
363            )),
364            Err(error) => Err(error),
365        }
366    }
367}
368
369#[must_use]
370pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
371    token_set
372        .expires_at
373        .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
374}
375
376pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
377    let Some(token_set) = load_saved_oauth_token()? else {
378        return Ok(None);
379    };
380    resolve_saved_oauth_token_set(config, token_set).map(Some)
381}
382
383pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
384    Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
385        || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
386        || load_saved_oauth_token()?.is_some())
387}
388
389pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
390where
391    F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
392{
393    if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
394        return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
395            Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
396                api_key,
397                bearer_token,
398            }),
399            None => Ok(AuthSource::ApiKey(api_key)),
400        };
401    }
402    if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
403        return Ok(AuthSource::BearerToken(bearer_token));
404    }
405
406    let Some(token_set) = load_saved_oauth_token()? else {
407        return Err(ApiError::missing_credentials(
408            "Anthropic",
409            &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
410        ));
411    };
412    if !oauth_token_is_expired(&token_set) {
413        return Ok(AuthSource::BearerToken(token_set.access_token));
414    }
415    if token_set.refresh_token.is_none() {
416        return Err(ApiError::ExpiredOAuthToken);
417    }
418
419    let Some(config) = load_oauth_config()? else {
420        return Err(ApiError::Auth(
421            "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
422        ));
423    };
424    Ok(AuthSource::from(resolve_saved_oauth_token_set(
425        &config, token_set,
426    )?))
427}
428
429fn resolve_saved_oauth_token_set(
430    config: &OAuthConfig,
431    token_set: OAuthTokenSet,
432) -> Result<OAuthTokenSet, ApiError> {
433    if !oauth_token_is_expired(&token_set) {
434        return Ok(token_set);
435    }
436    let Some(refresh_token) = token_set.refresh_token.clone() else {
437        return Err(ApiError::ExpiredOAuthToken);
438    };
439    let client = CodineerApiClient::from_auth(AuthSource::None).with_base_url(read_base_url());
440    let refreshed = client_runtime_block_on(async {
441        client
442            .refresh_oauth_token(
443                config,
444                &OAuthRefreshRequest::from_config(
445                    config,
446                    refresh_token,
447                    Some(token_set.scopes.clone()),
448                ),
449            )
450            .await
451    })?;
452    let resolved = OAuthTokenSet {
453        access_token: refreshed.access_token,
454        refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
455        expires_at: refreshed.expires_at,
456        scopes: refreshed.scopes,
457    };
458    save_oauth_credentials(&runtime::OAuthTokenSet {
459        access_token: resolved.access_token.clone(),
460        refresh_token: resolved.refresh_token.clone(),
461        expires_at: resolved.expires_at,
462        scopes: resolved.scopes.clone(),
463    })
464    .map_err(ApiError::from)?;
465    Ok(resolved)
466}
467
468fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
469where
470    F: std::future::Future<Output = Result<T, ApiError>>,
471{
472    match tokio::runtime::Handle::try_current() {
473        Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
474        Err(_) => tokio::runtime::Runtime::new()
475            .map_err(ApiError::from)?
476            .block_on(future),
477    }
478}
479
480fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
481    let token_set = load_oauth_credentials().map_err(ApiError::from)?;
482    Ok(token_set.map(|token_set| OAuthTokenSet {
483        access_token: token_set.access_token,
484        refresh_token: token_set.refresh_token,
485        expires_at: token_set.expires_at,
486        scopes: token_set.scopes,
487    }))
488}
489
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_unix_timestamp() -> u64 {
    let elapsed = SystemTime::now().duration_since(UNIX_EPOCH);
    match elapsed {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
495
496fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
497    match std::env::var(key) {
498        Ok(value) if !value.is_empty() => Ok(Some(value)),
499        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
500        Err(error) => Err(ApiError::from(error)),
501    }
502}
503
/// Test helper: the API key if configured, otherwise the bearer token.
///
/// The missing-credentials error is built lazily so the success path does not
/// construct an error value it immediately discards.
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
    let auth = AuthSource::from_env_or_saved()?;
    auth.api_key()
        .or_else(|| auth.bearer_token())
        .map(ToOwned::to_owned)
        .ok_or_else(|| {
            ApiError::missing_credentials(
                "Anthropic",
                &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
            )
        })
}
515
/// Test helper: the non-empty `ANTHROPIC_AUTH_TOKEN`, if any.
///
/// Read errors (e.g. a non-Unicode value) are treated like an unset variable.
#[cfg(test)]
fn read_auth_token() -> Option<String> {
    read_env_non_empty("ANTHROPIC_AUTH_TOKEN").ok().flatten()
}
522
523#[must_use]
524pub fn read_base_url() -> String {
525    std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
526}
527
528fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
529    headers
530        .get(REQUEST_ID_HEADER)
531        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
532        .and_then(|value| value.to_str().ok())
533        .map(ToOwned::to_owned)
534}
535
/// Incremental reader over a streaming (SSE) messages response.
#[derive(Debug)]
pub struct MessageStream {
    /// Request id taken from the response headers, if present.
    request_id: Option<String>,
    /// Underlying HTTP response whose body is read chunk by chunk.
    response: reqwest::Response,
    /// Parser that turns raw body chunks into stream events.
    parser: SseParser,
    /// Events already parsed but not yet returned to the caller.
    pending: VecDeque<StreamEvent>,
    /// Set once the response body has been fully consumed.
    done: bool,
}
544
545impl MessageStream {
546    #[must_use]
547    pub fn request_id(&self) -> Option<&str> {
548        self.request_id.as_deref()
549    }
550
551    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
552        loop {
553            if let Some(event) = self.pending.pop_front() {
554                return Ok(Some(event));
555            }
556
557            if self.done {
558                let remaining = self.parser.finish()?;
559                self.pending.extend(remaining);
560                if let Some(event) = self.pending.pop_front() {
561                    return Ok(Some(event));
562                }
563                return Ok(None);
564            }
565
566            match self.response.chunk().await? {
567                Some(chunk) => {
568                    self.pending.extend(self.parser.push(&chunk)?);
569                }
570                None => {
571                    self.done = true;
572                }
573            }
574        }
575    }
576}
577
578async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
579    let status = response.status();
580    if status.is_success() {
581        return Ok(response);
582    }
583
584    let body = response.text().await.unwrap_or_else(|_| String::new());
585    let parsed_error = serde_json::from_str::<ApiErrorEnvelope>(&body).ok();
586    let retryable = is_retryable_status(status);
587
588    Err(ApiError::Api {
589        status,
590        error_type: parsed_error
591            .as_ref()
592            .map(|error| error.error.error_type.clone()),
593        message: parsed_error
594            .as_ref()
595            .map(|error| error.error.message.clone()),
596        body,
597        retryable,
598    })
599}
600
601const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
602    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
603}
604
/// Top-level JSON error body: `{"error": {...}}`.
#[derive(Debug, Deserialize)]
struct ApiErrorEnvelope {
    /// The nested error details.
    error: ApiErrorBody,
}
609
/// Error details nested under `error` in an API error response.
#[derive(Debug, Deserialize)]
struct ApiErrorBody {
    /// Machine-readable error category, read from the JSON `type` field.
    #[serde(rename = "type")]
    error_type: String,
    /// Human-readable error description.
    message: String,
}
616
// Unit tests for this module live in the sibling file `codineer_provider_tests.rs`.
#[cfg(test)]
#[path = "codineer_provider_tests.rs"]
mod tests;