pub mod cookie;
pub mod email;
pub mod password;
pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
use serde::{Deserialize, Serialize};
/// Identity and authorization facts attached to a caller.
///
/// Only `Serialize` is derived; this type is not deserialized from input here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
/// Identifier of the caller; `None` for the anonymous context.
pub user_id: Option<String>,
/// When `true`, `has_role` and `has_any_role` succeed for every role.
pub is_admin: bool,
/// Marks a guest identity: a guest carries a `user_id` but `is_authenticated`
/// reports `false`. Omitted from serialized output when `false`.
#[serde(default, skip_serializing_if = "is_false")]
pub is_guest: bool,
/// Role names consulted by `has_role`.
pub roles: Vec<String>,
/// Optional tenant identifier; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tenant_id: Option<String>,
}
/// Serde `skip_serializing_if` predicate: `true` when the flag is unset.
fn is_false(b: &bool) -> bool {
    !*b
}
impl AuthContext {
    /// Context with no identity, no roles and no tenant.
    pub fn anonymous() -> Self {
        Self {
            user_id: None,
            is_admin: false,
            is_guest: false,
            roles: Vec::new(),
            tenant_id: None,
        }
    }

    /// Context for a signed-in, non-admin, non-guest user without roles.
    pub fn authenticated(user_id: String) -> Self {
        Self {
            user_id: Some(user_id),
            ..Self::anonymous()
        }
    }

    /// Context for a guest: it has an id but is flagged `is_guest`.
    pub fn guest(guest_id: String) -> Self {
        Self {
            is_guest: true,
            ..Self::authenticated(guest_id)
        }
    }

    /// Context for the built-in administrator (`"__admin__"`, role `"admin"`).
    pub fn admin() -> Self {
        Self {
            is_admin: true,
            roles: vec![String::from("admin")],
            ..Self::authenticated(String::from("__admin__"))
        }
    }

    /// Alias for [`AuthContext::authenticated`].
    pub fn user(user_id: String) -> Self {
        Self::authenticated(user_id)
    }

    /// Borrowed view of the tenant identifier, if one is set.
    pub fn tenant_id(&self) -> Option<&str> {
        self.tenant_id.as_ref().map(String::as_str)
    }

    /// Returns the context with `tenant_id` replaced.
    pub fn with_tenant(self, tenant_id: String) -> Self {
        Self {
            tenant_id: Some(tenant_id),
            ..self
        }
    }

    /// `true` when an id is present and the context is not a guest.
    pub fn is_authenticated(&self) -> bool {
        if self.is_guest {
            return false;
        }
        self.user_id.is_some()
    }

    /// `true` for admins, otherwise `true` only if `role` is in `roles`.
    pub fn has_role(&self, role: &str) -> bool {
        if self.is_admin {
            return true;
        }
        self.roles.iter().any(|held| held.as_str() == role)
    }

    /// `true` for admins, otherwise `true` if any of `roles` is held.
    pub fn has_any_role(&self, roles: &[&str]) -> bool {
        if self.is_admin {
            return true;
        }
        roles.iter().any(|wanted| self.has_role(wanted))
    }

    /// Returns the context with the role list replaced (not merged).
    pub fn with_roles(self, roles: Vec<String>) -> Self {
        Self { roles, ..self }
    }
}
/// Byte-slice equality that does not stop at the first differing byte.
///
/// Slices of different length return `false` immediately. Otherwise the XOR of
/// every byte pair is OR-ed into one accumulator, which is compared to zero
/// only after the whole slice has been visited.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    diff == 0
}
/// Access requirement that an `AuthContext` is checked against.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
/// Every context passes, including anonymous ones.
Public,
/// Only contexts for which `is_authenticated()` is `true` pass.
User,
}
impl AuthMode {
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Option<Self> {
match s {
"public" => Some(AuthMode::Public),
"user" => Some(AuthMode::User),
_ => None,
}
}
pub fn check(&self, ctx: &AuthContext) -> bool {
match self {
AuthMode::Public => true,
AuthMode::User => ctx.is_authenticated(),
}
}
}
/// A login session identified by an opaque token.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
/// Opaque session token; used as the lookup key by `SessionStore`.
pub token: String,
/// Identifier of the user (or guest) owning the session.
pub user_id: String,
/// Expiry as Unix seconds. `0` means "never expires" (see `is_expired`);
/// it is also the value used when the field is missing on deserialization.
#[serde(default)]
pub expires_at: u64,
/// Optional device label; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub device: Option<String>,
/// Creation time as Unix seconds; `0` when missing on deserialization.
#[serde(default)]
pub created_at: u64,
/// Optional tenant identifier; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tenant_id: Option<String>,
}
impl Session {
    /// Lifetime applied by [`Session::new`]: 30 days, in seconds.
    pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;

    /// Creates a session with a fresh token and the default lifetime.
    pub fn new(user_id: String) -> Self {
        Self::with_lifetime(user_id, Self::DEFAULT_LIFETIME_SECS)
    }

    /// Creates a session with a fresh token that expires `lifetime_secs` from now.
    ///
    /// A lifetime of `0` stores `expires_at = 0`, which `is_expired` treats as
    /// "never expires".
    pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
        let created_at = now_secs();
        let expires_at = match lifetime_secs {
            0 => 0,
            secs => created_at.saturating_add(secs),
        };
        Self {
            token: generate_token(),
            user_id,
            expires_at,
            device: None,
            created_at,
            tenant_id: None,
        }
    }

    /// Builds an authenticated `AuthContext` for this session's user, carrying
    /// the tenant over when one is set.
    pub fn to_auth_context(&self) -> AuthContext {
        let ctx = AuthContext::authenticated(self.user_id.clone());
        if let Some(tenant) = self.tenant_id.as_ref() {
            ctx.with_tenant(tenant.clone())
        } else {
            ctx
        }
    }

    /// `true` once the current time has reached `expires_at`; never for `0`.
    pub fn is_expired(&self) -> bool {
        if self.expires_at == 0 {
            return false;
        }
        now_secs() >= self.expires_at
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Client credentials and redirect target for one OAuth provider.
// NOTE(review): the derived `Debug` and `Serialize` impls include
// `client_secret`; confirm this value is never logged or sent to clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthConfig {
/// Provider key; the methods on this type recognize `"google"` and `"github"`.
pub provider: String,
/// OAuth client identifier.
pub client_id: String,
/// OAuth client secret, sent in the token request.
pub client_secret: String,
/// Redirect URI sent in both the authorization URL and the token request.
pub redirect_uri: String,
}
impl OAuthConfig {
    /// Builds the provider's authorization URL (without a `state` parameter).
    ///
    /// Returns an empty string for providers other than `"google"` and `"github"`.
    // NOTE(review): `client_id` and `redirect_uri` are interpolated as-is; this is
    // left unchanged because callers may depend on the exact URL text.
    pub fn auth_url(&self) -> String {
        match self.provider.as_str() {
            "google" => format!(
                "https://accounts.google.com/o/oauth2/v2/auth?client_id={}&redirect_uri={}&response_type=code&scope=openid%20email%20profile",
                self.client_id, self.redirect_uri
            ),
            "github" => format!(
                "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope=user:email",
                self.client_id, self.redirect_uri
            ),
            _ => String::new(),
        }
    }

    /// Like [`OAuthConfig::auth_url`], with a percent-encoded `state` appended.
    ///
    /// Returns an empty string for unknown providers. States made only of
    /// unreserved characters (letters, digits, `-`, `_`, `.`, `~`) are unchanged
    /// by the encoding.
    pub fn auth_url_with_state(&self, state: &str) -> String {
        let base = self.auth_url();
        if base.is_empty() {
            return base;
        }
        format!("{}&state={}", base, url_encode(state))
    }

    /// Token endpoint for the provider; empty for unknown providers.
    pub fn token_url(&self) -> &str {
        match self.provider.as_str() {
            "google" => "https://oauth2.googleapis.com/token",
            "github" => "https://github.com/login/oauth/access_token",
            _ => "",
        }
    }

    /// User-info endpoint for the provider; empty for unknown providers.
    pub fn userinfo_url(&self) -> &str {
        match self.provider.as_str() {
            "google" => "https://www.googleapis.com/oauth2/v3/userinfo",
            "github" => "https://api.github.com/user",
            _ => "",
        }
    }

    /// Builds the form body of the token request.
    ///
    /// Every value, including `code`, is percent-encoded so that a code
    /// containing `&`, `=` or `+` cannot alter or add request parameters.
    /// `code` is expected to be the decoded value from the callback query.
    fn token_request_body(&self, code: &str) -> Result<String, String> {
        let grant_suffix = match self.provider.as_str() {
            "google" => "&grant_type=authorization_code",
            "github" => "",
            _ => return Err(format!("unknown OAuth provider: {}", self.provider)),
        };
        Ok(format!(
            "code={}&client_id={}&client_secret={}&redirect_uri={}{}",
            url_encode(code),
            url_encode(&self.client_id),
            url_encode(&self.client_secret),
            url_encode(&self.redirect_uri),
            grant_suffix
        ))
    }

    /// Posts the token request and returns the raw response body.
    fn request_token(&self, code: &str) -> Result<String, String> {
        let body = self.token_request_body(code)?;
        // Only the GitHub request asks for a JSON response explicitly.
        http_post_form(self.token_url(), &body, self.provider.as_str() == "github")
    }

    /// Exchanges an authorization code for the full token set.
    ///
    /// # Errors
    /// Returns a message for unknown providers, HTTP failures, or a response
    /// without an `access_token`.
    pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
        let out = self.request_token(code)?;
        parse_token_response(&out)
    }

    /// Exchanges an authorization code and returns only the access token.
    ///
    /// # Errors
    /// Same failure cases as [`OAuthConfig::exchange_code_full`].
    pub fn exchange_code(&self, code: &str) -> Result<String, String> {
        let out = self.request_token(code)?;
        extract_access_token(&out)
    }

    /// Fetches the user's `(email, name)` pair from the provider.
    pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
        let info = self.fetch_userinfo_full(access_token)?;
        Ok((info.email, info.name))
    }

    /// Fetches and normalizes the provider's user-info document.
    ///
    /// # Errors
    /// Returns a message when the request fails, the body is not JSON, the
    /// provider is unknown, or a required field (email, account id) is missing.
    pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
        let out = http_get_bearer(self.userinfo_url(), access_token)?;
        let parsed: serde_json::Value =
            serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
        match self.provider.as_str() {
            "google" => {
                let email = parsed
                    .get("email")
                    .and_then(|v| v.as_str())
                    .ok_or("no email in userinfo")?
                    .to_string();
                let name = parsed
                    .get("name")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                let provider_account_id = parsed
                    .get("sub")
                    .and_then(|v| v.as_str())
                    .ok_or("no sub in userinfo")?
                    .to_string();
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            "github" => {
                // Fall back to the login when no display name is present.
                let name = parsed
                    .get("name")
                    .and_then(|v| v.as_str())
                    .or_else(|| parsed.get("login").and_then(|v| v.as_str()))
                    .map(String::from);
                let email = parsed
                    .get("email")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                // When the user document has no email string, ask the emails
                // endpoint for the primary verified address.
                let email = email
                    .or_else(|| fetch_github_primary_email(access_token).ok())
                    .ok_or("no accessible email on GitHub account")?;
                // The id is accepted as either a JSON number or a string.
                let provider_account_id = parsed
                    .get("id")
                    .map(|v| {
                        v.as_i64()
                            .map(|n| n.to_string())
                            .or_else(|| v.as_str().map(String::from))
                            .unwrap_or_default()
                    })
                    .filter(|s| !s.is_empty())
                    .ok_or("no id in userinfo")?;
                Ok(UserInfo {
                    provider: self.provider.clone(),
                    provider_account_id,
                    email,
                    name,
                })
            }
            _ => Err(format!("unknown provider: {}", self.provider)),
        }
    }
}
/// Normalized user profile returned by `OAuthConfig::fetch_userinfo_full`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
/// Provider key copied from the `OAuthConfig` that made the request.
pub provider: String,
/// Provider-side account identifier (Google `sub`, GitHub `id`).
pub provider_account_id: String,
/// Email address reported by the provider.
pub email: String,
/// Display name, when the provider supplied one.
pub name: Option<String>,
}
/// Tokens extracted from a provider's token response by `parse_token_response`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
/// The `access_token` field; always present.
pub access_token: String,
/// The `refresh_token` field, if present.
pub refresh_token: Option<String>,
/// The `id_token` field, if present.
pub id_token: Option<String>,
/// Absolute expiry in Unix seconds, computed as now + `expires_in`.
pub expires_at: Option<u64>,
/// The `scope` field, if present.
pub scope: Option<String>,
}
/// Parses a token endpoint response into a [`TokenSet`].
///
/// The body is read as JSON first. If it is not valid JSON it is treated as
/// `key=value` pairs separated by `&` (values are taken verbatim, without
/// percent-decoding).
///
/// # Errors
/// Returns a message containing the body when no `access_token` string is found.
fn parse_token_response(body: &str) -> Result<TokenSet, String> {
    let json: serde_json::Value = match serde_json::from_str(body) {
        Ok(value) => value,
        Err(_) => {
            let mut fields = serde_json::Map::new();
            for (key, value) in body.split('&').filter_map(|pair| pair.split_once('=')) {
                fields.insert(
                    key.to_string(),
                    serde_json::Value::String(value.to_string()),
                );
            }
            serde_json::Value::Object(fields)
        }
    };

    // Reads an optional string-valued field.
    let text = |key: &str| json.get(key).and_then(|v| v.as_str()).map(String::from);

    let access_token = text("access_token")
        .ok_or_else(|| format!("no access_token in token response: {body}"))?;
    let refresh_token = text("refresh_token");
    let id_token = text("id_token");
    // `expires_in` may arrive as a number (JSON) or a numeric string (form body).
    let expires_at = json
        .get("expires_in")
        .and_then(|v| match v.as_u64() {
            Some(secs) => Some(secs),
            None => v.as_str().and_then(|s| s.parse::<u64>().ok()),
        })
        .map(|secs| now_secs().saturating_add(secs));
    let scope = text("scope");

    Ok(TokenSet {
        access_token,
        refresh_token,
        id_token,
        expires_at,
        scope,
    })
}
/// Percent-encodes every byte except ASCII letters, digits, `-`, `_`, `.` and `~`.
///
/// Escapes use uppercase hex digits; multi-byte UTF-8 characters are encoded
/// byte by byte.
fn url_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Timeout (10 seconds) applied separately to connect, read and write.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// Builds a fresh HTTP agent with `HTTP_TIMEOUT` on connect/read/write and the
/// `pylon/0.1` user agent. A new agent is built on every call.
fn ureq_agent() -> ureq::Agent {
ureq::AgentBuilder::new()
.timeout_connect(HTTP_TIMEOUT)
.timeout_read(HTTP_TIMEOUT)
.timeout_write(HTTP_TIMEOUT)
.user_agent("pylon/0.1")
.build()
}
/// POSTs `body` as `application/x-www-form-urlencoded` and returns the response text.
///
/// When `accept_json` is set, an `Accept: application/json` header is added.
///
/// # Errors
/// Non-success statuses become `"HTTP <code>: <body>"`; other request failures
/// become `"HTTP error: ..."`; a failure reading the body becomes `"read body: ..."`.
fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
    let agent = ureq_agent();
    let request = agent
        .post(url)
        .set("Content-Type", "application/x-www-form-urlencoded");
    let request = if accept_json {
        request.set("Accept", "application/json")
    } else {
        request
    };

    let response = match request.send_string(body) {
        Ok(resp) => resp,
        Err(ureq::Error::Status(code, resp)) => {
            let body = resp.into_string().unwrap_or_default();
            return Err(format!("HTTP {code}: {body}"));
        }
        Err(e) => return Err(format!("HTTP error: {e}")),
    };
    response.into_string().map_err(|e| format!("read body: {e}"))
}
/// GETs `url` with `Authorization: Bearer <token>` and `Accept: application/json`,
/// returning the response text.
///
/// # Errors
/// Non-success statuses become `"HTTP <code>: <body>"`; other request failures
/// become `"HTTP error: ..."`; a failure reading the body becomes `"read body: ..."`.
fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
    let agent = ureq_agent();
    let authorization = format!("Bearer {token}");
    let outcome = agent
        .get(url)
        .set("Authorization", &authorization)
        .set("Accept", "application/json")
        .call();

    let response = match outcome {
        Ok(resp) => resp,
        Err(ureq::Error::Status(code, resp)) => {
            let body = resp.into_string().unwrap_or_default();
            return Err(format!("HTTP {code}: {body}"));
        }
        Err(e) => return Err(format!("HTTP error: {e}")),
    };
    response.into_string().map_err(|e| format!("read body: {e}"))
}
/// Looks up the first GitHub email entry flagged both `primary` and `verified`.
///
/// # Errors
/// Returns a message when the request fails, the body is not JSON, the body is
/// not an array, no entry has both flags, or the matching entry has no `email`
/// string.
fn fetch_github_primary_email(token: &str) -> Result<String, String> {
    let out = http_get_bearer("https://api.github.com/user/emails", token)?;
    let emails: serde_json::Value =
        serde_json::from_str(&out).map_err(|e| format!("emails not JSON: {e}"))?;

    // A missing or non-boolean flag counts as `false`.
    let flag = |entry: &serde_json::Value, key: &str| {
        entry.get(key).and_then(|v| v.as_bool()).unwrap_or(false)
    };

    let mut found: Option<String> = None;
    if let Some(entries) = emails.as_array() {
        for entry in entries {
            if flag(entry, "primary") && flag(entry, "verified") {
                // Only the first matching entry is considered.
                found = entry.get("email").and_then(|v| v.as_str()).map(String::from);
                break;
            }
        }
    }
    found.ok_or_else(|| "no primary verified email on GitHub".into())
}
/// Pulls just the access token out of a token endpoint response.
///
/// A JSON `access_token` string is preferred. Otherwise the body is scanned as
/// `&`-separated pairs for one starting with `access_token=`, whose value is
/// returned verbatim.
///
/// # Errors
/// Returns a message containing the body when neither form yields a token.
fn extract_access_token(body: &str) -> Result<String, String> {
    let from_json = serde_json::from_str::<serde_json::Value>(body)
        .ok()
        .and_then(|json| {
            json.get("access_token")
                .and_then(|v| v.as_str())
                .map(String::from)
        });
    let from_form = || {
        body.split('&')
            .find_map(|pair| pair.strip_prefix("access_token="))
            .map(String::from)
    };
    from_json
        .or_else(from_form)
        .ok_or_else(|| format!("no access_token in token response: {body}"))
}
/// Collection of OAuth provider configurations keyed by provider name.
pub struct OAuthRegistry {
// Keyed by `OAuthConfig::provider`; registering the same provider twice
// replaces the earlier entry.
providers: std::collections::HashMap<String, OAuthConfig>,
}
/// The default registry is empty, identical to `OAuthRegistry::new()`.
impl Default for OAuthRegistry {
fn default() -> Self {
Self::new()
}
}
impl OAuthRegistry {
    /// Creates a registry with no providers.
    pub fn new() -> Self {
        Self {
            providers: std::collections::HashMap::new(),
        }
    }

    /// Adds `config`, replacing any existing entry for the same provider name.
    pub fn register(&mut self, config: OAuthConfig) {
        let key = config.provider.clone();
        self.providers.insert(key, config);
    }

    /// Looks up the configuration registered under `provider`.
    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
        self.providers.get(provider)
    }

    /// Builds a registry from `PYLON_OAUTH_*` environment variables.
    ///
    /// A provider is registered only when both its client id and client secret
    /// variables are set. The redirect variable is optional and falls back to a
    /// `http://localhost:3000` callback URL.
    pub fn from_env() -> Self {
        // (provider, client-id var, client-secret var, redirect var, default redirect)
        const PROVIDERS: [(&str, &str, &str, &str, &str); 2] = [
            (
                "google",
                "PYLON_OAUTH_GOOGLE_CLIENT_ID",
                "PYLON_OAUTH_GOOGLE_CLIENT_SECRET",
                "PYLON_OAUTH_GOOGLE_REDIRECT",
                "http://localhost:3000/api/auth/callback/google",
            ),
            (
                "github",
                "PYLON_OAUTH_GITHUB_CLIENT_ID",
                "PYLON_OAUTH_GITHUB_CLIENT_SECRET",
                "PYLON_OAUTH_GITHUB_REDIRECT",
                "http://localhost:3000/api/auth/callback/github",
            ),
        ];

        let mut reg = Self::new();
        for (provider, id_var, secret_var, redirect_var, default_redirect) in PROVIDERS {
            let credentials = (std::env::var(id_var), std::env::var(secret_var));
            if let (Ok(client_id), Ok(client_secret)) = credentials {
                let redirect_uri =
                    std::env::var(redirect_var).unwrap_or_else(|_| default_redirect.into());
                reg.register(OAuthConfig {
                    provider: provider.into(),
                    client_id,
                    client_secret,
                    redirect_uri,
                });
            }
        }
        reg
    }
}
/// Data remembered for one in-flight OAuth authorization attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
/// Provider the flow was started for; `OAuthStateStore::validate` requires a match.
pub provider: String,
/// URL recorded when the flow was created.
// NOTE(review): presumably the post-login redirect target; confirm against the caller.
pub callback_url: String,
/// Error-case URL recorded when the flow was created.
pub error_callback_url: String,
/// Expiry as Unix seconds; the in-memory backend rejects entries once
/// `expires_at <= now`.
pub expires_at: u64,
}
/// Storage for OAuth `state` tokens, usable across threads (`Send + Sync`).
pub trait OAuthStateBackend: Send + Sync {
/// Stores `state` under `token`.
fn put(&self, token: &str, state: &OAuthState);
/// Fetches the state stored under `token` given the current time.
/// The in-memory implementation removes the entry in every case and returns
/// it only when it has not yet expired.
fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}
/// `OAuthStateBackend` that keeps states in a mutex-guarded map.
// NOTE(review): entries leave the map only through `take`; a state that is
// never taken stays in memory.
pub struct InMemoryOAuthBackend {
states: Mutex<HashMap<String, OAuthState>>,
}
impl InMemoryOAuthBackend {
/// Creates a backend with no stored states.
pub fn new() -> Self {
Self {
states: Mutex::new(HashMap::new()),
}
}
}
/// The default backend is empty, identical to `InMemoryOAuthBackend::new()`.
impl Default for InMemoryOAuthBackend {
fn default() -> Self {
Self::new()
}
}
impl OAuthStateBackend for InMemoryOAuthBackend {
    /// Stores a copy of `state` under `token`, replacing any previous entry.
    ///
    /// Panics if the internal mutex is poisoned.
    fn put(&self, token: &str, state: &OAuthState) {
        let mut states = self.states.lock().unwrap();
        states.insert(token.to_owned(), state.clone());
    }

    /// Removes the entry for `token` and returns it only if it is still valid
    /// (`expires_at > now_unix_secs`). An expired entry is removed and dropped.
    ///
    /// Panics if the internal mutex is poisoned.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
        let mut states = self.states.lock().unwrap();
        states
            .remove(token)
            .filter(|entry| entry.expires_at > now_unix_secs)
    }
}
/// Issues and validates OAuth `state` tokens on top of an `OAuthStateBackend`.
pub struct OAuthStateStore {
backend: Box<dyn OAuthStateBackend>,
}
/// The default store uses the in-memory backend, identical to `OAuthStateStore::new()`.
impl Default for OAuthStateStore {
fn default() -> Self {
Self::new()
}
}
impl OAuthStateStore {
pub fn new() -> Self {
Self {
backend: Box::new(InMemoryOAuthBackend::new()),
}
}
pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
Self { backend }
}
pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let token = generate_token();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let state = OAuthState {
provider: provider.to_string(),
callback_url: callback_url.to_string(),
error_callback_url: error_callback_url.to_string(),
expires_at: now + 600,
};
self.backend.put(&token, &state);
token
}
pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
use std::time::{SystemTime, UNIX_EPOCH};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let entry = self.backend.take(state, now)?;
if entry.provider != expected_provider {
return None;
}
Some(entry)
}
}
/// Checks that `url` is an http(s) URL whose origin is listed in `trusted_origins`.
///
/// The origin is compared by exact string equality, so a trusted entry must be
/// written exactly as `origin_of` renders it (scheme, host and port, no
/// trailing slash).
///
/// # Errors
/// * `Empty` when `url` is empty.
/// * `NotHttp` when `url` does not start with `http://` or `https://`.
/// * `NotTrusted` when the origin is not in the list.
pub fn validate_trusted_redirect(
    url: &str,
    trusted_origins: &[String],
) -> Result<(), TrustedOriginError> {
    if url.is_empty() {
        return Err(TrustedOriginError::Empty);
    }
    let is_http = url.starts_with("http://") || url.starts_with("https://");
    if !is_http {
        return Err(TrustedOriginError::NotHttp);
    }
    let origin = origin_of(url);
    match trusted_origins.contains(&origin) {
        true => Ok(()),
        false => Err(TrustedOriginError::NotTrusted { origin }),
    }
}
/// Reasons `validate_trusted_redirect` rejects a URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
/// The URL was the empty string.
Empty,
/// The URL did not start with `http://` or `https://`.
NotHttp,
/// The URL's origin was not found in the trusted list.
NotTrusted { origin: String },
}
impl std::fmt::Display for TrustedOriginError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TrustedOriginError::Empty => write!(f, "redirect URL is empty"),
TrustedOriginError::NotHttp => {
write!(f, "redirect URL must use http:// or https:// scheme")
}
TrustedOriginError::NotTrusted { origin } => write!(
f,
"redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS"
),
}
}
}
/// Returns the `scheme://authority` prefix of `url`.
///
/// Everything from the first `/`, `?` or `#` after `://` is dropped. When the
/// input has no `://`, it is returned with trailing slashes trimmed.
pub fn origin_of(url: &str) -> String {
    match url.split_once("://") {
        None => url.trim_end_matches('/').to_string(),
        Some((scheme, rest)) => {
            // `split` always yields at least one piece, possibly empty.
            let authority = rest
                .split(|c: char| matches!(c, '/' | '?' | '#'))
                .next()
                .unwrap_or(rest);
            format!("{scheme}://{authority}")
        }
    }
}
/// Persistence for magic codes, keyed by email; usable across threads.
pub trait MagicCodeBackend: Send + Sync {
/// Stores `code` under `email`.
fn put(&self, email: &str, code: &MagicCode);
/// Returns the code stored under `email`, if any.
fn get(&self, email: &str) -> Option<MagicCode>;
/// Deletes the code stored under `email`.
fn remove(&self, email: &str);
/// Records one more failed attempt for `email`.
fn bump_attempts(&self, email: &str);
/// Returns every stored code; `MagicCodeStore::with_backend` uses this to
/// fill its cache.
fn load_all(&self) -> Vec<MagicCode>;
}
/// `MagicCodeBackend` that keeps codes in a mutex-guarded map keyed by email.
pub struct InMemoryMagicCodeBackend {
codes: Mutex<HashMap<String, MagicCode>>,
}
impl InMemoryMagicCodeBackend {
/// Creates a backend with no stored codes.
pub fn new() -> Self {
Self {
codes: Mutex::new(HashMap::new()),
}
}
}
/// The default backend is empty, identical to `InMemoryMagicCodeBackend::new()`.
impl Default for InMemoryMagicCodeBackend {
fn default() -> Self {
Self::new()
}
}
/// Map-backed implementation; every method panics if the mutex is poisoned.
impl MagicCodeBackend for InMemoryMagicCodeBackend {
    /// Stores a copy of `code` under `email`, replacing any previous entry.
    fn put(&self, email: &str, code: &MagicCode) {
        let mut codes = self.codes.lock().unwrap();
        codes.insert(email.to_owned(), code.clone());
    }

    /// Returns a copy of the code stored under `email`.
    fn get(&self, email: &str) -> Option<MagicCode> {
        let codes = self.codes.lock().unwrap();
        codes.get(email).cloned()
    }

    /// Deletes the entry for `email`; a missing entry is a no-op.
    fn remove(&self, email: &str) {
        let mut codes = self.codes.lock().unwrap();
        codes.remove(email);
    }

    /// Increments the attempt counter (saturating); a missing entry is a no-op.
    fn bump_attempts(&self, email: &str) {
        let mut codes = self.codes.lock().unwrap();
        if let Some(entry) = codes.get_mut(email) {
            entry.attempts = entry.attempts.saturating_add(1);
        }
    }

    /// Returns copies of all stored codes, in unspecified order.
    fn load_all(&self) -> Vec<MagicCode> {
        let codes = self.codes.lock().unwrap();
        codes.values().cloned().collect()
    }
}
/// Creates and verifies emailed one-time codes.
///
/// Reads are served from `cache`; every change is also written to `backend`.
pub struct MagicCodeStore {
cache: Mutex<HashMap<String, MagicCode>>,
backend: Box<dyn MagicCodeBackend>,
}
/// One pending magic code.
#[derive(Debug, Clone)]
pub struct MagicCode {
/// Email the code was issued for; also the storage key.
pub email: String,
/// The code text the user must present.
pub code: String,
/// Expiry as Unix seconds.
pub expires_at: u64,
/// Number of failed verification attempts so far.
pub attempts: u32,
}
/// Failed-attempt count at which `try_verify` reports `TooManyAttempts`.
const MAX_ATTEMPTS: u32 = 5;
/// Minimum age in seconds of a still-valid code before `try_create` will issue
/// a replacement for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;
/// Failure reasons reported by `MagicCodeStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
/// No code is stored for the email.
NotFound,
/// The attempt counter has reached `MAX_ATTEMPTS`.
TooManyAttempts,
/// The presented code did not match.
BadCode,
/// The stored code had expired; it is removed when this is reported.
Expired,
/// A code was created too recently; retry after the given number of seconds.
Throttled { retry_after_secs: u64 },
}
/// The default store uses the in-memory backend, identical to `MagicCodeStore::new()`.
impl Default for MagicCodeStore {
fn default() -> Self {
Self::new()
}
}
impl MagicCodeStore {
    /// Creates a store backed by a fresh [`InMemoryMagicCodeBackend`].
    pub fn new() -> Self {
        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
    }

    /// Creates a store on `backend`, preloading every code that has not expired.
    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
        let now = now_secs();
        let cache: HashMap<String, MagicCode> = backend
            .load_all()
            .into_iter()
            .filter(|stored| stored.expires_at > now)
            .map(|stored| (stored.email.clone(), stored))
            .collect();
        Self {
            cache: Mutex::new(cache),
            backend,
        }
    }

    /// Like [`MagicCodeStore::try_create`], but returns an empty string on failure.
    pub fn create(&self, email: &str) -> String {
        match self.try_create(email) {
            Ok(code) => code,
            Err(_) => String::new(),
        }
    }

    /// Issues a new code for `email`, valid for 600 seconds.
    ///
    /// # Errors
    /// Returns `Throttled` when a still-valid code for the same email was
    /// created less than `CREATE_COOLDOWN_SECS` ago.
    ///
    /// # Panics
    /// Panics if the cache mutex is poisoned.
    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
        let now = now_secs();
        let mut codes = self.cache.lock().unwrap();

        let live_expiry = codes
            .get(email)
            .map(|existing| existing.expires_at)
            .filter(|&expiry| expiry > now);
        if let Some(expiry) = live_expiry {
            // Creation time is derived from the expiry (expiry minus the 600 s lifetime).
            let created_at = expiry.saturating_sub(600);
            let age = now.saturating_sub(created_at);
            if age < CREATE_COOLDOWN_SECS {
                return Err(MagicCodeError::Throttled {
                    retry_after_secs: CREATE_COOLDOWN_SECS - age,
                });
            }
        }

        let code = generate_magic_code();
        let mc = MagicCode {
            email: email.to_string(),
            code: code.clone(),
            expires_at: now + 600,
            attempts: 0,
        };
        codes.insert(email.to_string(), mc.clone());
        self.backend.put(email, &mc);
        Ok(code)
    }

    /// `true` when [`MagicCodeStore::try_verify`] succeeds.
    pub fn verify(&self, email: &str, code: &str) -> bool {
        self.try_verify(email, code).is_ok()
    }

    /// Returns copies of every cached code, expired or not.
    ///
    /// Yields an empty list if the cache mutex is poisoned.
    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
        match self.cache.lock() {
            Ok(codes) => codes.values().cloned().collect(),
            Err(_) => Vec::new(),
        }
    }

    /// Checks `code` against the stored code for `email`.
    ///
    /// Checks run in this order: existence, attempt limit, expiry, code match.
    /// A successful match and an expired code both remove the entry; a mismatch
    /// increments the attempt counter in the cache and the backend.
    ///
    /// # Errors
    /// `NotFound`, `TooManyAttempts`, `Expired` or `BadCode` as described above;
    /// the mismatch that reaches `MAX_ATTEMPTS` reports `TooManyAttempts`.
    ///
    /// # Panics
    /// Panics if the cache mutex is poisoned.
    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
        let now = now_secs();
        let mut codes = self.cache.lock().unwrap();

        let outcome = {
            let mc = codes.get_mut(email).ok_or(MagicCodeError::NotFound)?;
            if mc.attempts >= MAX_ATTEMPTS {
                return Err(MagicCodeError::TooManyAttempts);
            }
            if mc.expires_at <= now {
                Err(MagicCodeError::Expired)
            } else if constant_time_eq(mc.code.as_bytes(), code.as_bytes()) {
                Ok(())
            } else {
                mc.attempts += 1;
                self.backend.bump_attempts(email);
                return Err(if mc.attempts >= MAX_ATTEMPTS {
                    MagicCodeError::TooManyAttempts
                } else {
                    MagicCodeError::BadCode
                });
            }
        };

        // Reached on expiry and on success: both consume the stored code.
        codes.remove(email);
        self.backend.remove(email);
        outcome
    }
}
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
/// Draws a random number in `0..1_000_000` from the thread-local RNG and
/// formats it as six digits, zero-padded.
fn generate_magic_code() -> String {
    use rand::Rng;
    let value: u32 = rand::thread_rng().gen_range(0..1_000_000);
    format!("{value:06}")
}
/// Produces an opaque token: `pylon_` followed by 32 random bytes in lowercase
/// hex (70 characters in total).
fn generate_token() -> String {
    use rand::Rng;
    let bytes = rand::thread_rng().gen::<[u8; 32]>();
    let mut token = String::from("pylon_");
    token.push_str(&hex_encode(&bytes));
    token
}
use std::collections::HashMap;
use std::sync::Mutex;
/// Persistence for sessions; usable across threads (`Send + Sync`).
pub trait SessionBackend: Send + Sync {
/// Returns every stored session; `SessionStore::with_backend` uses this to
/// fill its in-memory map.
fn load_all(&self) -> Vec<Session>;
/// Stores `session`; `SessionStore` calls this for new and for modified sessions.
fn save(&self, session: &Session);
/// Deletes the session stored under `token`.
fn remove(&self, token: &str);
}
/// In-memory session table with an optional persistence backend.
pub struct SessionStore {
// Keyed by session token.
sessions: Mutex<HashMap<String, Session>>,
// When present, every create/update/remove is forwarded to it.
backend: Option<Box<dyn SessionBackend>>,
// Lifetime given to sessions created by this store; 0 means "never expires".
default_lifetime_secs: u64,
}
/// The default store has no backend, identical to `SessionStore::new()`.
impl Default for SessionStore {
fn default() -> Self {
Self::new()
}
}
impl SessionStore {
    /// Creates an empty store with no backend and the default session lifetime.
    pub fn new() -> Self {
        Self {
            sessions: Mutex::new(HashMap::new()),
            backend: None,
            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
        }
    }

    /// Sets the lifetime for sessions created from now on.
    ///
    /// A value of `0` produces sessions that never expire.
    pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
        self.default_lifetime_secs = lifetime_secs;
        self
    }

    /// Creates a store on `backend`, preloading every session that has not expired.
    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
        let mut map = HashMap::new();
        for s in backend.load_all() {
            if !s.is_expired() {
                map.insert(s.token.clone(), s);
            }
        }
        Self {
            sessions: Mutex::new(map),
            backend: Some(backend),
            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
        }
    }

    /// Applies `apply` to the session stored under `token` and persists the result.
    ///
    /// Returns `false` when no such session exists.
    fn update_session(&self, token: &str, apply: impl FnOnce(&mut Session)) -> bool {
        let mut sessions = self.sessions.lock().unwrap();
        match sessions.get_mut(token) {
            Some(session) => {
                apply(session);
                if let Some(b) = &self.backend {
                    b.save(session);
                }
                true
            }
            None => false,
        }
    }

    /// Removes every session matching `pred` from memory and the backend,
    /// returning how many were removed.
    fn remove_matching(&self, pred: impl Fn(&Session) -> bool) -> usize {
        let mut sessions = self.sessions.lock().unwrap();
        let tokens: Vec<String> = sessions
            .iter()
            .filter(|(_, s)| pred(s))
            .map(|(t, _)| t.clone())
            .collect();
        for t in &tokens {
            sessions.remove(t);
            if let Some(b) = &self.backend {
                b.remove(t);
            }
        }
        tokens.len()
    }

    /// Creates, stores and persists a session for `user_id`.
    ///
    /// # Panics
    /// Panics if the session mutex is poisoned (as do the other methods that
    /// lock it, except `list_all_unfiltered`).
    pub fn create(&self, user_id: String) -> Session {
        let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
        let mut sessions = self.sessions.lock().unwrap();
        sessions.insert(session.token.clone(), session.clone());
        if let Some(b) = &self.backend {
            b.save(&session);
        }
        session
    }

    /// Returns the live session for `token`.
    ///
    /// An expired session is evicted from memory and from the backend, so the
    /// persisted copy does not outlive the in-memory one.
    pub fn get(&self, token: &str) -> Option<Session> {
        let mut sessions = self.sessions.lock().unwrap();
        let session = sessions.get(token)?;
        if !session.is_expired() {
            return Some(session.clone());
        }
        sessions.remove(token);
        // `sweep_expired` only walks the in-memory map, so the backend row must
        // be removed here or it would never be cleaned up.
        if let Some(b) = &self.backend {
            b.remove(token);
        }
        None
    }

    /// Resolves an optional token to an `AuthContext`; missing, unknown or
    /// expired tokens resolve to the anonymous context.
    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
        match token {
            Some(t) => match self.get(t) {
                Some(session) => session.to_auth_context(),
                None => AuthContext::anonymous(),
            },
            None => AuthContext::anonymous(),
        }
    }

    /// Rotates a session: the old token is always removed, and a new session
    /// for the same user is issued unless the old one had expired.
    ///
    /// The new session keeps the old session's `device` and `tenant_id`.
    pub fn refresh(&self, old_token: &str) -> Option<Session> {
        let mut sessions = self.sessions.lock().unwrap();
        let old = sessions.remove(old_token)?;
        if let Some(b) = &self.backend {
            b.remove(old_token);
        }
        if old.is_expired() {
            return None;
        }
        let mut new = Session::with_lifetime(old.user_id, self.default_lifetime_secs);
        new.device = old.device;
        // Carry the tenant across rotation so a refresh does not silently drop it.
        new.tenant_id = old.tenant_id;
        sessions.insert(new.token.clone(), new.clone());
        if let Some(b) = &self.backend {
            b.save(&new);
        }
        Some(new)
    }

    /// Returns copies of every in-memory session, expired or not.
    ///
    /// Yields an empty list if the session mutex is poisoned.
    pub fn list_all_unfiltered(&self) -> Vec<Session> {
        self.sessions
            .lock()
            .map(|m| m.values().cloned().collect())
            .unwrap_or_default()
    }

    /// Returns the unexpired sessions belonging to `user_id`.
    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
        let sessions = self.sessions.lock().unwrap();
        sessions
            .values()
            .filter(|s| s.user_id == user_id && !s.is_expired())
            .cloned()
            .collect()
    }

    /// Removes every session of `user_id` (expired or not); returns the count.
    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
        self.remove_matching(|s| s.user_id == user_id)
    }

    /// Removes every expired session; returns the count.
    pub fn sweep_expired(&self) -> usize {
        self.remove_matching(Session::is_expired)
    }

    /// Sets the device label of the session; `false` if the token is unknown.
    pub fn set_device(&self, token: &str, device: String) -> bool {
        self.update_session(token, |s| s.device = Some(device))
    }

    /// Creates a session for a new random `guest_<hex>` user id.
    pub fn create_guest(&self) -> Session {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let bytes: [u8; 16] = rng.gen();
        let guest_id = format!("guest_{}", hex_encode(&bytes));
        self.create(guest_id)
    }

    /// Replaces the session's user id, keeping its token; `false` if the token
    /// is unknown.
    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
        self.update_session(token, |s| s.user_id = real_user_id)
    }

    /// Sets or clears the session's tenant; `false` if the token is unknown.
    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
        self.update_session(token, |s| s.tenant_id = tenant_id)
    }

    /// Removes the session for `token`; `false` if there was none.
    pub fn revoke(&self, token: &str) -> bool {
        let mut sessions = self.sessions.lock().unwrap();
        let removed = sessions.remove(token).is_some();
        if removed {
            if let Some(b) = &self.backend {
                b.remove(token);
            }
        }
        removed
    }
}
/// Link between a local user and an external provider account, with its tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Account {
/// Identifier of this link record; `Account::new` fills it with a generated token.
pub id: String,
/// Local user the provider account is linked to.
pub user_id: String,
/// Provider key (copied from `UserInfo::provider`).
pub provider_id: String,
/// Provider-side account identifier (copied from `UserInfo::provider_account_id`).
pub account_id: String,
/// Provider access token, if stored.
pub access_token: Option<String>,
/// Provider refresh token, if any.
pub refresh_token: Option<String>,
/// Provider ID token, if any.
pub id_token: Option<String>,
/// Access token expiry as Unix seconds; `None` is treated as "not expired".
pub access_token_expires_at: Option<u64>,
/// Refresh token expiry as Unix seconds; `Account::new` leaves it `None`.
pub refresh_token_expires_at: Option<u64>,
/// Granted scope string, if reported.
pub scope: Option<String>,
/// `Account::new` leaves this `None`.
// NOTE(review): presumably a password hash for credential-based accounts; confirm.
pub password: Option<String>,
/// Creation time as Unix seconds.
pub created_at: u64,
/// Last update time as Unix seconds.
pub updated_at: u64,
}
impl Account {
    /// Builds a link record for `user_id` from the provider profile and tokens.
    ///
    /// The record gets a freshly generated id, and `created_at` / `updated_at`
    /// are both set to the current time.
    pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
        let timestamp = now_secs();
        let id = generate_token();
        Self {
            id,
            user_id,
            provider_id: info.provider.clone(),
            account_id: info.provider_account_id.clone(),
            access_token: Some(tokens.access_token.clone()),
            refresh_token: tokens.refresh_token.clone(),
            id_token: tokens.id_token.clone(),
            access_token_expires_at: tokens.expires_at,
            refresh_token_expires_at: None,
            scope: tokens.scope.clone(),
            password: None,
            created_at: timestamp,
            updated_at: timestamp,
        }
    }

    /// `true` once the current time has reached the access token's expiry;
    /// always `false` when no expiry is recorded.
    pub fn access_token_expired(&self) -> bool {
        self.access_token_expires_at
            .map_or(false, |deadline| now_secs() >= deadline)
    }
}
/// Persistence for linked provider accounts; usable across threads.
pub trait AccountBackend: Send + Sync {
/// Inserts or replaces `account`.
fn upsert(&self, account: &Account);
/// Finds the account with the given provider and provider-side account id.
fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
/// Returns every account linked to `user_id`.
fn find_for_user(&self, user_id: &str) -> Vec<Account>;
/// Removes the given link; the in-memory implementation returns whether one existed.
fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
/// Returns every stored account.
fn list_all(&self) -> Vec<Account>;
}
/// `AccountBackend` that keeps accounts in a mutex-guarded map.
pub struct InMemoryAccountBackend {
// Keyed by `(provider_id, account_id)`.
accounts: Mutex<HashMap<(String, String), Account>>,
}
impl InMemoryAccountBackend {
/// Creates a backend with no stored accounts.
pub fn new() -> Self {
Self {
accounts: Mutex::new(HashMap::new()),
}
}
}
/// The default backend is empty, identical to `InMemoryAccountBackend::new()`.
impl Default for InMemoryAccountBackend {
fn default() -> Self {
Self::new()
}
}
/// Map-backed implementation; every method panics if the mutex is poisoned.
impl AccountBackend for InMemoryAccountBackend {
    /// Stores a copy of `account` under its `(provider_id, account_id)` key,
    /// replacing any existing entry.
    fn upsert(&self, account: &Account) {
        let mut accounts = self.accounts.lock().unwrap();
        accounts.insert(
            (account.provider_id.clone(), account.account_id.clone()),
            account.clone(),
        );
    }

    /// Returns a copy of the account with the given key.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
        let key = (provider_id.to_owned(), account_id.to_owned());
        let accounts = self.accounts.lock().unwrap();
        accounts.get(&key).cloned()
    }

    /// Returns copies of every account linked to `user_id`, in unspecified order.
    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
        let accounts = self.accounts.lock().unwrap();
        accounts
            .values()
            .filter(|account| account.user_id == user_id)
            .cloned()
            .collect()
    }

    /// Removes the account with the given key; `true` if one was present.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
        let key = (provider_id.to_owned(), account_id.to_owned());
        let mut accounts = self.accounts.lock().unwrap();
        accounts.remove(&key).is_some()
    }

    /// Returns copies of all stored accounts, in unspecified order.
    fn list_all(&self) -> Vec<Account> {
        let accounts = self.accounts.lock().unwrap();
        accounts.values().cloned().collect()
    }
}
/// Facade over an `AccountBackend`; every method delegates to the backend.
pub struct AccountStore {
backend: Box<dyn AccountBackend>,
}
/// The default store uses the in-memory backend, identical to `AccountStore::new()`.
impl Default for AccountStore {
fn default() -> Self {
Self::new()
}
}
impl AccountStore {
/// Creates a store backed by a fresh `InMemoryAccountBackend`.
pub fn new() -> Self {
Self {
backend: Box::new(InMemoryAccountBackend::new()),
}
}
/// Creates a store on top of the given backend.
pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
Self { backend }
}
/// Inserts or replaces `account` in the backend.
pub fn upsert(&self, account: &Account) {
self.backend.upsert(account);
}
/// Finds the account with the given provider and provider-side account id.
pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
self.backend.find_by_provider(provider_id, account_id)
}
/// Returns every account linked to `user_id`.
pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
self.backend.find_for_user(user_id)
}
/// Removes the given link and returns the backend's result.
pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
self.backend.unlink(provider_id, account_id)
}
/// Returns every account the backend holds, with no filtering.
pub fn list_all_unfiltered(&self) -> Vec<Account> {
self.backend.list_all()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn anonymous_context() {
let ctx = AuthContext::anonymous();
assert!(!ctx.is_authenticated());
assert!(ctx.user_id.is_none());
}
#[test]
fn authenticated_context() {
let ctx = AuthContext::authenticated("user-1".into());
assert!(ctx.is_authenticated());
assert_eq!(ctx.user_id, Some("user-1".into()));
}
#[test]
fn auth_mode_public_allows_anonymous() {
let mode = AuthMode::Public;
assert!(mode.check(&AuthContext::anonymous()));
assert!(mode.check(&AuthContext::authenticated("user-1".into())));
}
#[test]
fn auth_mode_user_requires_authenticated() {
let mode = AuthMode::User;
assert!(!mode.check(&AuthContext::anonymous()));
assert!(mode.check(&AuthContext::authenticated("user-1".into())));
}
#[test]
fn auth_mode_from_str() {
assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
assert_eq!(AuthMode::from_str("admin"), None);
}
#[test]
fn session_store_create_and_get() {
let store = SessionStore::new();
let session = store.create("user-1".into());
assert!(!session.token.is_empty());
assert!(session.token.starts_with("pylon_"));
let retrieved = store.get(&session.token).unwrap();
assert_eq!(retrieved.user_id, "user-1");
}
#[test]
fn session_store_resolve() {
let store = SessionStore::new();
let session = store.create("user-1".into());
let ctx = store.resolve(Some(&session.token));
assert!(ctx.is_authenticated());
assert_eq!(ctx.user_id, Some("user-1".into()));
let anon = store.resolve(None);
assert!(!anon.is_authenticated());
let bad = store.resolve(Some("invalid-token"));
assert!(!bad.is_authenticated());
}
#[test]
fn session_store_revoke() {
let store = SessionStore::new();
let session = store.create("user-1".into());
assert!(store.revoke(&session.token));
assert!(store.get(&session.token).is_none());
assert!(!store.revoke(&session.token)); }
#[test]
fn session_to_auth_context() {
let session = Session::new("user-42".into());
let ctx = session.to_auth_context();
assert_eq!(ctx.user_id, Some("user-42".into()));
}
#[test]
fn admin_context() {
let ctx = AuthContext::admin();
assert!(ctx.is_admin);
assert!(ctx.is_authenticated());
}
#[test]
fn anonymous_not_admin() {
let ctx = AuthContext::anonymous();
assert!(!ctx.is_admin);
}
#[test]
fn authenticated_not_admin() {
let ctx = AuthContext::authenticated("user-1".into());
assert!(!ctx.is_admin);
}
#[test]
fn magic_code_create_and_verify() {
let store = MagicCodeStore::new();
let code = store.create("test@example.com");
assert_eq!(code.len(), 6);
assert!(store.verify("test@example.com", &code));
}
#[test]
fn magic_code_wrong_code_rejected() {
let store = MagicCodeStore::new();
store.create("test@example.com");
assert!(!store.verify("test@example.com", "000000"));
}
#[test]
fn magic_code_wrong_email_rejected() {
let store = MagicCodeStore::new();
let code = store.create("test@example.com");
assert!(!store.verify("other@example.com", &code));
}
#[test]
fn magic_code_consumed_after_verify() {
let store = MagicCodeStore::new();
let code = store.create("test@example.com");
assert!(store.verify("test@example.com", &code));
assert!(!store.verify("test@example.com", &code));
}
#[test]
fn magic_code_different_emails_independent() {
let store = MagicCodeStore::new();
let code1 = store.create("alice@example.com");
let code2 = store.create("bob@example.com");
assert!(store.verify("alice@example.com", &code1));
assert!(store.verify("bob@example.com", &code2));
}
#[test]
fn constant_time_eq_equal() {
assert!(constant_time_eq(b"hello", b"hello"));
assert!(constant_time_eq(b"", b""));
}
/// Byte strings that differ in content or in length compare as unequal.
#[test]
fn constant_time_eq_not_equal() {
    let mismatched: [(&[u8], &[u8]); 3] = [
        (b"hello", b"world"), // same length, different content
        (b"hello", b"hell"),  // different length
        (b"a", b"b"),         // single differing byte
    ];

    for (left, right) in mismatched {
        assert!(!constant_time_eq(left, right));
    }
}
/// Two consecutively generated tokens differ from each other, and each one
/// starts with the "pylon_" prefix and has the expected total length.
///
/// Both tokens are checked for prefix and length; the expected length is the
/// prefix length plus a 64-character suffix.
#[test]
fn generated_tokens_are_unique() {
    const PREFIX: &str = "pylon_";
    const SUFFIX_LEN: usize = 64;

    let first = generate_token();
    let second = generate_token();

    assert_ne!(first, second);
    for token in [&first, &second] {
        assert!(token.starts_with(PREFIX));
        assert_eq!(token.len(), PREFIX.len() + SUFFIX_LEN);
    }
}
/// A new registry has no provider registered, so a lookup returns `None`.
#[test]
fn oauth_registry_empty() {
    let registry = OAuthRegistry::new();
    let lookup = registry.get("google");

    assert!(lookup.is_none());
}
/// A registered provider config can be fetched back by provider name; the
/// fetched config keeps its client id and produces a Google auth URL.
#[test]
fn oauth_registry_register_and_get() {
    let google = OAuthConfig {
        provider: "google".into(),
        client_id: "test-id".into(),
        client_secret: "test-secret".into(),
        redirect_uri: "http://localhost/callback".into(),
    };

    let mut registry = OAuthRegistry::new();
    registry.register(google);

    let found = registry.get("google").unwrap();
    assert_eq!(found.client_id, "test-id");
    assert!(found.auth_url().contains("accounts.google.com"));
}
/// A guest session gets a "guest_"-prefixed user id and a non-empty token,
/// and resolving that token yields a context carrying the same kind of id.
#[test]
fn guest_session() {
    let sessions = SessionStore::new();
    let guest = sessions.create_guest();

    assert!(guest.user_id.starts_with("guest_"));
    assert!(!guest.token.is_empty());

    let resolved = sessions.resolve(Some(&guest.token));
    // NOTE(review): `AuthContext::is_authenticated` is false whenever
    // `is_guest` is set (see the `guest_context` test). This assertion only
    // holds if `resolve` returns a context with `is_guest == false` for guest
    // sessions - confirm that this is the intended behavior.
    assert!(resolved.is_authenticated());

    let resolved_id = resolved.user_id.unwrap();
    assert!(resolved_id.starts_with("guest_"));
}
/// Upgrading a guest session keeps the token but swaps the identity: after a
/// successful `upgrade`, resolving the same token yields the real user id.
#[test]
fn upgrade_guest_to_real_user() {
    let sessions = SessionStore::new();
    let guest = sessions.create_guest();
    assert!(guest.user_id.starts_with("guest_"));

    assert!(sessions.upgrade(&guest.token, "real-user-123".into()));

    let resolved = sessions.resolve(Some(&guest.token));
    assert_eq!(resolved.user_id.as_deref(), Some("real-user-123"));
}
/// Upgrading a token the store has never issued reports failure.
#[test]
fn upgrade_invalid_token_fails() {
    let sessions = SessionStore::new();

    assert!(!sessions.upgrade("nonexistent-token", "user".into()));
}
/// A guest context carries an id and the guest flag, but is neither
/// authenticated nor admin; it is rejected by `AuthMode::User` and accepted
/// by `AuthMode::Public`.
#[test]
fn guest_context() {
    let guest = AuthContext::guest(String::from("guest_123"));

    // Identity and flags.
    assert!(!guest.is_authenticated());
    assert!(guest.is_guest);
    assert!(!guest.is_admin);
    assert_eq!(guest.user_id.as_deref(), Some("guest_123"));

    // Access-mode checks.
    assert!(!AuthMode::User.check(&guest));
    assert!(AuthMode::Public.check(&guest));
}
/// Token endpoints per provider: Google and GitHub map to their fixed token
/// URLs, while an unrecognized provider yields an empty token URL and an
/// empty auth URL.
#[test]
fn oauth_token_urls() {
    // All configs in this test differ only in the provider name.
    let config_for = |provider: &'static str| OAuthConfig {
        provider: provider.into(),
        client_id: "x".into(),
        client_secret: "x".into(),
        redirect_uri: "x".into(),
    };

    let google = config_for("google");
    assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");

    let github = config_for("github");
    assert_eq!(
        github.token_url(),
        "https://github.com/login/oauth/access_token"
    );

    let unknown = config_for("unknown");
    assert_eq!(unknown.token_url(), "");
    assert!(unknown.auth_url().is_empty());
}
/// The GitHub auth URL points at github.com and embeds the client id.
#[test]
fn oauth_auth_url_github() {
    let github = OAuthConfig {
        provider: "github".into(),
        client_id: "gh-id".into(),
        client_secret: "gh-secret".into(),
        redirect_uri: "http://localhost/cb".into(),
    };

    for expected_fragment in ["github.com", "gh-id"] {
        assert!(github.auth_url().contains(expected_fragment));
    }
}
/// `auth_url_with_state` appends the given state as a `&state=` query
/// parameter.
#[test]
fn oauth_auth_url_with_state() {
    let google = OAuthConfig {
        provider: "google".into(),
        client_id: "test-id".into(),
        client_secret: "test-secret".into(),
        redirect_uri: "http://localhost/cb".into(),
    };

    let with_state = google.auth_url_with_state("random_state_123");

    assert!(with_state.contains("&state=random_state_123"));
}
/// A state token validates once for its provider, returning the callback
/// URLs it was created with; a second validation of the same token fails.
#[test]
fn oauth_state_store_create_and_validate() {
    let states = OAuthStateStore::new();
    let state = states.create("google", "https://app/cb", "https://app/login");

    let record = states.validate(&state, "google").expect("valid first time");
    assert_eq!(record.callback_url, "https://app/cb");
    assert_eq!(record.error_callback_url, "https://app/login");

    // The token has been used once, so replaying it must fail.
    assert!(states.validate(&state, "google").is_none());
}
/// A state token created for one provider does not validate for another.
#[test]
fn oauth_state_store_wrong_provider_rejected() {
    let states = OAuthStateStore::new();
    let google_state = states.create("google", "https://app/cb", "https://app/cb");

    let as_github = states.validate(&google_state, "github");

    assert!(as_github.is_none());
}
/// A state token that was never issued by the store does not validate.
#[test]
fn oauth_state_store_invalid_state_rejected() {
    let states = OAuthStateStore::new();
    let outcome = states.validate("nonexistent", "google");

    assert!(outcome.is_none());
}
/// With `http://localhost:3000` as the only trusted origin: URLs on that
/// origin (with a path, bare, or with a query) are accepted; a different
/// port is `NotTrusted`; a `javascript:` URL is `NotHttp`; and the empty
/// string is `Empty`.
#[test]
fn validate_trusted_redirect_basics() {
    let trusted = vec![String::from("http://localhost:3000")];

    // Accepted: same origin, varying path and query.
    let accepted = [
        "http://localhost:3000/dashboard",
        "http://localhost:3000",
        "http://localhost:3000/x?y=1",
    ];
    for url in accepted {
        assert!(validate_trusted_redirect(url, &trusted).is_ok());
    }

    // Rejected: each input maps to its own error variant.
    let other_port = validate_trusted_redirect("http://localhost:4321/dashboard", &trusted);
    assert!(matches!(
        other_port,
        Err(TrustedOriginError::NotTrusted { .. })
    ));

    let script_url = validate_trusted_redirect("javascript:alert(1)", &trusted);
    assert!(matches!(script_url, Err(TrustedOriginError::NotHttp)));

    let empty_url = validate_trusted_redirect("", &trusted);
    assert!(matches!(empty_url, Err(TrustedOriginError::Empty)));
}
}