1use std::collections::HashMap;
2use std::io::{self, BufRead, Write};
3use std::net::TcpListener;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::{mpsc, Arc};
6use std::thread;
7use std::time::Duration;
8
9use anyhow::{Context, Result};
10use serde::{Deserialize, Serialize};
11
12use atomcode_telemetry::{Event, Telemetry};
13
/// Platform server used when the `ATOMCODE_PLATFORM_SERVER` environment
/// variable does not supply one (see `platform_base_url`).
const DEFAULT_PLATFORM_SERVER: &str = "https://acs.atomgit.com";
17
/// Normalizes a user-supplied server address into a base URL.
///
/// Surrounding whitespace is trimmed, `http://` is prepended when the input
/// carries no `scheme://`, and every trailing `/` is removed so that endpoint
/// paths can be appended with a single `/`.
fn sanitize_base_url(raw: &str) -> String {
    let input = raw.trim();
    let mut url = String::with_capacity(input.len() + "http://".len());
    if !input.contains("://") {
        url.push_str("http://");
    }
    url.push_str(input);
    url.trim_end_matches('/').to_owned()
}
29
30fn platform_base_url() -> &'static str {
35 use std::sync::OnceLock;
36 static BASE: OnceLock<String> = OnceLock::new();
37 BASE.get_or_init(|| {
38 let raw = std::env::var("ATOMCODE_PLATFORM_SERVER")
39 .unwrap_or_else(|_| DEFAULT_PLATFORM_SERVER.to_string());
40 sanitize_base_url(&raw)
41 })
42}
43
44pub fn platform_broker_url() -> String { platform_base_url().to_string() }
46pub fn platform_login_url() -> String { format!("{}/auth/login", platform_base_url()) }
47pub fn platform_check_url() -> String { format!("{}/auth/check", platform_base_url()) }
48pub fn platform_token_url() -> String { format!("{}/auth/token", platform_base_url()) }
49pub fn platform_exchange_url() -> String { format!("{}/oauth/exchange", platform_base_url()) }
50pub fn platform_refresh_url() -> String { format!("{}/oauth/refresh", platform_base_url()) }
51#[allow(dead_code)]
52pub fn authorize_url() -> String { format!("{}/oauth/authorize", platform_base_url()) }
53#[allow(dead_code)]
54pub fn token_url() -> String { format!("{}/oauth/token", platform_base_url()) }
55#[allow(dead_code)]
56pub fn user_url() -> String { format!("{}/api/v5/user", platform_base_url()) }
57
58fn blocking_client() -> reqwest::blocking::Client {
63 reqwest::blocking::Client::builder()
70 .connect_timeout(std::time::Duration::from_secs(5))
71 .timeout(std::time::Duration::from_secs(10))
72 .user_agent(crate::ATOMCODE_USER_AGENT)
73 .build()
74 .unwrap_or_else(|_| reqwest::blocking::Client::new())
75}
76
/// Credentials persisted to `auth.toml` after a successful login or refresh.
///
/// Field order determines the layout written by `toml::to_string_pretty` in
/// `save_auth`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthInfo {
    /// Bearer token sent to the platform.
    pub access_token: String,
    /// Token used by `refresh_access_token`; `None` means re-login is required.
    pub refresh_token: Option<String>,
    /// Token type reported by the platform.
    pub token_type: String,
    /// Lifetime in seconds, measured from `created_at`; `None` if the server gave none.
    pub expires_in: Option<i64>,
    // Missing in older auth files: serde defaults it to 0, which
    // `get_valid_token` treats as "age unknown".
    #[serde(default)]
    /// Unix timestamp (seconds) at which the token was obtained.
    pub created_at: i64,
    /// The account the token belongs to.
    pub user: UserInfo,
}

/// Account details stored alongside the token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    pub id: String,
    pub username: String,
    pub name: Option<String>,
    pub email: Option<String>,
    pub avatar_url: Option<String>,
}
98
/// Body of `GET /auth/login`: the URL to open in the browser and the opaque
/// `state` used to poll `/auth/check` and fetch `/auth/token`.
#[derive(Debug, Deserialize)]
struct PlatformLoginResponse {
    login_url: String,
    state: String,
}

/// Body of `GET /auth/check`; `valid == true` means the login completed.
#[derive(Debug, Deserialize)]
struct PlatformCheckResponse {
    valid: bool,
}

/// User object as returned by the platform (converted into [`UserInfo`]).
#[derive(Debug, Deserialize)]
struct PlatformUserInfo {
    id: String,
    username: String,
    name: Option<String>,
    email: Option<String>,
    avatar_url: Option<String>,
}

/// Body of `GET /auth/token` for an authorized `state`.
#[derive(Debug, Deserialize)]
struct PlatformTokenResponse {
    access_token: String,
    token_type: String,
    expires_in: Option<i64>,
    refresh_token: Option<String>,
    user: PlatformUserInfo,
}
131
/// Result of waiting for keyboard input during the login poll loop.
// Only produced by the unix input path; unused on Windows.
#[cfg_attr(target_os = "windows", allow(dead_code))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EscOutcome {
    /// A lone ESC byte was read: the user wants to abort.
    Cancelled,
    /// No input arrived in time.
    Timeout,
    /// Something other than a bare ESC (arrow keys, Alt-combos, pasted text, ...).
    OtherInput,
}

/// Classifies one chunk of raw bytes read from stdin.
///
/// Only a chunk consisting of exactly one `0x1B` counts as a cancel. Escape
/// sequences such as arrow keys (`ESC [ A`) arrive as a single multi-byte chunk
/// and are therefore not mistaken for ESC. An empty chunk means nothing was read.
#[cfg_attr(target_os = "windows", allow(dead_code))]
fn classify_input(bytes: &[u8]) -> EscOutcome {
    if bytes.is_empty() {
        EscOutcome::Timeout
    } else if bytes == b"\x1B" {
        EscOutcome::Cancelled
    } else {
        EscOutcome::OtherInput
    }
}
199
/// Guard that puts the terminal on stdin into non-canonical, no-echo mode and
/// restores the saved attributes when dropped.
///
/// NOTE(review): only ICANON and ECHO are cleared, so ISIG stays enabled and
/// Ctrl+C still raises SIGINT. If the process is killed by that signal, `Drop`
/// does not run and the terminal is left without echo, unless a signal handler
/// elsewhere restores it. Confirm.
#[cfg(not(target_os = "windows"))]
struct CbreakGuard {
    /// stdin's file descriptor.
    fd: std::os::unix::io::RawFd,
    /// Terminal attributes captured before modification; restored in `Drop`.
    orig: libc::termios,
}

/// Windows placeholder: cbreak mode is not implemented, so `new()` always
/// returns `None`.
#[cfg(target_os = "windows")]
struct CbreakGuard;

impl CbreakGuard {
    /// Switches stdin to cbreak mode. Returns `None` (leaving the terminal
    /// untouched) if the attributes cannot be read or set, e.g. when stdin
    /// is not a terminal.
    #[cfg(not(target_os = "windows"))]
    fn new() -> Option<Self> {
        use std::os::unix::io::AsRawFd;
        let fd = io::stdin().as_raw_fd();
        // SAFETY: termios is a plain C struct; the zeroed value is only a
        // placeholder that tcgetattr overwrites on success.
        let mut orig: libc::termios = unsafe { std::mem::zeroed() };
        // SAFETY: `orig` is a valid, exclusively borrowed termios.
        if unsafe { libc::tcgetattr(fd, &mut orig) } != 0 {
            return None;
        }
        let mut new = orig;
        // Disable line buffering and echo so a single ESC keystroke is
        // delivered immediately and not printed.
        new.c_lflag &= !(libc::ICANON | libc::ECHO);
        // VMIN = 0, VTIME = 0: read() returns right away with whatever is
        // available; the caller uses poll() to do the actual waiting.
        new.c_cc[libc::VMIN] = 0;
        new.c_cc[libc::VTIME] = 0;
        // SAFETY: `new` is a fully initialised termios copied from `orig`.
        if unsafe { libc::tcsetattr(fd, libc::TCSANOW, &new) } != 0 {
            return None;
        }
        Some(Self { fd, orig })
    }

    /// Cbreak mode is not supported on Windows.
    #[cfg(target_os = "windows")]
    fn new() -> Option<Self> {
        None
    }
}

#[cfg(not(target_os = "windows"))]
impl Drop for CbreakGuard {
    fn drop(&mut self) {
        // SAFETY: `self.orig` was filled by a successful tcgetattr on
        // `self.fd`. The return value is ignored because nothing useful can be
        // done about a failure while dropping.
        unsafe {
            libc::tcsetattr(self.fd, libc::TCSANOW, &self.orig);
        }
    }
}
248
249#[cfg(not(target_os = "windows"))]
257fn wait_for_esc_or_timeout(guard: &Option<CbreakGuard>, timeout: Duration) -> EscOutcome {
258 let Some(g) = guard.as_ref() else {
259 thread::sleep(timeout);
260 return EscOutcome::Timeout;
261 };
262
263 let mut pfd = libc::pollfd {
264 fd: g.fd,
265 events: libc::POLLIN,
266 revents: 0,
267 };
268 let timeout_ms = timeout.as_millis().min(i32::MAX as u128) as i32;
269 let rc = unsafe { libc::poll(&mut pfd, 1, timeout_ms) };
270 if rc <= 0 {
271 return EscOutcome::Timeout;
275 }
276 let mut buf = [0u8; 32];
277 let n = unsafe { libc::read(g.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
278 if n <= 0 {
279 return EscOutcome::Timeout;
280 }
281 classify_input(&buf[..n as usize])
282}
283
284#[cfg(target_os = "windows")]
285fn wait_for_esc_or_timeout(_guard: &Option<CbreakGuard>, timeout: Duration) -> EscOutcome {
286 thread::sleep(timeout);
287 EscOutcome::Timeout
288}
289
/// Result of one `/auth/check` poll.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollOutcome {
    /// The login has not been confirmed yet, or the check response was
    /// unusable; keep polling.
    Pending,
    /// The platform reports the login `state` as valid; call `finish`.
    Authorized,
}

/// An in-progress browser login started by [`start_login`].
pub struct LoginSession {
    /// Opaque value from `/auth/login`, sent to `/auth/check` and `/auth/token`.
    state: String,
    /// URL the user opens in a browser (`force_login=true` already stripped).
    login_url: String,
    /// HTTP client reused for every request of this session.
    client: reqwest::blocking::Client,
}
312
313impl LoginSession {
314 pub fn url(&self) -> &str {
317 &self.login_url
318 }
319
320 pub fn open_browser_best_effort(&self) {
323 let _ = open_browser(&self.login_url);
324 }
325
326 pub fn poll_once(&self) -> Result<PollOutcome> {
331 let resp = self
332 .client
333 .get(platform_check_url())
334 .query(&[("state", &self.state)])
335 .send()
336 .context("Failed to call /auth/check")?;
337 if resp.status().is_success() {
338 if let Ok(check) = resp.json::<PlatformCheckResponse>() {
339 if check.valid {
340 return Ok(PollOutcome::Authorized);
341 }
342 }
343 }
344 Ok(PollOutcome::Pending)
345 }
346
347 pub fn finish(self, tel: Option<&Arc<Telemetry>>) -> Result<AuthInfo> {
351 let token_resp: PlatformTokenResponse = self
352 .client
353 .get(platform_token_url())
354 .query(&[("state", &self.state)])
355 .send()
356 .context("Failed to call /auth/token")?
357 .json()
358 .context("Failed to parse /auth/token response")?;
359
360 let created_at = std::time::SystemTime::now()
361 .duration_since(std::time::UNIX_EPOCH)
362 .unwrap()
363 .as_secs() as i64;
364
365 let auth_info = AuthInfo {
366 access_token: token_resp.access_token,
367 refresh_token: token_resp.refresh_token,
368 token_type: token_resp.token_type,
369 expires_in: token_resp.expires_in,
370 created_at,
371 user: UserInfo {
372 id: token_resp.user.id,
373 username: token_resp.user.username,
374 name: token_resp.user.name,
375 email: token_resp.user.email,
376 avatar_url: token_resp.user.avatar_url,
377 },
378 };
379
380 if let Some(t) = tel {
381 t.set_account_id(Some(auth_info.user.id.to_string()));
387 t.track(Event::LoginSuccess);
388 }
389
390 Ok(auth_info)
391 }
392}
393
394pub fn start_login() -> Result<LoginSession> {
399 let client = reqwest::blocking::Client::new();
400 let resp: PlatformLoginResponse = client
401 .get(platform_login_url())
402 .query(&[("provider", "atomgit")])
403 .send()
404 .context("Failed to call /auth/login")?
405 .json()
406 .context("Failed to parse /auth/login response")?;
407 Ok(LoginSession {
408 state: resp.state,
409 login_url: strip_force_login(&resp.login_url),
410 client,
411 })
412}
413
/// Removes every `force_login=true` query parameter from `url`.
///
/// Parameters are matched as whole `&`-separated pairs. A plain substring
/// replace would also hit parameters that merely start with the text: for
/// example, `?force_login=true2&state=x` would turn into `...2&state=x`.
/// The `?` is dropped when no parameters remain, and any `#fragment` is kept.
/// URLs without the parameter are returned unchanged.
fn strip_force_login(url: &str) -> String {
    const PARAM: &str = "force_login=true";

    let Some((base, rest)) = url.split_once('?') else {
        return url.to_string();
    };
    let (query, fragment) = match rest.split_once('#') {
        Some((q, f)) => (q, Some(f)),
        None => (rest, None),
    };

    let kept: Vec<&str> = query.split('&').filter(|pair| *pair != PARAM).collect();

    let mut out = String::with_capacity(url.len());
    out.push_str(base);
    if !kept.is_empty() {
        out.push('?');
        out.push_str(&kept.join("&"));
    }
    if let Some(fragment) = fragment {
        out.push('#');
        out.push_str(fragment);
    }
    out
}
425
426pub fn login(tel: Option<&Arc<Telemetry>>) -> Result<AuthInfo> {
438 let session = start_login()?;
439
440 println!(" Browser didn't open? Open the URL below in any browser to sign in:");
444 println!(" {}", session.url());
445
446 let cbreak = CbreakGuard::new();
450 if cbreak.is_some() {
451 println!();
452 println!(" Press ESC to cancel");
453 }
454
455 session.open_browser_best_effort();
456
457 loop {
458 match session.poll_once()? {
459 PollOutcome::Authorized => break,
460 PollOutcome::Pending => {}
461 }
462 match wait_for_esc_or_timeout(&cbreak, Duration::from_secs(2)) {
463 EscOutcome::Cancelled => anyhow::bail!("login cancelled by user"),
464 EscOutcome::Timeout | EscOutcome::OtherInput => {}
465 }
466 }
467
468 session.finish(tel)
469}
470
471#[allow(dead_code)]
473fn pasted_state(url: &str) -> Option<String> {
474 url.split('?')
475 .nth(1)?
476 .split('&')
477 .filter_map(|pair| {
478 let mut parts = pair.splitn(2, '=');
479 if parts.next()? == "state" {
480 Some(urlencoding_decode(parts.next()?))
481 } else {
482 None
483 }
484 })
485 .next()
486}
487
/// Builds a login `state` of the form `atomcode_<nanoseconds since epoch>`.
///
/// NOTE(review): a timestamp is predictable and therefore weak as an OAuth
/// anti-CSRF state. Use a random value if this is ever revived.
// Not referenced by the active login flow.
#[allow(dead_code)]
fn generate_state() -> String {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let mut state = String::from("atomcode_");
    state.push_str(&nanos.to_string());
    state
}
498
/// Opens `url` in the default browser with macOS `open`.
///
/// The child process is spawned and not waited for.
#[cfg(target_os = "macos")]
pub fn open_browser(url: &str) -> Result<()> {
    std::process::Command::new("open")
        .arg(url)
        .spawn()
        .context("Failed to open browser")?;
    Ok(())
}

/// Opens `url` in the default browser with `xdg-open`.
///
/// The child process is spawned and not waited for.
#[cfg(target_os = "linux")]
pub fn open_browser(url: &str) -> Result<()> {
    std::process::Command::new("xdg-open")
        .arg(url)
        .spawn()
        .context("Failed to open browser")?;
    Ok(())
}

/// Opens `url` in the default browser with `cmd /C start "" "<url>"`.
///
/// `raw_arg` passes the command line verbatim. The empty `""` is the window
/// title argument of `start`, and the URL is quoted so that `&` is not treated
/// as a command separator.
/// NOTE(review): cmd.exe may still expand `%NAME%` sequences inside the quoted
/// URL, and a `"` in the URL would break the quoting. Confirm that login URLs
/// can never contain these.
#[cfg(target_os = "windows")]
pub fn open_browser(url: &str) -> Result<()> {
    use std::os::windows::process::CommandExt;
    std::process::Command::new("cmd")
        .raw_arg(format!("/C start \"\" \"{}\"", url))
        .spawn()
        .context("Failed to open browser")?;
    Ok(())
}

/// Fallback for other platforms: always returns an error.
#[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
pub fn open_browser(_url: &str) -> Result<()> {
    anyhow::bail!("Unsupported platform for browser auto-open");
}
538
/// Waits for an OAuth redirect (`code`, `state`) either on a loopback HTTP
/// listener at `127.0.0.1:<port>` or from a callback URL pasted on stdin.
///
/// Each source runs on its own thread and sends its result into a channel.
/// The first message received is returned, *including an error*. For example,
/// a stdin error ("stdin closed") ends the wait even though the listener might
/// still have succeeded. After that, `stop` is set so the unix workers exit.
/// On Windows the stdin reader is only started when binding failed, and it
/// blocks in `read_line` without observing `stop`.
///
/// # Errors
/// Returns the first worker's error, or "login cancelled" if every worker
/// exited without sending (e.g. panicked).
// Not referenced by the active login flow (`login` uses polling instead).
#[allow(dead_code)]
fn await_callback(port: u16) -> Result<(String, String)> {
    let listener = match TcpListener::bind(("127.0.0.1", port)) {
        Ok(l) => Some(l),
        Err(e) => {
            println!(" Could not bind port {} ({}). Paste path only.", port, e);
            None
        }
    };

    println!(
        " Waiting for callback on http://127.0.0.1:{}/callback",
        port
    );
    println!(" Or paste the full callback URL here and press Enter:");
    println!(" (Ctrl+C to cancel)\n");

    let (tx, rx) = mpsc::channel::<Result<(String, String)>>();
    // Shared cancellation flag polled by the worker loops.
    let stop = Arc::new(AtomicBool::new(false));

    // Only read by the Windows branch below.
    #[cfg_attr(not(target_os = "windows"), allow(unused_variables))]
    let has_listener = listener.is_some();
    if let Some(listener) = listener {
        let tx_l = tx.clone();
        let stop_l = Arc::clone(&stop);
        thread::spawn(move || {
            let r = accept_callback_until_stopped(listener, &stop_l);
            let _ = tx_l.send(r);
        });
    }

    #[cfg(not(target_os = "windows"))]
    {
        let tx_stdin = tx.clone();
        let stop_stdin = Arc::clone(&stop);
        thread::spawn(move || {
            let r = read_callback_from_stdin_until_stopped(&stop_stdin);
            let _ = tx_stdin.send(r);
        });
    }
    #[cfg(target_os = "windows")]
    {
        if !has_listener {
            let tx_stdin = tx.clone();
            thread::spawn(move || {
                let stdin = io::stdin();
                let mut line = String::new();
                let r = match stdin.lock().read_line(&mut line) {
                    Ok(0) => Err(anyhow::anyhow!("stdin closed")),
                    Ok(_) => parse_pasted_callback(&line),
                    Err(e) => Err(anyhow::Error::new(e).context("Failed to read from stdin")),
                };
                let _ = tx_stdin.send(r);
            });
        }
    }
    // Drop our own sender so `recv` errors out if every worker has exited.
    drop(tx);

    let result = rx.recv().context("login cancelled")?;
    stop.store(true, Ordering::Relaxed);
    result
}
639
/// Reads one pasted line from stdin (unix) and parses it as a callback URL,
/// giving up once `stop` is set.
///
/// stdin is switched to `O_NONBLOCK` and polled in 100 ms slices so the `stop`
/// flag is re-checked regularly. The original file-status flags are restored by
/// `FlagGuard` on every exit path.
/// NOTE(review): O_NONBLOCK is a property of the open file description, which
/// may be shared with the parent shell while it is set. It also reads the raw
/// fd, so anything already buffered in Rust's `Stdin` would be missed. Confirm
/// both are acceptable.
///
/// # Errors
/// "stdin cancelled" when stopped, "stdin closed" on EOF, poll/read failures,
/// or whatever `parse_pasted_callback` reports for the line.
#[cfg(not(target_os = "windows"))]
// Only used by `await_callback`, which is itself unused.
#[allow(dead_code)]
fn read_callback_from_stdin_until_stopped(stop: &AtomicBool) -> Result<(String, String)> {
    use std::os::unix::io::AsRawFd;

    let stdin = io::stdin();
    let fd = stdin.as_raw_fd();

    // SAFETY: F_GETFL on a valid fd takes no pointer arguments.
    let orig_flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
    // Only add O_NONBLOCK if the current flags could be read, so the guard
    // can restore them exactly.
    if orig_flags >= 0 {
        // SAFETY: F_SETFL with an integer flag set; no pointers involved.
        unsafe {
            libc::fcntl(fd, libc::F_SETFL, orig_flags | libc::O_NONBLOCK);
        }
    }

    /// Restores stdin's original file-status flags on drop.
    struct FlagGuard {
        fd: std::os::unix::io::RawFd,
        orig_flags: i32,
    }
    impl Drop for FlagGuard {
        fn drop(&mut self) {
            if self.orig_flags >= 0 {
                // SAFETY: F_SETFL with the flags previously returned by F_GETFL.
                unsafe {
                    libc::fcntl(self.fd, libc::F_SETFL, self.orig_flags);
                }
            }
        }
    }
    let _guard = FlagGuard { fd, orig_flags };

    let mut line = String::new();
    let mut buf = [0u8; 256];
    loop {
        if stop.load(Ordering::Relaxed) {
            anyhow::bail!("stdin cancelled");
        }
        let mut pfd = libc::pollfd {
            fd,
            events: libc::POLLIN,
            revents: 0,
        };
        // SAFETY: `pfd` is a valid pollfd and nfds == 1; 100 ms timeout.
        let poll_rc = unsafe { libc::poll(&mut pfd, 1, 100) };
        if poll_rc < 0 {
            let err = io::Error::last_os_error();
            if err.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(anyhow::Error::new(err).context("poll(stdin)"));
        }
        if poll_rc == 0 {
            // Timed out: loop to re-check `stop`.
            continue;
        }
        // SAFETY: `buf` is valid for writes of `buf.len()` bytes.
        let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
        if n < 0 {
            let err = io::Error::last_os_error();
            if err.kind() == io::ErrorKind::WouldBlock || err.kind() == io::ErrorKind::Interrupted {
                continue;
            }
            return Err(anyhow::Error::new(err).context("read(stdin)"));
        }
        if n == 0 {
            anyhow::bail!("stdin closed");
        }
        // Each chunk is decoded on its own, so a multi-byte UTF-8 character
        // split across two reads becomes U+FFFD. Callback URLs are expected to
        // be ASCII.
        line.push_str(&String::from_utf8_lossy(&buf[..n as usize]));
        if line.contains('\n') {
            return parse_pasted_callback(&line);
        }
    }
}
729
/// Accepts one HTTP connection on `listener` and parses its request line as
/// an OAuth callback, giving up once `stop` is set.
///
/// `accept` is polled every 200 ms in non-blocking mode. Only the *first*
/// connection is handled: its request target must contain a query string.
/// On an `error` parameter the browser is redirected to https://atomgit.com
/// and an error is returned. Otherwise a success page is written and
/// `(code, state)` is returned, with `state` empty if absent.
/// NOTE(review): a stray first request without a query (e.g. a favicon fetch)
/// makes this fail with "No query parameters in callback". The accepted stream
/// also has no read timeout, so a silent client blocks here regardless of
/// `stop`.
// Only used by `await_callback`, which is itself unused.
#[allow(dead_code)]
fn accept_callback_until_stopped(
    listener: TcpListener,
    stop: &AtomicBool,
) -> Result<(String, String)> {
    listener
        .set_nonblocking(true)
        .context("Failed to set non-blocking mode")?;

    let mut stream = loop {
        if stop.load(Ordering::Relaxed) {
            anyhow::bail!("listener cancelled");
        }
        match listener.accept() {
            Ok((stream, _)) => break stream,
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                thread::sleep(Duration::from_millis(200));
                continue;
            }
            Err(e) => return Err(e).context("Failed to accept connection"),
        }
    };

    // The accepted socket may inherit non-blocking mode; make reads block.
    stream.set_nonblocking(false)?;

    let mut reader = io::BufReader::new(&mut stream);
    let mut request_line = String::new();
    reader.read_line(&mut request_line)?;

    // Request line: "<METHOD> <target> <VERSION>"; take the target.
    let url: String = request_line
        .split_whitespace()
        .nth(1)
        .context("Invalid HTTP request")?
        .to_string();

    let query_start = url.find('?').context("No query parameters in callback")?;
    let query = &url[query_start + 1..];

    // Later duplicates of a key overwrite earlier ones.
    let params: HashMap<String, String> = query
        .split('&')
        .filter_map(|pair| {
            let mut parts = pair.splitn(2, '=');
            let key = parts.next()?;
            let value = parts
                .next()
                .map(|v| urlencoding_decode(v))
                .unwrap_or_default();
            Some((key.to_string(), value))
        })
        .collect();

    if let Some(error) = params.get("error") {
        let error_desc = params
            .get("error_description")
            .map(|s| s.as_str())
            .unwrap_or(error);
        // Best effort: redirect the browser away before reporting the error.
        let response = "HTTP/1.1 302 Found\r\nLocation: https://atomgit.com\r\n\r\n";
        let _ = stream.write_all(response.as_bytes());
        let _ = stream.flush();
        anyhow::bail!("OAuth error: {}", error_desc);
    }

    let code = params.get("code").context("No code in callback")?.clone();
    let state = params.get("state").cloned().unwrap_or_default();

    // Minimal success page; the connection is closed when `stream` drops.
    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\n\r\n\
        <html><head><title>AtomCode Login</title>\
        <style>body{font-family:system-ui;display:flex;justify-content:center;align-items:center;height:100vh;margin:0;background:#1a1a2e;color:#eee}\
        .container{text-align:center;padding:2rem}h1{color:#7c3aed;margin:0}p{color:#888}\
        .success{color:#22c55e;font-size:4rem}</style></head>\
        <body><div class=\"container\">\
        <div class=\"success\">✓</div>\
        <h1>Authorization Successful</h1>\
        <p>You can close this window and return to AtomCode.</p>\
        </div></body></html>";

    stream.write_all(response.as_bytes())?;
    stream.flush()?;

    Ok((code, state))
}
818
/// Decodes a URL query component (`application/x-www-form-urlencoded`).
///
/// `+` becomes a space, and `%XX` (exactly two ASCII hex digits) becomes the
/// raw byte `0xXX`. The decoded bytes are then read as UTF-8, so multi-byte
/// escapes such as `%E2%9C%93` yield `✓` rather than mojibake; invalid UTF-8
/// is replaced with U+FFFD. A `%` that is not followed by two hex digits is
/// kept literally instead of silently swallowing the following characters.
/// Non-ASCII input characters are passed through unchanged.
fn urlencoding_decode(s: &str) -> String {
    // `to_digit(16)` only accepts 0-9/a-f/A-F, unlike `u8::from_str_radix`,
    // which would also accept a leading `+` sign.
    let hex = |b: u8| (b as char).to_digit(16);

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex);
                let lo = bytes.get(i + 2).copied().and_then(hex);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    // hi, lo <= 15, so the value fits in a u8.
                    out.push((hi * 16 + lo) as u8);
                    i += 3;
                    continue;
                }
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
839
840pub fn refresh_access_token(auth: &AuthInfo) -> Result<AuthInfo> {
843 let refresh_token = auth
844 .refresh_token
845 .as_deref()
846 .context("No refresh_token available — please /login again")?;
847
848 let client = blocking_client();
849
850 let response = client
852 .post(platform_refresh_url())
853 .json(&serde_json::json!({ "refresh_token": refresh_token }))
854 .send()
855 .context("Failed to send refresh token request to broker")?;
856
857 if !response.status().is_success() {
858 let status = response.status();
859 let body = response.text().unwrap_or_default();
860 anyhow::bail!(
861 "Token refresh failed ({}): {} — please /login again",
862 status,
863 body
864 );
865 }
866
867 #[derive(Deserialize)]
868 struct BrokerResponse {
869 access_token: String,
870 token_type: Option<String>,
871 expires_in: Option<i64>,
872 refresh_token: Option<String>,
873 user: Option<PlatformUserInfo>,
874 }
875
876 let broker_resp: BrokerResponse = response.json().context("Failed to parse broker response")?;
877
878 let created_at = std::time::SystemTime::now()
879 .duration_since(std::time::UNIX_EPOCH)
880 .unwrap()
881 .as_secs() as i64;
882
883 let new_auth = AuthInfo {
884 access_token: broker_resp.access_token,
885 refresh_token: broker_resp
886 .refresh_token
887 .or_else(|| auth.refresh_token.clone()),
888 token_type: broker_resp
889 .token_type
890 .unwrap_or_else(|| auth.token_type.clone()),
891 expires_in: broker_resp.expires_in.or(auth.expires_in),
892 created_at,
893 user: broker_resp
894 .user
895 .map(|u| UserInfo {
896 id: u.id,
897 username: u.username,
898 name: u.name,
899 email: u.email,
900 avatar_url: u.avatar_url,
901 })
902 .unwrap_or_else(|| auth.user.clone()),
903 };
904
905 save_auth(&new_auth)?;
906 Ok(new_auth)
907}
908
909pub fn get_valid_token() -> Result<String> {
912 let auth = get_stored_auth().context("Not logged in — please use /login first")?;
913
914 if let Some(expires_in) = auth.expires_in {
916 let now = std::time::SystemTime::now()
917 .duration_since(std::time::UNIX_EPOCH)
918 .unwrap()
919 .as_secs() as i64;
920 let expires_at = auth.created_at + expires_in;
921
922 if now >= expires_at - 300 {
923 match refresh_access_token(&auth) {
925 Ok(new_auth) => return Ok(new_auth.access_token),
926 Err(e) => anyhow::bail!("Token expired and refresh failed: {}", e),
927 }
928 }
929 } else if auth.created_at == 0 {
930 if auth.refresh_token.is_some() {
933 if let Ok(new_auth) = refresh_access_token(&auth) {
934 return Ok(new_auth.access_token);
935 }
936 }
937 }
938
939 Ok(auth.access_token)
940}
941
942pub fn logout() -> Result<()> {
952 let auth_path = auth_file_path();
953 if auth_path.exists() {
954 std::fs::remove_file(&auth_path).context("Failed to remove auth file")?;
955 }
956 Ok(())
957}
958
959pub fn get_stored_auth() -> Option<AuthInfo> {
961 let auth_path = auth_file_path();
962 if !auth_path.exists() {
963 return None;
964 }
965
966 let content = std::fs::read_to_string(&auth_path).ok()?;
967 toml::from_str(&content).ok()
968}
969
970pub fn save_auth(auth: &AuthInfo) -> Result<()> {
972 let auth_path = auth_file_path();
973
974 if let Some(parent) = auth_path.parent() {
976 std::fs::create_dir_all(parent).context("Failed to create auth directory")?;
977 #[cfg(unix)]
979 {
980 use std::os::unix::fs::PermissionsExt;
981 let _ = std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700));
982 }
983 }
984
985 let content = toml::to_string_pretty(auth).context("Failed to serialize auth info")?;
986 super::write_auth_file_secure(&auth_path, &content).context("Failed to write auth file")?;
987
988 #[cfg(unix)]
990 {
991 use std::os::unix::fs::PermissionsExt;
992 std::fs::set_permissions(&auth_path, std::fs::Permissions::from_mode(0o600))
993 .context("Failed to set auth file permissions")?;
994 }
995
996 Ok(())
1003}
1004
1005pub fn auth_file_path() -> std::path::PathBuf {
1007 crate::config::Config::config_dir().join("auth.toml")
1008}
1009
1010pub fn is_logged_in() -> bool {
1012 get_stored_auth().is_some()
1013}
1014
1015pub fn current_user() -> Option<UserInfo> {
1017 get_stored_auth().map(|auth| auth.user)
1018}
1019
1020#[allow(dead_code)]
1026fn parse_pasted_callback(input: &str) -> Result<(String, String)> {
1027 let cleaned = input
1031 .trim()
1032 .trim_start_matches("\x1b[200~")
1033 .trim_end_matches("\x1b[201~")
1034 .trim();
1035
1036 let query_start = cleaned.find('?').context(
1037 "Could not parse callback URL — paste the full http://127.0.0.1:8765/callback?... URL",
1038 )?;
1039 let query = &cleaned[query_start + 1..];
1040
1041 let params: HashMap<String, String> = query
1042 .split('&')
1043 .filter_map(|pair| {
1044 let mut parts = pair.splitn(2, '=');
1045 let key = parts.next()?;
1046 let value = parts
1047 .next()
1048 .map(|v| urlencoding_decode(v))
1049 .unwrap_or_default();
1050 Some((key.to_string(), value))
1051 })
1052 .collect();
1053
1054 if let Some(error) = params.get("error") {
1055 let desc = params
1056 .get("error_description")
1057 .map(|s| s.as_str())
1058 .unwrap_or(error);
1059 anyhow::bail!("OAuth error: {}", desc);
1060 }
1061
1062 let code = params
1063 .get("code")
1064 .context("Callback URL missing 'code' parameter")?
1065 .clone();
1066 let state = params
1067 .get("state")
1068 .context("Callback URL missing 'state' parameter (paste the full URL, not just the code)")?
1069 .clone();
1070
1071 Ok((code, state))
1072}
1073
/// Unit tests for the pure helpers: URL cleanup, callback parsing, ESC
/// classification and base-URL sanitizing. No network or terminal access.
#[cfg(test)]
mod tests {
    use super::*;

    // --- strip_force_login: parameter at each position, and absent ---

    #[test]
    fn strip_force_login_removes_trailing_param() {
        let url = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz&force_login=true";
        assert_eq!(
            strip_force_login(url),
            "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz"
        );
    }

    #[test]
    fn strip_force_login_removes_middle_param() {
        let url = "https://atomgit.com/oauth/authorize?client_id=abc&force_login=true&state=xyz";
        assert_eq!(
            strip_force_login(url),
            "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz"
        );
    }

    #[test]
    fn strip_force_login_removes_only_param() {
        let url = "https://atomgit.com/oauth/authorize?force_login=true";
        assert_eq!(
            strip_force_login(url),
            "https://atomgit.com/oauth/authorize"
        );
    }

    #[test]
    fn strip_force_login_removes_first_of_many() {
        let url = "https://atomgit.com/oauth/authorize?force_login=true&state=xyz";
        assert_eq!(
            strip_force_login(url),
            "https://atomgit.com/oauth/authorize?state=xyz"
        );
    }

    #[test]
    fn strip_force_login_passthrough_when_absent() {
        let url = "https://atomgit.com/oauth/authorize?client_id=abc&state=xyz";
        assert_eq!(strip_force_login(url), url);
    }

    // --- parse_pasted_callback: happy paths, missing params, error replies ---

    #[test]
    fn parse_happy_path_loopback_url() {
        let (code, state) =
            parse_pasted_callback("http://127.0.0.1:8765/callback?code=abc&state=xyz").unwrap();
        assert_eq!(code, "abc");
        assert_eq!(state, "xyz");
    }

    #[test]
    fn parse_any_host_with_extra_params() {
        let (code, state) =
            parse_pasted_callback("https://example.com/x?foo=1&code=abc&state=xyz&bar=2").unwrap();
        assert_eq!(code, "abc");
        assert_eq!(state, "xyz");
    }

    #[test]
    fn parse_missing_state_errors_with_full_url_hint() {
        let err = parse_pasted_callback("http://127.0.0.1:8765/callback?code=abc")
            .unwrap_err()
            .to_string();
        assert!(err.contains("state"), "got: {err}");
        assert!(err.contains("full URL"), "got: {err}");
    }

    #[test]
    fn parse_missing_code_errors() {
        let err = parse_pasted_callback("http://127.0.0.1:8765/callback?state=xyz")
            .unwrap_err()
            .to_string();
        assert!(err.contains("code"), "got: {err}");
    }

    #[test]
    fn parse_error_response_includes_description() {
        let err = parse_pasted_callback(
            "http://127.0.0.1:8765/callback?error=access_denied&error_description=User+denied",
        )
        .unwrap_err()
        .to_string();
        assert!(err.contains("User denied"), "got: {err}");
    }

    #[test]
    fn parse_not_a_url_errors() {
        let err = parse_pasted_callback("this is not a url")
            .unwrap_err()
            .to_string();
        assert!(err.contains("full"), "got: {err}");
    }

    #[test]
    fn parse_url_encoded_state_is_decoded() {
        let (_, state) =
            parse_pasted_callback("http://127.0.0.1:8765/callback?code=c&state=atomcode_%3Atest")
                .unwrap();
        assert_eq!(state, "atomcode_:test");
    }

    #[test]
    fn parse_strips_bracketed_paste_markers() {
        let input = "\x1b[200~http://127.0.0.1:8765/callback?code=abc&state=xyz\x1b[201~";
        let (code, state) = parse_pasted_callback(input).unwrap();
        assert_eq!(code, "abc");
        assert_eq!(state, "xyz");
    }

    #[test]
    fn parse_trims_surrounding_whitespace() {
        let (code, state) =
            parse_pasted_callback(" http://127.0.0.1:8765/callback?code=abc&state=xyz\n")
                .unwrap();
        assert_eq!(code, "abc");
        assert_eq!(state, "xyz");
    }

    // --- classify_input: only a lone ESC byte cancels ---

    #[test]
    fn classify_input_bare_esc_cancels() {
        assert_eq!(classify_input(&[0x1B]), EscOutcome::Cancelled);
    }

    #[test]
    fn classify_input_arrow_key_ignored() {
        assert_eq!(classify_input(b"\x1B[A"), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_alt_letter_ignored() {
        assert_eq!(classify_input(b"\x1Ba"), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_normal_byte_ignored() {
        assert_eq!(classify_input(b"q"), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_empty_is_timeout() {
        assert_eq!(classify_input(&[]), EscOutcome::Timeout);
    }

    #[test]
    fn classify_input_pasted_text_ignored() {
        assert_eq!(classify_input(b"hello\n"), EscOutcome::OtherInput);
    }

    #[test]
    fn classify_input_csi_color_code_ignored() {
        assert_eq!(classify_input(b"\x1B[31m"), EscOutcome::OtherInput);
    }

    // --- sanitize_base_url: scheme defaulting, trailing slashes, whitespace ---

    #[test]
    fn sanitize_adds_http_if_no_scheme() {
        assert_eq!(sanitize_base_url("127.0.0.1:8765"), "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_preserves_http_scheme() {
        assert_eq!(sanitize_base_url("http://127.0.0.1:8765"), "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_preserves_https_scheme() {
        assert_eq!(sanitize_base_url("https://acs.example.com"), "https://acs.example.com");
    }

    #[test]
    fn sanitize_strips_trailing_slash() {
        assert_eq!(sanitize_base_url("http://127.0.0.1:8765/"), "http://127.0.0.1:8765");
        assert_eq!(sanitize_base_url("http://127.0.0.1:8765///"), "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_trims_whitespace() {
        assert_eq!(sanitize_base_url(" http://127.0.0.1:8765 "), "http://127.0.0.1:8765");
    }

    #[test]
    fn sanitize_no_scheme_with_trailing_slash() {
        assert_eq!(sanitize_base_url("127.0.0.1:8765/"), "http://127.0.0.1:8765");
    }
}