1use anyhow::{Context, Result, anyhow};
32use ring::aead::{self, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey};
33use ring::rand::{SecureRandom, SystemRandom};
34use serde::{Deserialize, Serialize};
35use std::fs;
36use std::path::PathBuf;
37
38pub use super::credentials::AuthCredentialsStoreMode;
39use super::pkce::PkceChallenge;
40use crate::storage_paths::auth_storage_dir;
41
/// OpenRouter authorization page; `get_auth_url` appends the PKCE query parameters to it.
const OPENROUTER_AUTH_URL: &str = "https://openrouter.ai/auth";
/// Endpoint that receives the authorization code plus PKCE verifier and answers with the API key.
const OPENROUTER_KEYS_URL: &str = "https://openrouter.ai/api/v1/auth/keys";

/// Default port for the `http://localhost:<port>/callback` OAuth redirect URL.
pub const DEFAULT_CALLBACK_PORT: u16 = 8484;
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
52#[serde(default)]
53pub struct OpenRouterOAuthConfig {
54 pub use_oauth: bool,
56 pub callback_port: u16,
58 pub auto_refresh: bool,
60 pub flow_timeout_secs: u64,
62}
63
64impl Default for OpenRouterOAuthConfig {
65 fn default() -> Self {
66 Self {
67 use_oauth: false,
68 callback_port: DEFAULT_CALLBACK_PORT,
69 auto_refresh: true,
70 flow_timeout_secs: 300,
71 }
72 }
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct OpenRouterToken {
78 pub api_key: String,
80 pub obtained_at: u64,
82 pub expires_at: Option<u64>,
84 pub label: Option<String>,
86}
87
88impl OpenRouterToken {
89 pub fn is_expired(&self) -> bool {
91 if let Some(expires_at) = self.expires_at {
92 let now = std::time::SystemTime::now()
93 .duration_since(std::time::UNIX_EPOCH)
94 .map(|d| d.as_secs())
95 .unwrap_or(0);
96 now >= expires_at
97 } else {
98 false
99 }
100 }
101}
102
/// On-disk envelope for a file-stored token: AES-256-GCM ciphertext plus its nonce.
#[derive(Debug, Serialize, Deserialize)]
struct EncryptedToken {
    /// Standard-alphabet base64 of the `NONCE_LEN`-byte AEAD nonce.
    nonce: String,
    /// Standard-alphabet base64 of the sealed token JSON, with the GCM tag appended.
    ciphertext: String,
    /// Envelope format version; `decrypt_token` accepts only `1`.
    version: u8,
}
113
114pub fn get_auth_url(challenge: &PkceChallenge, callback_port: u16) -> String {
123 let callback_url = format!("http://localhost:{}/callback", callback_port);
124 format!(
125 "{}?callback_url={}&code_challenge={}&code_challenge_method={}",
126 OPENROUTER_AUTH_URL,
127 urlencoding::encode(&callback_url),
128 urlencoding::encode(&challenge.code_challenge),
129 challenge.code_challenge_method
130 )
131}
132
133pub async fn exchange_code_for_token(code: &str, challenge: &PkceChallenge) -> Result<String> {
145 let client = reqwest::Client::new();
146
147 let payload = serde_json::json!({
148 "code": code,
149 "code_verifier": challenge.code_verifier,
150 "code_challenge_method": challenge.code_challenge_method
151 });
152
153 let response = client
154 .post(OPENROUTER_KEYS_URL)
155 .header("Content-Type", "application/json")
156 .json(&payload)
157 .send()
158 .await
159 .context("Failed to send token exchange request")?;
160
161 let status = response.status();
162 let body = response
163 .text()
164 .await
165 .context("Failed to read response body")?;
166
167 if !status.is_success() {
168 if status.as_u16() == 400 {
170 return Err(anyhow!(
171 "Invalid code_challenge_method. Ensure you're using the same method (S256) in both steps."
172 ));
173 } else if status.as_u16() == 403 {
174 return Err(anyhow!(
175 "Invalid code or code_verifier. The authorization code may have expired."
176 ));
177 } else if status.as_u16() == 405 {
178 return Err(anyhow!(
179 "Method not allowed. Ensure you're using POST over HTTPS."
180 ));
181 }
182 return Err(anyhow!("Token exchange failed (HTTP {}): {}", status, body));
183 }
184
185 let response_json: serde_json::Value =
187 serde_json::from_str(&body).context("Failed to parse token response")?;
188
189 let api_key = response_json
190 .get("key")
191 .and_then(|v| v.as_str())
192 .ok_or_else(|| anyhow!("Response missing 'key' field"))?
193 .to_string();
194
195 Ok(api_key)
196}
197
198fn get_token_path() -> Result<PathBuf> {
200 Ok(auth_storage_dir()?.join("openrouter.json"))
201}
202
203fn derive_encryption_key() -> Result<LessSafeKey> {
205 use ring::digest::{SHA256, digest};
206
207 let mut key_material = Vec::new();
209
210 if let Ok(hostname) = hostname::get() {
212 key_material.extend_from_slice(hostname.as_encoded_bytes());
213 }
214
215 #[cfg(unix)]
217 {
218 key_material.extend_from_slice(&nix::unistd::getuid().as_raw().to_le_bytes());
219 }
220 #[cfg(not(unix))]
221 {
222 if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) {
223 key_material.extend_from_slice(user.as_bytes());
224 }
225 }
226
227 key_material.extend_from_slice(b"vtcode-openrouter-oauth-v1");
229
230 let hash = digest(&SHA256, &key_material);
232 let key_bytes: &[u8; 32] = hash.as_ref()[..32].try_into().context("Hash too short")?;
233
234 let unbound_key = UnboundKey::new(&aead::AES_256_GCM, key_bytes)
235 .map_err(|_| anyhow!("Invalid key length"))?;
236
237 Ok(LessSafeKey::new(unbound_key))
238}
239
240fn encrypt_token(token: &OpenRouterToken) -> Result<EncryptedToken> {
242 let key = derive_encryption_key()?;
243 let rng = SystemRandom::new();
244
245 let mut nonce_bytes = [0u8; NONCE_LEN];
247 rng.fill(&mut nonce_bytes)
248 .map_err(|_| anyhow!("Failed to generate nonce"))?;
249
250 let plaintext = serde_json::to_vec(token).context("Failed to serialize token")?;
252
253 let mut ciphertext = plaintext;
255 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
256 key.seal_in_place_append_tag(nonce, Aad::empty(), &mut ciphertext)
257 .map_err(|_| anyhow!("Encryption failed"))?;
258
259 use base64::{Engine, engine::general_purpose::STANDARD};
260
261 Ok(EncryptedToken {
262 nonce: STANDARD.encode(nonce_bytes),
263 ciphertext: STANDARD.encode(&ciphertext),
264 version: 1,
265 })
266}
267
268fn decrypt_token(encrypted: &EncryptedToken) -> Result<OpenRouterToken> {
270 if encrypted.version != 1 {
271 return Err(anyhow!(
272 "Unsupported token format version: {}",
273 encrypted.version
274 ));
275 }
276
277 use base64::{Engine, engine::general_purpose::STANDARD};
278
279 let key = derive_encryption_key()?;
280
281 let nonce_bytes: [u8; NONCE_LEN] = STANDARD
282 .decode(&encrypted.nonce)
283 .context("Invalid nonce encoding")?
284 .try_into()
285 .map_err(|_| anyhow!("Invalid nonce length"))?;
286
287 let mut ciphertext = STANDARD
288 .decode(&encrypted.ciphertext)
289 .context("Invalid ciphertext encoding")?;
290
291 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
292 let plaintext = key
293 .open_in_place(nonce, Aad::empty(), &mut ciphertext)
294 .map_err(|_| {
295 anyhow!("Decryption failed - token may be corrupted or from different machine")
296 })?;
297
298 serde_json::from_slice(plaintext).context("Failed to deserialize token")
299}
300
301pub fn save_oauth_token_with_mode(
307 token: &OpenRouterToken,
308 mode: AuthCredentialsStoreMode,
309) -> Result<()> {
310 let effective_mode = mode.effective_mode();
311
312 match effective_mode {
313 AuthCredentialsStoreMode::Keyring => save_oauth_token_keyring(token),
314 AuthCredentialsStoreMode::File => save_oauth_token_file(token),
315 _ => unreachable!(),
316 }
317}
318
319fn save_oauth_token_keyring(token: &OpenRouterToken) -> Result<()> {
321 let entry =
322 keyring::Entry::new("vtcode", "openrouter_oauth").context("Failed to access OS keyring")?;
323
324 let token_json =
326 serde_json::to_string(token).context("Failed to serialize token for keyring")?;
327
328 entry
329 .set_password(&token_json)
330 .context("Failed to store token in OS keyring")?;
331
332 tracing::info!("OAuth token saved to OS keyring");
333 Ok(())
334}
335
336fn save_oauth_token_file(token: &OpenRouterToken) -> Result<()> {
338 let path = get_token_path()?;
339 let encrypted = encrypt_token(token)?;
340 let json =
341 serde_json::to_string_pretty(&encrypted).context("Failed to serialize encrypted token")?;
342
343 fs::write(&path, json).context("Failed to write token file")?;
344
345 #[cfg(unix)]
347 {
348 use std::os::unix::fs::PermissionsExt;
349 let perms = fs::Permissions::from_mode(0o600);
350 fs::set_permissions(&path, perms).context("Failed to set token file permissions")?;
351 }
352
353 tracing::info!("OAuth token saved to {}", path.display());
354 Ok(())
355}
356
357pub fn save_oauth_token(token: &OpenRouterToken) -> Result<()> {
362 save_oauth_token_with_mode(token, AuthCredentialsStoreMode::default())
363}
364
365pub fn load_oauth_token_with_mode(
369 mode: AuthCredentialsStoreMode,
370) -> Result<Option<OpenRouterToken>> {
371 let effective_mode = mode.effective_mode();
372
373 match effective_mode {
374 AuthCredentialsStoreMode::Keyring => load_oauth_token_keyring(),
375 AuthCredentialsStoreMode::File => load_oauth_token_file(),
376 _ => unreachable!(),
377 }
378}
379
380fn load_oauth_token_keyring() -> Result<Option<OpenRouterToken>> {
382 let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
383 Ok(e) => e,
384 Err(_) => return Ok(None),
385 };
386
387 let token_json = match entry.get_password() {
388 Ok(json) => json,
389 Err(keyring::Error::NoEntry) => return Ok(None),
390 Err(e) => return Err(anyhow!("Failed to read from keyring: {}", e)),
391 };
392
393 let token: OpenRouterToken =
394 serde_json::from_str(&token_json).context("Failed to parse token from keyring")?;
395
396 if token.is_expired() {
398 tracing::warn!("OAuth token has expired, removing...");
399 clear_oauth_token_keyring()?;
400 return Ok(None);
401 }
402
403 Ok(Some(token))
404}
405
406fn load_oauth_token_file() -> Result<Option<OpenRouterToken>> {
408 let path = get_token_path()?;
409
410 if !path.exists() {
411 return Ok(None);
412 }
413
414 let json = fs::read_to_string(&path).context("Failed to read token file")?;
415 let encrypted: EncryptedToken =
416 serde_json::from_str(&json).context("Failed to parse token file")?;
417
418 let token = decrypt_token(&encrypted)?;
419
420 if token.is_expired() {
422 tracing::warn!("OAuth token has expired, removing...");
423 clear_oauth_token_file()?;
424 return Ok(None);
425 }
426
427 Ok(Some(token))
428}
429
430pub fn load_oauth_token() -> Result<Option<OpenRouterToken>> {
442 match load_oauth_token_keyring() {
443 Ok(Some(token)) => return Ok(Some(token)),
444 Ok(None) => {
445 tracing::debug!("No token in keyring, checking file storage");
447 }
448 Err(e) => {
449 let error_str = e.to_string().to_lowercase();
451 if error_str.contains("no entry") || error_str.contains("not found") {
452 tracing::debug!("Keyring entry not found, checking file storage");
453 } else {
454 return Err(e);
457 }
458 }
459 }
460
461 load_oauth_token_file()
463}
464
465fn clear_oauth_token_keyring() -> Result<()> {
467 let entry = match keyring::Entry::new("vtcode", "openrouter_oauth") {
468 Ok(e) => e,
469 Err(_) => return Ok(()),
470 };
471
472 match entry.delete_credential() {
473 Ok(_) => tracing::info!("OAuth token cleared from keyring"),
474 Err(keyring::Error::NoEntry) => {}
475 Err(e) => return Err(anyhow!("Failed to clear keyring entry: {}", e)),
476 }
477
478 Ok(())
479}
480
481fn clear_oauth_token_file() -> Result<()> {
483 let path = get_token_path()?;
484
485 if path.exists() {
486 fs::remove_file(&path).context("Failed to remove token file")?;
487 tracing::info!("OAuth token cleared from file");
488 }
489
490 Ok(())
491}
492
493pub fn clear_oauth_token_with_mode(mode: AuthCredentialsStoreMode) -> Result<()> {
495 match mode.effective_mode() {
496 AuthCredentialsStoreMode::Keyring => clear_oauth_token_keyring(),
497 AuthCredentialsStoreMode::File => clear_oauth_token_file(),
498 AuthCredentialsStoreMode::Auto => {
499 let _ = clear_oauth_token_keyring();
500 let _ = clear_oauth_token_file();
501 Ok(())
502 }
503 }
504}
505
506pub fn clear_oauth_token() -> Result<()> {
507 let _ = clear_oauth_token_keyring();
509 let _ = clear_oauth_token_file();
510
511 tracing::info!("OAuth token cleared from all storage");
512 Ok(())
513}
514
515pub fn get_auth_status_with_mode(mode: AuthCredentialsStoreMode) -> Result<AuthStatus> {
517 match load_oauth_token_with_mode(mode)? {
518 Some(token) => {
519 let now = std::time::SystemTime::now()
520 .duration_since(std::time::UNIX_EPOCH)
521 .map(|d| d.as_secs())
522 .unwrap_or(0);
523
524 let age_seconds = now.saturating_sub(token.obtained_at);
525
526 Ok(AuthStatus::Authenticated {
527 label: token.label,
528 age_seconds,
529 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
530 })
531 }
532 None => Ok(AuthStatus::NotAuthenticated),
533 }
534}
535
536pub fn get_auth_status() -> Result<AuthStatus> {
537 match load_oauth_token()? {
538 Some(token) => {
539 let now = std::time::SystemTime::now()
540 .duration_since(std::time::UNIX_EPOCH)
541 .map(|d| d.as_secs())
542 .unwrap_or(0);
543
544 let age_seconds = now.saturating_sub(token.obtained_at);
545
546 Ok(AuthStatus::Authenticated {
547 label: token.label,
548 age_seconds,
549 expires_in: token.expires_at.map(|e| e.saturating_sub(now)),
550 })
551 }
552 None => Ok(AuthStatus::NotAuthenticated),
553 }
554}
555
/// Authentication state derived from the stored OpenRouter token.
#[derive(Debug, Clone)]
pub enum AuthStatus {
    /// A usable (unexpired) token is stored.
    Authenticated {
        /// Optional human-readable label of the stored key.
        label: Option<String>,
        /// Seconds elapsed since the token was obtained.
        age_seconds: u64,
        /// Seconds until expiry, if the token has an expiry time.
        expires_in: Option<u64>,
    },
    /// No usable token is stored.
    NotAuthenticated,
}

impl AuthStatus {
    /// Returns `true` for the [`AuthStatus::Authenticated`] variant.
    pub fn is_authenticated(&self) -> bool {
        matches!(self, AuthStatus::Authenticated { .. })
    }

    /// Renders a one-line, human-readable summary of the status.
    ///
    /// Example: `Authenticated (My App), obtained 1h ago, expires in 23h`.
    pub fn display_string(&self) -> String {
        match self {
            AuthStatus::Authenticated {
                label,
                age_seconds,
                expires_in,
            } => {
                let label_str = label
                    .as_ref()
                    .map(|l| format!(" ({})", l))
                    .unwrap_or_default();
                let age_str = humanize_duration(*age_seconds);
                // Expiry lies in the future, so use the bare duration; going through
                // `humanize_duration` would read "expires in 23h ago".
                let expiry_str = expires_in
                    .map(|e| format!(", expires in {}", format_duration_compact(e)))
                    .unwrap_or_default();
                format!(
                    "Authenticated{}, obtained {}{}",
                    label_str, age_str, expiry_str
                )
            }
            AuthStatus::NotAuthenticated => "Not authenticated".to_string(),
        }
    }
}

/// Formats a past duration as a coarse relative time, e.g. `"5m ago"`.
fn humanize_duration(seconds: u64) -> String {
    format!("{} ago", format_duration_compact(seconds))
}

/// Formats a duration in its largest whole unit: seconds, minutes, hours or days (e.g. `"23h"`).
///
/// Values are truncated, not rounded.
fn format_duration_compact(seconds: u64) -> String {
    const MINUTE: u64 = 60;
    const HOUR: u64 = 60 * MINUTE;
    const DAY: u64 = 24 * HOUR;

    if seconds < MINUTE {
        format!("{}s", seconds)
    } else if seconds < HOUR {
        format!("{}m", seconds / MINUTE)
    } else if seconds < DAY {
        format!("{}h", seconds / HOUR)
    } else {
        format!("{}d", seconds / DAY)
    }
}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds an unlabeled token with the dummy key `"test"`.
    fn token_with(obtained_at: u64, expires_at: Option<u64>) -> OpenRouterToken {
        OpenRouterToken {
            api_key: "test".to_string(),
            obtained_at,
            expires_at,
            label: None,
        }
    }

    #[test]
    fn test_auth_url_generation() {
        let challenge = PkceChallenge {
            code_verifier: "test_verifier".to_string(),
            code_challenge: "test_challenge".to_string(),
            code_challenge_method: "S256".to_string(),
        };

        let url = get_auth_url(&challenge, 8484);

        let expected_fragments = [
            "callback_url=",
            "code_challenge=test_challenge",
            "code_challenge_method=S256",
        ];
        assert!(url.starts_with("https://openrouter.ai/auth"));
        for fragment in expected_fragments {
            assert!(url.contains(fragment));
        }
    }

    #[test]
    fn test_token_expiry_check() {
        let now = unix_now();

        assert!(!token_with(now, Some(now + 3600)).is_expired());
        assert!(token_with(now - 7200, Some(now - 3600)).is_expired());
        assert!(!token_with(now, None).is_expired());
    }

    #[test]
    fn test_encryption_roundtrip() {
        let original = OpenRouterToken {
            api_key: "sk-test-key-12345".to_string(),
            obtained_at: 1234567890,
            expires_at: Some(1234567890 + 86400),
            label: Some("Test Token".to_string()),
        };

        let roundtripped = decrypt_token(&encrypt_token(&original).unwrap()).unwrap();

        assert_eq!(roundtripped.api_key, original.api_key);
        assert_eq!(roundtripped.obtained_at, original.obtained_at);
        assert_eq!(roundtripped.expires_at, original.expires_at);
        assert_eq!(roundtripped.label, original.label);
    }

    #[test]
    fn test_auth_status_display() {
        let rendered = AuthStatus::Authenticated {
            label: Some("My App".to_string()),
            age_seconds: 3700,
            expires_in: Some(86000),
        }
        .display_string();

        assert!(rendered.contains("Authenticated"));
        assert!(rendered.contains("My App"));
    }
}