Skip to main content

entelix_auth_claude_code/
provider.rs

1//! [`ClaudeCodeOAuthProvider`] — [`CredentialProvider`] impl that
2//! resolves the access token the `claude` CLI manages, refreshing
3//! it through the standard OAuth2 `refresh_token` grant when needed.
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8use tokio::sync::Mutex;
9
10use entelix_core::auth::{CredentialProvider, Credentials};
11use entelix_core::error::Result;
12
13use crate::config::ClaudeCodeOAuthConfig;
14use crate::credential::{CredentialFile, OAuthCredential};
15use crate::error::ClaudeCodeAuthError;
16use crate::refresh::refresh_access_token;
17use crate::store::CredentialStore;
18
19/// Resolve credentials from a [`CredentialStore`] backend,
20/// refreshing the access token via the Anthropic console token
21/// endpoint when expiry is imminent.
22///
23/// Concurrent refresh attempts are serialised through an internal
24/// mutex — refresh tokens may rotate on every grant, so two
25/// in-flight refresh calls would race each other into rejection.
26pub struct ClaudeCodeOAuthProvider {
27    store: Arc<dyn CredentialStore>,
28    http: reqwest::Client,
29    refresh_guard: Mutex<()>,
30    config: ClaudeCodeOAuthConfig,
31}
32
33impl std::fmt::Debug for ClaudeCodeOAuthProvider {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("ClaudeCodeOAuthProvider")
36            .field("config", &self.config)
37            .finish_non_exhaustive()
38    }
39}
40
41impl ClaudeCodeOAuthProvider {
42    /// Build a provider over the supplied store backend with the
43    /// canonical Anthropic console token endpoint.
44    pub fn new(store: impl CredentialStore) -> Self {
45        Self::with_config(store, ClaudeCodeOAuthConfig::default())
46    }
47
48    /// Build a provider with an explicit config (custom token URL,
49    /// refresh timeout, …).
50    pub fn with_config(store: impl CredentialStore, config: ClaudeCodeOAuthConfig) -> Self {
51        let http = reqwest::Client::builder()
52            .timeout(config.refresh_timeout)
53            .build()
54            .unwrap_or_else(|_| reqwest::Client::new());
55        Self {
56            store: Arc::new(store),
57            http,
58            refresh_guard: Mutex::new(()),
59            config,
60        }
61    }
62
63    async fn load_oauth(&self) -> Result<OAuthCredential> {
64        let envelope =
65            self.store
66                .load()
67                .await?
68                .ok_or_else(|| ClaudeCodeAuthError::CredentialsMissing {
69                    path: "<store>".into(),
70                })?;
71        envelope
72            .claude_ai_oauth
73            .ok_or_else(|| ClaudeCodeAuthError::OAuthSectionMissing {
74                path: "<store>".into(),
75            })
76            .map_err(Into::into)
77    }
78
79    async fn refresh(&self, prior: OAuthCredential) -> Result<OAuthCredential> {
80        // Serialise refresh attempts — concurrent refresh_token
81        // usage races into the server-side rotation and one of the
82        // two callers loses.
83        let _guard = self.refresh_guard.lock().await;
84
85        // Re-load under the lock: another caller may have refreshed
86        // while we were waiting. Fall back to the prior in-memory
87        // credential when the store has no record (tests with
88        // ephemeral backends).
89        let current = self
90            .store
91            .load()
92            .await?
93            .and_then(|e| e.claude_ai_oauth)
94            .unwrap_or(prior);
95        if !current.needs_refresh() {
96            return Ok(current);
97        }
98
99        let refresh_token = current
100            .refresh_token
101            .as_deref()
102            .ok_or(ClaudeCodeAuthError::RefreshTokenMissing)?;
103
104        let mut refreshed = refresh_access_token(
105            &self.http,
106            &self.config.token_url,
107            refresh_token,
108            self.config.client_id.as_deref(),
109        )
110        .await?;
111
112        // The token endpoint only returns token fields; carry the
113        // operator-visible metadata through unchanged so store
114        // round-trips preserve it. Rotated refresh tokens come back
115        // populated; missing means "reuse prior".
116        if refreshed.subscription_type.is_none() {
117            refreshed
118                .subscription_type
119                .clone_from(&current.subscription_type);
120        }
121        if refreshed.scopes.is_empty() {
122            refreshed.scopes.clone_from(&current.scopes);
123        }
124        if refreshed.refresh_token.is_none() {
125            refreshed.refresh_token.clone_from(&current.refresh_token);
126        }
127
128        self.store
129            .save(&CredentialFile::with_oauth(refreshed.clone()))
130            .await?;
131        Ok(refreshed)
132    }
133}
134
135#[async_trait]
136impl CredentialProvider for ClaudeCodeOAuthProvider {
137    async fn resolve(&self) -> Result<Credentials> {
138        let oauth = self.load_oauth().await?;
139        let active = if oauth.needs_refresh() {
140            self.refresh(oauth).await?
141        } else {
142            oauth
143        };
144        Ok(Credentials {
145            header_name: http::header::AUTHORIZATION,
146            header_value: active.to_bearer_secret(),
147        })
148    }
149}
150
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::store::CredentialStore;
    use chrono::Utc;
    use secrecy::ExposeSecret;
    use std::sync::Mutex as StdMutex;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// In-memory [`CredentialStore`] so tests never touch a real
    /// credentials file. Clones share the same underlying slot.
    #[derive(Clone, Default)]
    struct MemoryCredentialStore {
        inner: Arc<StdMutex<Option<CredentialFile>>>,
    }

    impl MemoryCredentialStore {
        fn seeded(file: CredentialFile) -> Self {
            Self {
                inner: Arc::new(StdMutex::new(Some(file))),
            }
        }
    }

    #[async_trait]
    impl CredentialStore for MemoryCredentialStore {
        async fn load(&self) -> crate::error::ClaudeCodeAuthResult<Option<CredentialFile>> {
            Ok(self.inner.lock().unwrap().clone())
        }
        async fn save(&self, file: &CredentialFile) -> crate::error::ClaudeCodeAuthResult<()> {
            *self.inner.lock().unwrap() = Some(file.clone());
            Ok(())
        }
    }

    /// Credential valid for another hour.
    fn fresh_oauth() -> OAuthCredential {
        OAuthCredential::new(
            "fresh-access",
            (Utc::now() + chrono::Duration::hours(1)).timestamp_millis(),
        )
        .with_refresh_token("ref")
        .with_subscription_type("pro")
        .with_scopes(["user:inference"])
    }

    /// Credential that expired a few seconds ago.
    fn expired_oauth() -> OAuthCredential {
        OAuthCredential::new(
            "stale-access",
            (Utc::now() - chrono::Duration::seconds(5)).timestamp_millis(),
        )
        .with_refresh_token("ref")
        .with_subscription_type("pro")
        .with_scopes(["user:inference"])
    }

    /// Token-endpoint reply carrying a rotated access + refresh token.
    fn renewed_token_response() -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "access_token": "renewed-access",
            "refresh_token": "renewed-refresh",
            "expires_in": 3600
        }))
    }

    #[tokio::test]
    async fn resolve_returns_bearer_when_token_fresh() {
        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(fresh_oauth()));
        let provider = ClaudeCodeOAuthProvider::new(store);
        let creds = provider.resolve().await.unwrap();
        assert_eq!(creds.header_name, http::header::AUTHORIZATION);
        assert_eq!(creds.header_value.expose_secret(), "Bearer fresh-access");
    }

    #[tokio::test]
    async fn resolve_refreshes_when_token_expired() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/oauth/token"))
            .respond_with(renewed_token_response())
            .mount(&server)
            .await;

        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(expired_oauth()));
        let provider = ClaudeCodeOAuthProvider::with_config(
            store.clone(),
            ClaudeCodeOAuthConfig::new().with_token_url(format!("{}/oauth/token", server.uri())),
        );
        let creds = provider.resolve().await.unwrap();
        assert_eq!(creds.header_value.expose_secret(), "Bearer renewed-access");

        // Store round-trip preserves operator metadata + persists
        // the rotated refresh token.
        let saved = store
            .load()
            .await
            .unwrap()
            .unwrap()
            .claude_ai_oauth
            .unwrap();
        assert_eq!(saved.access_token, "renewed-access");
        assert_eq!(saved.refresh_token.as_deref(), Some("renewed-refresh"));
        assert_eq!(saved.subscription_type.as_deref(), Some("pro"));
        assert!(saved.scopes.contains(&"user:inference".to_owned()));
    }

    #[tokio::test]
    async fn concurrent_resolves_share_a_single_refresh() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/oauth/token"))
            .respond_with(renewed_token_response())
            // Verified when `server` is dropped: the caller that waits on
            // the refresh guard must reuse the first caller's result
            // instead of issuing a second grant with a rotated token.
            .expect(1)
            .mount(&server)
            .await;

        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(expired_oauth()));
        let provider = ClaudeCodeOAuthProvider::with_config(
            store,
            ClaudeCodeOAuthConfig::new().with_token_url(format!("{}/oauth/token", server.uri())),
        );
        let (first, second) = tokio::join!(provider.resolve(), provider.resolve());
        assert_eq!(
            first.unwrap().header_value.expose_secret(),
            "Bearer renewed-access"
        );
        assert_eq!(
            second.unwrap().header_value.expose_secret(),
            "Bearer renewed-access"
        );
    }

    #[tokio::test]
    async fn resolve_errors_when_store_empty() {
        let store = MemoryCredentialStore::default();
        let provider = ClaudeCodeOAuthProvider::new(store);
        let err = provider.resolve().await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("not found"), "got: {msg}");
    }

    #[tokio::test]
    async fn resolve_errors_when_refresh_token_absent_and_expired() {
        let stale = OAuthCredential::new(
            "stale-access",
            (Utc::now() - chrono::Duration::seconds(5)).timestamp_millis(),
        )
        .with_subscription_type("pro");
        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(stale));
        let provider = ClaudeCodeOAuthProvider::new(store);
        let err = provider.resolve().await.unwrap_err();
        assert!(err.to_string().contains("refresh token absent"));
    }
}