use serde::{Deserialize, Serialize};
use tokio::sync::oneshot;
mod pkce;
mod callback;
mod token;
mod storage;
mod browser;
mod openai_codex;
pub use pkce::{generate_code_verifier, generate_code_challenge, generate_state, build_auth_url};
pub use callback::{CallbackServerHandle, start_callback_server};
pub use token::{exchange_code_for_tokens, refresh_token, ensure_fresh_token, ensure_fresh_provider_token};
pub use storage::{auth_file_path, load_auth, load_provider_auth, save_auth, save_provider_auth};
pub use browser::open_browser;
pub use openai_codex::{extract_account_id as extract_codex_account_id, login as login_openai_codex};
/// OAuth client ID for the Anthropic login flow; appears as `client_id` in the
/// authorization URL built by `build_auth_url`.
pub(super) const CLIENT_ID: &str = "9d1c250a-e61b-44d9-88ed-5944d1962f5e";
/// Browser-facing authorization endpoint the user is sent to in order to sign in.
pub(super) const AUTHORIZE_URL: &str = "https://claude.ai/oauth/authorize";
/// Token endpoint — presumably used by the `token` module for code exchange and
/// refresh (its usage is not visible in this file).
pub(super) const TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token";
/// Loopback host for the local OAuth callback listener — NOTE(review): its use
/// lives in the `callback` module; confirm whether it is the bind address.
pub(super) const CALLBACK_HOST: &str = "127.0.0.1";
/// Fixed local port for the OAuth redirect; `login` passes it to
/// `start_callback_server`, `build_auth_url` and `exchange_code_for_tokens`.
pub(super) const CALLBACK_PORT: u16 = 53692;
/// Space-separated OAuth scopes requested during authorization (presumably sent
/// as the `scope` parameter; the tests only check that `scope=` is present).
pub(super) const SCOPES: &str = "org:create_api_key user:profile user:inference user:sessions:claude_code user:mcp_servers user:file_upload";
/// OAuth credentials for one provider, as stored in the auth file.
///
/// Serialized with serde: `auth_type` is written under the JSON key `"type"`
/// and `account_id` under `"accountId"` (omitted entirely when `None`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthCredentials {
/// Credential kind, serialized as `"type"` (the tests use `"oauth"`).
#[serde(rename = "type")]
pub auth_type: String,
/// Refresh token.
pub refresh: String,
/// Access token.
pub access: String,
/// Expiry instant; `is_token_expired` compares it against `now_millis()`, so it
/// is in the same millisecond timebase.
pub expires: u64,
/// Optional provider account identifier, serialized as `"accountId"` and
/// skipped when `None`.
#[serde(rename = "accountId", skip_serializing_if = "Option::is_none")]
pub account_id: Option<String>,
}
/// Top-level layout of the persisted auth file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthFile {
/// Anthropic OAuth credentials (required key).
pub anthropic: OAuthCredentials,
/// OpenAI Codex credentials under the `"openai-codex"` key; defaults to `None`
/// when the key is missing and is omitted from the output when `None`.
#[serde(rename = "openai-codex", default, skip_serializing_if = "Option::is_none")]
pub openai_codex: Option<OAuthCredentials>,
}
/// The fields of the token endpoint's JSON response that this crate reads;
/// any other fields in the response are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
pub(crate) struct TokenResponse {
/// Newly issued access token.
pub(crate) access_token: String,
/// Refresh token to store alongside the access token.
pub(crate) refresh_token: String,
/// Token lifetime — NOTE(review): presumably seconds as in OAuth 2.0
/// (RFC 6749 §5.1); confirm the conversion to `OAuthCredentials::expires`.
pub(crate) expires_in: u64,
}
/// Authorization code and state delivered back to `login`, either by the local
/// callback server or by the manual paste fallback on stdin.
#[derive(Debug, Clone)]
pub struct CallbackResult {
/// Authorization code to exchange for tokens.
pub code: String,
/// State echoed back with the code; `login` rejects it unless it equals the
/// state it generated.
pub state: String,
}
/// Reports whether `creds` has reached its expiry time.
///
/// The stored `expires` value is checked against the current time from
/// [`now_millis`]. A token whose expiry equals the current millisecond already
/// counts as expired.
pub fn is_token_expired(creds: &OAuthCredentials) -> bool {
    let current = now_millis();
    creds.expires <= current
}
/// Current time used for token-expiry checks.
///
/// Delegates to the crate-level `epoch_millis` helper; judging by its name and
/// by the expiry tests, this is milliseconds since the Unix epoch.
pub(crate) fn now_millis() -> u64 {
crate::epoch_millis()
}
/// Extracts the authorization code (and optional state) from text the user
/// pasted during the manual login fallback.
///
/// Accepted forms, tried in order:
/// 1. A URL carrying a non-empty `code` query parameter: `code` and `state` are
///    read from the query string.
/// 2. An `http`/`https` URL *without* a usable `code` (e.g. a redirect that
///    carries `error=access_denied`): yields `(None, None)` so the URL itself
///    is never mistaken for a code.
/// 3. `code#state`, where both halves are non-empty.
/// 4. Any other non-empty text: a bare code with no state.
///
/// Blank input yields `(None, None)`.
fn parse_manual_input(input: &str) -> (Option<String>, Option<String>) {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return (None, None);
    }

    if let Ok(url) = url::Url::parse(trimmed) {
        let query_value = |key: &str| {
            url.query_pairs()
                .find(|(k, _)| k == key)
                .map(|(_, v)| v.into_owned())
        };
        // An empty `code=` is as useless as a missing one.
        let code = query_value("code").filter(|c| !c.is_empty());
        if code.is_some() {
            return (code, query_value("state"));
        }
        // A pasted web URL without a code cannot be a raw code; forwarding it
        // would only produce a confusing token-exchange failure later.
        if matches!(url.scheme(), "http" | "https") {
            return (None, None);
        }
    }

    if let Some((code, state)) = trimmed.split_once('#') {
        if !code.is_empty() && !state.is_empty() {
            return (Some(code.to_string()), Some(state.to_string()));
        }
    }

    (Some(trimmed.to_string()), None)
}
/// Runs the interactive Anthropic OAuth (PKCE) login and saves the credentials.
///
/// Starts a local callback server on [`CALLBACK_PORT`] and opens the
/// authorization URL in the browser (printing it as a fallback). It also reads
/// a manually pasted code/URL from stdin. Whichever source delivers a code
/// first is used. If stdin yields nothing usable (EOF, blank line, input
/// without a code), login keeps waiting for the browser callback. The returned
/// state must match the generated one before the code is exchanged for tokens.
/// The credentials are then persisted with [`save_auth`].
///
/// # Errors
/// Returns a message if the callback server cannot start, the callback channel
/// closes before any code arrives, the state does not match, the token
/// exchange fails, or the credentials cannot be saved. The callback server is
/// shut down on every path after it has started.
pub async fn login() -> std::result::Result<OAuthCredentials, String> {
    let port = CALLBACK_PORT;
    let verifier = generate_code_verifier();
    let challenge = generate_code_challenge(&verifier);
    let state = generate_state();

    let (rx, server_handle) = start_callback_server(state.clone(), port).await?;
    let auth_url = build_auth_url(&challenge, &state, port);

    eprintln!("\n\x1b[1mOpening browser to sign in...\x1b[0m\n");
    if let Err(e) = open_browser(&auth_url) {
        eprintln!("Could not open browser automatically: {}", e);
    }
    eprintln!("\x1b[2mIf the browser didn't open, visit this URL:\x1b[0m");
    eprintln!("\x1b[36m{}\x1b[0m\n", auth_url);

    let (manual_tx, manual_rx) = oneshot::channel::<CallbackResult>();
    let manual_state = state.clone();
    let stdin_task = tokio::spawn(async move {
        eprintln!("\x1b[2mOr paste the authorization code here:\x1b[0m");
        let mut line = String::new();
        let result = tokio::task::spawn_blocking(move || {
            std::io::stdin().read_line(&mut line).ok();
            line.trim().to_string()
        })
        .await;
        if let Ok(input) = result {
            if !input.is_empty() {
                let (code, parsed_state) = parse_manual_input(&input);
                if let Some(code) = code {
                    // A bare code carries no state, so fall back to the expected one.
                    let _ = manual_tx.send(CallbackResult {
                        code,
                        state: parsed_state.unwrap_or(manual_state),
                    });
                }
            }
        }
    });

    // If the stdin task finishes without sending (EOF, blank line, no code),
    // `manual_tx` is dropped and `manual_rx` resolves to `Err`. The `Ok(..)`
    // pattern then disables that branch, and we keep waiting for the browser
    // callback instead of aborting the whole login.
    let received = tokio::select! {
        callback = rx => callback.map_err(|_| "Callback channel closed".to_string()),
        Ok(result) = manual_rx => Ok(result),
    };
    // NOTE: this only cancels the async wrapper. A blocking `read_line` running
    // under `spawn_blocking` cannot be aborted and stays pending until a line
    // or EOF arrives on stdin.
    stdin_task.abort();

    let result = match received {
        Ok(result) => result,
        Err(e) => {
            server_handle.shutdown().await;
            return Err(e);
        }
    };
    if result.state != state {
        server_handle.shutdown().await;
        return Err("OAuth state mismatch — possible CSRF attack".to_string());
    }

    eprintln!("\n\x1b[1mExchanging code for tokens...\x1b[0m");
    let exchanged = exchange_code_for_tokens(&result.code, &result.state, &verifier, port).await;
    // Stop the callback server whether or not the exchange succeeded.
    server_handle.shutdown().await;
    let creds = exchanged?;
    save_auth(&creds)?;
    Ok(creds)
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};

    /// Builds throwaway OAuth credentials that expire at `expires`.
    fn creds_expiring_at(expires: u64) -> OAuthCredentials {
        OAuthCredentials {
            auth_type: "oauth".to_string(),
            refresh: "test_refresh".to_string(),
            access: "test_access".to_string(),
            expires,
            account_id: None,
        }
    }

    #[test]
    fn test_generate_code_verifier() {
        let verifier = generate_code_verifier();
        assert!(!verifier.is_empty(), "Code verifier should not be empty");
        assert!(verifier.len() > 20, "Code verifier should be longer than 20 characters");
        let verifier2 = generate_code_verifier();
        assert_ne!(verifier, verifier2, "Two calls should produce different verifiers");
    }

    #[test]
    fn test_generate_code_challenge() {
        let verifier = "test_verifier_123";
        let challenge = generate_code_challenge(verifier);
        assert!(!challenge.is_empty(), "Code challenge should not be empty");
        let challenge2 = generate_code_challenge(verifier);
        assert_eq!(challenge, challenge2, "Same verifier should produce same challenge");
        let different_challenge = generate_code_challenge("different_verifier_456");
        assert_ne!(
            challenge, different_challenge,
            "Different verifiers should produce different challenges"
        );
    }

    #[test]
    fn test_generate_state() {
        let state = generate_state();
        assert!(!state.is_empty(), "State should not be empty");
        let state2 = generate_state();
        assert_ne!(state, state2, "Two calls should produce different states");
    }

    #[test]
    fn test_build_auth_url() {
        let challenge = "test_challenge";
        let state = "test_state";
        let port = 8080;
        let url = build_auth_url(challenge, state, port);
        assert!(url.contains("claude.ai/oauth/authorize"));
        assert!(url.contains("client_id=9d1c250a-e61b-44d9-88ed-5944d1962f5e"));
        assert!(url.contains(&format!("code_challenge={}", challenge)));
        assert!(url.contains(&format!("state={}", state)));
        assert!(url.contains("localhost"));
        assert!(url.contains(&port.to_string()));
        assert!(url.contains("redirect_uri="));
    }

    #[test]
    fn test_is_token_expired() {
        assert!(is_token_expired(&creds_expiring_at(0)));
        let fresh_creds = creds_expiring_at(now_millis() + 3_600_000);
        assert!(!is_token_expired(&fresh_creds));
        assert_eq!(fresh_creds.auth_type, "oauth");
    }

    #[test]
    fn test_pkce_challenge_sha256() {
        let verifier = "test_verifier_string";
        let challenge = generate_code_challenge(verifier);
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(verifier.as_bytes());
        let hash = hasher.finalize();
        let expected = URL_SAFE_NO_PAD.encode(hash);
        assert_eq!(challenge, expected);
    }

    #[test]
    fn test_code_verifier_length() {
        let verifier = generate_code_verifier();
        assert_eq!(verifier.len(), 43);
    }

    #[test]
    fn test_state_length() {
        let state = generate_state();
        assert_eq!(state.len(), 43);
    }

    #[test]
    fn test_build_auth_url_required_params() {
        let url = build_auth_url("test_challenge", "test_state", 8080);
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge_method=S256"));
        assert!(url.contains("scope="));
        assert!(url.contains("redirect_uri="));
        assert!(url.contains("8080"));
    }

    #[test]
    fn test_is_token_expired_edge_cases() {
        // Read the clock first. Time only moves forward, so a token expiring at
        // or before this instant must already be reported as expired.
        let current_time = now_millis();
        assert!(is_token_expired(&creds_expiring_at(current_time)));
        assert!(is_token_expired(&creds_expiring_at(current_time.saturating_sub(1))));
        // A "future" expiry needs real headroom. A +1 ms margin is flaky because
        // a millisecond can elapse before `is_token_expired` reads the clock.
        assert!(!is_token_expired(&creds_expiring_at(current_time + 60_000)));
    }

    #[test]
    fn test_auth_file_path() {
        let path = auth_file_path();
        let path_str = path.to_string_lossy();
        assert!(path_str.ends_with("auth.json"));
    }

    #[test]
    fn test_oauth_credentials_serialization_roundtrip() {
        let original_creds = OAuthCredentials {
            auth_type: "oauth".to_string(),
            refresh: "test_refresh_token".to_string(),
            access: "test_access_token".to_string(),
            expires: 1234567890,
            account_id: None,
        };
        let json = serde_json::to_string(&original_creds).expect("Should serialize");
        // `auth_type` is renamed on the wire, and a missing account id is omitted.
        assert!(json.contains("\"type\":\"oauth\""));
        assert!(!json.contains("accountId"));
        let deserialized_creds: OAuthCredentials =
            serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(original_creds.auth_type, deserialized_creds.auth_type);
        assert_eq!(original_creds.refresh, deserialized_creds.refresh);
        assert_eq!(original_creds.access, deserialized_creds.access);
        assert_eq!(original_creds.expires, deserialized_creds.expires);
        assert_eq!(deserialized_creds.account_id, None);
    }

    #[test]
    fn test_oauth_credentials_account_id_roundtrip() {
        let creds = OAuthCredentials {
            account_id: Some("acct_123".to_string()),
            ..creds_expiring_at(42)
        };
        let json = serde_json::to_string(&creds).expect("Should serialize");
        assert!(json.contains("\"accountId\":\"acct_123\""));
        let back: OAuthCredentials = serde_json::from_str(&json).expect("Should deserialize");
        assert_eq!(back.account_id.as_deref(), Some("acct_123"));
    }

    #[test]
    fn test_auth_file_without_openai_codex() {
        let json = r#"{"anthropic":{"type":"oauth","refresh":"r","access":"a","expires":1}}"#;
        let file: AuthFile = serde_json::from_str(json).expect("Should deserialize");
        assert!(file.openai_codex.is_none());
        assert_eq!(file.anthropic.access, "a");
        assert_eq!(file.anthropic.account_id, None);
        let reserialized = serde_json::to_string(&file).expect("Should serialize");
        assert!(!reserialized.contains("openai-codex"));
    }

    #[test]
    fn test_parse_manual_input_redirect_url() {
        let (code, state) =
            parse_manual_input("  http://localhost:53692/callback?code=abc&state=xyz \n");
        assert_eq!(code.as_deref(), Some("abc"));
        assert_eq!(state.as_deref(), Some("xyz"));
    }

    #[test]
    fn test_parse_manual_input_code_hash_state() {
        let (code, state) = parse_manual_input("abc#xyz");
        assert_eq!(code.as_deref(), Some("abc"));
        assert_eq!(state.as_deref(), Some("xyz"));
    }

    #[test]
    fn test_parse_manual_input_bare_code() {
        let (code, state) = parse_manual_input("  abc123  ");
        assert_eq!(code.as_deref(), Some("abc123"));
        assert_eq!(state, None);
    }

    #[test]
    fn test_parse_manual_input_blank() {
        assert_eq!(parse_manual_input("   \n"), (None, None));
    }
}