1use santui_core::auth::{AuthHandle, User};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::io::{BufRead, BufReader, Write};
5use std::net::TcpListener;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::sync::Mutex;
9use std::thread;
10use std::time::Duration;
11use url::Url;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14struct StoredToken {
15 id: String,
16 email: String,
17 name: String,
18 avatar_url: Option<String>,
19 provider: String,
20 access_token: String,
21 refresh_token: Option<String>,
22}
23
/// Static settings describing one OAuth provider.
#[derive(Clone, Debug)]
pub struct AuthConfig {
    /// OAuth client identifier registered with the provider.
    pub client_id: String,
    /// Client secret, if any (the GitHub constructor leaves this `None`).
    pub client_secret: Option<String>,
    /// Authorization endpoint (the GitHub constructor leaves this empty).
    pub auth_uri: String,
    /// Token endpoint; the GitHub device flow polls this URL for the access token.
    pub token_uri: String,
    /// Scopes requested from the provider.
    pub scopes: Vec<String>,
    /// Redirect port recorded for the provider (`0` for GitHub). The local
    /// redirect listener in this file picks its port independently.
    pub redirect_port: u16,
}
33
34impl AuthConfig {
35 pub fn google(client_id: String, client_secret: Option<String>) -> Self {
36 AuthConfig {
37 client_id,
38 client_secret,
39 auth_uri: "https://accounts.google.com/o/oauth2/v2/auth".into(),
40 token_uri: "https://oauth2.googleapis.com/token".into(),
41 scopes: vec!["openid".into(), "email".into(), "profile".into()],
42 redirect_port: 9842,
43 }
44 }
45
46 pub fn github(client_id: String) -> Self {
47 AuthConfig {
48 client_id,
49 client_secret: None,
50 auth_uri: String::new(),
51 token_uri: "https://github.com/login/oauth/access_token".into(),
52 scopes: vec!["read:user".into(), "user:email".into()],
53 redirect_port: 0,
54 }
55 }
56}
57
/// Opens `url` in the default browser via `cmd /c start` (best effort).
///
/// Launch failures are ignored and the child process is not waited on.
#[cfg(target_os = "windows")]
fn open_browser(url: &str) {
    // `cmd` treats `^ & | < >` as operators, so caret-escape each of them.
    // Escaping in one pass avoids re-escaping the carets we insert.
    let mut escaped = String::with_capacity(url.len());
    for ch in url.chars() {
        if matches!(ch, '^' | '&' | '|' | '<' | '>') {
            escaped.push('^');
        }
        escaped.push(ch);
    }
    // `start` interprets its first quoted argument as a window title. Passing an
    // explicit empty title keeps the URL from being swallowed if it ever gets quoted.
    let _ = std::process::Command::new("cmd")
        .args(["/c", "start", "", &escaped])
        .spawn();
}
64
/// Opens `url` with `xdg-open` (best effort).
///
/// Launch failures are ignored and the child process is not waited on.
#[cfg(target_os = "linux")]
fn open_browser(url: &str) {
    let mut launcher = std::process::Command::new("xdg-open");
    launcher.arg(url);
    let _ = launcher.spawn();
}
69
/// Opens `url` with macOS `open` (best effort).
///
/// Launch failures are ignored and the child process is not waited on.
#[cfg(target_os = "macos")]
fn open_browser(url: &str) {
    let mut launcher = std::process::Command::new("open");
    launcher.arg(url);
    let _ = launcher.spawn();
}
74
/// Fallback for other targets: tries `xdg-open` (best effort).
///
/// Launch failures are ignored and the child process is not waited on.
#[cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))]
fn open_browser(url: &str) {
    let mut launcher = std::process::Command::new("xdg-open");
    launcher.arg(url);
    let _ = launcher.spawn();
}
79
/// Binds a loopback listener for the OAuth redirect and returns it with its port.
///
/// Ports 9842 through 9849 are tried in order. If all are taken, the OS assigns
/// an ephemeral port.
///
/// # Errors
///
/// Fails only if binding to an OS-assigned port, or reading its address, fails.
fn bind_with_fallback() -> Result<(TcpListener, u16), Box<dyn std::error::Error>> {
    let preferred = (9842..9850).find_map(|candidate| {
        TcpListener::bind(("127.0.0.1", candidate))
            .ok()
            .map(|bound| (bound, candidate))
    });

    match preferred {
        Some(found) => Ok(found),
        None => {
            let bound = TcpListener::bind(("127.0.0.1", 0))?;
            let assigned = bound.local_addr()?.port();
            Ok((bound, assigned))
        }
    }
}
90
91fn handle_redirect(
92 listener: TcpListener,
93) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
94 let (stream, _) = listener.accept()?;
95 stream.set_read_timeout(Some(Duration::from_secs(120)))?;
96 let mut reader = BufReader::new(&stream);
97 let mut request_line = String::new();
98 reader.read_line(&mut request_line)?;
99
100 let params = request_line
101 .split_whitespace()
102 .nth(1)
103 .and_then(|path| {
104 let full_url = format!("http://localhost{path}");
105 Url::parse(&full_url).ok().map(|u| {
106 u.query_pairs()
107 .map(|(k, v)| (k.into_owned(), v.into_owned()))
108 .collect::<HashMap<String, String>>()
109 })
110 })
111 .ok_or_else(|| "No query parameters in redirect".to_string())?;
112
113 let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<!DOCTYPE html><html lang=\"en\"><head><meta charset=\"UTF-8\"><script src=\"https://cdn.tailwindcss.com\"></script><title>Santui — Signed In</title></head><body class=\"bg-gradient-to-br from-gray-900 via-slate-800 to-gray-900 min-h-screen flex items-center justify-center font-sans\"><div class=\"bg-white/10 backdrop-blur-lg rounded-lg shadow-2xl border border-white/20 p-8 max-w-md w-full mx-4 text-center\"><div class=\"text-emerald-400 mb-4\"><svg class=\"w-16 h-16 mx-auto mb-4\" fill=\"none\" stroke=\"currentColor\" viewBox=\"0 0 24 24\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"1.5\" d=\"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"/></svg><h1 class=\"text-2xl font-bold mb-1\">Signed In!</h1><p class=\"text-gray-400 text-sm\">You can close this window.</p></div></div></body></html>";
114 let mut stream = stream;
115 let _ = stream.write_all(response.as_bytes());
116
117 if let Some(err) = params.get("error") {
118 return Err(format!("OAuth error from server: {err}").into());
119 }
120
121 Ok(params)
122}
123
124#[derive(Deserialize)]
125struct DeviceCodeResponse {
126 device_code: String,
127 user_code: String,
128 #[allow(dead_code)]
129 verification_uri: String,
130 interval: Option<u64>,
131}
132
133#[derive(Deserialize)]
134struct DeviceTokenResponse {
135 access_token: Option<String>,
136 error: Option<String>,
137}
138
139fn request_device_code(
140 config: &AuthConfig,
141) -> Result<DeviceCodeResponse, Box<dyn std::error::Error>> {
142 let scope = config.scopes.join(" ");
143 let mut resp = ureq::post("https://github.com/login/device/code")
144 .header("Accept", "application/json")
145 .send_form([
146 ("client_id", config.client_id.as_str()),
147 ("scope", scope.as_str()),
148 ])?;
149 let text = resp.body_mut().read_to_string()?;
150 Ok(serde_json::from_str(&text)?)
151}
152
153fn poll_device_token(
154 config: &AuthConfig,
155 device_code: &str,
156 interval: u64,
157) -> Result<String, Box<dyn std::error::Error>> {
158 loop {
159 std::thread::sleep(std::time::Duration::from_secs(interval));
160 let mut resp = ureq::post(&config.token_uri)
161 .header("Accept", "application/json")
162 .send_form([
163 ("client_id", config.client_id.as_str()),
164 ("device_code", device_code),
165 ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
166 ])?;
167 let text = resp.body_mut().read_to_string()?;
168 let body: DeviceTokenResponse = serde_json::from_str(&text)?;
169 if let Some(token) = body.access_token {
170 return Ok(token);
171 }
172 match body.error.as_deref() {
173 Some("authorization_pending") => continue,
174 Some("slow_down") => continue,
175 Some(err) => return Err(format!("device flow error: {err}").into()),
176 None => return Err("unexpected device flow response".into()),
177 }
178 }
179}
180
181fn user_from_token(provider: &str, access_token: &str) -> Result<User, Box<dyn std::error::Error>> {
182 match provider {
183 "github" => {
184 let mut resp = ureq::get("https://api.github.com/user")
185 .header("Authorization", &format!("Bearer {access_token}"))
186 .header("Accept", "application/vnd.github.v3+json")
187 .call()?;
188 let body: serde_json::Value = serde_json::from_str(&resp.body_mut().read_to_string()?)?;
189 Ok(User {
190 id: body["id"].to_string(),
191 email: body["email"].as_str().unwrap_or("").into(),
192 name: body["login"].as_str().unwrap_or("").into(),
193 avatar_url: body["avatar_url"].as_str().map(|s| s.into()),
194 provider: provider.into(),
195 })
196 }
197 _ => Err("unsupported provider".into()),
198 }
199}
200
201pub struct AuthClient {
202 providers: HashMap<String, AuthConfig>,
203 user: Arc<Mutex<Option<User>>>,
204 pending_sign_in: Arc<Mutex<Option<Result<User, String>>>>,
205 auth_msg: Arc<Mutex<Option<String>>>,
206 token_path: PathBuf,
207 vercel_url: String,
208}
209
210impl AuthClient {
211 pub fn new(providers: Vec<(String, AuthConfig)>) -> Self {
212 let token_path = dirs::data_dir()
213 .unwrap_or_else(|| PathBuf::from("."))
214 .join("santui")
215 .join("auth-tokens.json");
216 let user = Self::load_tokens(&token_path);
217 AuthClient {
218 providers: providers.into_iter().collect(),
219 user: Arc::new(Mutex::new(user)),
220 pending_sign_in: Arc::new(Mutex::new(None)),
221 auth_msg: Arc::new(Mutex::new(None)),
222 token_path,
223 vercel_url: String::new(),
224 }
225 }
226
227 pub fn with_vercel(mut self, url: String) -> Self {
228 self.vercel_url = url;
229 self
230 }
231
232 fn load_tokens(path: &PathBuf) -> Option<User> {
233 let data = std::fs::read_to_string(path).ok()?;
234 let stored: StoredToken = serde_json::from_str(&data).ok()?;
235 Some(User {
236 id: stored.id,
237 email: stored.email,
238 name: stored.name,
239 avatar_url: stored.avatar_url,
240 provider: stored.provider,
241 })
242 }
243
244 fn clear_tokens(&self) {
245 let _ = std::fs::remove_file(&self.token_path);
246 }
247
248 fn run_google_redirect_flow(
249 vercel_url: &str,
250 token_path: &PathBuf,
251 user_lock: &Arc<Mutex<Option<User>>>,
252 pending: &Arc<Mutex<Option<Result<User, String>>>>,
253 auth_msg: &Arc<Mutex<Option<String>>>,
254 ) {
255 let vercel = if vercel_url.is_empty() {
256 "https://santuiapp.vercel.app".to_string()
257 } else {
258 vercel_url.to_string()
259 };
260
261 let (listener, port) = match bind_with_fallback() {
262 Ok(v) => v,
263 Err(e) => {
264 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
265 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
266 return;
267 }
268 };
269 let auth_url = format!("{vercel}/api/auth/google?port={port}");
270 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) =
271 Some("Google: waiting for browser…".into());
272 open_browser(&auth_url);
273
274 let params = match handle_redirect(listener) {
275 Ok(p) => p,
276 Err(e) => {
277 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
278 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
279 return;
280 }
281 };
282
283 let access_token = match params.get("access_token") {
284 Some(t) => t.clone(),
285 None => {
286 *pending.lock().unwrap_or_else(|e| e.into_inner()) =
287 Some(Err("No access_token in redirect".into()));
288 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
289 return;
290 }
291 };
292
293 let user = User {
294 id: params.get("id").cloned().unwrap_or_default(),
295 email: params.get("email").cloned().unwrap_or_default(),
296 name: params.get("name").cloned().unwrap_or_default(),
297 avatar_url: params.get("avatar_url").cloned(),
298 provider: "google".into(),
299 };
300
301 let stored = StoredToken {
302 id: user.id.clone(),
303 email: user.email.clone(),
304 name: user.name.clone(),
305 avatar_url: user.avatar_url.clone(),
306 provider: user.provider.clone(),
307 access_token,
308 refresh_token: None,
309 };
310 save_tokens_to_path(token_path, &stored);
311 *user_lock.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
312 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
313 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Ok(user));
314 }
315
316 fn sign_in_google(&self) -> Result<User, Box<dyn std::error::Error>> {
317 let vercel_url = self.vercel_url.clone();
318 Self::run_google_redirect_flow(
319 &vercel_url,
320 &self.token_path,
321 &self.user,
322 &self.pending_sign_in,
323 &self.auth_msg,
324 );
325
326 loop {
328 if let Some(result) = self
329 .pending_sign_in
330 .lock()
331 .unwrap_or_else(|e| e.into_inner())
332 .take()
333 {
334 *self.auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
335 return result.map_err(|e| e.into());
336 }
337 thread::sleep(Duration::from_millis(100));
338 }
339 }
340
341 fn start_sign_in_google(&self) -> Result<(), Box<dyn std::error::Error>> {
342 let vercel_url = self.vercel_url.clone();
343 let token_path = self.token_path.clone();
344 let user_lock = Arc::clone(&self.user);
345 let pending = Arc::clone(&self.pending_sign_in);
346 let auth_msg = Arc::clone(&self.auth_msg);
347
348 thread::spawn(move || {
349 Self::run_google_redirect_flow(
350 &vercel_url,
351 &token_path,
352 &user_lock,
353 &pending,
354 &auth_msg,
355 );
356 });
357
358 Ok(())
359 }
360
361 fn run_github_device_flow(
362 config: &AuthConfig,
363 token_path: &PathBuf,
364 user_lock: &Arc<Mutex<Option<User>>>,
365 pending: &Arc<Mutex<Option<Result<User, String>>>>,
366 auth_msg: &Arc<Mutex<Option<String>>>,
367 ) {
368 let device = match request_device_code(config) {
369 Ok(d) => d,
370 Err(e) => {
371 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
372 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
373 return;
374 }
375 };
376 let user_code = device.user_code.clone();
377 let interval = device.interval.unwrap_or(5);
378 let activation_url = format!("https://github.com/login/device?user_code={user_code}");
379 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = Some(format!(
380 "GitHub: enter code {user_code} at github.com/login/device"
381 ));
382 open_browser(&activation_url);
383
384 let access_token = match poll_device_token(config, &device.device_code, interval) {
385 Ok(t) => t,
386 Err(e) => {
387 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
388 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
389 return;
390 }
391 };
392
393 let user = match user_from_token("github", &access_token) {
394 Ok(u) => u,
395 Err(e) => {
396 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Err(e.to_string()));
397 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
398 return;
399 }
400 };
401
402 let stored = StoredToken {
403 id: user.id.clone(),
404 email: user.email.clone(),
405 name: user.name.clone(),
406 avatar_url: user.avatar_url.clone(),
407 provider: user.provider.clone(),
408 access_token,
409 refresh_token: None,
410 };
411 save_tokens_to_path(token_path, &stored);
412 *user_lock.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
413 *auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
414 *pending.lock().unwrap_or_else(|e| e.into_inner()) = Some(Ok(user));
415 }
416
417 fn sign_in_github(&self) -> Result<User, Box<dyn std::error::Error>> {
418 let config = self
419 .providers
420 .get("github")
421 .ok_or_else(|| "GitHub auth not configured".to_string())?;
422
423 let clone = config.clone();
424 Self::run_github_device_flow(
425 &clone,
426 &self.token_path,
427 &self.user,
428 &self.pending_sign_in,
429 &self.auth_msg,
430 );
431
432 loop {
434 if let Some(result) = self
435 .pending_sign_in
436 .lock()
437 .unwrap_or_else(|e| e.into_inner())
438 .take()
439 {
440 return result.map_err(|e| e.into());
441 }
442 thread::sleep(Duration::from_millis(100));
443 }
444 }
445
446 fn start_sign_in_github(&self) -> Result<(), Box<dyn std::error::Error>> {
447 let config = self
448 .providers
449 .get("github")
450 .ok_or_else(|| "GitHub auth not configured".to_string())?
451 .clone();
452 let token_path = self.token_path.clone();
453 let user_lock = Arc::clone(&self.user);
454 let pending = Arc::clone(&self.pending_sign_in);
455 let msg = Arc::clone(&self.auth_msg);
456
457 thread::spawn(move || {
458 Self::run_github_device_flow(&config, &token_path, &user_lock, &pending, &msg);
459 });
460
461 Ok(())
462 }
463}
464
465fn save_tokens_to_path(token_path: &PathBuf, stored: &StoredToken) {
466 if let Some(parent) = token_path.parent() {
467 let _ = std::fs::create_dir_all(parent);
468 }
469 if let Ok(data) = serde_json::to_string_pretty(stored) {
470 let _ = std::fs::write(token_path, data);
471 }
472}
473
474impl AuthHandle for AuthClient {
475 fn current_user(&self) -> Option<User> {
476 self.user.lock().unwrap_or_else(|e| e.into_inner()).clone()
477 }
478
479 fn bearer_token(&self) -> Option<String> {
480 let data = std::fs::read_to_string(&self.token_path).ok()?;
481 let stored: StoredToken = serde_json::from_str(&data).ok()?;
482 Some(stored.access_token)
483 }
484
485 fn sign_in(&self, provider: &str) -> Result<User, Box<dyn std::error::Error>> {
486 match provider {
487 "google" => self.sign_in_google(),
488 "github" => self.sign_in_github(),
489 _ => Err("unsupported provider".into()),
490 }
491 }
492
493 fn start_sign_in(&self, provider: &str) -> Result<(), Box<dyn std::error::Error>> {
494 match provider {
495 "github" => self.start_sign_in_github(),
496 "google" => self.start_sign_in_google(),
497 _ => Err("unsupported provider".into()),
498 }
499 }
500
501 fn drain_pending_sign_in(&self) -> Option<Result<User, Box<dyn std::error::Error>>> {
502 let mut guard = self
503 .pending_sign_in
504 .lock()
505 .unwrap_or_else(|e| e.into_inner());
506 guard.take().map(|r| r.map_err(|e| e.into()))
507 }
508
509 fn auth_message(&self) -> Option<String> {
510 self.auth_msg
511 .lock()
512 .unwrap_or_else(|e| e.into_inner())
513 .clone()
514 }
515
516 fn sign_out(&self) {
517 self.clear_tokens();
518 *self.auth_msg.lock().unwrap_or_else(|e| e.into_inner()) = None;
519 *self.user.lock().unwrap_or_else(|e| e.into_inner()) = None;
520 }
521}