use santui_core::auth::{AuthHandle, User};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::Duration;
use url::Url;
/// Persisted sign-in session, written as pretty-printed JSON to the token file
/// by `AuthClient::save_tokens` and read back by `AuthClient::load_tokens` and
/// `AuthHandle::bearer_token`.
///
/// The field names are the JSON keys of that file; renaming a field breaks
/// loading of files written by earlier builds.
///
/// NOTE(review): the derived `Debug` prints `access_token` in clear text — avoid
/// logging values of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredToken {
    /// Provider-specific user id.
    id: String,
    email: String,
    /// Display name (for GitHub this is the login; see `user_from_token`).
    name: String,
    avatar_url: Option<String>,
    /// Provider key such as `"google"` or `"github"`.
    provider: String,
    /// Token returned by `AuthHandle::bearer_token`.
    access_token: String,
    /// Always written as `None` by the sign-in paths in this file.
    refresh_token: Option<String>,
}
/// OAuth settings for a single sign-in provider.
#[derive(Debug, Clone)]
pub struct AuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret, for providers that use one.
    pub client_secret: Option<String>,
    /// Authorization endpoint; empty for providers whose flow needs none.
    pub auth_uri: String,
    /// Token endpoint.
    pub token_uri: String,
    /// Scopes requested at sign-in.
    pub scopes: Vec<String>,
    /// Loopback port intended for a redirect listener (`github()` sets 0 because
    /// the device flow needs no listener).
    pub redirect_port: u16,
}

impl AuthConfig {
    /// Google OAuth endpoints, the `openid`/`email`/`profile` scopes and
    /// redirect port 9842.
    pub fn google(client_id: String, client_secret: Option<String>) -> Self {
        let scopes = Vec::from(["openid", "email", "profile"].map(String::from));
        Self {
            auth_uri: String::from("https://accounts.google.com/o/oauth2/v2/auth"),
            token_uri: String::from("https://oauth2.googleapis.com/token"),
            redirect_port: 9842,
            client_id,
            client_secret,
            scopes,
        }
    }

    /// GitHub device-flow settings: no client secret, no authorization URI and
    /// no redirect port, requesting the `read:user` and `user:email` scopes.
    pub fn github(client_id: String) -> Self {
        let scopes = Vec::from(["read:user", "user:email"].map(String::from));
        Self {
            token_uri: String::from("https://github.com/login/oauth/access_token"),
            auth_uri: String::new(),
            client_secret: None,
            redirect_port: 0,
            client_id,
            scopes,
        }
    }
}
#[cfg(target_os = "windows")]
fn open_browser(url: &str) {
    // `cmd` treats a bare `&` as a command separator, so caret-escape it before
    // handing the URL to `start`. Launch failures are deliberately ignored.
    let escaped = url.replace('&', "^&");
    let mut command = std::process::Command::new("cmd");
    command.arg("/c").arg("start").arg(&escaped);
    let _ = command.spawn();
}
/// Opens `url` in the desktop's default browser via `xdg-open`.
///
/// Best effort: a missing `xdg-open` or a failed launch is ignored.
#[cfg(target_os = "linux")]
fn open_browser(url: &str) {
    if let Ok(mut child) = std::process::Command::new("xdg-open").arg(url).spawn() {
        // Reap the launcher in the background so it does not linger as a
        // zombie process until this program exits. If the thread cannot be
        // created the child is simply dropped, as before.
        let _ = std::thread::Builder::new()
            .name("open-browser-reaper".into())
            .spawn(move || {
                let _ = child.wait();
            });
    }
}
/// Opens `url` in the default browser via macOS `open`.
///
/// Best effort: a failed launch is ignored.
#[cfg(target_os = "macos")]
fn open_browser(url: &str) {
    if let Ok(mut child) = std::process::Command::new("open").arg(url).spawn() {
        // Reap the launcher in the background so it does not linger as a
        // zombie process until this program exits. If the thread cannot be
        // created the child is simply dropped, as before.
        let _ = std::thread::Builder::new()
            .name("open-browser-reaper".into())
            .spawn(move || {
                let _ = child.wait();
            });
    }
}
/// Fallback for other platforms: tries `xdg-open`.
///
/// Best effort: a missing `xdg-open` or a failed launch is ignored.
#[cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))]
fn open_browser(url: &str) {
    if let Ok(mut child) = std::process::Command::new("xdg-open").arg(url).spawn() {
        // Reap the launcher in the background so it does not linger as a
        // zombie process until this program exits. If the thread cannot be
        // created the child is simply dropped, as before.
        let _ = std::thread::Builder::new()
            .name("open-browser-reaper".into())
            .spawn(move || {
                let _ = child.wait();
            });
    }
}
/// Accepts one HTTP connection on `listener` and returns the query parameters
/// of the OAuth redirect request it carries.
///
/// The browser is always answered with a small HTML page:
/// - a "Signed In" page when the request had a parseable path and no `error` parameter;
/// - a failure page otherwise.
///
/// This blocks in `accept` until a client connects, unless the caller configured
/// the listener otherwise. The 120 s read timeout only applies once connected.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - accepting the connection or reading the request line fails;
/// - the request line has no parseable path ("No query parameters in redirect");
/// - the redirect carries an `error` parameter ("OAuth error from server: …").
fn handle_redirect(
    listener: TcpListener,
) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
    const MAX_HEADER_LINES: usize = 100;
    const SUCCESS_PAGE: &str = "<!DOCTYPE html><html lang=\"en\"><head><meta charset=\"UTF-8\"><script src=\"https://cdn.tailwindcss.com\"></script><title>Santui — Signed In</title></head><body class=\"bg-gradient-to-br from-gray-900 via-slate-800 to-gray-900 min-h-screen flex items-center justify-center font-sans\"><div class=\"bg-white/10 backdrop-blur-lg rounded-lg shadow-2xl border border-white/20 p-8 max-w-md w-full mx-4 text-center\"><div class=\"text-emerald-400 mb-4\"><svg class=\"w-16 h-16 mx-auto mb-4\" fill=\"none\" stroke=\"currentColor\" viewBox=\"0 0 24 24\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"1.5\" d=\"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"/></svg><h1 class=\"text-2xl font-bold mb-1\">Signed In!</h1><p class=\"text-gray-400 text-sm\">You can close this window.</p></div></div></body></html>";
    const FAILURE_PAGE: &str = "<!DOCTYPE html><html lang=\"en\"><head><meta charset=\"UTF-8\"><script src=\"https://cdn.tailwindcss.com\"></script><title>Santui — Sign-in Failed</title></head><body class=\"bg-gradient-to-br from-gray-900 via-slate-800 to-gray-900 min-h-screen flex items-center justify-center font-sans\"><div class=\"bg-white/10 backdrop-blur-lg rounded-lg shadow-2xl border border-white/20 p-8 max-w-md w-full mx-4 text-center\"><div class=\"text-red-400 mb-4\"><svg class=\"w-16 h-16 mx-auto mb-4\" fill=\"none\" stroke=\"currentColor\" viewBox=\"0 0 24 24\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"1.5\" d=\"M9.75 9.75l4.5 4.5m0-4.5l-4.5 4.5M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"/></svg><h1 class=\"text-2xl font-bold mb-1\">Sign-in Failed</h1><p class=\"text-gray-400 text-sm\">Return to Santui and try again. You can close this window.</p></div></div></body></html>";

    let (stream, _) = listener.accept()?;
    stream.set_read_timeout(Some(Duration::from_secs(120)))?;
    let mut reader = BufReader::new(&stream);
    let mut request_line = String::new();
    reader.read_line(&mut request_line)?;

    // Drain the rest of the request headers (up to the blank line). Closing a
    // TCP socket that still has unread input makes the OS send RST, which can
    // make the browser discard the page below. Best effort: stop on EOF/error.
    let mut header = String::new();
    for _ in 0..MAX_HEADER_LINES {
        header.clear();
        match reader.read_line(&mut header) {
            Ok(0) | Err(_) => break,
            Ok(_) if header.trim_end().is_empty() => break,
            Ok(_) => {}
        }
    }

    // The request target is the second token of "GET /path?query HTTP/1.1".
    let params = request_line
        .split_whitespace()
        .nth(1)
        .and_then(|path| Url::parse(&format!("http://localhost{path}")).ok())
        .map(|url| {
            url.query_pairs()
                .map(|(k, v)| (k.into_owned(), v.into_owned()))
                .collect::<HashMap<String, String>>()
        });

    let succeeded = matches!(&params, Some(p) if !p.contains_key("error"));
    let (status, page) = if succeeded {
        ("200 OK", SUCCESS_PAGE)
    } else {
        ("400 Bad Request", FAILURE_PAGE)
    };
    let length = page.len();
    let response = format!(
        "HTTP/1.1 {status}\r\n\
         Content-Type: text/html; charset=utf-8\r\n\
         Content-Length: {length}\r\n\
         Connection: close\r\n\
         \r\n\
         {page}"
    );
    // The page is cosmetic; the sign-in outcome is carried by the return value,
    // so write failures are ignored.
    let mut writer = &stream;
    let _ = writer.write_all(response.as_bytes());
    let _ = writer.flush();

    let params = params.ok_or_else(|| "No query parameters in redirect".to_string())?;
    if let Some(err) = params.get("error") {
        return Err(format!("OAuth error from server: {err}").into());
    }
    Ok(params)
}
/// JSON body returned by GitHub's device-code endpoint
/// (`https://github.com/login/device/code`, see `request_device_code`).
#[derive(Deserialize)]
struct DeviceCodeResponse {
    /// Code sent back to the token endpoint while polling (`poll_device_token`).
    device_code: String,
    /// Code the user enters on the activation page; embedded in the activation
    /// URL opened by `AuthClient::sign_in_github`.
    user_code: String,
    // Required for deserialization but never read: `sign_in_github` opens a
    // hard-coded activation URL instead. The allow silences the unused-field lint.
    #[allow(dead_code)]
    verification_uri: String,
    /// Seconds to sleep between polls; `sign_in_github` falls back to 5 when absent.
    interval: Option<u64>,
}
/// JSON body of one device-flow token poll; `poll_device_token` expects exactly
/// one of the two fields to be set.
#[derive(Deserialize)]
struct DeviceTokenResponse {
    /// Set once a token is issued; `poll_device_token` returns it.
    access_token: Option<String>,
    /// OAuth error code, e.g. `authorization_pending` or `slow_down` while waiting.
    error: Option<String>,
}
fn request_device_code(
config: &AuthConfig,
) -> Result<DeviceCodeResponse, Box<dyn std::error::Error>> {
let scope = config.scopes.join(" ");
let mut resp = ureq::post("https://github.com/login/device/code")
.header("Accept", "application/json")
.send_form([
("client_id", config.client_id.as_str()),
("scope", scope.as_str()),
])?;
let text = resp.body_mut().read_to_string()?;
Ok(serde_json::from_str(&text)?)
}
fn poll_device_token(
config: &AuthConfig,
device_code: &str,
interval: u64,
) -> Result<String, Box<dyn std::error::Error>> {
loop {
std::thread::sleep(std::time::Duration::from_secs(interval));
let mut resp = ureq::post(&config.token_uri)
.header("Accept", "application/json")
.send_form([
("client_id", config.client_id.as_str()),
("device_code", device_code),
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
])?;
let text = resp.body_mut().read_to_string()?;
let body: DeviceTokenResponse = serde_json::from_str(&text)?;
if let Some(token) = body.access_token {
return Ok(token);
}
match body.error.as_deref() {
Some("authorization_pending") => continue,
Some("slow_down") => continue,
Some(err) => return Err(format!("device flow error: {err}").into()),
None => return Err("unexpected device flow response".into()),
}
}
}
/// Fetches the profile behind `access_token` and maps it to a [`User`].
///
/// Only `"github"` is supported. It calls `GET https://api.github.com/user` and
/// maps the fields as follows:
/// - `id` becomes `id`;
/// - `email` becomes `email`, or empty when null or absent;
/// - `login` becomes `name`;
/// - `avatar_url` becomes `avatar_url`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the provider is not `"github"`;
/// - the HTTP request fails, or the body is not valid JSON;
/// - the response has no usable `id`.
fn user_from_token(provider: &str, access_token: &str) -> Result<User, Box<dyn std::error::Error>> {
    match provider {
        "github" => {
            let mut resp = ureq::get("https://api.github.com/user")
                .header("Authorization", &format!("Bearer {access_token}"))
                .header("Accept", "application/vnd.github.v3+json")
                .call()?;
            let body: serde_json::Value =
                serde_json::from_str(&resp.body_mut().read_to_string()?)?;
            // Formatting `body["id"]` directly would store the literal "null" for
            // a missing id (and a quoted string for a string id), silently
            // corrupting the persisted identity.
            let id = match &body["id"] {
                serde_json::Value::Number(n) => n.to_string(),
                serde_json::Value::String(s) => s.clone(),
                _ => return Err("GitHub user response has no id".into()),
            };
            Ok(User {
                id,
                email: body["email"].as_str().unwrap_or("").into(),
                name: body["login"].as_str().unwrap_or("").into(),
                avatar_url: body["avatar_url"].as_str().map(|s| s.into()),
                provider: provider.into(),
            })
        }
        _ => Err("unsupported provider".into()),
    }
}
/// File-backed [`AuthHandle`] supporting Google sign-in (browser redirect through
/// a hosted web endpoint to a loopback listener) and GitHub sign-in (device flow).
pub struct AuthClient {
    /// Provider configurations keyed by provider name (e.g. `"github"`).
    providers: HashMap<String, AuthConfig>,
    /// Signed-in user held in memory; `None` when signed out.
    user: Mutex<Option<User>>,
    /// JSON file holding the persisted `StoredToken`.
    token_path: PathBuf,
    /// Base URL of the Google sign-in web service; empty selects the built-in default.
    vercel_url: String,
}
impl AuthClient {
pub fn new(providers: Vec<(String, AuthConfig)>) -> Self {
let token_path = dirs::data_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("santui")
.join("auth-tokens.json");
let user = Self::load_tokens(&token_path);
AuthClient {
providers: providers.into_iter().collect(),
user: Mutex::new(user),
token_path,
vercel_url: String::new(),
}
}
pub fn with_vercel(mut self, url: String) -> Self {
self.vercel_url = url;
self
}
fn load_tokens(path: &PathBuf) -> Option<User> {
let data = std::fs::read_to_string(path).ok()?;
let stored: StoredToken = serde_json::from_str(&data).ok()?;
Some(User {
id: stored.id,
email: stored.email,
name: stored.name,
avatar_url: stored.avatar_url,
provider: stored.provider,
})
}
fn save_tokens(&self, stored: &StoredToken) {
if let Some(parent) = self.token_path.parent() {
let _ = std::fs::create_dir_all(parent);
}
if let Ok(data) = serde_json::to_string_pretty(stored) {
let _ = std::fs::write(&self.token_path, data);
}
}
fn clear_tokens(&self) {
let _ = std::fs::remove_file(&self.token_path);
}
fn sign_in_google(&self) -> Result<User, Box<dyn std::error::Error>> {
let port = 9842;
let listener = TcpListener::bind(("127.0.0.1", port))?;
let vercel = if self.vercel_url.is_empty() {
"https://santuiapp.vercel.app".to_string()
} else {
self.vercel_url.clone()
};
let auth_url = format!("{vercel}/api/auth/google?port={port}");
open_browser(&auth_url);
let params = handle_redirect(listener)?;
let access_token = params
.get("access_token")
.ok_or_else(|| "No access_token in redirect".to_string())?;
let user = User {
id: params.get("id").cloned().unwrap_or_default(),
email: params.get("email").cloned().unwrap_or_default(),
name: params.get("name").cloned().unwrap_or_default(),
avatar_url: params.get("avatar_url").cloned(),
provider: "google".into(),
};
self.save_tokens(&StoredToken {
id: user.id.clone(),
email: user.email.clone(),
name: user.name.clone(),
avatar_url: user.avatar_url.clone(),
provider: user.provider.clone(),
access_token: access_token.clone(),
refresh_token: None,
});
*self.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
Ok(user)
}
fn sign_in_github(&self) -> Result<User, Box<dyn std::error::Error>> {
let config = self
.providers
.get("github")
.ok_or_else(|| "GitHub auth not configured".to_string())?;
let device = request_device_code(config)?;
let user_code = &device.user_code;
let interval = device.interval.unwrap_or(5);
let activation_url = format!("https://github.com/login/device?user_code={user_code}");
open_browser(&activation_url);
let access_token = poll_device_token(config, &device.device_code, interval)?;
let user = user_from_token("github", &access_token)?;
self.save_tokens(&StoredToken {
id: user.id.clone(),
email: user.email.clone(),
name: user.name.clone(),
avatar_url: user.avatar_url.clone(),
provider: user.provider.clone(),
access_token,
refresh_token: None,
});
*self.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
Ok(user)
}
}
impl AuthHandle for AuthClient {
    /// Returns a copy of the signed-in user kept in memory, if any.
    fn current_user(&self) -> Option<User> {
        // Recover from a poisoned lock: the stored `Option<User>` is still usable.
        let guard = self.user.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        Option::clone(&guard)
    }

    /// Reads the access token from the token file on every call; `None` when the
    /// file is missing or does not parse.
    fn bearer_token(&self) -> Option<String> {
        std::fs::read_to_string(&self.token_path)
            .ok()
            .and_then(|raw| serde_json::from_str::<StoredToken>(&raw).ok())
            .map(|stored| stored.access_token)
    }

    /// Dispatches to the Google or GitHub sign-in flow by provider name.
    fn sign_in(&self, provider: &str) -> Result<User, Box<dyn std::error::Error>> {
        if provider == "google" {
            self.sign_in_google()
        } else if provider == "github" {
            self.sign_in_github()
        } else {
            Err("unsupported provider".into())
        }
    }

    /// Deletes the persisted tokens, then forgets the in-memory user.
    fn sign_out(&self) {
        self.clear_tokens();
        let mut slot = self.user.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = None;
    }
}