use anyhow::{Context, Result, anyhow};
use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
use ring::rand::{SecureRandom, SystemRandom};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
pub use super::credentials::AuthCredentialsStoreMode;
use super::pkce::PkceChallenge;
use crate::storage_paths::auth_storage_dir;
/// Base URL of the OpenRouter authorization page; `get_auth_url` appends the
/// callback URL and PKCE challenge to it as query parameters.
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that `exchange_code_for_token` POSTs the authorization code and
/// PKCE verifier to; a successful response carries the API key.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";
/// Default value of `OpenRouterOAuthConfig::callback_port`.
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
/// OAuth-related settings for the OpenRouter provider.
///
/// Because of `#[serde(default)]`, any field missing from the serialized form
/// is filled in from the `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(default)]
pub struct OpenRouterOAuthConfig {
/// Defaults to `false`.
// NOTE(review): presumably switches authentication from a static API key to
// the OAuth flow; the consumer of this flag is not in this file - confirm.
pub use_oauth: bool,
/// Defaults to `DEFAULT_CALLBACK_PORT` (8484).
// NOTE(review): presumably the port handed to `get_auth_url` for the
// `http://localhost:<port>/callback` redirect - confirm against the caller.
pub callback_port: u16,
/// Defaults to `true`.
// NOTE(review): nothing in this file reads this flag - confirm its meaning.
pub auto_refresh: bool,
/// Defaults to `300`.
// NOTE(review): presumably the maximum duration of the interactive flow in
// seconds; the timeout itself is enforced elsewhere - confirm.
pub flow_timeout_secs: u64,
}
impl Default for OpenRouterOAuthConfig {
/// Returns the baseline configuration: OAuth disabled, callback on
/// `DEFAULT_CALLBACK_PORT`, auto-refresh enabled and a 300-second flow
/// timeout. These values also back-fill missing fields on deserialization
/// (the struct is marked `#[serde(default)]`).
fn default() -> Self {
Self {
use_oauth: false,
callback_port: DEFAULT_CALLBACK_PORT,
auto_refresh: true,
flow_timeout_secs: 300,
}
}
}
/// A stored OpenRouter credential together with its bookkeeping metadata.
///
/// This is the value serialized to JSON before being written to the OS keyring
/// or encrypted into the token file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenRouterToken {
/// The secret key itself.
// NOTE(review): presumably the value returned by `exchange_code_for_token`;
// the code that builds this struct is not in this file - confirm.
pub api_key: String,
/// Seconds since the Unix epoch; `get_auth_status` subtracts it from the
/// current time to report the token's age.
pub obtained_at: u64,
/// Expiry instant in seconds since the Unix epoch. `None` means the token
/// never expires (see `is_expired`).
pub expires_at: Option<u64>,
/// Optional label, shown in parentheses by `AuthStatus::display_string`.
pub label: Option<String>,
}
impl OpenRouterToken {
    /// Reports whether the token's expiry instant has been reached.
    ///
    /// A token without `expires_at` never expires. The comparison is inclusive:
    /// a token whose expiry equals the current second counts as expired. If the
    /// system clock reads earlier than the Unix epoch, the current time is
    /// treated as `0`.
    pub fn is_expired(&self) -> bool {
        let Some(deadline) = self.expires_at else {
            return false;
        };

        let now_secs = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        };

        deadline <= now_secs
    }
}
/// On-disk envelope for a token encrypted by `encrypt_token`.
#[derive(Debug, Serialize, Deserialize)]
struct EncryptedToken {
// Standard-alphabet base64 of the random AES-GCM nonce (`NONCE_LEN` bytes).
nonce: String,
// Standard-alphabet base64 of the sealed token JSON with the authentication
// tag appended.
ciphertext: String,
// Envelope format version; `encrypt_token` writes 1 and `decrypt_token`
// rejects any other value.
version: u8,
}
/// Builds the OpenRouter authorization URL for a PKCE flow.
///
/// The redirect target is `http://localhost:<callback_port>/callback`. The
/// callback URL and the code challenge are percent-encoded; the challenge
/// method is appended as-is.
pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
    let redirect_target = format!("http://localhost:{}/callback", callback_port);
    let encoded_redirect = urlencoding::encode(&redirect_target);
    let encoded_challenge = urlencoding::encode(&challenge.code_challenge);
    let method = &challenge.code_challenge_method;

    format!(
        "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
        OPENROUTER_AUTH_URL, encoded_redirect, encoded_challenge, method
    )
}
/// Exchanges an authorization `code` for an OpenRouter API key.
///
/// POSTs the code, the PKCE verifier and the challenge method as JSON to
/// `OPENROUTER_KEYS_URL` and returns the `key` string from the response.
///
/// # Errors
///
/// Fails when the request cannot be sent, the body cannot be read, the server
/// answers with a non-success status (400, 403 and 405 get dedicated messages;
/// anything else reports the status and body), the body is not valid JSON, or
/// the JSON has no string `key` field.
pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
    let request_body = serde_json::json!({
        "code": code,
        "code_verifier": challenge.code_verifier,
        "code_challenge_method": challenge.code_challenge_method
    });

    let response = reqwest::Client::new()
        .post(OPENROUTER_KEYS_URL)
        .header("Content-Type", "application/json")
        .json(&request_body)
        .send()
        .await
        .context("Failed to send token exchange request")?;

    // The body is read before the status is inspected so that the generic
    // failure message below can include it.
    let status = response.status();
    let body = response
        .text()
        .await
        .context("Failed to read response body")?;

    if !status.is_success() {
        let failure = match status.as_u16() {
            400 => anyhow!(
                "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
            ),
            403 => anyhow!(
                "Invalid code or code_verifier. The authorization code may have expired."
            ),
            405 => anyhow!("Method not allowed. Ensure you're using POST over HTTPS."),
            _ => anyhow!("Token exchange failed (HTTP {}): {}", status, body),
        };
        return Err(failure);
    }

    let parsed: serde_json::Value =
        serde_json::from_str(&body).context("Failed to parse token response")?;

    match parsed.get("key").and_then(serde_json::Value::as_str) {
        Some(key) => Ok(key.to_string()),
        None => Err(anyhow!("Response missing 'key' field")),
    }
}
/// Returns the location of the encrypted token file: `openrouter.json` inside
/// the directory reported by `auth_storage_dir`.
fn get_token_path() -> Result<PathBuf> {
    let storage_dir = auth_storage_dir()?;
    Ok(storage_dir.join("openrouter.json"))
}
/// Derives the AES-256-GCM key used for the token file.
///
/// The key is the SHA-256 digest of, in order: the hostname (when it can be
/// read), the numeric user id on Unix or the `USER`/`USERNAME` environment
/// variable elsewhere (when set), and a fixed application salt. No secret is
/// involved, so the same machine and user always derive the same key.
fn derive_encryption_key() -> Result<LessSafeKey> {
    use ring::digest::SHA256;

    // Feeding the pieces into the hasher one by one yields the same digest as
    // hashing their concatenation.
    let mut hasher = ring::digest::Context::new(&SHA256);

    if let Ok(host) = hostname::get() {
        hasher.update(host.as_encoded_bytes());
    }

    #[cfg(unix)]
    {
        let uid_bytes = nix::unistd::getuid().as_raw().to_le_bytes();
        hasher.update(&uid_bytes);
    }
    #[cfg(not(unix))]
    {
        if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
            hasher.update(user.as_bytes());
        }
    }

    hasher.update(b"vtcode-openrouter-oauth-v1");
    let fingerprint = hasher.finish();

    let key_bytes: &[u8; 32] = fingerprint.as_ref()[..32]
        .try_into()
        .context("Hash too short")?;

    UnboundKey::new(&aead::AES_256_GCM, key_bytes)
        .map(LessSafeKey::new)
        .map_err(|_| anyhow!("Invalid key length"))
}
/// Encrypts `token` into a version-1 [`EncryptedToken`] envelope.
///
/// The token is serialized to JSON and sealed with AES-256-GCM under the key
/// from `derive_encryption_key`, using a fresh random nonce and no associated
/// data. Nonce and ciphertext (tag appended) are base64-encoded.
fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
    use base64::{Engine, engine::general_purpose::STANDARD};

    let sealing_key = derive_encryption_key()?;

    let mut nonce_bytes = [0u8; NONCE_LEN];
    SystemRandom::new()
        .fill(&mut nonce_bytes)
        .map_err(|_| anyhow!("Failed to generate nonce"))?;

    // Sealing happens in place: the buffer starts as plaintext JSON and ends
    // as ciphertext followed by the authentication tag.
    let mut buffer = serde_json::to_vec(token).context("Failed to serialize token")?;
    sealing_key
        .seal_in_place_append_tag(
            Nonce::assume_unique_for_key(nonce_bytes),
            Aad::empty(),
            &mut buffer,
        )
        .map_err(|_| anyhow!("Encryption failed"))?;

    let envelope = EncryptedToken {
        nonce: STANDARD.encode(nonce_bytes),
        ciphertext: STANDARD.encode(&buffer),
        version: 1,
    };
    Ok(envelope)
}
/// Decrypts an [`EncryptedToken`] envelope back into an [`OpenRouterToken`].
///
/// # Errors
///
/// Fails when the envelope version is not 1, the nonce or ciphertext is not
/// valid base64, the nonce has the wrong length, authentication/decryption
/// fails (for example because the key derived on this machine differs), or
/// the decrypted bytes are not a valid token.
fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
    match encrypted.version {
        1 => {}
        _ => {
            return Err(anyhow!(
                "Unsupported token format version: {}",
                encrypted.version
            ));
        }
    }

    use base64::{Engine, engine::general_purpose::STANDARD};

    let opening_key = derive_encryption_key()?;

    let raw_nonce = STANDARD
        .decode(&encrypted.nonce)
        .context("Invalid nonce encoding")?;
    let nonce_bytes = <[u8; NONCE_LEN]>::try_from(raw_nonce)
        .map_err(|_| anyhow!("Invalid nonce length"))?;

    let mut sealed = STANDARD
        .decode(&encrypted.ciphertext)
        .context("Invalid ciphertext encoding")?;

    // Opening happens in place and returns the plaintext prefix of `sealed`.
    let opened = opening_key
        .open_in_place(
            Nonce::assume_unique_for_key(nonce_bytes),
            Aad::empty(),
            &mut sealed,
        )
        .map_err(|_| {
            anyhow!("Decryption failed - token may be corrupted or from different machine")
        })?;

    serde_json::from_slice(opened).context("Failed to deserialize token")
}
pub fn save_oauth_token_with_mode(
token: &OpenRouterToken,
mode: AuthCredentialsStoreMode,
) -> Result<()> {
let effective_mode = mode.effective_mode();
match effective_mode {
AuthCredentialsStoreMode::Keyring => save_oauth_token_keyring(token),
AuthCredentialsStoreMode::File => save_oauth_token_file(token),
_ => unreachable!(),
}
}
fn save_oauth_token_keyring(token: &OpenRouterToken) -> Result<()> {
let entry =
keyring::Entry::new("vtcode", "openrouter_oauth").context("Failed to access OS keyring")?;
let token_json =
serde_json::to_string(token).context("Failed to serialize token for keyring")?;
entry
.set_password(&token_json)
.context("Failed to store token in OS keyring")?;
tracing::info!("OAuth token saved to OS keyring");
Ok(())
}
fn save_oauth_token_file(token: &OpenRouterToken) -> Result<()> {
let path = get_token_path()?;
let encrypted = encrypt_token(token)?;
let json =
serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
fs::write(&path, json).context("Failed to write token file")?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = fs::Permissions::from_mode(0o600);
fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
}
tracing::info!("OAuth token saved to {}", path.display());
Ok(())
}
/// Persists `token` using the default credentials storage mode.
pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
    let default_mode = AuthCredentialsStoreMode::default();
    save_oauth_token_with_mode(token, default_mode)
}
pub fn load_oauth_token_with_mode(
mode: AuthCredentialsStoreMode,
) -> Result<Option<OpenRouterToken>> {
let effective_mode = mode.effective_mode();
match effective_mode {
AuthCredentialsStoreMode::Keyring => load_oauth_token_keyring(),
AuthCredentialsStoreMode::File => load_oauth_token_file(),
_ => unreachable!(),
}
}
fn load_oauth_token_keyring() -> Result<Option<OpenRouterToken>> {
let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
Ok(e) => e,
Err(_) => return Ok(None),
};
let token_json = match entry.get_password() {
Ok(json) => json,
Err(keyring::Error::NoEntry) => return Ok(None),
Err(e) => return Err(anyhow!("Failed to read from keyring: {}", e)),
};
let token: OpenRouterToken =
serde_json::from_str(&token_json).context("Failed to parse token from keyring")?;
if token.is_expired() {
tracing::warn!("OAuth token has expired, removing...");
clear_oauth_token_keyring()?;
return Ok(None);
}
Ok(Some(token))
}
/// Reads and decrypts the token from the token file.
///
/// Returns `Ok(None)` when the file does not exist or when the stored token
/// has expired (in which case the file is also removed).
fn load_oauth_token_file() -> Result<Option<OpenRouterToken>> {
    let token_path = get_token_path()?;
    if !token_path.exists() {
        return Ok(None);
    }

    let raw = fs::read_to_string(&token_path).context("Failed to read token file")?;
    let token = serde_json::from_str::<EncryptedToken>(&raw)
        .context("Failed to parse token file")
        .and_then(|envelope| decrypt_token(&envelope))?;

    if !token.is_expired() {
        return Ok(Some(token));
    }

    tracing::warn!("OAuth token has expired, removing...");
    clear_oauth_token_file()?;
    Ok(None)
}
pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
match load_oauth_token_keyring() {
Ok(Some(token)) => return Ok(Some(token)),
Ok(None) => {
tracing::debug!("No token in keyring, checking file storage");
}
Err(e) => {
let error_str = e.to_string().to_lowercase();
if error_str.contains("no entry") || error_str.contains("not found") {
tracing::debug!("Keyring entry not found, checking file storage");
} else {
return Err(e);
}
}
}
load_oauth_token_file()
}
fn clear_oauth_token_keyring() -> Result<()> {
let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
Ok(e) => e,
Err(_) => return Ok(()),
};
match entry.delete_credential() {
Ok(_) => tracing::info!("OAuth token cleared from keyring"),
Err(keyring::Error::NoEntry) => {}
Err(e) => return Err(anyhow!("Failed to clear keyring entry: {}", e)),
}
Ok(())
}
/// Deletes the token file if it exists; a missing file is not an error.
fn clear_oauth_token_file() -> Result<()> {
    let token_path = get_token_path()?;
    if !token_path.exists() {
        return Ok(());
    }

    fs::remove_file(&token_path).context("Failed to remove token file")?;
    tracing::info!("OAuth token cleared from file");
    Ok(())
}
pub fn clear_oauth_token_with_mode(mode: AuthCredentialsStoreMode) -> Result<()> {
match mode.effective_mode() {
AuthCredentialsStoreMode::Keyring => clear_oauth_token_keyring(),
AuthCredentialsStoreMode::File => clear_oauth_token_file(),
AuthCredentialsStoreMode::Auto => {
let _ = clear_oauth_token_keyring();
let _ = clear_oauth_token_file();
Ok(())
}
}
}
pub fn clear_oauth_token() -> Result<()> {
let _ = clear_oauth_token_keyring();
let _ = clear_oauth_token_file();
tracing::info!("OAuth token cleared from all storage");
Ok(())
}
pub fn get_auth_status_with_mode(mode: AuthCredentialsStoreMode) -> Result<AuthStatus> {
match load_oauth_token_with_mode(mode)? {
Some(token) => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let age_seconds = now.saturating_sub(token.obtained_at);
Ok(AuthStatus::Authenticated {
label: token.label,
age_seconds,
expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
})
}
None => Ok(AuthStatus::NotAuthenticated),
}
}
pub fn get_auth_status() -> Result<AuthStatus> {
match load_oauth_token()? {
Some(token) => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let age_seconds = now.saturating_sub(token.obtained_at);
Ok(AuthStatus::Authenticated {
label: token.label,
age_seconds,
expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
})
}
None => Ok(AuthStatus::NotAuthenticated),
}
}
/// Snapshot of the OpenRouter authentication state.
#[derive(Debug, Clone)]
pub enum AuthStatus {
    /// A usable token is stored.
    Authenticated {
        /// Optional label attached to the token.
        label: Option<String>,
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds until the token expires; `None` when it has no expiry.
        expires_in: Option<u64>,
    },
    /// No token is stored.
    NotAuthenticated,
}

impl AuthStatus {
    /// Returns `true` for the `Authenticated` variant.
    pub fn is_authenticated(&self) -> bool {
        matches!(self, AuthStatus::Authenticated { .. })
    }

    /// Renders the status as a single human-readable line, for example
    /// `Authenticated (My App), obtained 1h ago, expires in 23h`.
    pub fn display_string(&self) -> String {
        match self {
            AuthStatus::Authenticated {
                label,
                age_seconds,
                expires_in,
            } => {
                let label_str = label
                    .as_ref()
                    .map(|l| format!(" ({})", l))
                    .unwrap_or_default();
                let age_str = humanize_duration(*age_seconds);
                // The expiry lies in the future, so it must use the bare
                // duration; `humanize_duration` would render "expires in 23h ago".
                let expiry_str = expires_in
                    .map(|e| format!(", expires in {}", compact_duration(e)))
                    .unwrap_or_default();
                format!(
                    "Authenticated{}, obtained {}{}",
                    label_str, age_str, expiry_str
                )
            }
            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
        }
    }
}

/// Formats a duration in its largest whole unit (`s`, `m`, `h` or `d`),
/// truncating any remainder, e.g. `3700` becomes `1h`.
fn compact_duration(seconds: u64) -> String {
    match seconds {
        0..=59 => format!("{}s", seconds),
        60..=3599 => format!("{}m", seconds / 60),
        3600..=86399 => format!("{}h", seconds / 3600),
        _ => format!("{}d", seconds / 86400),
    }
}

/// Formats how long ago something happened, e.g. `3700` becomes `1h ago`.
fn humanize_duration(seconds: u64) -> String {
    format!("{} ago", compact_duration(seconds))
}
// Unit tests for URL construction, expiry checks, the encryption round trip
// and status rendering. None of them touch the network, the keyring or the
// token file.
#[cfg(test)]
mod tests {
use super::*;
// The authorization URL must start at the OpenRouter auth page and carry the
// callback URL, the challenge and the challenge method as query parameters.
#[test]
fn test_auth_url_generation() {
let challenge = PkceChallenge {
code_verifier: "test_verifier".to_string(),
code_challenge: "test_challenge".to_string(),
code_challenge_method: "S256".to_string(),
};
let url = get_auth_url(&challenge, 8484);
assert!(url.starts_with("https://openrouter.ai/auth"));
assert!(url.contains("callback_url="));
assert!(url.contains("code_challenge=test_challenge"));
assert!(url.contains("code_challenge_method=S256"));
}
// `is_expired` is false for a future expiry, true for a past expiry and
// false when the token has no expiry at all.
#[test]
fn test_token_expiry_check() {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let token = OpenRouterToken {
api_key: "test".to_string(),
obtained_at: now,
expires_at: Some(now + 3600),
label: None,
};
assert!(!token.is_expired());
let expired_token = OpenRouterToken {
api_key: "test".to_string(),
obtained_at: now - 7200,
expires_at: Some(now - 3600),
label: None,
};
assert!(expired_token.is_expired());
let no_expiry_token = OpenRouterToken {
api_key: "test".to_string(),
obtained_at: now,
expires_at: None,
label: None,
};
assert!(!no_expiry_token.is_expired());
}
// Encrypting and then decrypting on the same machine/user must reproduce
// every field of the token.
#[test]
fn test_encryption_roundtrip() {
let token = OpenRouterToken {
api_key: "sk-test-key-12345".to_string(),
obtained_at: 1234567890,
expires_at: Some(1234567890 + 86400),
label: Some("Test Token".to_string()),
};
let encrypted = encrypt_token(&token).unwrap();
let decrypted = decrypt_token(&encrypted).unwrap();
assert_eq!(decrypted.api_key, token.api_key);
assert_eq!(decrypted.obtained_at, token.obtained_at);
assert_eq!(decrypted.expires_at, token.expires_at);
assert_eq!(decrypted.label, token.label);
}
// Only checks that the rendered status mentions the state and the label; the
// exact wording of the age and expiry parts is not asserted here.
#[test]
fn test_auth_status_display() {
let status = AuthStatus::Authenticated {
label: Some("My App".to_string()),
age_seconds: 3700,
expires_in: Some(86000),
};
let display = status.display_string();
assert!(display.contains("Authenticated"));
assert!(display.contains("My App"));
}
}