Skip to main content

atomcode_core/auth/
oauth.rs

1use std::collections::HashMap;
2use std::io::{self, BufRead, Write};
3use std::net::TcpListener;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::{mpsc, Arc};
6use std::thread;
7use std::time::Duration;
8
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11
12use atomcode_telemetry::{Event, Telemetry};
13
/// Default Platform server base URL. The OAuth `client_secret` is kept on
/// the broker behind this URL, not in the client.
///
/// Override with the `ATOMCODE_PLATFORM_SERVER` environment variable
/// (read once and cached by `platform_base_url`).
const DEFAULT_PLATFORM_SERVER: &str = "https://acs.atomgit.com";
17
/// Normalize a user-supplied base URL.
///
/// Surrounding whitespace is trimmed, and `http://` is prepended when the
/// input has no `scheme://` part. Every trailing `/` is then removed, so
/// appending a path such as `/auth/login` never produces `//`.
fn sanitize_base_url(raw: &str) -> String {
    let input = raw.trim();
    let mut url = String::with_capacity(input.len() + "http://".len());
    if !input.contains("://") {
        url.push_str("http://");
    }
    url.push_str(input);
    // Only ASCII '/' is trimmed, so the new length is always a char boundary.
    let kept_len = url.trim_end_matches('/').len();
    url.truncate(kept_len);
    url
}
29
30/// Return the Platform server base URL, reading `ATOMCODE_PLATFORM_SERVER` once
31/// at first call and caching the result for the process lifetime. This ensures
32/// all URL-derived functions within a single login/session flow target the
33/// same server even if the env var changes mid-flight.
34fn platform_base_url() -> &'static str {
35    use std::sync::OnceLock;
36    static BASE: OnceLock<String> = OnceLock::new();
37    BASE.get_or_init(|| {
38        let raw = std::env::var("ATOMCODE_PLATFORM_SERVER")
39            .unwrap_or_else(|_| DEFAULT_PLATFORM_SERVER.to_string());
40        sanitize_base_url(&raw)
41    })
42}
43
44/// Platform server URLs (derived from `ATOMCODE_PLATFORM_SERVER`).
45pub fn platform_broker_url() -> String { platform_base_url().to_string() }
46pub fn platform_login_url() -> String { format!("{}/auth/login", platform_base_url()) }
47pub fn platform_check_url() -> String { format!("{}/auth/check", platform_base_url()) }
48pub fn platform_token_url() -> String { format!("{}/auth/token", platform_base_url()) }
49pub fn platform_exchange_url() -> String { format!("{}/oauth/exchange", platform_base_url()) }
50pub fn platform_refresh_url() -> String { format!("{}/oauth/refresh", platform_base_url()) }
51#[allow(dead_code)]
52pub fn authorize_url() -> String { format!("{}/oauth/authorize", platform_base_url()) }
53#[allow(dead_code)]
54pub fn token_url() -> String { format!("{}/oauth/token", platform_base_url()) }
55#[allow(dead_code)]
56pub fn user_url() -> String { format!("{}/api/v5/user", platform_base_url()) }
57
58/// Blocking HTTP client pre-configured with `ATOMCODE_USER_AGENT`. Every
59/// OAuth-side request must carry the token or AtomGit's gate rejects it.
60/// Centralized so a future UA format change (e.g. append install-id)
61/// happens in one spot rather than at each `Client::new()` site.
62fn blocking_client() -> reqwest::blocking::Client {
63    // Hard timeouts here too — the `get_valid_token` path calls
64    // `refresh_access_token` synchronously whenever a stored token
65    // looks expired, and that runs on the main TUI thread (via
66    // `Client::from_stored_auth` → `/status`, drift monitor, etc.).
67    // Without a cap, a slow or unreachable OAuth server would hang
68    // the UI indefinitely. Same budget as the coding-plan client.
69    reqwest::blocking::Client::builder()
70        .connect_timeout(std::time::Duration::from_secs(5))
71        .timeout(std::time::Duration::from_secs(10))
72        .user_agent(crate::ATOMCODE_USER_AGENT)
73        .build()
74        .unwrap_or_else(|_| reqwest::blocking::Client::new())
75}
76
/// Stored authentication data.
///
/// Built by `LoginSession::finish` from the broker's `/auth/token` response
/// and (de)serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthInfo {
    /// Access token returned by `/auth/token`.
    pub access_token: String,
    /// Refresh token, if the broker issued one.
    pub refresh_token: Option<String>,
    /// Token type string as reported by the broker.
    pub token_type: String,
    /// Token lifetime reported by the broker. Presumably seconds, per OAuth
    /// convention (TODO confirm). `None` when not reported.
    pub expires_in: Option<i64>,
    /// Unix timestamp (seconds) when this token was obtained.
    /// Deserializes as 0 when missing from stored data (`#[serde(default)]`).
    #[serde(default)]
    pub created_at: i64,
    /// Account the token belongs to.
    pub user: UserInfo,
}
89
/// Account details of the signed-in user, stored inside `AuthInfo`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Account identifier. On login it is also pushed to telemetry as the
    /// account id.
    pub id: String,
    /// Username as reported by the broker.
    pub username: String,
    /// Display name, if provided.
    pub name: Option<String>,
    /// E-mail address, if provided.
    pub email: Option<String>,
    /// Avatar image URL, if provided.
    pub avatar_url: Option<String>,
}
98
99// ============================================================================
100// Platform API types
101// ============================================================================
102
/// Response body of `GET /auth/login?provider=atomgit`.
#[derive(Debug, Deserialize)]
struct PlatformLoginResponse {
    /// Browser URL the user must visit. `start_login` passes it through
    /// `strip_force_login` before storing it.
    login_url: String,
    /// Session identifier sent as the `state` query parameter to
    /// `/auth/check` and `/auth/token`.
    state: String,
}
108
/// Response body of `GET /auth/check?state=…`.
#[derive(Debug, Deserialize)]
struct PlatformCheckResponse {
    /// When `true`, `LoginSession::poll_once` reports `Authorized`.
    valid: bool,
}
113
/// User record embedded in the `/auth/token` response. `LoginSession::finish`
/// copies it field-for-field into `UserInfo`.
#[derive(Debug, Deserialize)]
struct PlatformUserInfo {
    id: String,
    username: String,
    name: Option<String>,
    email: Option<String>,
    avatar_url: Option<String>,
}
122
/// Response body of `GET /auth/token?state=…`. `LoginSession::finish`
/// converts it into `AuthInfo`.
#[derive(Debug, Deserialize)]
struct PlatformTokenResponse {
    access_token: String,
    token_type: String,
    /// Token lifetime as reported by the broker. Presumably seconds, per
    /// OAuth convention (TODO confirm). Missing or `null` becomes `None`.
    expires_in: Option<i64>,
    /// Present only if the broker issued a refresh token.
    refresh_token: Option<String>,
    user: PlatformUserInfo,
}
131
132// ============================================================================
133// ESC-cancel support for the OAuth poll loop
134// ============================================================================
135//
136// The poll loop in `login()` historically did `loop { http_check; sleep(2s) }`
137// with no input handling — Linux/WSL users with broken `xdg-open` had no way
138// to exit short of Ctrl+C (which kills the whole CLI/TUI). We now print the
139// auth URL up-front for those users and accept ESC during the wait.
140//
141// Cooked mode (set by `suspend_for_external` in the TUI, default everywhere
142// in CLI mode) line-buffers stdin — ESC alone won't reach `read()` until the
143// user hits Enter. So while waiting, we temporarily switch stdin to cbreak
144// (non-canonical, no echo) via an RAII `CbreakGuard`, restoring the original
145// termios on every drop path. If `tcgetattr`/`tcsetattr` fail (non-tty stdin
146// from a pipe or CI), the guard returns `None` and the loop falls back to a
147// plain sleep — login still works, ESC just has no effect.
148//
149// Windows has no `poll(2)` over stdin and the existing
150// `read_callback_from_stdin_until_stopped` path is already gated off there
151// for the same reason. We follow the same pattern: `CbreakGuard` is a
152// zero-sized stub that always returns `None`, and `wait_for_esc_or_timeout`
153// degrades to `thread::sleep`.
154
/// Result of one `wait_for_esc_or_timeout` call during the OAuth poll loop.
//
// Only the Unix `wait_for_esc_or_timeout` ever produces `Cancelled` or
// `OtherInput`; the Windows one always yields `Timeout`. The variants
// must still exist on Windows because `classify_input` and its tests name
// them on every platform (`cargo test` runs everywhere). So the dead-code
// lint is silenced for Windows instead of cfg-gating the type.
#[cfg_attr(target_os = "windows", allow(dead_code))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscOutcome {
    /// The user pressed a lone ESC key — abort the login.
    Cancelled,
    /// Nothing useful arrived: poll(2) timed out, or `read` returned 0 or
    /// an error.
    Timeout,
    /// Bytes arrived, but not a lone ESC: an escape sequence, a regular key,
    /// Enter, or pasted text. Callers treat this exactly like `Timeout`
    /// and move on to the next HTTP check.
    OtherInput,
}

/// Map the bytes returned by one stdin `read` to an [`EscOutcome`].
///
/// * no bytes → `Timeout`
/// * exactly one `0x1B` byte → `Cancelled`
/// * anything else → `OtherInput`
///
/// Terminals write escape sequences such as arrow-up (`\x1B[A`) to the pty
/// in a single write. A 32-byte read therefore sees the whole sequence, and
/// its leading `0x1B` is never mistaken for a bare ESC. See spec
/// `2026-04-28-show-oauth-url-design.md` §5.
//
// Only the Unix `wait_for_esc_or_timeout` calls this. It stays compiled
// on Windows because it is plain byte matching with no platform
// dependencies and the tests exercise it everywhere. The dead-code allow
// is scoped to Windows, so Unix still warns if it ever becomes unused there.
#[cfg_attr(target_os = "windows", allow(dead_code))]
fn classify_input(bytes: &[u8]) -> EscOutcome {
    const ESC: u8 = 0x1B;
    match bytes.split_first() {
        None => EscOutcome::Timeout,
        Some((&ESC, [])) => EscOutcome::Cancelled,
        Some(_) => EscOutcome::OtherInput,
    }
}
199
/// RAII guard that holds stdin in cbreak mode (non-canonical, no echo)
/// while it lives. Dropping it restores the termios captured at creation.
#[cfg(not(target_os = "windows"))]
struct CbreakGuard {
    /// File descriptor of stdin, whose terminal settings were changed.
    fd: std::os::unix::io::RawFd,
    /// Settings read by `tcgetattr` before switching; restored on drop.
    orig: libc::termios,
}

/// Windows stand-in. `new` always returns `None`, so this is never
/// constructed; it exists only so callers can hold an
/// `Option<CbreakGuard>` on every platform.
#[cfg(target_os = "windows")]
struct CbreakGuard;

impl CbreakGuard {
    /// Try to switch stdin to cbreak. Returns `None` if stdin isn't a
    /// tty (ENOTTY) or if `tcsetattr` fails. On Windows always returns
    /// `None` — no equivalent of the Unix poll-based path.
    #[cfg(not(target_os = "windows"))]
    fn new() -> Option<Self> {
        use std::os::unix::io::AsRawFd;
        let fd = io::stdin().as_raw_fd();
        // SAFETY: `termios` is a plain C struct; all-zero is a valid bit
        // pattern, and `tcgetattr` overwrites it before we read it.
        let mut orig: libc::termios = unsafe { std::mem::zeroed() };
        // SAFETY: `orig` is a valid, writable termios for tcgetattr to fill.
        if unsafe { libc::tcgetattr(fd, &mut orig) } != 0 {
            return None;
        }
        let mut new = orig;
        // Clear only ICANON and ECHO. ISIG stays set, so Ctrl+C still
        // raises SIGINT.
        new.c_lflag &= !(libc::ICANON | libc::ECHO);
        // VMIN=0 / VTIME=0: `read` returns immediately with whatever is
        // buffered (possibly nothing) instead of waiting.
        new.c_cc[libc::VMIN] = 0;
        new.c_cc[libc::VTIME] = 0;
        // SAFETY: `new` is a fully initialized termios copied from `orig`.
        if unsafe { libc::tcsetattr(fd, libc::TCSANOW, &new) } != 0 {
            return None;
        }
        Some(Self { fd, orig })
    }

    #[cfg(target_os = "windows")]
    fn new() -> Option<Self> {
        None
    }
}

#[cfg(not(target_os = "windows"))]
impl Drop for CbreakGuard {
    fn drop(&mut self) {
        // Best-effort restore. If this somehow fails, the terminal is
        // stuck in cbreak; `stty sane` recovers it. Drop runs on normal
        // return, on `?`/`bail!` exits, and during panic unwinding.
        // NOTE(review): it does NOT run under `panic = "abort"`, or when
        // SIGINT's default action kills the process (ISIG is left enabled
        // in `new`). In those cases the terminal stays in cbreak.
        // SAFETY: `self.orig` is the termios captured by `tcgetattr` in `new`.
        unsafe {
            libc::tcsetattr(self.fd, libc::TCSANOW, &self.orig);
        }
    }
}
248
249/// Wait up to `timeout` for stdin activity (ESC keypress) or sleep
250/// until the timeout expires. Used to interleave ESC-cancel checks
251/// with the OAuth `/auth/check` poll cadence.
252///
253/// On Windows or when the cbreak guard couldn't be established, this
254/// just sleeps and returns `Timeout` — ESC never fires but login still
255/// works.
256#[cfg(not(target_os = "windows"))]
257fn wait_for_esc_or_timeout(guard: &Option<CbreakGuard>, timeout: Duration) -> EscOutcome {
258    let Some(g) = guard.as_ref() else {
259        thread::sleep(timeout);
260        return EscOutcome::Timeout;
261    };
262
263    let mut pfd = libc::pollfd {
264        fd: g.fd,
265        events: libc::POLLIN,
266        revents: 0,
267    };
268    let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
269    let rc = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
270    if rc <= 0 {
271        // 0 = timeout (no data); <0 = poll error (EINTR etc.). Either
272        // way the right move is "fall through to HTTP check"; the
273        // outer loop's HTTP round-trip is the natural rate limit.
274        return EscOutcome::Timeout;
275    }
276    let mut buf = [0u8; 32];
277    let n = unsafe { libc::read(g.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
278    if n <= 0 {
279        return EscOutcome::Timeout;
280    }
281    classify_input(&buf[..n as usize])
282}
283
284#[cfg(target_os = "windows")]
285fn wait_for_esc_or_timeout(_guard: &Option<CbreakGuard>, timeout: Duration) -> EscOutcome {
286    thread::sleep(timeout);
287    EscOutcome::Timeout
288}
289
/// Outcome of one `LoginSession::poll_once` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollOutcome {
    /// User hasn't completed the browser sign-in yet — wait and retry.
    /// Also reported when `/auth/check` answers with a non-2xx status or
    /// a body that doesn't parse.
    Pending,
    /// `/auth/check` reported `valid=true`. Caller should call `finish()`.
    Authorized,
}
298
299/// In-flight OAuth session. Returned by `start_login()`. The caller
300/// drives the flow:
301///
302/// 1. Display `session.url()` and (best-effort) `open_browser()`.
303/// 2. Loop `poll_once()` until `Authorized`, sleeping between calls
304///    AT THE CALLER'S CADENCE — this lets the TUI interleave UI events
305///    (ESC for cancel) and the CLI use a simple `thread::sleep`.
306/// 3. Call `finish()` to exchange `state` → token.
307pub struct LoginSession {
308    state: String,
309    login_url: String,
310    client: reqwest::blocking::Client,
311}
312
313impl LoginSession {
314    /// Authorization URL the user must visit. Stable for the lifetime
315    /// of the session — safe to show once and reuse.
316    pub fn url(&self) -> &str {
317        &self.login_url
318    }
319
320    /// Best-effort browser launch. Always silent — failures are expected
321    /// on Linux/WSL where the URL on screen is the user's actual path.
322    pub fn open_browser_best_effort(&self) {
323        let _ = open_browser(&self.login_url);
324    }
325
326    /// One non-blocking HTTP check against `/auth/check`. Returns
327    /// `Pending` until the user finishes the browser flow, then
328    /// `Authorized`. Errors only on transport/parse failures; a
329    /// "not yet" answer is `Ok(Pending)`, never `Err`.
330    pub fn poll_once(&self) -> Result<PollOutcome> {
331        let resp = self
332            .client
333            .get(platform_check_url())
334            .query(&[("state", &self.state)])
335            .send()
336            .context("Failed to call /auth/check")?;
337        if resp.status().is_success() {
338            if let Ok(check) = resp.json::<PlatformCheckResponse>() {
339                if check.valid {
340                    return Ok(PollOutcome::Authorized);
341                }
342            }
343        }
344        Ok(PollOutcome::Pending)
345    }
346
347    /// Final step: `/auth/token` exchange + `LoginSuccess` telemetry.
348    /// Consumes the session — only call after `poll_once` returned
349    /// `Authorized`.
350    pub fn finish(self, tel: Option<&Arc<Telemetry>>) -> Result<AuthInfo> {
351        let token_resp: PlatformTokenResponse = self
352            .client
353            .get(platform_token_url())
354            .query(&[("state", &self.state)])
355            .send()
356            .context("Failed to call /auth/token")?
357            .json()
358            .context("Failed to parse /auth/token response")?;
359
360        let created_at = std::time::SystemTime::now()
361            .duration_since(std::time::UNIX_EPOCH)
362            .unwrap()
363            .as_secs() as i64;
364
365        let auth_info = AuthInfo {
366            access_token: token_resp.access_token,
367            refresh_token: token_resp.refresh_token,
368            token_type: token_resp.token_type,
369            expires_in: token_resp.expires_in,
370            created_at,
371            user: UserInfo {
372                id: token_resp.user.id,
373                username: token_resp.user.username,
374                name: token_resp.user.name,
375                email: token_resp.user.email,
376                avatar_url: token_resp.user.avatar_url,
377            },
378        };
379
380        if let Some(t) = tel {
381            // Push account_id onto the telemetry handle BEFORE emitting
382            // login_success so the event itself — and every subsequent event in
383            // this process — carries the id. The handle-level setter outlives
384            // any task-local scope, so events emitted outside the main scope
385            // (e.g. before scope is entered, or from spawned tasks) inherit it.
386            t.set_account_id(Some(auth_info.user.id.to_string()));
387            t.track(Event::LoginSuccess);
388        }
389
390        Ok(auth_info)
391    }
392}
393
394/// Begin OAuth login: call `/auth/login`, return a session containing
395/// the auth URL + state. Cheap (one HTTP round-trip), never blocks on
396/// user action — separated from polling so callers can render the URL
397/// before yielding control to the wait loop.
398pub fn start_login() -> Result<LoginSession> {
399    let client = reqwest::blocking::Client::new();
400    let resp: PlatformLoginResponse = client
401        .get(platform_login_url())
402        .query(&[("provider", "atomgit")])
403        .send()
404        .context("Failed to call /auth/login")?
405        .json()
406        .context("Failed to parse /auth/login response")?;
407    Ok(LoginSession {
408        state: resp.state,
409        login_url: strip_force_login(&resp.login_url),
410        client,
411    })
412}
413
/// Drop the `force_login=true` query parameter from the broker-supplied
/// OAuth URL.
///
/// The broker emits this flag to force re-authentication on every login.
/// Stripping it lets users already signed in to atomgit.com auto-authorize
/// and skip the consent page. State binding via the `state` parameter is
/// unchanged, so the request is still anchored to this specific login
/// attempt.
///
/// Only query pairs that are exactly `force_login=true` are removed.
/// Parameters that merely start with that text (e.g. `force_login=trueish`)
/// are left intact; plain substring replacement would have glued their
/// tail onto the previous parameter. A `#fragment` is also left intact.
/// A URL without such a pair is returned unchanged.
fn strip_force_login(url: &str) -> String {
    const FLAG: &str = "force_login=true";

    let Some((base, rest)) = url.split_once('?') else {
        return url.to_string();
    };
    // Keep any fragment out of the query being filtered.
    let (query, fragment) = match rest.split_once('#') {
        Some((query, fragment)) => (query, Some(fragment)),
        None => (rest, None),
    };
    if !query.split('&').any(|pair| pair == FLAG) {
        return url.to_string();
    }

    let kept: Vec<&str> = query.split('&').filter(|pair| *pair != FLAG).collect();
    let mut out = String::with_capacity(url.len());
    out.push_str(base);
    if !kept.is_empty() {
        out.push('?');
        out.push_str(&kept.join("&"));
    }
    if let Some(fragment) = fragment {
        out.push('#');
        out.push_str(fragment);
    }
    out
}
425
426/// Stdout-driven OAuth login: prints the URL, opens the browser,
427/// polls `/auth/check` with stdin-driven ESC cancel. Used by the CLI
428/// (`atomcode login`, `atomcode codingplan`) and by `setup.rs`'s
429/// `step_login` when the TUI hasn't already pre-flighted login.
430///
431/// TUI callers should NOT use this — render via `start_login()` +
432/// `LoginSession::poll_once()` so the input box stays visible and ESC
433/// is captured through `input_rx` (no termios manipulation needed).
434///
435/// `tel` is optional so non-CLI callers (tests, coding_plan setup) can
436/// pass `None` when they don't hold a telemetry handle.
437pub fn login(tel: Option<&Arc<Telemetry>>) -> Result<AuthInfo> {
438    let session = start_login()?;
439
440    // Always print the URL — `xdg-open` on Linux/WSL silently fails
441    // often enough that we can't rely on it. On the desktop happy path
442    // the browser opens *and* the URL stays in scrollback as a backup.
443    println!("  Browser didn't open? Open the URL below in any browser to sign in:");
444    println!("  {}", session.url());
445
446    // Try to enter cbreak so we can detect a bare-ESC keypress. None
447    // (non-tty stdin / tcsetattr failure) → fall back to plain sleep,
448    // and don't advertise an ESC affordance that wouldn't work.
449    let cbreak = CbreakGuard::new();
450    if cbreak.is_some() {
451        println!();
452        println!("  Press ESC to cancel");
453    }
454
455    session.open_browser_best_effort();
456
457    loop {
458        match session.poll_once()? {
459            PollOutcome::Authorized => break,
460            PollOutcome::Pending => {}
461        }
462        match wait_for_esc_or_timeout(&cbreak, Duration::from_secs(2)) {
463            EscOutcome::Cancelled => anyhow::bail!("login cancelled by user"),
464            EscOutcome::Timeout | EscOutcome::OtherInput => {}
465        }
466    }
467
468    session.finish(tel)
469}
470
471/// Extract state from a pasted callback URL (kept for potential future fallback use)
472#[allow(dead_code)]
473fn pasted_state(url: &str) -> Option<String> {
474    url.split('?')
475        .nth(1)?
476        .split('&')
477        .filter_map(|pair| {
478            let mut parts = pair.splitn(2, '=');
479            if parts.next()? == "state" {
480                Some(urlencoding_decode(parts.next()?))
481            } else {
482                None
483            }
484        })
485        .next()
486}
487
/// Build a `state` value of the form
/// `atomcode_<nanoseconds since the Unix epoch>`, intended for CSRF
/// protection.
///
/// NOTE(review): a timestamp is predictable, not random. Revisit before
/// relying on this helper for CSRF protection again.
///
/// # Panics
/// If the system clock is set before the Unix epoch.
#[allow(dead_code)] // currently unreferenced
fn generate_state() -> String {
    use std::time::UNIX_EPOCH;
    let nanos = UNIX_EPOCH.elapsed().unwrap().as_nanos();
    format!("atomcode_{nanos}")
}
498
499/// Open browser with the authorization URL.
500///
501/// `pub` because TUI modals (e.g. the QR-login onboarding step) need to
502/// invoke the same platform browser launch the CLI flow already does via
503/// `LoginSession::open_browser_best_effort` — callers without a live
504/// `LoginSession` only carry the URL string, so they go through this
505/// free function directly.
506#[cfg(target_os = "macos")]
507pub fn open_browser(url: &str) -> Result<()> {
508    std::process::Command::new("open")
509        .arg(url)
510        .spawn()
511        .context("Failed to open browser")?;
512    Ok(())
513}
514
515#[cfg(target_os = "linux")]
516pub fn open_browser(url: &str) -> Result<()> {
517    std::process::Command::new("xdg-open")
518        .arg(url)
519        .spawn()
520        .context("Failed to open browser")?;
521    Ok(())
522}
523
524#[cfg(target_os = "windows")]
525pub fn open_browser(url: &str) -> Result<()> {
526    use std::os::windows::process::CommandExt;
527    std::process::Command::new("cmd")
528        .raw_arg(format!("/C start \"\" \"{}\"", url))
529        .spawn()
530        .context("Failed to open browser")?;
531    Ok(())
532}
533
534#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
535pub fn open_browser(_url: &str) -> Result<()> {
536    anyhow::bail!("Unsupported platform for browser auto-open");
537}
538
/// Race a local TCP listener against stdin paste; return whichever reader
/// reports first.
///
/// The listener handles the normal desktop path, where the browser hits
/// `127.0.0.1:<port>` (e.g. 8765). The stdin path handles WSL / headless
/// Linux, where the user copies the callback URL from the browser's
/// address bar and pastes it in.
///
/// The first result wins even if it is an `Err`. For example, a malformed
/// request reaching the listener ends the wait although the paste path
/// could still have succeeded.
///
/// Kept for potential future fallback use — the platform-broker flow in
/// `login()` is the active callback path now.
#[allow(dead_code)]
fn await_callback(port: u16) -> Result<(String, String)> {
    // A bind failure isn't fatal: fall back to the paste path only.
    let listener = match TcpListener::bind(("127.0.0.1", port)) {
        Ok(l) => Some(l),
        Err(e) => {
            println!("  Could not bind port {} ({}). Paste path only.", port, e);
            None
        }
    };

    println!(
        "  Waiting for callback on http://127.0.0.1:{}/callback",
        port
    );
    println!("  Or paste the full callback URL here and press Enter:");
    println!("  (Ctrl+C to cancel)\n");

    // Each reader thread sends exactly one result on `tx`; `stop` tells the
    // loser to give up once a winner has been received.
    let (tx, rx) = mpsc::channel::<Result<(String, String)>>();
    let stop = Arc::new(AtomicBool::new(false));

    // Only read by the Windows branch below; hence the lint allowance elsewhere.
    #[cfg_attr(not(target_os = "windows"), allow(unused_variables))]
    let has_listener = listener.is_some();
    if let Some(listener) = listener {
        let tx_l = tx.clone();
        let stop_l = Arc::clone(&stop);
        thread::spawn(move || {
            let r = accept_callback_until_stopped(listener, &stop_l);
            let _ = tx_l.send(r);
        });
    }

    // Stdin reader — spawn on Unix **regardless** of listener status. The
    // listener covers the desktop path where the browser hits
    // 127.0.0.1:8765; stdin covers everything else (headless Linux / SSH /
    // Wayland without xdg-open / WSL under X forwarding failure). Earlier
    // versions gated this on `!has_listener`, which silently broke Linux:
    // the listener binds fine but the browser can't reach it. With no
    // stdin reader spawned, the user's pasted URL went nowhere and the
    // whole login hung forever.
    //
    // Must be cancellable: previous revisions used a blocking
    // `stdin.lock().read_line()` + a "zombie thread is harmless" comment.
    // It wasn't harmless — FD 0 and /dev/tty point to the same terminal
    // device on Unix, so the kernel's line discipline delivers each byte
    // to whichever reader calls `read` first. When the listener won the
    // race, the zombie `read_line` was still blocked; the user's first
    // keystroke after login got read by the zombie (parsed as a bad
    // callback URL, dropped) instead of by crossterm's /dev/tty reader.
    // Reported as "Chinese IME commits need two attempts to land".
    //
    // Fix: poll(2)-based loop that checks the `stop` AtomicBool between
    // 100 ms timeouts, so when the listener wins we set `stop=true` and
    // the stdin thread exits before the user types anything.
    //
    // Windows is still gated off because its stdin `read_line` blocks on
    // a console handle that can't be cancelled from another thread and
    // doesn't have an equivalent poll(2) path.
    #[cfg(not(target_os = "windows"))]
    {
        let tx_stdin = tx.clone();
        let stop_stdin = Arc::clone(&stop);
        thread::spawn(move || {
            let r = read_callback_from_stdin_until_stopped(&stop_stdin);
            let _ = tx_stdin.send(r);
        });
    }
    #[cfg(target_os = "windows")]
    {
        // Paste path only when there is no listener. This thread blocks in
        // `read_line` and ignores `stop`.
        if !has_listener {
            let tx_stdin = tx.clone();
            thread::spawn(move || {
                let stdin = io::stdin();
                let mut line = String::new();
                let r = match stdin.lock().read_line(&mut line) {
                    Ok(0) => Err(anyhow::anyhow!("stdin closed")),
                    Ok(_) => parse_pasted_callback(&line),
                    Err(e) => Err(anyhow::Error::new(e).context("Failed to read from stdin")),
                };
                let _ = tx_stdin.send(r);
            });
        }
    }
    // Drop the original `tx` — the listener and stdin readers each
    // cloned their own. Without this drop the channel would never
    // close after both readers finish, so `rx.recv()` on an early
    // cancellation would hang.
    drop(tx);

    // First message wins (Ok or Err). Setting `stop` then makes the losing
    // listener / Unix stdin thread exit at its next check (200 ms / 100 ms).
    let result = rx.recv().context("login cancelled")?;
    stop.store(true, Ordering::Relaxed);
    result
}
639
640/// Accept a single OAuth callback on an already-bound listener, polling a
641/// Poll stdin for a pasted callback URL, checking `stop` every 100 ms so
642/// the caller can cancel (e.g. when the listener won the race). Returns
643/// `Err("stdin cancelled")` on stop, `Err(...)` on a read error or a line
644/// that doesn't parse as a callback URL, `Ok((code, state))` on success.
645///
646/// Uses `poll(2)` + non-blocking reads so we never sit inside a blocking
647/// `read_line()` — that was the bug behind "first keystroke after login
648/// goes to a zombie stdin thread instead of crossterm". On macOS / Linux,
649/// FD 0 (this thread's read) and /dev/tty (crossterm's read) point to
650/// the same terminal device; whichever syscall lands on a byte first
651/// gets it, and a blocked `read_line` stays in line for the next input.
652#[cfg(not(target_os = "windows"))]
653#[allow(dead_code)]
654fn read_callback_from_stdin_until_stopped(stop: &AtomicBool) -> Result<(String, String)> {
655    use std::os::unix::io::AsRawFd;
656
657    let stdin = io::stdin();
658    let fd = stdin.as_raw_fd();
659
660    // Save original flags so we restore them on exit — leaving stdin
661    // non-blocking after login would break subsequent code that expects
662    // the normal blocking shape (e.g. any future CLI prompt helper).
663    let orig_flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
664    if orig_flags >= 0 {
665        unsafe {
666            libc::fcntl(fd, libc::F_SETFL, orig_flags | libc::O_NONBLOCK);
667        }
668    }
669
670    // RAII guard: restore flags on any exit path (stop, error, parse fail).
671    struct FlagGuard {
672        fd: std::os::unix::io::RawFd,
673        orig_flags: i32,
674    }
675    impl Drop for FlagGuard {
676        fn drop(&mut self) {
677            if self.orig_flags >= 0 {
678                unsafe {
679                    libc::fcntl(self.fd, libc::F_SETFL, self.orig_flags);
680                }
681            }
682        }
683    }
684    let _guard = FlagGuard { fd, orig_flags };
685
686    let mut line = String::new();
687    let mut buf = [0u8; 256];
688    loop {
689        if stop.load(Ordering::Relaxed) {
690            anyhow::bail!("stdin cancelled");
691        }
692        let mut pfd = libc::pollfd {
693            fd,
694            events: libc::POLLIN,
695            revents: 0,
696        };
697        let poll_rc = unsafe { libc::poll(&mut pfd, 1, 100) };
698        if poll_rc < 0 {
699            let err = io::Error::last_os_error();
700            if err.kind() == io::ErrorKind::Interrupted {
701                continue;
702            }
703            return Err(anyhow::Error::new(err).context("poll(stdin)"));
704        }
705        if poll_rc == 0 {
706            continue; // timeout — re-check stop, re-poll
707        }
708        // Data available; drain what's there. read(2) in non-blocking
709        // mode returns up to one pipe buffer in a single call.
710        let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
711        if n < 0 {
712            let err = io::Error::last_os_error();
713            if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted {
714                continue;
715            }
716            return Err(anyhow::Error::new(err).context("read(stdin)"));
717        }
718        if n == 0 {
719            anyhow::bail!("stdin closed");
720        }
721        // Append as UTF-8 (lossy — pasted URLs are ASCII; any weird
722        // bytes in a URL would fail `parse_pasted_callback` anyway).
723        line.push_str(&String::from_utf8_lossy(&buf[..n as usize]));
724        if line.contains('\n') {
725            return parse_pasted_callback(&line);
726        }
727    }
728}
729
730/// `stop` flag every 200ms so the caller can cancel (e.g. when the paste
731/// path won the race).
732#[allow(dead_code)]
733fn accept_callback_until_stopped(
734    listener: TcpListener,
735    stop: &AtomicBool,
736) -> Result<(String, String)> {
737    listener
738        .set_nonblocking(true)
739        .context("Failed to set non-blocking mode")?;
740
741    let mut stream = loop {
742        if stop.load(Ordering::Relaxed) {
743            anyhow::bail!("listener cancelled");
744        }
745        match listener.accept() {
746            Ok((stream, _)) => break stream,
747            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
748                thread::sleep(Duration::from_millis(200));
749                continue;
750            }
751            Err(e) => return Err(e).context("Failed to accept connection"),
752        }
753    };
754
755    stream.set_nonblocking(false)?;
756
757    // Read HTTP request
758    let mut reader = io::BufReader::new(&mut stream);
759    let mut request_line = String::new();
760    reader.read_line(&mut request_line)?;
761
762    // Parse the request line (GET /callback?code=...&state=... HTTP/1.1)
763    let url: String = request_line
764        .split_whitespace()
765        .nth(1)
766        .context("Invalid HTTP request")?
767        .to_string();
768
769    // Parse query parameters
770    let query_start = url.find('?').context("No query parameters in callback")?;
771    let query = &url[query_start + 1..];
772
773    let params: HashMap<String, String> = query
774        .split('&')
775        .filter_map(|pair| {
776            let mut parts = pair.splitn(2, '=');
777            let key = parts.next()?;
778            let value = parts
779                .next()
780                .map(|v| urlencoding_decode(v))
781                .unwrap_or_default();
782            Some((key.to_string(), value))
783        })
784        .collect();
785
786    // Check for error — redirect browser to AtomGit
787    if let Some(error) = params.get("error") {
788        let error_desc = params
789            .get("error_description")
790            .map(|s| s.as_str())
791            .unwrap_or(error);
792        let response = "HTTP/1.1 302 Found\r\nLocation: https://atomgit.com\r\n\r\n";
793        let _ = stream.write_all(response.as_bytes());
794        let _ = stream.flush();
795        anyhow::bail!("OAuth error: {}", error_desc);
796    }
797
798    let code = params.get("code").context("No code in callback")?.clone();
799    let state = params.get("state").cloned().unwrap_or_default();
800
801    // Send success response to browser
802    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\n\r\n\
803        <html><head><title>AtomCode Login</title>\
804        <style>body{font-family:system-ui;display:flex;justify-content:center;align-items:center;height:100vh;margin:0;background:#1a1a2e;color:#eee}\
805        .container{text-align:center;padding:2rem}h1{color:#7c3aed;margin:0}p{color:#888}\
806        .success{color:#22c55e;font-size:4rem}</style></head>\
807        <body><div class=\"container\">\
808        <div class=\"success\">✓</div>\
809        <h1>Authorization Successful</h1>\
810        <p>You can close this window and return to AtomCode.</p>\
811        </div></body></html>";
812
813    stream.write_all(response.as_bytes())?;
814    stream.flush()?;
815
816    Ok((code, state))
817}
818
/// Percent-decode a URL query component, also mapping `+` to a space.
///
/// Escapes are decoded to raw bytes first and the result is interpreted as
/// UTF-8, so multi-byte sequences such as `%E4%B8%AD` become `中`. Invalid
/// UTF-8 is replaced with U+FFFD rather than failing.
///
/// A `%` that is not followed by exactly two ASCII hex digits is kept
/// literally, in line with the WHATWG percent-decode algorithm. Examples:
/// `"100%"` stays `"100%"` and `"%zz"` stays `"%zz"`.
fn urlencoding_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;

    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                // Validate digits explicitly: `u8::from_str_radix` would also
                // accept a leading `+` sign (e.g. "+f").
                let hex_at = |idx: usize| bytes.get(idx).and_then(|b| (*b as char).to_digit(16));
                if let (Some(hi), Some(lo)) = (hex_at(i + 1), hex_at(i + 2)) {
                    out.push((hi * 16 + lo) as u8);
                    i += 3;
                    continue;
                }
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
        i += 1;
    }

    String::from_utf8_lossy(&out).into_owned()
}
839
840/// Refresh the access token using the stored refresh_token via Platform Broker.
841/// Returns updated AuthInfo with new tokens, and saves it to disk.
842pub fn refresh_access_token(auth: &AuthInfo) -> Result<AuthInfo> {
843    let refresh_token = auth
844        .refresh_token
845        .as_deref()
846        .context("No refresh_token available — please /login again")?;
847
848    let client = blocking_client();
849
850    // Call Platform Broker API for refresh
851    let response = client
852        .post(platform_refresh_url())
853        .json(&serde_json::json!({ "refresh_token": refresh_token }))
854        .send()
855        .context("Failed to send refresh token request to broker")?;
856
857    if !response.status().is_success() {
858        let status = response.status();
859        let body = response.text().unwrap_or_default();
860        anyhow::bail!(
861            "Token refresh failed ({}): {} — please /login again",
862            status,
863            body
864        );
865    }
866
867    #[derive(Deserialize)]
868    struct BrokerResponse {
869        access_token: String,
870        token_type: Option<String>,
871        expires_in: Option<i64>,
872        refresh_token: Option<String>,
873        user: Option<PlatformUserInfo>,
874    }
875
876    let broker_resp: BrokerResponse = response.json().context("Failed to parse broker response")?;
877
878    let created_at = std::time::SystemTime::now()
879        .duration_since(std::time::UNIX_EPOCH)
880        .unwrap()
881        .as_secs() as i64;
882
883    let new_auth = AuthInfo {
884        access_token: broker_resp.access_token,
885        refresh_token: broker_resp
886            .refresh_token
887            .or_else(|| auth.refresh_token.clone()),
888        token_type: broker_resp
889            .token_type
890            .unwrap_or_else(|| auth.token_type.clone()),
891        expires_in: broker_resp.expires_in.or(auth.expires_in),
892        created_at,
893        user: broker_resp
894            .user
895            .map(|u| UserInfo {
896                id: u.id,
897                username: u.username,
898                name: u.name,
899                email: u.email,
900                avatar_url: u.avatar_url,
901            })
902            .unwrap_or_else(|| auth.user.clone()),
903    };
904
905    save_auth(&new_auth)?;
906    Ok(new_auth)
907}
908
909/// Get a valid access token, refreshing automatically if expired.
910/// Returns the access token string ready to use.
911pub fn get_valid_token() -> Result<String> {
912    let auth = get_stored_auth().context("Not logged in — please use /login first")?;
913
914    // Check if token is expired (with 5-minute safety margin)
915    if let Some(expires_in) = auth.expires_in {
916        let now = std::time::SystemTime::now()
917            .duration_since(std::time::UNIX_EPOCH)
918            .unwrap()
919            .as_secs() as i64;
920        let expires_at = auth.created_at + expires_in;
921
922        if now >= expires_at - 300 {
923            // Token expired or about to expire — try refresh
924            match refresh_access_token(&auth) {
925                Ok(new_auth) => return Ok(new_auth.access_token),
926                Err(e) => anyhow::bail!("Token expired and refresh failed: {}", e),
927            }
928        }
929    } else if auth.created_at == 0 {
930        // Legacy auth.toml without created_at — no way to know if expired,
931        // try refresh if refresh_token is available, otherwise use as-is
932        if auth.refresh_token.is_some() {
933            if let Ok(new_auth) = refresh_access_token(&auth) {
934                return Ok(new_auth.access_token);
935            }
936        }
937    }
938
939    Ok(auth.access_token)
940}
941
942/// Logout - clear stored auth.
943///
944/// Core-layer function: does the filesystem work and returns. User-facing
945/// messaging is the caller's job — this was previously `println!`-ing
946/// "Logged out successfully" directly, which bypassed the TUI renderer
947/// and bled into the input box area on next repaint, and also produced
948/// a duplicate line in CLI mode where `handle_command` prints its own
949/// confirmation. No `Err` distinguishes "file absent" from "file removed" —
950/// both are success from the user's perspective ("you're logged out").
951pub fn logout() -> Result<()> {
952    let auth_path = auth_file_path();
953    if auth_path.exists() {
954        std::fs::remove_file(&auth_path).context("Failed to remove auth file")?;
955    }
956    Ok(())
957}
958
959/// Get stored auth info
960pub fn get_stored_auth() -> Option<AuthInfo> {
961    let auth_path = auth_file_path();
962    if !auth_path.exists() {
963        return None;
964    }
965
966    let content = std::fs::read_to_string(&auth_path).ok()?;
967    toml::from_str(&content).ok()
968}
969
970/// Save auth info to file
971pub fn save_auth(auth: &AuthInfo) -> Result<()> {
972    let auth_path = auth_file_path();
973
974    // Ensure parent directory exists
975    if let Some(parent) = auth_path.parent() {
976        std::fs::create_dir_all(parent).context("Failed to create auth directory")?;
977        // Set directory permissions to 0o700 (owner only) on Unix
978        #[cfg(unix)]
979        {
980            use std::os::unix::fs::PermissionsExt;
981            let _ = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700));
982        }
983    }
984
985    let content = toml::to_string_pretty(auth).context("Failed to serialize auth info")?;
986    super::write_auth_file_secure(&auth_path, &content).context("Failed to write auth file")?;
987
988    // Set file permissions to 0o600 (owner read/write only) on Unix
989    #[cfg(unix)]
990    {
991        use std::os::unix::fs::PermissionsExt;
992        std::fs::set_permissions(&auth_path, std::fs::Permissions::from_mode(0o600))
993            .context("Failed to set auth file permissions")?;
994    }
995
996    // No stdout output here. `save_auth` is called from CLI flows, TUI
997    // slash commands, the daemon, AND the silent in-chat 401 → refresh
998    // path. Printing here would corrupt the TUI input box on the silent
999    // refresh path (the cursor sits in the prompt and `println!` bypasses
1000    // the renderer). CLI callers print their own user-facing success
1001    // message right after calling this.
1002    Ok(())
1003}
1004
1005/// Get path to auth file
1006pub fn auth_file_path() -> std::path::PathBuf {
1007    crate::config::Config::config_dir().join("auth.toml")
1008}
1009
1010/// Check if user is logged in
1011pub fn is_logged_in() -> bool {
1012    get_stored_auth().is_some()
1013}
1014
1015/// Get current user info (if logged in)
1016pub fn current_user() -> Option<UserInfo> {
1017    get_stored_auth().map(|auth| auth.user)
1018}
1019
1020/// Parse a user-pasted OAuth callback URL into (code, state).
1021///
1022/// Accepts any URL with a query string containing `code` and `state`.
1023/// Rejects raw `code` without URL context — state validation is CSRF
1024/// protection and we want the full round-trip, not a manually typed code.
1025#[allow(dead_code)]
1026fn parse_pasted_callback(input: &str) -> Result<(String, String)> {
1027    // Defensively strip bracketed-paste markers. The TUI disables DECSET
1028    // 2004 before calling us, but a user pasting into a terminal we didn't
1029    // configure (or with a stray prior session) can still deliver these.
1030    let cleaned = input
1031        .trim()
1032        .trim_start_matches("\x1b[200~")
1033        .trim_end_matches("\x1b[201~")
1034        .trim();
1035
1036    let query_start = cleaned.find('?').context(
1037        "Could not parse callback URL — paste the full http://127.0.0.1:8765/callback?... URL",
1038    )?;
1039    let query = &cleaned[query_start + 1..];
1040
1041    let params: HashMap<String, String> = query
1042        .split('&')
1043        .filter_map(|pair| {
1044            let mut parts = pair.splitn(2, '=');
1045            let key = parts.next()?;
1046            let value = parts
1047                .next()
1048                .map(|v| urlencoding_decode(v))
1049                .unwrap_or_default();
1050            Some((key.to_string(), value))
1051        })
1052        .collect();
1053
1054    if let Some(error) = params.get("error") {
1055        let desc = params
1056            .get("error_description")
1057            .map(|s| s.as_str())
1058            .unwrap_or(error);
1059        anyhow::bail!("OAuth error: {}", desc);
1060    }
1061
1062    let code = params
1063        .get("code")
1064        .context("Callback URL missing 'code' parameter")?
1065        .clone();
1066    let state = params
1067        .get("state")
1068        .context("Callback URL missing 'state' parameter (paste the full URL, not just the code)")?
1069        .clone();
1070
1071    Ok((code, state))
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;

    // ----- helpers -----

    /// Parse `input`, panicking if it is rejected.
    #[track_caller]
    fn parse_ok(input: &str) -> (String, String) {
        parse_pasted_callback(input).unwrap()
    }

    /// Parse `input`, panicking if it is accepted; returns the error text.
    #[track_caller]
    fn parse_err(input: &str) -> String {
        parse_pasted_callback(input).unwrap_err().to_string()
    }

    /// Assert that `sanitize_base_url(raw)` produces `want`.
    #[track_caller]
    fn check_sanitized(raw: &str, want: &str) {
        assert_eq!(sanitize_base_url(raw), want);
    }

    // ----- strip_force_login -----

    #[test]
    fn strip_force_login_removes_trailing_param() {
        let input = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz&force_login=true";
        let want = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz";
        assert_eq!(strip_force_login(input), want);
    }

    #[test]
    fn strip_force_login_removes_middle_param() {
        let input = "https://atomgit.com/oauth/authorize?client_id=abc&force_login=true&state=xyz";
        let want = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz";
        assert_eq!(strip_force_login(input), want);
    }

    #[test]
    fn strip_force_login_removes_only_param() {
        let input = "https://atomgit.com/oauth/authorize?force_login=true";
        let want = "https://atomgit.com/oauth/authorize";
        assert_eq!(strip_force_login(input), want);
    }

    #[test]
    fn strip_force_login_removes_first_of_many() {
        let input = "https://atomgit.com/oauth/authorize?force_login=true&state=xyz";
        let want = "https://atomgit.com/oauth/authorize?state=xyz";
        assert_eq!(strip_force_login(input), want);
    }

    #[test]
    fn strip_force_login_passthrough_when_absent() {
        let input = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz";
        assert_eq!(strip_force_login(input), input);
    }

    // ----- parse_pasted_callback -----

    #[test]
    fn parse_happy_path_loopback_url() {
        let (code, state) = parse_ok("http://127.0.0.1:8765/callback?code=abc&state=xyz");
        assert_eq!((code.as_str(), state.as_str()), ("abc", "xyz"));
    }

    #[test]
    fn parse_any_host_with_extra_params() {
        let (code, state) = parse_ok("https://example.com/x?foo=1&code=abc&state=xyz&bar=2");
        assert_eq!((code.as_str(), state.as_str()), ("abc", "xyz"));
    }

    #[test]
    fn parse_missing_state_errors_with_full_url_hint() {
        let err = parse_err("http://127.0.0.1:8765/callback?code=abc");
        for needle in ["state", "full URL"] {
            assert!(err.contains(needle), "got: {err}");
        }
    }

    #[test]
    fn parse_missing_code_errors() {
        let err = parse_err("http://127.0.0.1:8765/callback?state=xyz");
        assert!(err.contains("code"), "got: {err}");
    }

    #[test]
    fn parse_error_response_includes_description() {
        let err = parse_err(
            "http://127.0.0.1:8765/callback?error=access_denied&error_description=User+denied",
        );
        assert!(err.contains("User denied"), "got: {err}");
    }

    #[test]
    fn parse_not_a_url_errors() {
        let err = parse_err("this is not a url");
        assert!(err.contains("full"), "got: {err}");
    }

    #[test]
    fn parse_url_encoded_state_is_decoded() {
        let (_, state) = parse_ok("http://127.0.0.1:8765/callback?code=c&state=atomcode_%3Atest");
        assert_eq!(state, "atomcode_:test");
    }

    #[test]
    fn parse_strips_bracketed_paste_markers() {
        let pasted = "\x1b[200~http://127.0.0.1:8765/callback?code=abc&state=xyz\x1b[201~";
        let (code, state) = parse_ok(pasted);
        assert_eq!((code.as_str(), state.as_str()), ("abc", "xyz"));
    }

    #[test]
    fn parse_trims_surrounding_whitespace() {
        let (code, state) = parse_ok("   http://127.0.0.1:8765/callback?code=abc&state=xyz\n");
        assert_eq!((code.as_str(), state.as_str()), ("abc", "xyz"));
    }

    // ----- classify_input (ESC vs escape-sequence disambiguation) -----

    #[test]
    fn classify_input_bare_esc_cancels() {
        assert_eq!(classify_input(&[0x1B]), EscOutcome::Cancelled);
    }

    #[test]
    fn classify_input_arrow_key_ignored() {
        // Up arrow = ESC [ A, three bytes arriving in a single read.
        let up_arrow = b"\x1B[A";
        assert_eq!(classify_input(up_arrow), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_alt_letter_ignored() {
        // Alt+a is delivered as ESC + 'a' on most terminals.
        let alt_a = b"\x1Ba";
        assert_eq!(classify_input(alt_a), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_normal_byte_ignored() {
        let plain = b"q";
        assert_eq!(classify_input(plain), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_empty_is_timeout() {
        assert_eq!(classify_input(&[]), EscOutcome::Timeout);
    }

    #[test]
    fn classify_input_pasted_text_ignored() {
        let pasted = b"hello\n";
        assert_eq!(classify_input(pasted), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_csi_color_code_ignored() {
        // CSI fragments such as SGR red (`\x1B[31m`) must not be mistaken for
        // a bare ESC.
        let sgr_red = b"\x1B[31m";
        assert_eq!(classify_input(sgr_red), EscOutcome::OtherInput);
    }

    // ----- sanitize_base_url -----

    #[test]
    fn sanitize_adds_http_if_no_scheme() {
        check_sanitized("127.0.0.1:8765", "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_preserves_http_scheme() {
        check_sanitized("http://127.0.0.1:8765", "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_preserves_https_scheme() {
        check_sanitized("https://acs.example.com", "https://acs.example.com");
    }

    #[test]
    fn sanitize_strips_trailing_slash() {
        check_sanitized("http://127.0.0.1:8765/", "http://127.0.0.1:8765");
        check_sanitized("http://127.0.0.1:8765///", "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_trims_whitespace() {
        check_sanitized("  http://127.0.0.1:8765  ", "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_no_scheme_with_trailing_slash() {
        check_sanitized("127.0.0.1:8765/", "http://127.0.0.1:8765");
    }
}