Skip to main content

codetether_agent/cli/
auth.rs

1//! Provider authentication commands.
2
3use super::{
4    AuthArgs, AuthCommand, CodexAuthArgs, CookieAuthArgs, CopilotAuthArgs, LoginAuthArgs,
5    RegisterAuthArgs,
6};
7use crate::provider::copilot::normalize_enterprise_domain;
8use crate::provider::openai_codex::{OAuthCredentials, OpenAiCodexProvider};
9use crate::secrets::{self, ProviderSecrets};
10use anyhow::{Context, Result};
11use reqwest::Client;
12use serde::Deserialize;
13use serde::de::{self, Deserializer};
14use serde_json::json;
15use std::collections::HashMap;
16use std::io::{self, Write};
17use std::path::PathBuf;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::net::TcpListener;
20use tokio::time::{Duration, Instant, sleep};
21
/// Default GitHub host. NOTE(review): not referenced in the visible part of this file;
/// presumably the default for Copilot / GitHub Enterprise auth — confirm.
const DEFAULT_GITHUB_DOMAIN: &str = "github.com";
/// Extra delay in milliseconds. NOTE(review): presumably added on top of the server's
/// OAuth device-flow polling interval; its use is not visible here — confirm.
const OAUTH_POLLING_SAFETY_MARGIN_MS: u64 = 3000;
/// IPv4 loopback address the Codex OAuth callback listener tries to bind.
const CODEX_CALLBACK_ADDR_V4: &str = "127.0.0.1:1455";
/// IPv6 loopback address the Codex OAuth callback listener tries to bind.
const CODEX_CALLBACK_ADDR_V6: &str = "[::1]:1455";
/// Host:port shown to the user when describing the callback URL.
const CODEX_CALLBACK_DISPLAY_ADDR: &str = "localhost:1455";
/// Seconds to wait for an automatic browser callback in a local (non-SSH) session.
const CODEX_CALLBACK_TIMEOUT_SECS: u64 = 300;
/// Shorter automatic-callback wait used when an SSH session is detected.
const CODEX_CALLBACK_TIMEOUT_SSH_SECS: u64 = 15;
/// Device-code authorization budget in seconds (15 minutes). NOTE(review): consumed
/// outside the visible section, presumably by the Codex device-code poller — confirm.
const CODEX_DEVICE_AUTH_TIMEOUT_SECS: u64 = 15 * 60;
30
/// Response body of an OAuth device-authorization request.
///
/// NOTE(review): not used in the visible part of this file; presumably consumed by the
/// Copilot device flow further down — confirm.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    /// Opaque device code returned by the server.
    device_code: String,
    /// Short code the user types on the verification page.
    user_code: String,
    /// URL where the user enters `user_code`.
    verification_uri: String,
    /// Optional verification URL with the code pre-filled; `None` when absent.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Optional polling interval as sent by the server; `None` when absent.
    #[serde(default)]
    interval: Option<u64>,
}
41
/// Response body of an OAuth token poll: either a token or an error.
///
/// Every field is optional so both success and pending/error bodies deserialize.
/// NOTE(review): usage is outside the visible section — confirm which flow consumes it.
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    /// Access token, present on success.
    #[serde(default)]
    access_token: Option<String>,
    /// OAuth error code, present on failure or while authorization is pending.
    #[serde(default)]
    error: Option<String>,
    /// Human-readable description accompanying `error`.
    #[serde(default)]
    error_description: Option<String>,
    /// Optional updated polling interval from the server.
    #[serde(default)]
    interval: Option<u64>,
}
53
/// Response from the Codex device-code request.
#[derive(Debug, Deserialize)]
struct CodexDeviceCodeResponse {
    /// Identifier for this device-authorization attempt.
    device_auth_id: String,
    /// Code shown to the user; also accepted from JSON under the key `usercode`.
    #[serde(alias = "usercode")]
    user_code: String,
    /// Polling interval, parsed by the custom `deserialize_interval_seconds` (defined
    /// elsewhere); `0` when the field is absent.
    #[serde(default, deserialize_with = "deserialize_interval_seconds")]
    interval: u64,
}
62
/// Successful result of Codex device authorization: an authorization code plus the
/// verifier to redeem it with.
#[derive(Debug, Deserialize)]
struct CodexDeviceCodeTokenResponse {
    /// Authorization code to exchange for OAuth tokens.
    authorization_code: String,
    /// Verifier sent alongside `authorization_code` in the token exchange.
    code_verifier: String,
}
68
/// Error body returned by Codex device-authorization endpoints.
///
/// NOTE(review): usage is outside the visible section; presumably read while polling.
#[derive(Debug, Deserialize)]
struct CodexDeviceErrorResponse {
    /// Machine-readable error code, if present.
    #[serde(default)]
    error: Option<String>,
    /// Human-readable error description, if present.
    #[serde(default)]
    error_description: Option<String>,
}
76
77pub async fn execute(args: AuthArgs) -> Result<()> {
78    match args.command {
79        AuthCommand::Copilot(copilot_args) => authenticate_copilot(copilot_args).await,
80        AuthCommand::Codex(codex_args) => authenticate_codex(codex_args).await,
81        AuthCommand::Cookies(cookie_args) => authenticate_cookie_import(cookie_args).await,
82        AuthCommand::Register(register_args) => authenticate_register(register_args).await,
83        AuthCommand::Login(login_args) => authenticate_login(login_args).await,
84    }
85}
86
/// Successful response body of `POST {server}/v1/users/login`.
#[derive(Debug, Deserialize)]
struct LoginResponsePayload {
    /// Bearer token persisted to the credentials file.
    access_token: String,
    /// Token expiry as sent by the server; `load_saved_credentials` later tries to parse
    /// it as RFC 3339.
    expires_at: String,
    /// Opaque user object; only its `email` field is read, for display.
    user: serde_json::Value,
}
93
94async fn login_with_password(
95    client: &Client,
96    server_url: &str,
97    email: &str,
98    password: &str,
99) -> Result<LoginResponsePayload> {
100    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
101
102    let resp = client
103        .post(format!("{}/v1/users/login", server_url))
104        .header("User-Agent", &user_agent)
105        .header("Content-Type", "application/json")
106        .json(&json!({
107            "email": email,
108            "password": password,
109        }))
110        .send()
111        .await
112        .context("Failed to connect to CodeTether server")?;
113
114    if !resp.status().is_success() {
115        let status = resp.status();
116        let body: serde_json::Value = resp.json().await.unwrap_or_default();
117        let detail = body
118            .get("detail")
119            .and_then(|v| v.as_str())
120            .unwrap_or("Authentication failed");
121        anyhow::bail!("Login failed ({}): {}", status, detail);
122    }
123
124    let login: LoginResponsePayload = resp
125        .json()
126        .await
127        .context("Failed to parse login response")?;
128
129    Ok(login)
130}
131
132fn write_saved_credentials(
133    server_url: &str,
134    email: &str,
135    login: &LoginResponsePayload,
136) -> Result<PathBuf> {
137    // Store token to ~/.config/codetether-agent/credentials.json
138    let cred_path = credential_file_path()?;
139    if let Some(parent) = cred_path.parent() {
140        std::fs::create_dir_all(parent)
141            .with_context(|| format!("Failed to create config dir: {}", parent.display()))?;
142    }
143
144    let creds = json!({
145        "server": server_url,
146        "access_token": login.access_token,
147        "expires_at": login.expires_at,
148        "email": email,
149    });
150
151    // Write with restrictive permissions (owner-only read/write)
152    #[cfg(unix)]
153    {
154        use std::os::unix::fs::OpenOptionsExt;
155        let file = std::fs::OpenOptions::new()
156            .write(true)
157            .create(true)
158            .truncate(true)
159            .mode(0o600)
160            .open(&cred_path)
161            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
162        serde_json::to_writer_pretty(file, &creds)?;
163    }
164    #[cfg(not(unix))]
165    {
166        let file = std::fs::File::create(&cred_path)
167            .with_context(|| format!("Failed to write credentials to {}", cred_path.display()))?;
168        serde_json::to_writer_pretty(file, &creds)?;
169    }
170
171    Ok(cred_path)
172}
173
174async fn authenticate_register(args: RegisterAuthArgs) -> Result<()> {
175    #[derive(Debug, Deserialize)]
176    struct RegisterResponse {
177        user_id: String,
178        email: String,
179        message: String,
180        #[serde(default)]
181        instance_url: Option<String>,
182        #[serde(default)]
183        instance_namespace: Option<String>,
184        #[serde(default)]
185        provisioning_status: Option<String>,
186    }
187
188    let server_url = args.server.trim_end_matches('/').to_string();
189
190    let email = match args.email {
191        Some(e) => e,
192        None => {
193            print!("Email: ");
194            io::stdout().flush()?;
195            let mut email = String::new();
196            io::stdin().read_line(&mut email)?;
197            email.trim().to_string()
198        }
199    };
200
201    if email.is_empty() {
202        anyhow::bail!("Email is required");
203    }
204
205    let password = rpassword_prompt("Password (min 8 chars): ")?;
206    if password.trim().len() < 8 {
207        anyhow::bail!("Password must be at least 8 characters");
208    }
209    let confirm = rpassword_prompt("Confirm password: ")?;
210    if password != confirm {
211        anyhow::bail!("Passwords do not match");
212    }
213
214    println!("Registering with {}...", server_url);
215
216    let client = Client::new();
217    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
218
219    let resp = client
220        .post(format!("{}/v1/users/register", server_url))
221        .header("User-Agent", &user_agent)
222        .header("Content-Type", "application/json")
223        .json(&json!({
224            "email": email,
225            "password": password,
226            "first_name": args.first_name,
227            "last_name": args.last_name,
228            "referral_source": args.referral_source,
229        }))
230        .send()
231        .await
232        .context("Failed to connect to CodeTether server")?;
233
234    if !resp.status().is_success() {
235        let status = resp.status();
236        let body: serde_json::Value = resp.json().await.unwrap_or_default();
237        let detail = body
238            .get("detail")
239            .and_then(|v| v.as_str())
240            .unwrap_or("Registration failed");
241        anyhow::bail!("Registration failed ({}): {}", status, detail);
242    }
243
244    let reg: RegisterResponse = resp
245        .json()
246        .await
247        .context("Failed to parse registration response")?;
248
249    println!(
250        "Account created for {} (user_id={})",
251        reg.email, reg.user_id
252    );
253    println!("{}", reg.message);
254    if let Some(status) = reg.provisioning_status.as_deref() {
255        println!("Provisioning status: {}", status);
256    }
257    if let Some(url) = reg.instance_url.as_deref() {
258        println!("Instance URL: {}", url);
259    }
260    if let Some(ns) = reg.instance_namespace.as_deref() {
261        println!("Instance namespace: {}", ns);
262    }
263
264    // Auto-login and save credentials for the worker.
265    println!("Logging in...");
266    let login = login_with_password(&client, &server_url, &reg.email, &password).await?;
267    let cred_path = write_saved_credentials(&server_url, &reg.email, &login)?;
268
269    let user_email = login
270        .user
271        .get("email")
272        .and_then(|v| v.as_str())
273        .unwrap_or(&reg.email);
274
275    println!("Logged in as {} (expires {})", user_email, login.expires_at);
276    println!("Credentials saved to {}", cred_path.display());
277    println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
278
279    Ok(())
280}
281
282async fn authenticate_login(args: LoginAuthArgs) -> Result<()> {
283    let server_url = args.server.trim_end_matches('/').to_string();
284
285    // Prompt for email if not provided
286    let email = match args.email {
287        Some(e) => e,
288        None => {
289            print!("Email: ");
290            io::stdout().flush()?;
291            let mut email = String::new();
292            io::stdin().read_line(&mut email)?;
293            email.trim().to_string()
294        }
295    };
296
297    if email.is_empty() {
298        anyhow::bail!("Email is required");
299    }
300
301    // Prompt for password (no echo)
302    let password = rpassword_prompt("Password: ")?;
303    if password.is_empty() {
304        anyhow::bail!("Password is required");
305    }
306
307    println!("Authenticating with {}...", server_url);
308
309    let client = Client::new();
310
311    let login = login_with_password(&client, &server_url, &email, &password).await?;
312    let cred_path = write_saved_credentials(&server_url, &email, &login)?;
313
314    let user_email = login
315        .user
316        .get("email")
317        .and_then(|v| v.as_str())
318        .unwrap_or(&email);
319
320    println!("Logged in as {} (expires {})", user_email, login.expires_at);
321    println!("Credentials saved to {}", cred_path.display());
322    println!("\nThe CLI will automatically use these credentials for `codetether worker`.");
323
324    Ok(())
325}
326
327/// Read password from terminal without echo.
328fn rpassword_prompt(prompt: &str) -> Result<String> {
329    print!("{}", prompt);
330    io::stdout().flush()?;
331
332    // Disable echo on Unix
333    #[cfg(unix)]
334    {
335        use std::io::BufRead;
336        // Save terminal state
337        let fd = 0; // stdin
338        let orig = unsafe {
339            let mut termios = std::mem::zeroed::<libc::termios>();
340            libc::tcgetattr(fd, &mut termios);
341            termios
342        };
343
344        // Disable echo
345        unsafe {
346            let mut termios = orig;
347            termios.c_lflag &= !libc::ECHO;
348            libc::tcsetattr(fd, libc::TCSANOW, &termios);
349        }
350
351        let mut password = String::new();
352        let result = io::stdin().lock().read_line(&mut password);
353
354        // Restore terminal state
355        unsafe {
356            libc::tcsetattr(fd, libc::TCSANOW, &orig);
357        }
358        println!(); // newline after password entry
359
360        result?;
361        Ok(password.trim().to_string())
362    }
363
364    #[cfg(not(unix))]
365    {
366        let mut password = String::new();
367        io::stdin().read_line(&mut password)?;
368        Ok(password.trim().to_string())
369    }
370}
371
372/// Get the path to the credential storage file.
373fn credential_file_path() -> Result<std::path::PathBuf> {
374    use directories::ProjectDirs;
375    let dirs = ProjectDirs::from("ai", "codetether", "codetether-agent")
376        .ok_or_else(|| anyhow::anyhow!("Cannot determine config directory"))?;
377    Ok(dirs.config_dir().join("credentials.json"))
378}
379
/// Stored credentials from `codetether auth login`.
///
/// Deserialized from the JSON credentials file written after a successful login or
/// registration.
#[derive(Debug, Deserialize)]
pub struct SavedCredentials {
    /// Base URL of the CodeTether server the token was issued by.
    pub server: String,
    /// Bearer token returned by the login endpoint.
    pub access_token: String,
    /// Expiry timestamp as stored; checked only when it parses as RFC 3339.
    pub expires_at: String,
    /// Account email; empty string when absent from the file.
    #[serde(default)]
    pub email: String,
}
389
390/// Load saved credentials from disk, returning `None` if the file doesn't exist,
391/// is malformed, or the token has expired.
392pub fn load_saved_credentials() -> Option<SavedCredentials> {
393    let path = credential_file_path().ok()?;
394    let data = std::fs::read_to_string(&path).ok()?;
395    let creds: SavedCredentials = serde_json::from_str(&data).ok()?;
396
397    // Check expiry if parseable
398    if let Ok(expires) = chrono::DateTime::parse_from_rfc3339(&creds.expires_at)
399        && expires < chrono::Utc::now()
400    {
401        tracing::warn!("Saved credentials have expired — run `codetether auth login` to refresh");
402        return None;
403    }
404
405    Some(creds)
406}
407
/// Sign in to OpenAI Codex with a ChatGPT account and store the resulting tokens in Vault.
///
/// Vault must be configured and reachable before any OAuth work starts. With
/// `args.device_code` set, the device-code flow runs immediately. Otherwise the
/// browser OAuth flow is used:
/// 1. The authorization URL is printed.
/// 2. A loopback listener tries to capture the redirect automatically (shorter wait
///    over SSH).
/// 3. If that does not complete, the user is asked to paste the callback URL. Empty
///    input falls back to the device-code flow; over SSH the prompt explicitly
///    allows empty input.
/// 4. The OAuth `state` returned with the code must match the one generated for this
///    attempt before the code is exchanged for tokens.
///
/// # Errors
/// Fails if:
/// - Vault is missing or unreachable;
/// - the pasted callback cannot be parsed;
/// - the OAuth state does not match;
/// - the code exchange or the Vault write fails.
async fn authenticate_codex(args: CodexAuthArgs) -> Result<()> {
    if secrets::secrets_manager().is_none() {
        anyhow::bail!(
            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth codex`."
        );
    }
    // Check reachability up front so a stale Vault config fails before the user
    // goes through the browser/device login.
    secrets::verify_reachable().await.context(
        "HashiCorp Vault is configured but unreachable. Start Vault, correct VAULT_ADDR/VAULT_TOKEN, or unset stale Vault env before running `codetether auth codex`.",
    )?;

    // Explicit device-code mode skips the browser flow entirely.
    if args.device_code {
        let credentials = authenticate_codex_device_code().await?;
        return store_codex_credentials(credentials).await;
    }

    // `code_verifier` is passed to `exchange_code` below; `expected_state` is compared
    // with the state that comes back on the callback.
    let (authorization_url, code_verifier, expected_state) =
        OpenAiCodexProvider::get_authorization_url();

    println!("OpenAI Codex OAuth authentication");
    println!(
        "Sign in with your ChatGPT subscription account (Plus/Pro/Team/Enterprise) to use Codex models without API credits."
    );

    // Either variable being set is taken to mean we are running over SSH.
    let is_ssh_session =
        std::env::var_os("SSH_CONNECTION").is_some() || std::env::var_os("SSH_TTY").is_some();
    if is_ssh_session {
        println!("Detected SSH session.");
        println!(
            "If your browser runs on your local machine, port-forward callback traffic first:"
        );
        println!("  ssh -L 1455:127.0.0.1:1455 <remote-host>");
        println!("Without forwarding, manual callback paste is still supported.");
    }

    println!("Open this URL: {}", authorization_url);
    println!(
        "After approving access, copy the browser callback URL and paste it below (it starts with http://localhost:1455/auth/callback)."
    );

    // Over SSH the listener is often unreachable, so don't make the user wait long.
    let callback_timeout = if is_ssh_session {
        Duration::from_secs(CODEX_CALLBACK_TIMEOUT_SSH_SECS)
    } else {
        Duration::from_secs(CODEX_CALLBACK_TIMEOUT_SECS)
    };
    // `None` means automatic capture failed or timed out; fall back to manual paste.
    let auto_callback = capture_oauth_callback_auto(callback_timeout).await?;
    let (authorization_code, returned_state) = if let Some(callback) = auto_callback {
        println!("Captured callback automatically.");
        callback
    } else {
        if is_ssh_session {
            println!(
                "Press Enter to switch to device-code auth, or paste callback URL from your browser."
            );
        }
        let callback_input = if is_ssh_session {
            prompt_optional_line("Callback URL: ")?
        } else {
            prompt_line("Callback URL: ")?
        };

        // Empty input switches to the device-code flow.
        if callback_input.trim().is_empty() {
            let credentials = authenticate_codex_device_code().await?;
            return store_codex_credentials(credentials).await;
        }

        extract_oauth_code_and_state(&callback_input)?
    };

    // Reject callbacks that do not belong to this login attempt.
    if returned_state != expected_state {
        anyhow::bail!(
            "OAuth state mismatch. Retry `codetether auth codex` and paste the callback URL from the same login attempt."
        );
    }

    let credentials = OpenAiCodexProvider::exchange_code(&authorization_code, &code_verifier)
        .await
        .context("Failed to exchange ChatGPT OAuth code for Codex tokens")?;

    store_codex_credentials(credentials).await
}
488
/// Save OpenAI Codex OAuth credentials to Vault under the `openai-codex` provider.
///
/// Steps:
/// 1. Resolve the ChatGPT account ID: first the value on `credentials`, then one
///    extracted from the id_token, then one extracted from the access token.
/// 2. Try to exchange the id_token for an OpenAI API key.
/// 3. Store the API key (if obtained), the raw OAuth fields in `extra`, and the
///    account ID as `organization`.
/// 4. Print a summary.
///
/// # Errors
/// Returns an error only if writing to Vault fails. A failed or impossible API-key
/// exchange is logged and tolerated.
async fn store_codex_credentials(credentials: OAuthCredentials) -> Result<()> {
    // First available source wins.
    let chatgpt_account_id = credentials
        .chatgpt_account_id
        .clone()
        .or_else(|| {
            credentials
                .id_token
                .as_deref()
                .and_then(OpenAiCodexProvider::extract_chatgpt_account_id)
        })
        .or_else(|| OpenAiCodexProvider::extract_chatgpt_account_id(&credentials.access_token));

    // Set when the exchange failed in a known-benign way, so the summary can say so.
    let mut expected_token_exchange_fallback = false;
    let api_key = if let Some(id_token) = credentials.id_token.as_deref() {
        match OpenAiCodexProvider::exchange_id_token_for_api_key(id_token).await {
            Ok(key) => Some(key),
            Err(error) => {
                // Known-benign failures are logged at info level, anything else at warn.
                if is_expected_codex_id_token_exchange_fallback(&error) {
                    expected_token_exchange_fallback = true;
                    tracing::info!(
                        error = %error,
                        "Expected id_token exchange fallback; using OAuth access token for Codex backend"
                    );
                } else {
                    tracing::warn!(
                        error = %error,
                        "Failed to exchange id_token for OpenAI API key; falling back to OAuth access token"
                    );
                }
                None
            }
        }
    } else {
        tracing::warn!(
            "OAuth token exchange did not return an id_token; cannot derive OpenAI API key"
        );
        None
    };

    // Raw OAuth material stored alongside the (optional) API key; per the summary
    // below, Codex requests fall back to these tokens when no API key was obtained.
    let mut extra = HashMap::new();
    extra.insert(
        "access_token".to_string(),
        serde_json::Value::String(credentials.access_token.clone()),
    );
    extra.insert(
        "refresh_token".to_string(),
        serde_json::Value::String(credentials.refresh_token.clone()),
    );
    extra.insert(
        "expires_at".to_string(),
        serde_json::Value::Number(credentials.expires_at.into()),
    );
    if let Some(id_token) = credentials.id_token.as_ref() {
        extra.insert(
            "id_token".to_string(),
            serde_json::Value::String(id_token.clone()),
        );
    }
    if let Some(account_id) = chatgpt_account_id.as_ref() {
        extra.insert(
            "chatgpt_account_id".to_string(),
            serde_json::Value::String(account_id.clone()),
        );
    }

    let provider_secrets = ProviderSecrets {
        api_key: api_key.clone(),
        base_url: None,
        organization: chatgpt_account_id.clone(),
        headers: None,
        extra,
    };

    secrets::set_provider_secrets("openai-codex", &provider_secrets)
        .await
        .context("Failed to store openai-codex OAuth credentials in Vault")?;

    // `expires_at` is treated as Unix seconds; if it cannot be turned into a
    // timestamp, print the raw number instead.
    let expires_display = chrono::DateTime::from_timestamp(credentials.expires_at as i64, 0)
        .map(|ts| ts.to_rfc3339())
        .unwrap_or_else(|| credentials.expires_at.to_string());

    println!("Saved openai-codex credentials to HashiCorp Vault.");
    if api_key.is_some() {
        println!("Stored exchanged OpenAI API key for Codex model requests.");
    } else {
        println!(
            "Could not exchange an OpenAI API key; Codex requests will use ChatGPT OAuth backend tokens."
        );
        if expected_token_exchange_fallback {
            println!(
                "Note: this fallback is expected when your id_token does not include organization context."
            );
        }
    }
    if let Some(account_id) = chatgpt_account_id {
        println!("Using ChatGPT workspace/account ID: {}", account_id);
    }
    println!("Access token expires at {}", expires_display);
    println!("You can now select models like `openai-codex/gpt-5.5`.");
    Ok(())
}
590
591fn is_expected_codex_id_token_exchange_fallback(error: &anyhow::Error) -> bool {
592    let msg = error.to_string().to_ascii_lowercase();
593    msg.contains("missing organization_id")
594        || (msg.contains("invalid_subject_token") && msg.contains("organization"))
595}
596
597async fn authenticate_codex_device_code() -> Result<OAuthCredentials> {
598    let client = Client::new();
599    let issuer = OpenAiCodexProvider::oauth_issuer().trim_end_matches('/');
600    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
601    let device_code = request_codex_device_code(&client, issuer, &user_agent).await?;
602
603    println!("OpenAI Codex device authentication");
604    println!("Open this URL: {issuer}/codex/device");
605    println!("Enter code: {}", device_code.user_code);
606    println!("Waiting for authorization...");
607
608    let code = poll_for_codex_authorization_code(&client, issuer, &user_agent, &device_code)
609        .await
610        .context("Timed out waiting for device authorization")?;
611
612    let redirect_uri = format!("{issuer}/deviceauth/callback");
613    OpenAiCodexProvider::exchange_code_with_redirect_uri(
614        &code.authorization_code,
615        &code.code_verifier,
616        &redirect_uri,
617    )
618    .await
619    .context("Failed to exchange device authorization code for Codex tokens")
620}
621
/// One parsed row of a Netscape-format (`cookies.txt`) cookie file.
#[derive(Debug, Clone)]
struct CookieRow {
    /// Column 1: cookie domain, without any `#HttpOnly_` prefix.
    domain: String,
    /// Column 2 (`TRUE`/`FALSE`, case-insensitive): include-subdomains flag.
    include_subdomains: bool,
    /// Column 3: cookie path.
    path: String,
    /// Column 4 (`TRUE`/`FALSE`, case-insensitive): secure flag.
    secure: bool,
    /// Column 5: expiry as Unix seconds. `0` (also used when the column does not
    /// parse) is treated as "no expiry" by the filters.
    expires_epoch: i64,
    /// Column 6: cookie name.
    name: String,
    /// Column 7: cookie value.
    value: String,
    /// Whether the line carried the `#HttpOnly_` prefix.
    http_only: bool,
}
633
634async fn authenticate_cookie_import(args: CookieAuthArgs) -> Result<()> {
635    if secrets::secrets_manager().is_none() {
636        anyhow::bail!(
637            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth cookies`."
638        );
639    }
640
641    let provider_id = args.provider.trim().to_string();
642    if provider_id.is_empty() {
643        anyhow::bail!("--provider cannot be empty");
644    }
645
646    let raw = tokio::fs::read_to_string(&args.file)
647        .await
648        .with_context(|| format!("Failed to read cookie file {}", args.file.display()))?;
649    let rows = parse_netscape_cookie_file(&raw);
650    if rows.is_empty() {
651        anyhow::bail!(
652            "No valid cookie rows found in {} (expected Netscape format)",
653            args.file.display()
654        );
655    }
656
657    let (selected, dropped_expired, dropped_non_auth) =
658        select_cookie_rows(&rows, &provider_id, args.keep_all);
659    if selected.is_empty() {
660        anyhow::bail!("No usable cookies remained after filtering");
661    }
662
663    let rendered = render_netscape_cookie_file(&selected);
664    let now = chrono::Utc::now();
665    let (earliest_expiry, latest_expiry) = cookie_expiry_bounds(&selected);
666    let cookie_names: Vec<String> = selected.iter().map(|row| row.name.clone()).collect();
667    let mut extra = HashMap::new();
668    extra.insert("cookies".to_string(), json!(rendered));
669    extra.insert("cookie_format".to_string(), json!("netscape"));
670    extra.insert("imported_at".to_string(), json!(now.to_rfc3339()));
671    extra.insert("cookie_count".to_string(), json!(selected.len()));
672    extra.insert("cookie_names".to_string(), json!(cookie_names));
673    extra.insert("dropped_expired".to_string(), json!(dropped_expired));
674    extra.insert("dropped_non_auth".to_string(), json!(dropped_non_auth));
675    extra.insert("keep_all".to_string(), json!(args.keep_all));
676    extra.insert(
677        "strategy".to_string(),
678        json!(if args.keep_all {
679            "cookies_all_v1"
680        } else {
681            "cookies_auth_subset_v1"
682        }),
683    );
684
685    if let Some(epoch) = earliest_expiry {
686        extra.insert("earliest_expiry_epoch".to_string(), json!(epoch));
687        if let Some(ts) = chrono::DateTime::from_timestamp(epoch, 0) {
688            extra.insert(
689                "earliest_expiry_rfc3339".to_string(),
690                json!(ts.to_rfc3339()),
691            );
692        }
693        extra.insert(
694            "rotate_before_epoch".to_string(),
695            json!(epoch.saturating_sub(24 * 60 * 60)),
696        );
697    }
698    if let Some(epoch) = latest_expiry {
699        extra.insert("latest_expiry_epoch".to_string(), json!(epoch));
700    }
701
702    let provider_secrets = ProviderSecrets {
703        api_key: None,
704        base_url: None,
705        organization: None,
706        headers: None,
707        extra,
708    };
709
710    secrets::set_provider_secrets(&provider_id, &provider_secrets)
711        .await
712        .with_context(|| format!("Failed to store {} cookies in Vault", provider_id))?;
713    let can_read_back = secrets::get_provider_secrets(&provider_id)
714        .await
715        .map(|saved| saved.extra.contains_key("cookies"))
716        .unwrap_or(false);
717
718    println!(
719        "Saved {} cookies to HashiCorp Vault provider '{}'.",
720        selected.len(),
721        provider_id
722    );
723    println!(
724        "Dropped {} expired and {} non-auth cookies.",
725        dropped_expired, dropped_non_auth
726    );
727    if let Some(epoch) = earliest_expiry
728        && let Some(ts) = chrono::DateTime::from_timestamp(epoch, 0)
729    {
730        println!(
731            "Earliest cookie expiry: {} (rotate at least 24h before this).",
732            ts.to_rfc3339()
733        );
734    }
735    println!(
736        "Vault path: {}/{}",
737        std::env::var("VAULT_SECRETS_PATH").unwrap_or_else(|_| "codetether/providers".to_string()),
738        provider_id
739    );
740    println!(
741        "Read-back verification: {}",
742        if can_read_back { "ok" } else { "failed" }
743    );
744    Ok(())
745}
746
747fn parse_netscape_cookie_file(raw: &str) -> Vec<CookieRow> {
748    raw.lines().filter_map(parse_netscape_cookie_line).collect()
749}
750
751fn parse_netscape_cookie_line(line: &str) -> Option<CookieRow> {
752    let trimmed = line.trim();
753    if trimmed.is_empty() || (trimmed.starts_with('#') && !trimmed.starts_with("#HttpOnly_")) {
754        return None;
755    }
756
757    let (http_only, normalized) = if let Some(rest) = trimmed.strip_prefix("#HttpOnly_") {
758        (true, rest)
759    } else {
760        (false, trimmed)
761    };
762    let parts: Vec<&str> = normalized.split('\t').collect();
763    if parts.len() < 7 {
764        return None;
765    }
766
767    Some(CookieRow {
768        domain: parts[0].trim().to_string(),
769        include_subdomains: parts[1].trim().eq_ignore_ascii_case("TRUE"),
770        path: parts[2].trim().to_string(),
771        secure: parts[3].trim().eq_ignore_ascii_case("TRUE"),
772        expires_epoch: parts[4].trim().parse::<i64>().unwrap_or(0),
773        name: parts[5].trim().to_string(),
774        value: parts[6].trim().to_string(),
775        http_only,
776    })
777}
778
779fn select_cookie_rows(
780    rows: &[CookieRow],
781    provider_id: &str,
782    keep_all: bool,
783) -> (Vec<CookieRow>, usize, usize) {
784    let now_epoch = chrono::Utc::now().timestamp();
785    let allowed = preferred_cookie_names(provider_id);
786    let mut selected_by_name: HashMap<String, CookieRow> = HashMap::new();
787    let mut dropped_expired = 0usize;
788    let mut dropped_non_auth = 0usize;
789
790    for row in rows {
791        if row.name.is_empty() {
792            continue;
793        }
794        if row.expires_epoch > 0 && row.expires_epoch <= now_epoch {
795            dropped_expired += 1;
796            continue;
797        }
798        if !keep_all && !allowed.is_empty() && !allowed.iter().any(|name| *name == row.name) {
799            dropped_non_auth += 1;
800            continue;
801        }
802        match selected_by_name.get(&row.name) {
803            Some(existing) if existing.expires_epoch >= row.expires_epoch => {}
804            _ => {
805                selected_by_name.insert(row.name.clone(), row.clone());
806            }
807        }
808    }
809
810    let mut selected: Vec<CookieRow> = selected_by_name.into_values().collect();
811    selected.sort_by(|left, right| left.name.cmp(&right.name));
812    (selected, dropped_expired, dropped_non_auth)
813}
814
/// Cookie names considered authentication-relevant for a known provider.
///
/// An empty slice means "no preference", i.e. no name-based filtering.
fn preferred_cookie_names(provider_id: &str) -> &'static [&'static str] {
    const NEXTDOOR_WEB: &[&str] = &[
        "ndbr_at",
        "ndbr_idt",
        "ndbr_adt",
        "csrftoken",
        "ndp_session_id",
        "WE",
        "WE3P",
        "DAID",
    ];
    const GEMINI_WEB: &[&str] = &[
        "__Secure-1PSID",
        "__Secure-1PSIDTS",
        "__Secure-1PSIDCC",
        "SID",
        "HSID",
        "SSID",
        "APISID",
        "SAPISID",
    ];
    const NONE: &[&str] = &[];

    match provider_id {
        "nextdoor-web" => NEXTDOOR_WEB,
        "gemini-web" => GEMINI_WEB,
        _ => NONE,
    }
}
840
841fn render_netscape_cookie_file(rows: &[CookieRow]) -> String {
842    let mut lines = vec![
843        "# Netscape HTTP Cookie File".to_string(),
844        "# Generated by codetether auth cookies".to_string(),
845    ];
846    lines.extend(rows.iter().map(|row| {
847        let domain = if row.http_only {
848            format!("#HttpOnly_{}", row.domain)
849        } else {
850            row.domain.clone()
851        };
852        format!(
853            "{}\t{}\t{}\t{}\t{}\t{}\t{}",
854            domain,
855            bool_flag(row.include_subdomains),
856            row.path,
857            bool_flag(row.secure),
858            row.expires_epoch,
859            row.name,
860            row.value
861        )
862    }));
863    format!("{}\n", lines.join("\n"))
864}
865
866fn cookie_expiry_bounds(rows: &[CookieRow]) -> (Option<i64>, Option<i64>) {
867    let mut expiries = rows.iter().map(|row| row.expires_epoch).filter(|e| *e > 0);
868    let first = expiries.next();
869    let Some(mut min_epoch) = first else {
870        return (None, None);
871    };
872    let mut max_epoch = min_epoch;
873    for epoch in expiries {
874        if epoch < min_epoch {
875            min_epoch = epoch;
876        }
877        if epoch > max_epoch {
878            max_epoch = epoch;
879        }
880    }
881    (Some(min_epoch), Some(max_epoch))
882}
883
/// Netscape cookie-file spelling of a boolean column (`TRUE` / `FALSE`).
fn bool_flag(value: bool) -> &'static str {
    match value {
        true => "TRUE",
        false => "FALSE",
    }
}
887
888async fn capture_oauth_callback_auto(timeout: Duration) -> Result<Option<(String, String)>> {
889    let mut listeners = Vec::new();
890    for address in [CODEX_CALLBACK_ADDR_V4, CODEX_CALLBACK_ADDR_V6] {
891        match TcpListener::bind(address).await {
892            Ok(listener) => listeners.push(listener),
893            Err(error) => {
894                tracing::debug!(
895                    address,
896                    error = %error,
897                    "Failed to bind one OAuth callback listener address"
898                );
899            }
900        }
901    }
902
903    if listeners.is_empty() {
904        tracing::warn!(
905            ipv4 = CODEX_CALLBACK_ADDR_V4,
906            ipv6 = CODEX_CALLBACK_ADDR_V6,
907            "Failed to bind OAuth callback listener on localhost addresses; falling back to manual paste"
908        );
909        return Ok(None);
910    }
911
912    println!(
913        "Waiting up to {}s for automatic callback capture on http://{}/auth/callback ...",
914        timeout.as_secs(),
915        CODEX_CALLBACK_DISPLAY_ADDR
916    );
917
918    match wait_for_oauth_callback_any(listeners, timeout).await {
919        Ok(callback) => Ok(Some(callback)),
920        Err(error) => {
921            tracing::warn!(
922                error = %error,
923                "Automatic OAuth callback capture did not complete; falling back to manual paste"
924            );
925            Ok(None)
926        }
927    }
928}
929
930async fn wait_for_oauth_callback_any(
931    mut listeners: Vec<TcpListener>,
932    timeout: Duration,
933) -> Result<(String, String)> {
934    match listeners.len() {
935        0 => anyhow::bail!("No OAuth callback listeners were available"),
936        1 => {
937            let listener = listeners.pop().expect("length checked");
938            wait_for_oauth_callback(listener, timeout).await
939        }
940        _ => {
941            let listener2 = listeners.pop().expect("length checked");
942            let listener1 = listeners.pop().expect("length checked");
943
944            let mut waiter1 = Box::pin(wait_for_oauth_callback(listener1, timeout));
945            let mut waiter2 = Box::pin(wait_for_oauth_callback(listener2, timeout));
946
947            tokio::select! {
948                result1 = &mut waiter1 => {
949                    match result1 {
950                        Ok(callback) => Ok(callback),
951                        Err(err1) => match waiter2.await {
952                            Ok(callback) => Ok(callback),
953                            Err(err2) => anyhow::bail!("{}; {}", err1, err2),
954                        },
955                    }
956                }
957                result2 = &mut waiter2 => {
958                    match result2 {
959                        Ok(callback) => Ok(callback),
960                        Err(err2) => match waiter1.await {
961                            Ok(callback) => Ok(callback),
962                            Err(err1) => anyhow::bail!("{}; {}", err2, err1),
963                        },
964                    }
965                }
966            }
967        }
968    }
969}
970
971async fn wait_for_oauth_callback(
972    listener: TcpListener,
973    timeout: Duration,
974) -> Result<(String, String)> {
975    let deadline = Instant::now() + timeout;
976
977    loop {
978        let now = Instant::now();
979        if now >= deadline {
980            anyhow::bail!("Timed out waiting for OAuth callback");
981        }
982        let remaining = deadline - now;
983
984        let (mut stream, peer_addr) = tokio::time::timeout(remaining, listener.accept())
985            .await
986            .context("Timed out waiting for callback connection")?
987            .context("Failed to accept callback connection")?;
988
989        let request = read_http_request(&mut stream).await?;
990        match parse_oauth_callback_request(&request) {
991            Ok((code, state)) => {
992                write_http_response(
993                    &mut stream,
994                    200,
995                    "OK",
996                    "<html><body><h1>CodeTether login complete</h1><p>You can close this tab.</p></body></html>",
997                )
998                .await?;
999                return Ok((code, state));
1000            }
1001            Err(error) => {
1002                tracing::warn!(
1003                    peer = %peer_addr,
1004                    error = %error,
1005                    "Ignoring non-callback HTTP request while waiting for OAuth callback"
1006                );
1007                write_http_response(
1008                    &mut stream,
1009                    400,
1010                    "Bad Request",
1011                    "<html><body><h1>Invalid callback request</h1><p>Retry authorization from CodeTether.</p></body></html>",
1012                )
1013                .await?;
1014            }
1015        }
1016    }
1017}
1018
1019async fn read_http_request(stream: &mut tokio::net::TcpStream) -> Result<String> {
1020    let mut buffer = [0u8; 8192];
1021    let read = stream
1022        .read(&mut buffer)
1023        .await
1024        .context("Failed to read callback request")?;
1025    if read == 0 {
1026        anyhow::bail!("Callback request stream closed before data was received");
1027    }
1028    Ok(String::from_utf8_lossy(&buffer[..read]).to_string())
1029}
1030
1031fn parse_oauth_callback_request(request: &str) -> Result<(String, String)> {
1032    let first_line = request
1033        .lines()
1034        .next()
1035        .ok_or_else(|| anyhow::anyhow!("Missing HTTP request line"))?;
1036    let mut parts = first_line.split_whitespace();
1037
1038    let method = parts.next().unwrap_or_default();
1039    let method = method.to_ascii_uppercase();
1040
1041    let target = parts
1042        .next()
1043        .ok_or_else(|| anyhow::anyhow!("Missing callback target"))?;
1044    let target_query = target.split_once('?').map(|(_, query)| query.trim());
1045    let body = request
1046        .split_once("\r\n\r\n")
1047        .map(|(_, body)| body)
1048        .or_else(|| request.split_once("\n\n").map(|(_, body)| body))
1049        .map(str::trim)
1050        .filter(|body| !body.is_empty());
1051
1052    let callback_payload = match method.as_str() {
1053        "GET" => target_query
1054            .or(body)
1055            .ok_or_else(|| anyhow::anyhow!("Callback target missing query string"))?,
1056        "POST" => body
1057            .or(target_query)
1058            .ok_or_else(|| anyhow::anyhow!("Callback POST body missing OAuth payload"))?,
1059        _ => anyhow::bail!("Unsupported callback method: {}", method),
1060    };
1061
1062    extract_oauth_code_and_state(callback_payload)
1063}
1064
1065async fn write_http_response(
1066    stream: &mut tokio::net::TcpStream,
1067    status_code: u16,
1068    status_text: &str,
1069    body: &str,
1070) -> Result<()> {
1071    let response = format!(
1072        "HTTP/1.1 {} {}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
1073        status_code,
1074        status_text,
1075        body.len(),
1076        body
1077    );
1078    stream
1079        .write_all(response.as_bytes())
1080        .await
1081        .context("Failed to write callback response")?;
1082    Ok(())
1083}
1084
1085async fn authenticate_copilot(args: CopilotAuthArgs) -> Result<()> {
1086    if secrets::secrets_manager().is_none() {
1087        anyhow::bail!(
1088            "HashiCorp Vault is not configured. Set VAULT_ADDR and VAULT_TOKEN before running `codetether auth copilot`."
1089        );
1090    }
1091
1092    let (provider_id, domain, enterprise_domain) = match args.enterprise_url {
1093        Some(raw) => {
1094            let domain = normalize_enterprise_domain(&raw);
1095            if domain.is_empty() {
1096                anyhow::bail!("--enterprise-url cannot be empty");
1097            }
1098            ("github-copilot-enterprise", domain.clone(), Some(domain))
1099        }
1100        None => ("github-copilot", DEFAULT_GITHUB_DOMAIN.to_string(), None),
1101    };
1102
1103    let client = Client::new();
1104    let client_id = resolve_client_id(args.client_id)?;
1105    let user_agent = format!("codetether-agent/{}", env!("CARGO_PKG_VERSION"));
1106    let device = request_device_code(&client, &domain, &user_agent, &client_id).await?;
1107
1108    println!("GitHub Copilot device authentication");
1109    println!(
1110        "Open this URL: {}",
1111        device
1112            .verification_uri_complete
1113            .as_deref()
1114            .unwrap_or(&device.verification_uri)
1115    );
1116    println!("Enter code: {}", device.user_code);
1117    println!("Waiting for authorization...");
1118
1119    let token = poll_for_access_token(&client, &domain, &user_agent, &client_id, &device).await?;
1120
1121    let mut extra = HashMap::new();
1122    if let Some(enterprise_url) = enterprise_domain {
1123        extra.insert(
1124            "enterpriseUrl".to_string(),
1125            serde_json::Value::String(enterprise_url),
1126        );
1127    }
1128
1129    let provider_secrets = ProviderSecrets {
1130        api_key: Some(token),
1131        base_url: None,
1132        organization: None,
1133        headers: None,
1134        extra,
1135    };
1136
1137    secrets::set_provider_secrets(provider_id, &provider_secrets)
1138        .await
1139        .with_context(|| format!("Failed to store {} auth token in Vault", provider_id))?;
1140
1141    println!("Saved {} credentials to HashiCorp Vault.", provider_id);
1142    Ok(())
1143}
1144
1145async fn request_codex_device_code(
1146    client: &Client,
1147    issuer: &str,
1148    user_agent: &str,
1149) -> Result<CodexDeviceCodeResponse> {
1150    let url = format!("{issuer}/api/accounts/deviceauth/usercode");
1151    let response = client
1152        .post(&url)
1153        .header("Accept", "application/json")
1154        .header("Content-Type", "application/json")
1155        .header("User-Agent", user_agent)
1156        .json(&json!({
1157            "client_id": OpenAiCodexProvider::oauth_client_id(),
1158        }))
1159        .send()
1160        .await
1161        .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
1162
1163    let status = response.status();
1164    if !status.is_success() {
1165        let body = response.text().await.unwrap_or_default();
1166        if status == reqwest::StatusCode::NOT_FOUND {
1167            anyhow::bail!(
1168                "Device code login is not enabled for this Codex server. Use browser OAuth flow instead."
1169            );
1170        }
1171        anyhow::bail!(
1172            "Failed to initiate Codex device authorization ({}): {}",
1173            status,
1174            truncate_body(&body)
1175        );
1176    }
1177
1178    let mut device: CodexDeviceCodeResponse = response
1179        .json()
1180        .await
1181        .context("Failed to parse Codex device authorization response")?;
1182    if device.interval == 0 {
1183        device.interval = 5;
1184    }
1185    Ok(device)
1186}
1187
1188async fn poll_for_codex_authorization_code(
1189    client: &Client,
1190    issuer: &str,
1191    user_agent: &str,
1192    device: &CodexDeviceCodeResponse,
1193) -> Result<CodexDeviceCodeTokenResponse> {
1194    let url = format!("{issuer}/api/accounts/deviceauth/token");
1195    let interval_secs = device.interval.max(1);
1196    let timeout = Duration::from_secs(CODEX_DEVICE_AUTH_TIMEOUT_SECS);
1197    let start = Instant::now();
1198
1199    loop {
1200        let response = client
1201            .post(&url)
1202            .header("Accept", "application/json")
1203            .header("Content-Type", "application/json")
1204            .header("User-Agent", user_agent)
1205            .json(&json!({
1206                "device_auth_id": device.device_auth_id,
1207                "user_code": device.user_code,
1208            }))
1209            .send()
1210            .await
1211            .with_context(|| format!("Failed to poll device authorization endpoint: {url}"))?;
1212
1213        let status = response.status();
1214        if status.is_success() {
1215            return response
1216                .json()
1217                .await
1218                .context("Failed to parse Codex device authorization response");
1219        }
1220
1221        let body = response.text().await.unwrap_or_default();
1222        if status == reqwest::StatusCode::FORBIDDEN || status == reqwest::StatusCode::NOT_FOUND {
1223            if start.elapsed() >= timeout {
1224                anyhow::bail!(
1225                    "Device authorization timed out after {} seconds",
1226                    CODEX_DEVICE_AUTH_TIMEOUT_SECS
1227                );
1228            }
1229            sleep_with_margin(interval_secs).await;
1230            continue;
1231        }
1232
1233        if let Ok(payload) = serde_json::from_str::<CodexDeviceErrorResponse>(&body)
1234            && let Some(error) = payload.error.as_deref()
1235        {
1236            let description = payload
1237                .error_description
1238                .as_deref()
1239                .unwrap_or("No error description provided");
1240            anyhow::bail!("Codex device authorization failed: {error} ({description})");
1241        }
1242
1243        anyhow::bail!(
1244            "Codex device authorization failed ({}): {}",
1245            status,
1246            truncate_body(&body)
1247        );
1248    }
1249}
1250
1251async fn request_device_code(
1252    client: &Client,
1253    domain: &str,
1254    user_agent: &str,
1255    client_id: &str,
1256) -> Result<DeviceCodeResponse> {
1257    let url = format!("https://{domain}/login/device/code");
1258    let response = client
1259        .post(&url)
1260        .header("Accept", "application/json")
1261        .header("Content-Type", "application/json")
1262        .header("User-Agent", user_agent)
1263        .json(&json!({
1264            "client_id": client_id,
1265            "scope": "read:user",
1266        }))
1267        .send()
1268        .await
1269        .with_context(|| format!("Failed to reach device authorization endpoint: {url}"))?;
1270
1271    let status = response.status();
1272    if !status.is_success() {
1273        let body = response.text().await.unwrap_or_default();
1274        anyhow::bail!(
1275            "Failed to initiate device authorization ({}): {}",
1276            status,
1277            truncate_body(&body)
1278        );
1279    }
1280
1281    let mut device: DeviceCodeResponse = response
1282        .json()
1283        .await
1284        .context("Failed to parse device authorization response")?;
1285    if device.interval.unwrap_or(0) == 0 {
1286        device.interval = Some(5);
1287    }
1288    Ok(device)
1289}
1290
1291async fn poll_for_access_token(
1292    client: &Client,
1293    domain: &str,
1294    user_agent: &str,
1295    client_id: &str,
1296    device: &DeviceCodeResponse,
1297) -> Result<String> {
1298    let url = format!("https://{domain}/login/oauth/access_token");
1299    let mut interval_secs = device.interval.unwrap_or(5).max(1);
1300
1301    loop {
1302        let response = client
1303            .post(&url)
1304            .header("Accept", "application/json")
1305            .header("Content-Type", "application/json")
1306            .header("User-Agent", user_agent)
1307            .json(&json!({
1308                "client_id": client_id,
1309                "device_code": device.device_code,
1310                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
1311            }))
1312            .send()
1313            .await
1314            .with_context(|| format!("Failed to poll token endpoint: {url}"))?;
1315
1316        let status = response.status();
1317        if !status.is_success() {
1318            let body = response.text().await.unwrap_or_default();
1319            anyhow::bail!(
1320                "Failed to exchange device code for access token ({}): {}",
1321                status,
1322                truncate_body(&body)
1323            );
1324        }
1325
1326        let payload: AccessTokenResponse = response
1327            .json()
1328            .await
1329            .context("Failed to parse OAuth token response")?;
1330
1331        if let Some(token) = payload.access_token
1332            && !token.trim().is_empty()
1333        {
1334            return Ok(token);
1335        }
1336
1337        match payload.error.as_deref() {
1338            Some("authorization_pending") => sleep_with_margin(interval_secs).await,
1339            Some("slow_down") => {
1340                interval_secs = payload
1341                    .interval
1342                    .filter(|value| *value > 0)
1343                    .unwrap_or(interval_secs + 5);
1344                sleep_with_margin(interval_secs).await;
1345            }
1346            Some(error) => {
1347                let description = payload
1348                    .error_description
1349                    .unwrap_or_else(|| "No error description provided".to_string());
1350                anyhow::bail!("Copilot OAuth failed: {} ({})", error, description);
1351            }
1352            None => sleep_with_margin(interval_secs).await,
1353        }
1354    }
1355}
1356
1357fn resolve_client_id(client_id: Option<String>) -> Result<String> {
1358    let id = client_id
1359        .map(|value| value.trim().to_string())
1360        .filter(|value| !value.is_empty())
1361        .ok_or_else(|| {
1362            anyhow::anyhow!(
1363                "GitHub OAuth client ID is required. Pass `--client-id <id>` or set `CODETETHER_COPILOT_OAUTH_CLIENT_ID`."
1364            )
1365        })?;
1366
1367    Ok(id)
1368}
1369
1370async fn sleep_with_margin(interval_secs: u64) {
1371    sleep(Duration::from_millis(
1372        interval_secs.saturating_mul(1000) + OAUTH_POLLING_SAFETY_MARGIN_MS,
1373    ))
1374    .await;
1375}
1376
/// Shortens an HTTP response body for inclusion in error messages.
///
/// Bodies of at most 300 bytes are returned unchanged. Longer bodies are cut
/// to at most 300 bytes and suffixed with `...`. The cut is moved back to the
/// nearest UTF-8 character boundary, so multi-byte characters (e.g. in
/// localized error pages) no longer cause a slicing panic.
fn truncate_body(body: &str) -> String {
    const MAX_LEN: usize = 300;
    if body.len() <= MAX_LEN {
        return body.to_string();
    }
    // Index 0 is always a char boundary, and a UTF-8 character is at most
    // 4 bytes, so this walks back at most 3 positions.
    let cut = (0..=MAX_LEN)
        .rev()
        .find(|&index| body.is_char_boundary(index))
        .unwrap_or(0);
    format!("{}...", &body[..cut])
}
1385
1386fn deserialize_interval_seconds<'de, D>(deserializer: D) -> std::result::Result<u64, D::Error>
1387where
1388    D: Deserializer<'de>,
1389{
1390    #[derive(Deserialize)]
1391    #[serde(untagged)]
1392    enum IntervalValue {
1393        Number(u64),
1394        String(String),
1395    }
1396
1397    let value = Option::<IntervalValue>::deserialize(deserializer)?;
1398    match value {
1399        Some(IntervalValue::Number(value)) => Ok(value),
1400        Some(IntervalValue::String(value)) => value
1401            .trim()
1402            .parse::<u64>()
1403            .map_err(|error| de::Error::custom(format!("invalid interval value: {error}"))),
1404        None => Ok(0),
1405    }
1406}
1407
1408fn prompt_line(prompt: &str) -> Result<String> {
1409    print!("{prompt}");
1410    io::stdout().flush()?;
1411
1412    let mut input = String::new();
1413    io::stdin().read_line(&mut input)?;
1414    let trimmed = input.trim().to_string();
1415    if trimmed.is_empty() {
1416        anyhow::bail!("Input is required");
1417    }
1418    Ok(trimmed)
1419}
1420
1421fn prompt_optional_line(prompt: &str) -> Result<String> {
1422    print!("{prompt}");
1423    io::stdout().flush()?;
1424
1425    let mut input = String::new();
1426    io::stdin().read_line(&mut input)?;
1427    Ok(input.trim().to_string())
1428}
1429
1430fn extract_oauth_code_and_state(callback_input: &str) -> Result<(String, String)> {
1431    let input = callback_input.trim();
1432    if input.is_empty() {
1433        anyhow::bail!("Callback URL is required");
1434    }
1435
1436    let query = if input.contains("://") {
1437        let url =
1438            reqwest::Url::parse(input).with_context(|| format!("Invalid callback URL: {input}"))?;
1439        url.query()
1440            .map(str::to_string)
1441            .ok_or_else(|| anyhow::anyhow!("Callback URL is missing query parameters"))?
1442    } else if let Some((_, params)) = input.split_once('?') {
1443        params.to_string()
1444    } else {
1445        input.to_string()
1446    };
1447
1448    let params = parse_query_pairs(&query);
1449    if let Some(error) = params.get("error") {
1450        let error_description = params
1451            .get("error_description")
1452            .map(String::as_str)
1453            .unwrap_or("No error description provided");
1454        anyhow::bail!(
1455            "OAuth authorization failed: {} ({})",
1456            error,
1457            error_description
1458        );
1459    }
1460
1461    let code = params
1462        .get("code")
1463        .cloned()
1464        .filter(|value| !value.is_empty())
1465        .ok_or_else(|| anyhow::anyhow!("Callback URL does not include an OAuth code"))?;
1466    let state = params
1467        .get("state")
1468        .cloned()
1469        .filter(|value| !value.is_empty())
1470        .ok_or_else(|| anyhow::anyhow!("Callback URL does not include OAuth state"))?;
1471
1472    Ok((code, state))
1473}
1474
1475fn parse_query_pairs(query: &str) -> HashMap<String, String> {
1476    let mut params = HashMap::new();
1477
1478    for pair in query.split('&') {
1479        if pair.trim().is_empty() {
1480            continue;
1481        }
1482
1483        let (raw_key, raw_value) = match pair.split_once('=') {
1484            Some((key, value)) => (key, value),
1485            None => (pair, ""),
1486        };
1487        let key = decode_query_component(raw_key);
1488        let value = decode_query_component(raw_value);
1489        params.insert(key, value);
1490    }
1491
1492    params
1493}
1494
1495fn decode_query_component(component: &str) -> String {
1496    match urlencoding::decode(component) {
1497        Ok(value) => value.into_owned(),
1498        Err(_) => component.to_string(),
1499    }
1500}
1501
/// Unit tests for OAuth callback parsing, Codex device-code deserialization
/// and Netscape cookie handling.
#[cfg(test)]
mod tests {
    use super::{
        CodexDeviceCodeResponse, extract_oauth_code_and_state, parse_netscape_cookie_line,
        parse_oauth_callback_request, select_cookie_rows,
    };

    /// A full redirect URL, as pasted from the browser, yields code and state.
    #[test]
    fn extracts_code_and_state_from_full_callback_url() {
        let input = "http://localhost:1455/auth/callback?code=abc123&state=xyz789";
        let (code, state) = extract_oauth_code_and_state(input).expect("expected callback parse");
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz789");
    }

    /// A bare query string is accepted as well.
    #[test]
    fn extracts_code_and_state_from_raw_query_string() {
        let input = "code=abc123&state=xyz789";
        let (code, state) = extract_oauth_code_and_state(input).expect("expected callback parse");
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz789");
    }

    /// Percent-encoded characters in code and state are decoded.
    #[test]
    fn decodes_percent_encoded_callback_values() {
        let input = "http://localhost:1455/auth/callback?code=a%2Fb%3Dc&state=x%20y";
        let (code, state) = extract_oauth_code_and_state(input).expect("expected callback parse");
        assert_eq!(code, "a/b=c");
        assert_eq!(state, "x y");
    }

    /// A missing `state` parameter is rejected.
    #[test]
    fn returns_error_when_state_is_missing() {
        let input = "http://localhost:1455/auth/callback?code=abc123";
        let err = extract_oauth_code_and_state(input).expect_err("expected missing state");
        assert!(err.to_string().contains("OAuth state"));
    }

    /// An `error` parameter is surfaced together with its description.
    #[test]
    fn surfaces_oauth_error_from_callback() {
        let input = "http://localhost:1455/auth/callback?error=access_denied&error_description=denied&state=xyz789";
        let err = extract_oauth_code_and_state(input).expect_err("expected OAuth error");
        let message = err.to_string();
        assert!(message.contains("access_denied"));
        assert!(message.contains("(denied)"));
    }

    /// A standard GET redirect is parsed from the request target.
    #[test]
    fn parses_oauth_callback_http_request() {
        let request =
            "GET /auth/callback?code=abc123&state=xyz789 HTTP/1.1\r\nHost: localhost:1455\r\n\r\n";
        let (code, state) =
            parse_oauth_callback_request(request).expect("expected valid callback request");
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz789");
    }

    /// A GET without any OAuth payload (e.g. `/favicon.ico`) is rejected.
    #[test]
    fn rejects_get_request_without_query() {
        let request = "GET /favicon.ico HTTP/1.1\r\nHost: localhost:1455\r\n\r\n";
        let err = parse_oauth_callback_request(request).expect_err("expected missing query");
        assert!(err.to_string().contains("missing query string"));
    }

    /// A POST with an empty body falls back to the target's query string.
    #[test]
    fn parses_post_callback_with_query_params() {
        let request =
            "POST /auth/callback?code=abc123&state=xyz789 HTTP/1.1\r\nHost: localhost:1455\r\n\r\n";
        let (code, state) =
            parse_oauth_callback_request(request).expect("expected POST callback parse");
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz789");
    }

    /// A form-encoded POST body is used when present.
    #[test]
    fn parses_post_form_encoded_callback_request() {
        // `Content-Length` matches the 24-byte body `code=abc123&state=xyz789`.
        let request = "POST /auth/callback HTTP/1.1\r\nHost: localhost:1455\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 24\r\n\r\ncode=abc123&state=xyz789";
        let (code, state) =
            parse_oauth_callback_request(request).expect("expected form POST callback parse");
        assert_eq!(code, "abc123");
        assert_eq!(state, "xyz789");
    }

    /// Methods other than GET/POST are refused.
    #[test]
    fn rejects_unsupported_callback_method() {
        let request = "OPTIONS /auth/callback HTTP/1.1\r\nHost: localhost:1455\r\n\r\n";
        let err = parse_oauth_callback_request(request)
            .expect_err("expected unsupported callback method");
        assert!(err.to_string().contains("Unsupported callback method"));
    }

    /// The Codex `interval` field accepts a numeric string.
    #[test]
    fn parses_codex_device_interval_from_string() {
        let parsed: CodexDeviceCodeResponse = serde_json::from_str(
            r#"{"device_auth_id":"id-1","user_code":"ABCD-EFGH","interval":"7"}"#,
        )
        .expect("expected valid device-code payload");
        assert_eq!(parsed.interval, 7);
    }

    /// The Codex `interval` field accepts a JSON number.
    #[test]
    fn parses_codex_device_interval_from_number() {
        let parsed: CodexDeviceCodeResponse = serde_json::from_str(
            r#"{"device_auth_id":"id-1","user_code":"ABCD-EFGH","interval":9}"#,
        )
        .expect("expected valid numeric interval payload");
        assert_eq!(parsed.interval, 9);
    }

    /// The `#HttpOnly_` domain prefix is stripped and recorded as a flag.
    #[test]
    fn parses_netscape_cookie_with_httponly_prefix() {
        let line = "#HttpOnly_.nextdoor.com\tTRUE\t/\tTRUE\t1803495701\tndbr_at\ttoken123";
        let parsed = parse_netscape_cookie_line(line).expect("expected cookie parse");
        assert_eq!(parsed.domain, ".nextdoor.com");
        assert!(parsed.http_only);
        assert_eq!(parsed.name, "ndbr_at");
    }

    /// The `nextdoor-web` profile keeps auth cookies and drops tracking ones.
    #[test]
    fn nextdoor_filter_keeps_auth_cookies_only() {
        let rows = vec![
            parse_netscape_cookie_line(
                ".nextdoor.com\tTRUE\t/\tTRUE\t4803495701\tndbr_at\tauth-token",
            )
            .expect("auth cookie"),
            parse_netscape_cookie_line(".nextdoor.com\tTRUE\t/\tFALSE\t4803495701\t_ga\ttracking")
                .expect("tracking cookie"),
        ];
        let (selected, dropped_expired, dropped_non_auth) =
            select_cookie_rows(&rows, "nextdoor-web", false);
        assert_eq!(selected.len(), 1);
        assert_eq!(selected[0].name, "ndbr_at");
        assert_eq!(dropped_expired, 0);
        assert_eq!(dropped_non_auth, 1);
    }
}
1613}