entelix_auth_claude_code/
provider.rs1use std::sync::Arc;
6
7use async_trait::async_trait;
8use tokio::sync::Mutex;
9
10use entelix_core::auth::{CredentialProvider, Credentials};
11use entelix_core::error::Result;
12
13use crate::config::ClaudeCodeOAuthConfig;
14use crate::credential::{CredentialFile, OAuthCredential};
15use crate::error::ClaudeCodeAuthError;
16use crate::refresh::refresh_access_token;
17use crate::store::CredentialStore;
18
19pub struct ClaudeCodeOAuthProvider {
27 store: Arc<dyn CredentialStore>,
28 http: reqwest::Client,
29 refresh_guard: Mutex<()>,
30 config: ClaudeCodeOAuthConfig,
31}
32
33impl std::fmt::Debug for ClaudeCodeOAuthProvider {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("ClaudeCodeOAuthProvider")
36 .field("config", &self.config)
37 .finish_non_exhaustive()
38 }
39}
40
41impl ClaudeCodeOAuthProvider {
42 pub fn new(store: impl CredentialStore) -> Self {
45 Self::with_config(store, ClaudeCodeOAuthConfig::default())
46 }
47
48 pub fn with_config(store: impl CredentialStore, config: ClaudeCodeOAuthConfig) -> Self {
51 let http = reqwest::Client::builder()
52 .timeout(config.refresh_timeout)
53 .build()
54 .unwrap_or_else(|_| reqwest::Client::new());
55 Self {
56 store: Arc::new(store),
57 http,
58 refresh_guard: Mutex::new(()),
59 config,
60 }
61 }
62
63 async fn load_oauth(&self) -> Result<OAuthCredential> {
64 let envelope =
65 self.store
66 .load()
67 .await?
68 .ok_or_else(|| ClaudeCodeAuthError::CredentialsMissing {
69 path: "<store>".into(),
70 })?;
71 envelope
72 .claude_ai_oauth
73 .ok_or_else(|| ClaudeCodeAuthError::OAuthSectionMissing {
74 path: "<store>".into(),
75 })
76 .map_err(Into::into)
77 }
78
79 async fn refresh(&self, prior: OAuthCredential) -> Result<OAuthCredential> {
80 let _guard = self.refresh_guard.lock().await;
84
85 let current = self
90 .store
91 .load()
92 .await?
93 .and_then(|e| e.claude_ai_oauth)
94 .unwrap_or(prior);
95 if !current.needs_refresh() {
96 return Ok(current);
97 }
98
99 let refresh_token = current
100 .refresh_token
101 .as_deref()
102 .ok_or(ClaudeCodeAuthError::RefreshTokenMissing)?;
103
104 let mut refreshed = refresh_access_token(
105 &self.http,
106 &self.config.token_url,
107 refresh_token,
108 self.config.client_id.as_deref(),
109 )
110 .await?;
111
112 if refreshed.subscription_type.is_none() {
117 refreshed
118 .subscription_type
119 .clone_from(¤t.subscription_type);
120 }
121 if refreshed.scopes.is_empty() {
122 refreshed.scopes.clone_from(¤t.scopes);
123 }
124 if refreshed.refresh_token.is_none() {
125 refreshed.refresh_token.clone_from(¤t.refresh_token);
126 }
127
128 self.store
129 .save(&CredentialFile::with_oauth(refreshed.clone()))
130 .await?;
131 Ok(refreshed)
132 }
133}
134
135#[async_trait]
136impl CredentialProvider for ClaudeCodeOAuthProvider {
137 async fn resolve(&self) -> Result<Credentials> {
138 let oauth = self.load_oauth().await?;
139 let active = if oauth.needs_refresh() {
140 self.refresh(oauth).await?
141 } else {
142 oauth
143 };
144 Ok(Credentials {
145 header_name: http::header::AUTHORIZATION,
146 header_value: active.to_bearer_secret(),
147 })
148 }
149}
150
#[cfg(test)]
// Unwrapping keeps test failures loud and the assertions short.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::store::CredentialStore;
    use chrono::Utc;
    use secrecy::ExposeSecret;
    use std::sync::Mutex as StdMutex;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// In-memory [`CredentialStore`]. Clones share one slot, so a test can
    /// keep a handle and inspect what the provider saved.
    #[derive(Clone, Default)]
    struct MemoryCredentialStore {
        inner: Arc<StdMutex<Option<CredentialFile>>>,
    }

    impl MemoryCredentialStore {
        /// Store that starts out holding `file`.
        fn seeded(file: CredentialFile) -> Self {
            Self {
                inner: Arc::new(StdMutex::new(Some(file))),
            }
        }

        /// OAuth section currently held by the store; panics if absent.
        async fn saved_oauth(&self) -> OAuthCredential {
            self.load()
                .await
                .unwrap()
                .unwrap()
                .claude_ai_oauth
                .unwrap()
        }
    }

    #[async_trait]
    impl CredentialStore for MemoryCredentialStore {
        async fn load(&self) -> crate::error::ClaudeCodeAuthResult<Option<CredentialFile>> {
            Ok(self.inner.lock().unwrap().clone())
        }
        async fn save(&self, file: &CredentialFile) -> crate::error::ClaudeCodeAuthResult<()> {
            *self.inner.lock().unwrap() = Some(file.clone());
            Ok(())
        }
    }

    /// Config whose token endpoint is the `/oauth/token` route on `server`.
    fn mock_config(server: &MockServer) -> ClaudeCodeOAuthConfig {
        ClaudeCodeOAuthConfig::new().with_token_url(format!("{}/oauth/token", server.uri()))
    }

    fn fresh_oauth() -> OAuthCredential {
        OAuthCredential::new(
            "fresh-access",
            (Utc::now() + chrono::Duration::hours(1)).timestamp_millis(),
        )
        .with_refresh_token("ref")
        .with_subscription_type("pro")
        .with_scopes(["user:inference"])
    }

    fn expired_oauth() -> OAuthCredential {
        OAuthCredential::new(
            "stale-access",
            (Utc::now() - chrono::Duration::seconds(5)).timestamp_millis(),
        )
        .with_refresh_token("ref")
        .with_subscription_type("pro")
        .with_scopes(["user:inference"])
    }

    #[tokio::test]
    async fn resolve_returns_bearer_when_token_fresh() {
        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(fresh_oauth()));
        let provider = ClaudeCodeOAuthProvider::new(store);
        let creds = provider.resolve().await.unwrap();
        assert_eq!(creds.header_name, http::header::AUTHORIZATION);
        assert_eq!(creds.header_value.expose_secret(), "Bearer fresh-access");
    }

    #[tokio::test]
    async fn resolve_refreshes_when_token_expired() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/oauth/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "renewed-access",
                "refresh_token": "renewed-refresh",
                "expires_in": 3600
            })))
            .expect(1)
            .mount(&server)
            .await;

        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(expired_oauth()));
        let provider = ClaudeCodeOAuthProvider::with_config(store.clone(), mock_config(&server));
        let creds = provider.resolve().await.unwrap();
        assert_eq!(creds.header_value.expose_secret(), "Bearer renewed-access");

        let saved = store.saved_oauth().await;
        assert_eq!(saved.access_token, "renewed-access");
        assert_eq!(saved.refresh_token.as_deref(), Some("renewed-refresh"));
        assert_eq!(saved.subscription_type.as_deref(), Some("pro"));
        assert!(saved.scopes.contains(&"user:inference".to_owned()));
    }

    #[tokio::test]
    async fn refresh_reuses_credential_already_refreshed_in_store() {
        // Simulates a concurrent caller that refreshed while we waited for the
        // lock: we still hold the expired credential, but the store is fresh.
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/oauth/token"))
            .respond_with(ResponseTemplate::new(500))
            .expect(0)
            .mount(&server)
            .await;

        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(fresh_oauth()));
        let provider = ClaudeCodeOAuthProvider::with_config(store, mock_config(&server));
        let active = provider.refresh(expired_oauth()).await.unwrap();
        assert_eq!(active.access_token, "fresh-access");
    }

    #[tokio::test]
    async fn refresh_keeps_refresh_token_when_endpoint_omits_it() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/oauth/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "renewed-access",
                "expires_in": 3600
            })))
            .expect(1)
            .mount(&server)
            .await;

        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(expired_oauth()));
        let provider = ClaudeCodeOAuthProvider::with_config(store.clone(), mock_config(&server));
        let creds = provider.resolve().await.unwrap();
        assert_eq!(creds.header_value.expose_secret(), "Bearer renewed-access");

        let saved = store.saved_oauth().await;
        assert_eq!(saved.access_token, "renewed-access");
        assert_eq!(saved.refresh_token.as_deref(), Some("ref"));
    }

    #[tokio::test]
    async fn resolve_errors_when_store_empty() {
        let store = MemoryCredentialStore::default();
        let provider = ClaudeCodeOAuthProvider::new(store);
        let err = provider.resolve().await.unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("not found"), "got: {msg}");
    }

    #[tokio::test]
    async fn resolve_errors_when_refresh_token_absent_and_expired() {
        let stale = OAuthCredential::new(
            "stale-access",
            (Utc::now() - chrono::Duration::seconds(5)).timestamp_millis(),
        )
        .with_subscription_type("pro");
        let store = MemoryCredentialStore::seeded(CredentialFile::with_oauth(stale));
        let provider = ClaudeCodeOAuthProvider::new(store);
        let err = provider.resolve().await.unwrap_err();
        assert!(err.to_string().contains("refresh token absent"));
    }
}