Skip to main content

codineer_api/providers/
codineer_provider.rs

1use std::collections::VecDeque;
2use std::time::{SystemTime, UNIX_EPOCH};
3
4use runtime::{
5    load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
6    OAuthTokenExchangeRequest,
7};
8use serde::Deserialize;
9
10use crate::error::ApiError;
11use crate::providers::RetryPolicy;
12use crate::sse::SseParser;
13use crate::types::{MessageRequest, MessageResponse, StreamEvent};
14
/// Default API root, used by the constructors and as the fallback in `read_base_url`.
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header of every Messages API request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Response header consulted first for the server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Response header consulted only when `request-id` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
19
/// Credentials attached to outgoing Anthropic API requests.
///
/// `Debug` is implemented by hand (not derived) so secrets are never printed.
#[derive(Clone, PartialEq, Eq)]
pub enum AuthSource {
    /// No credentials: no auth headers are added.
    None,
    /// Sent in the `x-api-key` header.
    ApiKey(String),
    /// Sent as `Authorization: Bearer <token>`.
    BearerToken(String),
    /// Both the `x-api-key` header and the bearer `Authorization` header are sent.
    ApiKeyAndBearer {
        api_key: String,
        bearer_token: String,
    },
}
30
31impl std::fmt::Debug for AuthSource {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            Self::None => write!(f, "AuthSource::None"),
35            Self::ApiKey(_) => write!(f, "AuthSource::ApiKey(***)"),
36            Self::BearerToken(_) => write!(f, "AuthSource::BearerToken(***)"),
37            Self::ApiKeyAndBearer { .. } => write!(f, "AuthSource::ApiKeyAndBearer(***)"),
38        }
39    }
40}
41
42impl AuthSource {
43    pub fn from_env() -> Result<Self, ApiError> {
44        let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
45        let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
46        match (api_key, auth_token) {
47            (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
48                api_key,
49                bearer_token,
50            }),
51            (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
52            (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
53            (None, None) => Err(ApiError::missing_credentials(
54                "Anthropic",
55                &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
56            )),
57        }
58    }
59
60    #[must_use]
61    pub fn api_key(&self) -> Option<&str> {
62        match self {
63            Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
64            Self::None | Self::BearerToken(_) => None,
65        }
66    }
67
68    #[must_use]
69    pub fn bearer_token(&self) -> Option<&str> {
70        match self {
71            Self::BearerToken(token)
72            | Self::ApiKeyAndBearer {
73                bearer_token: token,
74                ..
75            } => Some(token),
76            Self::None | Self::ApiKey(_) => None,
77        }
78    }
79
80    #[must_use]
81    pub fn masked_authorization_header(&self) -> &'static str {
82        if self.bearer_token().is_some() {
83            "Bearer [REDACTED]"
84        } else {
85            "<absent>"
86        }
87    }
88
89    pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
90        if let Some(api_key) = self.api_key() {
91            request_builder = request_builder.header("x-api-key", api_key);
92        }
93        if let Some(token) = self.bearer_token() {
94            request_builder = request_builder.bearer_auth(token);
95        }
96        request_builder
97    }
98}
99
100#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
101pub struct OAuthTokenSet {
102    pub access_token: String,
103    pub refresh_token: Option<String>,
104    pub expires_at: Option<u64>,
105    #[serde(default)]
106    pub scopes: Vec<String>,
107}
108
109impl From<OAuthTokenSet> for AuthSource {
110    fn from(value: OAuthTokenSet) -> Self {
111        Self::BearerToken(value.access_token)
112    }
113}
114
115impl From<runtime::ResolvedCredential> for AuthSource {
116    fn from(value: runtime::ResolvedCredential) -> Self {
117        match value {
118            runtime::ResolvedCredential::ApiKey(key) => Self::ApiKey(key),
119            runtime::ResolvedCredential::BearerToken(token) => Self::BearerToken(token),
120            runtime::ResolvedCredential::ApiKeyAndBearer {
121                api_key,
122                bearer_token,
123            } => Self::ApiKeyAndBearer {
124                api_key,
125                bearer_token,
126            },
127        }
128    }
129}
130
131#[derive(Clone)]
132pub struct CodineerApiClient {
133    http: reqwest::Client,
134    auth: AuthSource,
135    base_url: String,
136    retry: RetryPolicy,
137}
138
139impl std::fmt::Debug for CodineerApiClient {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("CodineerApiClient")
142            .field("base_url", &self.base_url)
143            .field("auth", &self.auth)
144            .finish()
145    }
146}
147
148impl CodineerApiClient {
149    #[must_use]
150    pub fn new(api_key: impl Into<String>) -> Self {
151        Self {
152            http: crate::default_http_client(),
153            auth: AuthSource::ApiKey(api_key.into()),
154            base_url: DEFAULT_BASE_URL.to_string(),
155            retry: RetryPolicy::default(),
156        }
157    }
158
159    #[must_use]
160    pub fn from_auth(auth: AuthSource) -> Self {
161        Self {
162            http: crate::default_http_client(),
163            auth,
164            base_url: DEFAULT_BASE_URL.to_string(),
165            retry: RetryPolicy::default(),
166        }
167    }
168
169    pub fn from_env() -> Result<Self, ApiError> {
170        Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
171    }
172
173    #[must_use]
174    pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
175        self.auth = auth;
176        self
177    }
178
179    #[must_use]
180    pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
181        match (
182            self.auth.api_key().map(ToOwned::to_owned),
183            auth_token.filter(|token| !token.is_empty()),
184        ) {
185            (Some(api_key), Some(bearer_token)) => {
186                self.auth = AuthSource::ApiKeyAndBearer {
187                    api_key,
188                    bearer_token,
189                };
190            }
191            (Some(api_key), None) => {
192                self.auth = AuthSource::ApiKey(api_key);
193            }
194            (None, Some(bearer_token)) => {
195                self.auth = AuthSource::BearerToken(bearer_token);
196            }
197            (None, None) => {
198                self.auth = AuthSource::None;
199            }
200        }
201        self
202    }
203
204    #[must_use]
205    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
206        self.base_url = base_url.into();
207        self
208    }
209
210    #[must_use]
211    pub fn with_retry_policy(mut self, retry: RetryPolicy) -> Self {
212        self.retry = retry;
213        self
214    }
215
216    #[must_use]
217    pub fn auth_source(&self) -> &AuthSource {
218        &self.auth
219    }
220
221    pub async fn send_message(
222        &self,
223        request: &MessageRequest,
224    ) -> Result<MessageResponse, ApiError> {
225        let request = MessageRequest {
226            stream: false,
227            ..request.clone()
228        };
229        let response = self.send_with_retry(&request).await?;
230        let request_id = request_id_from_headers(response.headers());
231        let mut response = response
232            .json::<MessageResponse>()
233            .await
234            .map_err(ApiError::from)?;
235        if response.request_id.is_none() {
236            response.request_id = request_id;
237        }
238        Ok(response)
239    }
240
241    pub async fn stream_message(
242        &self,
243        request: &MessageRequest,
244    ) -> Result<MessageStream, ApiError> {
245        let response = self
246            .send_with_retry(&request.clone().with_streaming())
247            .await?;
248        Ok(MessageStream {
249            request_id: request_id_from_headers(response.headers()),
250            response,
251            parser: SseParser::new(),
252            pending: VecDeque::new(),
253            done: false,
254        })
255    }
256
257    pub async fn exchange_oauth_code(
258        &self,
259        config: &OAuthConfig,
260        request: &OAuthTokenExchangeRequest,
261    ) -> Result<OAuthTokenSet, ApiError> {
262        let response = self
263            .http
264            .post(&config.token_url)
265            .header("content-type", "application/x-www-form-urlencoded")
266            .form(&request.form_params())
267            .send()
268            .await
269            .map_err(ApiError::from)?;
270        let response = expect_success(response).await?;
271        response
272            .json::<OAuthTokenSet>()
273            .await
274            .map_err(ApiError::from)
275    }
276
277    pub async fn refresh_oauth_token(
278        &self,
279        config: &OAuthConfig,
280        request: &OAuthRefreshRequest,
281    ) -> Result<OAuthTokenSet, ApiError> {
282        let response = self
283            .http
284            .post(&config.token_url)
285            .header("content-type", "application/x-www-form-urlencoded")
286            .form(&request.form_params())
287            .send()
288            .await
289            .map_err(ApiError::from)?;
290        let response = expect_success(response).await?;
291        response
292            .json::<OAuthTokenSet>()
293            .await
294            .map_err(ApiError::from)
295    }
296
297    async fn send_with_retry(
298        &self,
299        request: &MessageRequest,
300    ) -> Result<reqwest::Response, ApiError> {
301        let mut attempts = 0;
302        let mut last_error: Option<ApiError>;
303
304        loop {
305            attempts += 1;
306            match self.send_raw_request(request).await {
307                Ok(response) => match expect_success(response).await {
308                    Ok(response) => return Ok(response),
309                    Err(error)
310                        if error.is_retryable() && attempts <= self.retry.max_retries + 1 =>
311                    {
312                        last_error = Some(error);
313                    }
314                    Err(error) => return Err(error),
315                },
316                Err(error) if error.is_retryable() && attempts <= self.retry.max_retries + 1 => {
317                    last_error = Some(error);
318                }
319                Err(error) => return Err(error),
320            }
321
322            if attempts > self.retry.max_retries {
323                break;
324            }
325
326            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
327        }
328
329        Err(ApiError::RetriesExhausted {
330            attempts,
331            last_error: Box::new(last_error.unwrap_or(ApiError::Auth(
332                "retry loop exited without capturing an error".into(),
333            ))),
334        })
335    }
336
337    async fn send_raw_request(
338        &self,
339        request: &MessageRequest,
340    ) -> Result<reqwest::Response, ApiError> {
341        let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
342        let request_builder = self
343            .http
344            .post(&request_url)
345            .header("anthropic-version", ANTHROPIC_VERSION)
346            .header("content-type", "application/json");
347        let mut request_builder = self.auth.apply(request_builder);
348
349        request_builder = request_builder.json(request);
350        request_builder.send().await.map_err(ApiError::from)
351    }
352
353    fn backoff_for_attempt(&self, attempt: u32) -> Result<std::time::Duration, ApiError> {
354        let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
355            return Err(ApiError::BackoffOverflow {
356                attempt,
357                base_delay: self.retry.initial_backoff,
358            });
359        };
360        Ok(self
361            .retry
362            .initial_backoff
363            .checked_mul(multiplier)
364            .map_or(self.retry.max_backoff, |delay| {
365                delay.min(self.retry.max_backoff)
366            }))
367    }
368}
369
370impl AuthSource {
371    pub fn from_env_or_saved() -> Result<Self, ApiError> {
372        if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
373            return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
374                Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
375                    api_key,
376                    bearer_token,
377                }),
378                None => Ok(Self::ApiKey(api_key)),
379            };
380        }
381        if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
382            return Ok(Self::BearerToken(bearer_token));
383        }
384        match load_saved_oauth_token() {
385            Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
386                if token_set.refresh_token.is_some() {
387                    Err(ApiError::Auth(
388                        "saved OAuth token is expired; load runtime OAuth config to refresh it"
389                            .to_string(),
390                    ))
391                } else {
392                    Err(ApiError::ExpiredOAuthToken)
393                }
394            }
395            Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
396            Ok(None) => Err(ApiError::missing_credentials(
397                "Anthropic",
398                &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
399            )),
400            Err(error) => Err(error),
401        }
402    }
403}
404
405#[must_use]
406pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
407    token_set
408        .expires_at
409        .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
410}
411
412pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
413    let Some(token_set) = load_saved_oauth_token()? else {
414        return Ok(None);
415    };
416    resolve_saved_oauth_token_set(config, token_set).map(Some)
417}
418
419pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
420    Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
421        || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
422        || load_saved_oauth_token()?.is_some())
423}
424
425pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
426where
427    F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
428{
429    if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
430        return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
431            Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
432                api_key,
433                bearer_token,
434            }),
435            None => Ok(AuthSource::ApiKey(api_key)),
436        };
437    }
438    if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
439        return Ok(AuthSource::BearerToken(bearer_token));
440    }
441
442    let Some(token_set) = load_saved_oauth_token()? else {
443        return Err(ApiError::missing_credentials(
444            "Anthropic",
445            &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
446        ));
447    };
448    if !oauth_token_is_expired(&token_set) {
449        return Ok(AuthSource::BearerToken(token_set.access_token));
450    }
451    if token_set.refresh_token.is_none() {
452        return Err(ApiError::ExpiredOAuthToken);
453    }
454
455    let Some(config) = load_oauth_config()? else {
456        return Err(ApiError::Auth(
457            "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
458        ));
459    };
460    Ok(AuthSource::from(resolve_saved_oauth_token_set(
461        &config, token_set,
462    )?))
463}
464
465fn resolve_saved_oauth_token_set(
466    config: &OAuthConfig,
467    token_set: OAuthTokenSet,
468) -> Result<OAuthTokenSet, ApiError> {
469    if !oauth_token_is_expired(&token_set) {
470        return Ok(token_set);
471    }
472    let Some(refresh_token) = token_set.refresh_token.clone() else {
473        return Err(ApiError::ExpiredOAuthToken);
474    };
475    let client = CodineerApiClient::from_auth(AuthSource::None).with_base_url(read_base_url());
476    let refreshed = client_runtime_block_on(async {
477        client
478            .refresh_oauth_token(
479                config,
480                &OAuthRefreshRequest::from_config(
481                    config,
482                    refresh_token,
483                    Some(token_set.scopes.clone()),
484                ),
485            )
486            .await
487    })?;
488    let resolved = OAuthTokenSet {
489        access_token: refreshed.access_token,
490        refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
491        expires_at: refreshed.expires_at,
492        scopes: refreshed.scopes,
493    };
494    save_oauth_credentials(&runtime::OAuthTokenSet {
495        access_token: resolved.access_token.clone(),
496        refresh_token: resolved.refresh_token.clone(),
497        expires_at: resolved.expires_at,
498        scopes: resolved.scopes.clone(),
499    })
500    .map_err(ApiError::from)?;
501    Ok(resolved)
502}
503
504fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
505where
506    F: std::future::Future<Output = Result<T, ApiError>>,
507{
508    match tokio::runtime::Handle::try_current() {
509        Ok(handle) => tokio::task::block_in_place(|| handle.block_on(future)),
510        Err(_) => tokio::runtime::Runtime::new()
511            .map_err(ApiError::from)?
512            .block_on(future),
513    }
514}
515
516fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
517    let token_set = load_oauth_credentials().map_err(ApiError::from)?;
518    Ok(token_set.map(|token_set| OAuthTokenSet {
519        access_token: token_set.access_token,
520        refresh_token: token_set.refresh_token,
521        expires_at: token_set.expires_at,
522        scopes: token_set.scopes,
523    }))
524}
525
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn now_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
531
532fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
533    match std::env::var(key) {
534        Ok(value) if !value.is_empty() => Ok(Some(value)),
535        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
536        Err(error) => Err(ApiError::from(error)),
537    }
538}
539
/// Test helper: resolves credentials like [`AuthSource::from_env_or_saved`] and
/// returns the API key, or the bearer token when there is no API key.
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
    let auth = AuthSource::from_env_or_saved()?;
    match auth.api_key().or(auth.bearer_token()) {
        Some(secret) => Ok(secret.to_owned()),
        None => Err(ApiError::missing_credentials(
            "Anthropic",
            &["ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"],
        )),
    }
}
551
/// Test helper: the non-empty `ANTHROPIC_AUTH_TOKEN`, or `None` when it is unset,
/// empty or unreadable.
#[cfg(test)]
fn read_auth_token() -> Option<String> {
    read_env_non_empty("ANTHROPIC_AUTH_TOKEN").ok().flatten()
}
558
559#[must_use]
560pub fn read_base_url() -> String {
561    std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
562}
563
564fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
565    headers
566        .get(REQUEST_ID_HEADER)
567        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
568        .and_then(|value| value.to_str().ok())
569        .map(ToOwned::to_owned)
570}
571
572#[derive(Debug)]
573pub struct MessageStream {
574    request_id: Option<String>,
575    response: reqwest::Response,
576    parser: SseParser,
577    pending: VecDeque<StreamEvent>,
578    done: bool,
579}
580
581impl MessageStream {
582    #[must_use]
583    pub fn request_id(&self) -> Option<&str> {
584        self.request_id.as_deref()
585    }
586
587    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
588        loop {
589            if let Some(event) = self.pending.pop_front() {
590                return Ok(Some(event));
591            }
592
593            if self.done {
594                let remaining = self.parser.finish()?;
595                self.pending.extend(remaining);
596                if let Some(event) = self.pending.pop_front() {
597                    return Ok(Some(event));
598                }
599                return Ok(None);
600            }
601
602            match self.response.chunk().await? {
603                Some(chunk) => {
604                    self.pending.extend(self.parser.push(&chunk)?);
605                }
606                None => {
607                    self.done = true;
608                }
609            }
610        }
611    }
612}
613
614async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
615    let status = response.status();
616    if status.is_success() {
617        return Ok(response);
618    }
619
620    let body = response.text().await.unwrap_or_else(|_| String::new());
621    let parsed_error = serde_json::from_str::<ApiErrorEnvelope>(&body).ok();
622    let retryable = is_retryable_status(status);
623
624    Err(ApiError::Api {
625        status,
626        error_type: parsed_error
627            .as_ref()
628            .map(|error| error.error.error_type.clone()),
629        message: parsed_error
630            .as_ref()
631            .map(|error| error.error.message.clone()),
632        body,
633        retryable,
634    })
635}
636
637const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
638    matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
639}
640
641#[derive(Debug, Deserialize)]
642struct ApiErrorEnvelope {
643    error: ApiErrorBody,
644}
645
646#[derive(Debug, Deserialize)]
647struct ApiErrorBody {
648    #[serde(rename = "type")]
649    error_type: String,
650    message: String,
651}
652
653#[cfg(test)]
654#[path = "codineer_provider_tests.rs"]
655mod tests;