Skip to main content

shunt/
oauth.rs

//! OAuth 2.0 PKCE flow + token refresh for claude.ai accounts.
//!
//! Claude Code authenticates via OAuth, not API keys. Credentials are stored
//! in the macOS Keychain (on macOS) or in ~/.claude/.credentials.json, and are
//! sent as `Authorization: Bearer <token>`.
5use anyhow::{bail, Context, Result};
6use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
7use serde::{Deserialize, Serialize};
8use sha2::{Digest, Sha256};
9use std::path::PathBuf;
10use std::time::{SystemTime, UNIX_EPOCH};
11
/// OAuth client ID presented in both the authorize URL and token requests.
pub const OAUTH_CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";

/// Browser page where the user grants access to their claude.ai account.
pub const OAUTH_AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";

/// Token endpoint, used both for authorization-code exchange and for refresh.
pub const OAUTH_TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
15
16// ---------------------------------------------------------------------------
17// Credential type
18// ---------------------------------------------------------------------------
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct OAuthCredential {
22    pub access_token: String,
23    pub refresh_token: String,
24    /// Milliseconds since Unix epoch
25    pub expires_at: u64,
26    /// Account email, fetched from roles endpoint after auth
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub email: Option<String>,
29}
30
31impl OAuthCredential {
32    /// True if the token expires within the next 5 minutes.
33    pub fn needs_refresh(&self) -> bool {
34        let now_ms = SystemTime::now()
35            .duration_since(UNIX_EPOCH)
36            .unwrap_or_default()
37            .as_millis() as u64;
38        now_ms >= self.expires_at.saturating_sub(5 * 60 * 1000)
39    }
40}
41
42// ---------------------------------------------------------------------------
43// Auto-import from Claude Code's own credential file
44// ---------------------------------------------------------------------------
45
46/// Raw format used by ~/.claude/.credentials.json
47#[derive(Deserialize)]
48#[serde(rename_all = "camelCase")]
49struct ClaudeCredentials {
50    claude_ai_oauth: Option<ClaudeOAuthRaw>,
51}
52
53#[derive(Deserialize)]
54#[serde(rename_all = "camelCase")]
55struct ClaudeOAuthRaw {
56    access_token: String,
57    refresh_token: String,
58    expires_at: u64,
59}
60
61// ---------------------------------------------------------------------------
62// Session info (plan + identity) from stored credentials
63// ---------------------------------------------------------------------------
64
/// Plan and identity details read from Claude Code's stored credentials.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionInfo {
    /// Identity label shown for the session.
    /// NOTE(review): `read_claude_session_info` fills this from the
    /// `rateLimitTier` field, not from an email or account id.
    pub email_or_id: String,
    /// Subscription type (e.g. `"pro"`).
    pub plan: String,
}
69
70/// Read plan and identity from Claude Code's stored credentials JSON.
71/// Works for both keychain and file-based storage.
72pub fn read_claude_session_info() -> Option<SessionInfo> {
73    #[derive(serde::Deserialize)]
74    #[serde(rename_all = "camelCase")]
75    struct Outer {
76        claude_ai_oauth: Option<Inner>,
77    }
78    #[derive(serde::Deserialize)]
79    #[serde(rename_all = "camelCase")]
80    struct Inner {
81        subscription_type: Option<String>,
82        #[serde(rename = "rateLimitTier")]
83        rate_limit_tier: Option<String>,
84    }
85
86    let text = read_raw_credentials_json()?;
87    let outer: Outer = serde_json::from_str(&text).ok()?;
88    let inner = outer.claude_ai_oauth?;
89
90    let plan = inner.subscription_type.unwrap_or_else(|| "pro".into());
91    let email_or_id = inner.rate_limit_tier.unwrap_or_else(|| "unknown".into());
92
93    Some(SessionInfo { email_or_id, plan })
94}
95
96/// Returns the raw credentials JSON string from keychain (macOS) or file.
97fn read_raw_credentials_json() -> Option<String> {
98    #[cfg(target_os = "macos")]
99    {
100        let out = std::process::Command::new("security")
101            .args(["find-generic-password", "-s", "Claude Code-credentials", "-w"])
102            .output()
103            .ok()?;
104        if out.status.success() {
105            let s = String::from_utf8(out.stdout).ok()?;
106            return Some(s.trim().to_owned());
107        }
108    }
109    std::fs::read_to_string(claude_credentials_path()).ok()
110}
111
112pub fn claude_credentials_path() -> PathBuf {
113    dirs::home_dir()
114        .unwrap_or_else(|| PathBuf::from("."))
115        .join(".claude")
116        .join(".credentials.json")
117}
118
119/// Read the OAuth credential from Claude Code's own credential file.
120/// On macOS, tries the Keychain first (service "Claude Code-credentials"),
121/// then falls back to ~/.claude/.credentials.json.
122pub fn read_claude_credentials() -> Option<OAuthCredential> {
123    // macOS: try Keychain first
124    #[cfg(target_os = "macos")]
125    if let Some(cred) = read_claude_credentials_keychain() {
126        return Some(cred);
127    }
128
129    // Fallback: JSON file (older Claude Code versions / non-macOS)
130    let path = claude_credentials_path();
131    let text = std::fs::read_to_string(&path).ok()?;
132    parse_claude_credentials_json(&text)
133}
134
/// macOS only: parses the credentials JSON obtained via
/// `read_raw_credentials_json()`.
///
/// Despite the name, that helper may return the credentials file's contents
/// when the Keychain lookup does not succeed.
#[cfg(target_os = "macos")]
fn read_claude_credentials_keychain() -> Option<OAuthCredential> {
    read_raw_credentials_json()
        .as_deref()
        .and_then(parse_claude_credentials_json)
}
140
141fn parse_claude_credentials_json(text: &str) -> Option<OAuthCredential> {
142    let raw: ClaudeCredentials = serde_json::from_str(text).ok()?;
143    let inner = raw.claude_ai_oauth?;
144    Some(OAuthCredential {
145        access_token: inner.access_token,
146        refresh_token: inner.refresh_token,
147        expires_at: inner.expires_at,
148        email: None,
149    })
150}
151
152// ---------------------------------------------------------------------------
153// Token refresh
154// ---------------------------------------------------------------------------
155
156/// Refresh an expired access token. Returns the updated credential.
157pub async fn refresh_token(cred: &OAuthCredential) -> Result<OAuthCredential> {
158    let client = reqwest::Client::new();
159
160    let resp = client
161        .post(OAUTH_TOKEN_URL)
162        .header("content-type", "application/x-www-form-urlencoded")
163        .body(format!(
164            "grant_type=refresh_token&refresh_token={}&client_id={}",
165            urlencoding::encode(&cred.refresh_token),
166            OAUTH_CLIENT_ID,
167        ))
168        .send()
169        .await
170        .context("token refresh request failed")?;
171
172    if !resp.status().is_success() {
173        let status = resp.status();
174        let body = resp.text().await.unwrap_or_default();
175        bail!("token refresh failed ({status}): {body}");
176    }
177
178    let body: serde_json::Value = resp.json().await.context("token refresh: invalid JSON")?;
179
180    let access_token = body["access_token"]
181        .as_str()
182        .context("token refresh: missing access_token")?
183        .to_owned();
184
185    let refresh_token = body["refresh_token"]
186        .as_str()
187        .unwrap_or(&cred.refresh_token)
188        .to_owned();
189
190    // expires_in is seconds from now
191    let expires_in_secs = body["expires_in"].as_u64().unwrap_or(3600);
192    let now_ms = SystemTime::now()
193        .duration_since(UNIX_EPOCH)
194        .unwrap_or_default()
195        .as_millis() as u64;
196    let expires_at = now_ms + expires_in_secs * 1000;
197
198    Ok(OAuthCredential { access_token, refresh_token, expires_at, email: cred.email.clone() })
199}
200
201// ---------------------------------------------------------------------------
202// Account identity (email) from roles endpoint
203// ---------------------------------------------------------------------------
204
205/// Fetch the account email from the Anthropic roles endpoint.
206/// Returns `None` on any error (non-fatal).
207pub async fn fetch_account_email(access_token: &str) -> Option<String> {
208    let client = reqwest::Client::builder()
209        .timeout(std::time::Duration::from_secs(8))
210        .build()
211        .ok()?;
212    let resp = client
213        .get("https://api.anthropic.com/api/oauth/claude_cli/roles")
214        .header("authorization", format!("Bearer {access_token}"))
215        .header("anthropic-version", "2023-06-01")
216        .header("anthropic-dangerous-direct-browser-access", "true")
217        .send()
218        .await
219        .ok()?;
220
221    if !resp.status().is_success() {
222        return None;
223    }
224
225    let body: serde_json::Value = resp.json().await.ok()?;
226    // organization_name is "email's Organization" — extract email prefix
227    let org = body["organization_name"].as_str()?;
228    if let Some(email) = org.strip_suffix("'s Organization") {
229        Some(email.to_owned())
230    } else {
231        Some(org.to_owned())
232    }
233}
234
235// ---------------------------------------------------------------------------
236// PKCE browser OAuth flow (for adding additional accounts)
237// ---------------------------------------------------------------------------
238
/// A PKCE code verifier together with its derived S256 code challenge.
struct Pkce {
    /// Secret sent only to the token endpoint as `code_verifier`.
    verifier: String,
    /// Sent in the authorize URL as `code_challenge`.
    challenge: String,
}
243
244fn generate_pkce() -> Pkce {
245    let verifier_bytes: [u8; 32] = rand_bytes();
246    let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
247
248    let hash = Sha256::digest(verifier.as_bytes());
249    let challenge = URL_SAFE_NO_PAD.encode(hash);
250
251    Pkce { verifier, challenge }
252}
253
/// Returns `N` unpredictable random bytes, used for the PKCE verifier and the
/// OAuth `state` value.
///
/// The PKCE verifier must be unguessable (RFC 7636): anyone who can predict it
/// can redeem an intercepted authorization code. Time and PID alone give far
/// too little entropy, so:
/// 1. On Unix, the bytes are read from `/dev/urandom` (the OS CSPRNG).
/// 2. Otherwise, or if that read fails, SipHash outputs keyed by a fresh
///    `RandomState` are expanded. std seeds those keys on a best-effort basis
///    from the host's secure randomness source.
fn rand_bytes<const N: usize>() -> [u8; N] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let mut bytes = [0u8; N];

    #[cfg(unix)]
    {
        use std::io::Read;
        let filled = std::fs::File::open("/dev/urandom")
            .and_then(|mut f| f.read_exact(&mut bytes))
            .is_ok();
        if filled {
            return bytes;
        }
    }

    // Fallback: every RandomState::new() yields distinct, randomly keyed
    // hashers. Time/PID/index only diversify the input; the unpredictability
    // comes from the key. All 8 output bytes of each hash are used.
    let keyed = RandomState::new();
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let pid = std::process::id();
    for (i, chunk) in bytes.chunks_mut(8).enumerate() {
        let mut h = keyed.build_hasher();
        h.write_u128(nanos);
        h.write_u32(pid);
        h.write_usize(i);
        let word = h.finish().to_le_bytes();
        chunk.copy_from_slice(&word[..chunk.len()]);
    }
    bytes
}
272
273fn random_state() -> String {
274    let bytes: [u8; 16] = rand_bytes();
275    hex::encode(bytes)
276}
277
/// Registered redirect URI. Its callback page displays the authorization code
/// (`CODE#STATE`) for the user to paste back into the CLI.
pub const OAUTH_REDIRECT_URI: &str = "https://platform.claude.com/oauth/code/callback";
279
280/// Run the PKCE OAuth flow using the registered redirect URI.
281///
282/// Opens the browser to claude.ai. After the user authorizes, the callback page
283/// displays a code (format: CODE#STATE). The user pastes it here; we split out
284/// the state and exchange the code at the token endpoint.
285pub async fn run_oauth_flow() -> Result<OAuthCredential> {
286    use std::io::{self, Write};
287
288    let pkce = generate_pkce();
289    let state = random_state();
290    let redirect_uri = OAUTH_REDIRECT_URI;
291
292    let scope = urlencoding::encode(
293        "user:inference user:profile user:file_upload user:mcp_servers user:sessions:claude_code",
294    );
295    let auth_url = format!(
296        "{base}?response_type=code\
297         &client_id={client_id}\
298         &redirect_uri={redirect}\
299         &scope={scope}\
300         &state={state}\
301         &code_challenge={challenge}\
302         &code_challenge_method=S256",
303        base = OAUTH_AUTHORIZE_URL,
304        client_id = OAUTH_CLIENT_ID,
305        redirect = urlencoding::encode(redirect_uri),
306        scope = scope,
307        state = state,
308        challenge = pkce.challenge,
309    );
310
311    println!("\nOpening browser for claude.ai login...");
312    println!("If it does not open automatically, visit:\n  {auth_url}\n");
313    open_browser(&auth_url);
314
315    println!("After you authorize, the page will show an authorization code.");
316    println!("Copy it and paste it here.");
317    println!();
318    print!("Paste code: ");
319    io::stdout().flush()?;
320
321    let mut pasted = String::new();
322    io::stdin().read_line(&mut pasted)?;
323    // Page shows "code#state"
324    let pasted = pasted.trim();
325    let (code, pasted_state) = if let Some((c, s)) = pasted.split_once('#') {
326        (c.trim(), s.trim())
327    } else {
328        (pasted, state.as_str())
329    };
330
331    if code.is_empty() {
332        bail!("No code entered.");
333    }
334
335    let cred = exchange_code(code, pasted_state, redirect_uri, &pkce.verifier).await?;
336    Ok(cred)
337}
338
339async fn exchange_code(code: &str, state: &str, redirect_uri: &str, verifier: &str) -> Result<OAuthCredential> {
340    let client = reqwest::Client::new();
341
342    let body = serde_json::json!({
343        "grant_type": "authorization_code",
344        "code": code,
345        "state": state,
346        "redirect_uri": redirect_uri,
347        "client_id": OAUTH_CLIENT_ID,
348        "code_verifier": verifier,
349    });
350
351    let resp = client
352        .post(OAUTH_TOKEN_URL)
353        .header("content-type", "application/json")
354        .header("anthropic-version", "2023-06-01")
355        .json(&body)
356        .send()
357        .await
358        .context("token exchange request failed")?;
359
360    if !resp.status().is_success() {
361        let status = resp.status();
362        let body = resp.text().await.unwrap_or_default();
363        bail!("token exchange failed ({status}): {body}");
364    }
365
366    let body: serde_json::Value = resp.json().await.context("token exchange: invalid JSON")?;
367
368    let access_token = body["access_token"]
369        .as_str()
370        .context("token exchange: missing access_token")?
371        .to_owned();
372    let refresh_token = body["refresh_token"]
373        .as_str()
374        .unwrap_or("")
375        .to_owned();
376    let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
377    let now_ms = SystemTime::now()
378        .duration_since(UNIX_EPOCH)
379        .unwrap_or_default()
380        .as_millis() as u64;
381
382    Ok(OAuthCredential {
383        access_token,
384        refresh_token,
385        expires_at: now_ms + expires_in * 1000,
386        email: None,
387    })
388}
389
390// ---------------------------------------------------------------------------
391// Token revocation
392// ---------------------------------------------------------------------------
393
/// Endpoint used to revoke an access token on the server.
pub const OAUTH_REVOKE_URL: &str = "https://platform.claude.com/v1/oauth/revoke";
395
396/// Revoke an OAuth token on the server. Best-effort — errors are non-fatal.
397pub async fn revoke_token(access_token: &str) -> bool {
398    let client = reqwest::Client::builder()
399        .timeout(std::time::Duration::from_secs(8))
400        .build()
401        .unwrap_or_default();
402    client
403        .post(OAUTH_REVOKE_URL)
404        .header("content-type", "application/x-www-form-urlencoded")
405        .header("anthropic-version", "2023-06-01")
406        .body(format!("token={}", urlencoding::encode(access_token)))
407        .send()
408        .await
409        .map(|r| r.status().is_success())
410        .unwrap_or(false)
411}
412
/// Best-effort: asks the OS to open `url` in the default browser.
///
/// Launch failures are ignored, so callers should also show the URL to the user.
fn open_browser(url: &str) {
    #[cfg(target_os = "macos")]
    {
        std::process::Command::new("open").arg(url).spawn().ok();
    }

    #[cfg(any(
        target_os = "linux",
        target_os = "freebsd",
        target_os = "openbsd",
        target_os = "netbsd",
        target_os = "dragonfly"
    ))]
    {
        std::process::Command::new("xdg-open").arg(url).spawn().ok();
    }

    // Not `cmd /c start <url>`: cmd.exe treats every `&` in the query string as
    // a command separator, so only the part before the first `&` would open.
    // rundll32's FileProtocolHandler receives the URL as one argument instead.
    #[cfg(target_os = "windows")]
    {
        std::process::Command::new("rundll32")
            .args(["url.dll,FileProtocolHandler", url])
            .spawn()
            .ok();
    }

    #[cfg(not(any(
        target_os = "macos",
        target_os = "linux",
        target_os = "freebsd",
        target_os = "openbsd",
        target_os = "netbsd",
        target_os = "dragonfly",
        target_os = "windows"
    )))]
    {
        // No known launcher on this platform; the caller displays the URL.
        let _ = url;
    }
}