Skip to main content

santui_auth/
lib.rs

use santui_core::auth::{AuthHandle, User};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{BufRead, BufReader, Write};
use std::net::TcpListener;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::Duration;
use url::Url;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12struct StoredToken {
13    id: String,
14    email: String,
15    name: String,
16    avatar_url: Option<String>,
17    provider: String,
18    access_token: String,
19    refresh_token: Option<String>,
20}
21
/// OAuth client settings and endpoints for one identity provider.
///
/// `Debug` is implemented by hand so that `client_secret` is never printed;
/// only whether it is set is shown.
#[derive(Clone)]
pub struct AuthConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// Optional OAuth client secret.
    pub client_secret: Option<String>,
    /// Authorization endpoint (empty in the GitHub preset).
    pub auth_uri: String,
    /// Token endpoint.
    pub token_uri: String,
    /// Requested scopes.
    pub scopes: Vec<String>,
    /// Local port for the loopback redirect listener (0 in the GitHub preset).
    pub redirect_port: u16,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .field("scopes", &self.scopes)
            .field("redirect_port", &self.redirect_port)
            .finish()
    }
}
31
32impl AuthConfig {
33    pub fn google(client_id: String, client_secret: Option<String>) -> Self {
34        AuthConfig {
35            client_id,
36            client_secret,
37            auth_uri: "https://accounts.google.com/o/oauth2/v2/auth".into(),
38            token_uri: "https://oauth2.googleapis.com/token".into(),
39            scopes: vec!["openid".into(), "email".into(), "profile".into()],
40            redirect_port: 9842,
41        }
42    }
43
44    pub fn github(client_id: String) -> Self {
45        AuthConfig {
46            client_id,
47            client_secret: None,
48            auth_uri: String::new(),
49            token_uri: "https://github.com/login/oauth/access_token".into(),
50            scopes: vec!["read:user".into(), "user:email".into()],
51            redirect_port: 0,
52        }
53    }
54}
55
#[cfg(target_os = "windows")]
fn open_browser(url: &str) {
    // `start` is a cmd.exe builtin, where a bare `&` would split the command;
    // `^` is cmd's escape character.
    let escaped = url.replace('&', "^&");
    let mut command = std::process::Command::new("cmd");
    command.arg("/c").arg("start").arg(&escaped);
    // Launch failures are deliberately ignored.
    let _ = command.spawn();
}
62
#[cfg(target_os = "linux")]
fn open_browser(url: &str) {
    // Delegate to `xdg-open`; launch failures are deliberately ignored.
    let mut command = std::process::Command::new("xdg-open");
    command.arg(url);
    let _ = command.spawn();
}
67
#[cfg(target_os = "macos")]
fn open_browser(url: &str) {
    // Delegate to macOS `open`; the spawn result is dropped on purpose.
    let launched = std::process::Command::new("open").arg(url).spawn();
    drop(launched);
}
72
#[cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))]
fn open_browser(url: &str) {
    // Other platforms: assume an `xdg-open` is available; failures are ignored.
    let launched = std::process::Command::new("xdg-open").arg(url).spawn();
    drop(launched);
}
77
78fn handle_redirect(
79    listener: TcpListener,
80) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
81    let (stream, _) = listener.accept()?;
82    stream.set_read_timeout(Some(Duration::from_secs(120)))?;
83    let mut reader = BufReader::new(&stream);
84    let mut request_line = String::new();
85    reader.read_line(&mut request_line)?;
86
87    let params = request_line
88        .split_whitespace()
89        .nth(1)
90        .and_then(|path| {
91            let full_url = format!("http://localhost{path}");
92            Url::parse(&full_url).ok().map(|u| {
93                u.query_pairs()
94                    .map(|(k, v)| (k.into_owned(), v.into_owned()))
95                    .collect::<HashMap<String, String>>()
96            })
97        })
98        .ok_or_else(|| "No query parameters in redirect".to_string())?;
99
100    let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n<!DOCTYPE html><html lang=\"en\"><head><meta charset=\"UTF-8\"><script src=\"https://cdn.tailwindcss.com\"></script><title>Santui — Signed In</title></head><body class=\"bg-gradient-to-br from-gray-900 via-slate-800 to-gray-900 min-h-screen flex items-center justify-center font-sans\"><div class=\"bg-white/10 backdrop-blur-lg rounded-lg shadow-2xl border border-white/20 p-8 max-w-md w-full mx-4 text-center\"><div class=\"text-emerald-400 mb-4\"><svg class=\"w-16 h-16 mx-auto mb-4\" fill=\"none\" stroke=\"currentColor\" viewBox=\"0 0 24 24\"><path stroke-linecap=\"round\" stroke-linejoin=\"round\" stroke-width=\"1.5\" d=\"M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z\"/></svg><h1 class=\"text-2xl font-bold mb-1\">Signed In!</h1><p class=\"text-gray-400 text-sm\">You can close this window.</p></div></div></body></html>";
101    let mut stream = stream;
102    let _ = stream.write_all(response.as_bytes());
103
104    if let Some(err) = params.get("error") {
105        return Err(format!("OAuth error from server: {err}").into());
106    }
107
108    Ok(params)
109}
110
111#[derive(Deserialize)]
112struct DeviceCodeResponse {
113    device_code: String,
114    user_code: String,
115    #[allow(dead_code)]
116    verification_uri: String,
117    interval: Option<u64>,
118}
119
120#[derive(Deserialize)]
121struct DeviceTokenResponse {
122    access_token: Option<String>,
123    error: Option<String>,
124}
125
126fn request_device_code(
127    config: &AuthConfig,
128) -> Result<DeviceCodeResponse, Box<dyn std::error::Error>> {
129    let scope = config.scopes.join(" ");
130    let mut resp = ureq::post("https://github.com/login/device/code")
131        .header("Accept", "application/json")
132        .send_form([
133            ("client_id", config.client_id.as_str()),
134            ("scope", scope.as_str()),
135        ])?;
136    let text = resp.body_mut().read_to_string()?;
137    Ok(serde_json::from_str(&text)?)
138}
139
140fn poll_device_token(
141    config: &AuthConfig,
142    device_code: &str,
143    interval: u64,
144) -> Result<String, Box<dyn std::error::Error>> {
145    loop {
146        std::thread::sleep(std::time::Duration::from_secs(interval));
147        let mut resp = ureq::post(&config.token_uri)
148            .header("Accept", "application/json")
149            .send_form([
150                ("client_id", config.client_id.as_str()),
151                ("device_code", device_code),
152                ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
153            ])?;
154        let text = resp.body_mut().read_to_string()?;
155        let body: DeviceTokenResponse = serde_json::from_str(&text)?;
156        if let Some(token) = body.access_token {
157            return Ok(token);
158        }
159        match body.error.as_deref() {
160            Some("authorization_pending") => continue,
161            Some("slow_down") => continue,
162            Some(err) => return Err(format!("device flow error: {err}").into()),
163            None => return Err("unexpected device flow response".into()),
164        }
165    }
166}
167
168fn user_from_token(provider: &str, access_token: &str) -> Result<User, Box<dyn std::error::Error>> {
169    match provider {
170        "github" => {
171            let mut resp = ureq::get("https://api.github.com/user")
172                .header("Authorization", &format!("Bearer {access_token}"))
173                .header("Accept", "application/vnd.github.v3+json")
174                .call()?;
175            let body: serde_json::Value = serde_json::from_str(&resp.body_mut().read_to_string()?)?;
176            Ok(User {
177                id: body["id"].to_string(),
178                email: body["email"].as_str().unwrap_or("").into(),
179                name: body["login"].as_str().unwrap_or("").into(),
180                avatar_url: body["avatar_url"].as_str().map(|s| s.into()),
181                provider: provider.into(),
182            })
183        }
184        _ => Err("unsupported provider".into()),
185    }
186}
187
188pub struct AuthClient {
189    providers: HashMap<String, AuthConfig>,
190    user: Mutex<Option<User>>,
191    token_path: PathBuf,
192    vercel_url: String,
193}
194
195impl AuthClient {
196    pub fn new(providers: Vec<(String, AuthConfig)>) -> Self {
197        let token_path = dirs::data_dir()
198            .unwrap_or_else(|| PathBuf::from("."))
199            .join("santui")
200            .join("auth-tokens.json");
201        let user = Self::load_tokens(&token_path);
202        AuthClient {
203            providers: providers.into_iter().collect(),
204            user: Mutex::new(user),
205            token_path,
206            vercel_url: String::new(),
207        }
208    }
209
210    pub fn with_vercel(mut self, url: String) -> Self {
211        self.vercel_url = url;
212        self
213    }
214
215    fn load_tokens(path: &PathBuf) -> Option<User> {
216        let data = std::fs::read_to_string(path).ok()?;
217        let stored: StoredToken = serde_json::from_str(&data).ok()?;
218        Some(User {
219            id: stored.id,
220            email: stored.email,
221            name: stored.name,
222            avatar_url: stored.avatar_url,
223            provider: stored.provider,
224        })
225    }
226
227    fn save_tokens(&self, stored: &StoredToken) {
228        if let Some(parent) = self.token_path.parent() {
229            let _ = std::fs::create_dir_all(parent);
230        }
231        if let Ok(data) = serde_json::to_string_pretty(stored) {
232            let _ = std::fs::write(&self.token_path, data);
233        }
234    }
235
236    fn clear_tokens(&self) {
237        let _ = std::fs::remove_file(&self.token_path);
238    }
239
240    fn sign_in_google(&self) -> Result<User, Box<dyn std::error::Error>> {
241        let port = 9842;
242        let listener = TcpListener::bind(("127.0.0.1", port))?;
243
244        let vercel = if self.vercel_url.is_empty() {
245            "https://santuiapp.vercel.app".to_string()
246        } else {
247            self.vercel_url.clone()
248        };
249        let auth_url = format!("{vercel}/api/auth/google?port={port}");
250        open_browser(&auth_url);
251
252        let params = handle_redirect(listener)?;
253
254        let access_token = params
255            .get("access_token")
256            .ok_or_else(|| "No access_token in redirect".to_string())?;
257        let user = User {
258            id: params.get("id").cloned().unwrap_or_default(),
259            email: params.get("email").cloned().unwrap_or_default(),
260            name: params.get("name").cloned().unwrap_or_default(),
261            avatar_url: params.get("avatar_url").cloned(),
262            provider: "google".into(),
263        };
264
265        self.save_tokens(&StoredToken {
266            id: user.id.clone(),
267            email: user.email.clone(),
268            name: user.name.clone(),
269            avatar_url: user.avatar_url.clone(),
270            provider: user.provider.clone(),
271            access_token: access_token.clone(),
272            refresh_token: None,
273        });
274        *self.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
275
276        Ok(user)
277    }
278
279    fn sign_in_github(&self) -> Result<User, Box<dyn std::error::Error>> {
280        let config = self
281            .providers
282            .get("github")
283            .ok_or_else(|| "GitHub auth not configured".to_string())?;
284
285        let device = request_device_code(config)?;
286        let user_code = &device.user_code;
287        let interval = device.interval.unwrap_or(5);
288
289        let activation_url = format!("https://github.com/login/device?user_code={user_code}");
290        open_browser(&activation_url);
291
292        let access_token = poll_device_token(config, &device.device_code, interval)?;
293        let user = user_from_token("github", &access_token)?;
294
295        self.save_tokens(&StoredToken {
296            id: user.id.clone(),
297            email: user.email.clone(),
298            name: user.name.clone(),
299            avatar_url: user.avatar_url.clone(),
300            provider: user.provider.clone(),
301            access_token,
302            refresh_token: None,
303        });
304        *self.user.lock().unwrap_or_else(|e| e.into_inner()) = Some(user.clone());
305
306        Ok(user)
307    }
308}
309
310impl AuthHandle for AuthClient {
311    fn current_user(&self) -> Option<User> {
312        self.user.lock().unwrap_or_else(|e| e.into_inner()).clone()
313    }
314
315    fn bearer_token(&self) -> Option<String> {
316        let data = std::fs::read_to_string(&self.token_path).ok()?;
317        let stored: StoredToken = serde_json::from_str(&data).ok()?;
318        Some(stored.access_token)
319    }
320
321    fn sign_in(&self, provider: &str) -> Result<User, Box<dyn std::error::Error>> {
322        match provider {
323            "google" => self.sign_in_google(),
324            "github" => self.sign_in_github(),
325            _ => Err("unsupported provider".into()),
326        }
327    }
328
329    fn sign_out(&self) {
330        self.clear_tokens();
331        *self.user.lock().unwrap_or_else(|e| e.into_inner()) = None;
332    }
333}