pub mod api_key;
pub mod apple_jwt;
pub mod captcha;
pub mod cookie;
pub mod email;
pub mod jwt;
pub mod oidc_provider;
pub mod org;
pub mod password;
pub mod phone;
pub mod provider;
pub mod scim;
pub mod siwe;
pub mod stripe;
pub mod totp;
pub mod webauthn;
pub use cookie::{extract_token as extract_session_cookie, CookieConfig, SameSite};
use serde::{Deserialize, Serialize};
/// Identity and authorization facts attached to a request.
///
/// Built through the constructors on this type (`anonymous`, `authenticated`,
/// `guest`, `admin`, `from_api_key`) and refined with `with_tenant` / `with_roles`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct AuthContext {
/// Acting user id (or guest id); `None` for anonymous requests.
pub user_id: Option<String>,
/// When true, every role check passes (see `has_role`).
pub is_admin: bool,
/// Guest identities have a `user_id` but `is_authenticated()` returns false.
/// Omitted from serialized output when false.
/// NOTE(review): `default` only affects `Deserialize`, which this struct does not derive.
#[serde(default, skip_serializing_if = "is_false")]
pub is_guest: bool,
/// Role names checked by `has_role` / `has_any_role`.
pub roles: Vec<String>,
/// Tenant the request is scoped to, if any; omitted from output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tenant_id: Option<String>,
/// Id of the API key used to authenticate, if the request used one.
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key_id: Option<String>,
/// Scope string stored verbatim from `from_api_key`; not interpreted by this type.
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key_scopes: Option<String>,
}
/// serde `skip_serializing_if` predicate: true when the flag is unset.
fn is_false(b: &bool) -> bool {
    let flag = *b;
    flag == false
}
impl AuthContext {
pub fn anonymous() -> Self {
Self {
user_id: None,
is_admin: false,
is_guest: false,
roles: Vec::new(),
tenant_id: None,
api_key_id: None,
api_key_scopes: None,
}
}
pub fn authenticated(user_id: String) -> Self {
Self {
user_id: Some(user_id),
is_admin: false,
is_guest: false,
roles: Vec::new(),
tenant_id: None,
api_key_id: None,
api_key_scopes: None,
}
}
pub fn from_api_key(user_id: String, key_id: String, scopes: Option<String>) -> Self {
Self {
user_id: Some(user_id),
is_admin: false,
is_guest: false,
roles: Vec::new(),
tenant_id: None,
api_key_id: Some(key_id),
api_key_scopes: scopes,
}
}
pub fn is_api_key_auth(&self) -> bool {
self.api_key_id.is_some()
}
pub fn guest(guest_id: String) -> Self {
Self {
user_id: Some(guest_id),
is_admin: false,
is_guest: true,
roles: Vec::new(),
tenant_id: None,
api_key_id: None,
api_key_scopes: None,
}
}
pub fn admin() -> Self {
Self {
user_id: Some("__admin__".into()),
is_admin: true,
is_guest: false,
roles: vec!["admin".into()],
tenant_id: None,
api_key_id: None,
api_key_scopes: None,
}
}
pub fn user(user_id: String) -> Self {
Self::authenticated(user_id)
}
pub fn tenant_id(&self) -> Option<&str> {
self.tenant_id.as_deref()
}
pub fn with_tenant(mut self, tenant_id: String) -> Self {
self.tenant_id = Some(tenant_id);
self
}
pub fn is_authenticated(&self) -> bool {
self.user_id.is_some() && !self.is_guest
}
pub fn has_role(&self, role: &str) -> bool {
self.is_admin || self.roles.iter().any(|r| r == role)
}
pub fn has_any_role(&self, roles: &[&str]) -> bool {
self.is_admin || roles.iter().any(|r| self.has_role(r))
}
pub fn with_roles(mut self, roles: Vec<String>) -> Self {
self.roles = roles;
self
}
}
/// Compares two byte strings without exiting early on the first differing byte.
///
/// Lengths are checked up front, so a length mismatch returns immediately:
/// only the *contents* comparison is data-independent, not the length.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together every XOR so the loop always visits every byte.
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
/// Access requirement for an endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthMode {
    /// Any caller is allowed.
    Public,
    /// Requires [`AuthContext::is_authenticated`] (a non-guest user).
    User,
}

impl AuthMode {
    /// Parses an exact, case-sensitive mode name (`"public"` or `"user"`).
    // Inherent fn returning `Option` (not `FromStr`'s `Result`), hence the allow.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        const NAMED: [(&str, AuthMode); 2] =
            [("public", AuthMode::Public), ("user", AuthMode::User)];
        NAMED
            .iter()
            .find(|(name, _)| *name == s)
            .map(|(_, mode)| mode.clone())
    }

    /// Whether `ctx` satisfies this mode.
    pub fn check(&self, ctx: &AuthContext) -> bool {
        matches!(self, AuthMode::Public) || ctx.is_authenticated()
    }
}
/// A login session identified by an opaque token.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Session {
/// Session token (constructors fill it from `generate_token()`).
pub token: String,
/// Owner of the session.
pub user_id: String,
/// Unix seconds at which the session expires; `0` means it never expires
/// (`is_expired` treats 0 as "no expiry").
#[serde(default)]
pub expires_at: u64,
/// Optional device descriptor; the constructors here leave it `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub device: Option<String>,
/// Unix seconds when the session was created.
#[serde(default)]
pub created_at: u64,
/// Tenant scope, copied into the `AuthContext` by `to_auth_context`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tenant_id: Option<String>,
}
impl Session {
pub const DEFAULT_LIFETIME_SECS: u64 = 30 * 24 * 60 * 60;
pub fn new(user_id: String) -> Self {
let now = now_secs();
Self {
token: generate_token(),
user_id,
expires_at: now.saturating_add(Self::DEFAULT_LIFETIME_SECS),
device: None,
created_at: now,
tenant_id: None,
}
}
pub fn with_lifetime(user_id: String, lifetime_secs: u64) -> Self {
let now = now_secs();
Self {
token: generate_token(),
user_id,
expires_at: if lifetime_secs == 0 {
0
} else {
now.saturating_add(lifetime_secs)
},
device: None,
created_at: now,
tenant_id: None,
}
}
pub fn to_auth_context(&self) -> AuthContext {
let ctx = AuthContext::authenticated(self.user_id.clone());
match &self.tenant_id {
Some(t) => ctx.with_tenant(t.clone()),
None => ctx,
}
}
pub fn is_expired(&self) -> bool {
self.expires_at != 0 && now_secs() >= self.expires_at
}
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A clock set before 1970 yields 0 rather than panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Configuration for one OAuth 2.0 / OpenID Connect provider.
///
/// `Debug` is implemented by hand instead of derived so that `client_secret`
/// and the Apple signing material are never written to logs through `{:?}`.
/// Serialization is unchanged: serde still emits every field.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct OAuthConfig {
    /// Provider id. Looked up among the builtin specs unless `oidc_issuer` is set.
    pub provider: String,
    /// OAuth client id.
    pub client_id: String,
    /// OAuth client secret (may be empty, e.g. for Apple, which signs one per exchange).
    pub client_secret: String,
    /// Redirect URI sent in both the authorization request and the token exchange.
    pub redirect_uri: String,
    /// Whitespace-separated scopes replacing the provider's default scopes.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub scopes_override: Option<String>,
    /// Tenant value made available to endpoint resolution (`provider::resolve_endpoint`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tenant: Option<String>,
    /// Apple key material (team id, key id, private key PEM) used to mint client secrets.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub apple: Option<provider::AppleConfig>,
    /// OIDC issuer URL; when set, the provider spec comes from OIDC discovery.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub oidc_issuer: Option<String>,
}

impl std::fmt::Debug for OAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthConfig")
            .field("provider", &self.provider)
            .field("client_id", &self.client_id)
            // Secrets: print only a placeholder, never the value.
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("scopes_override", &self.scopes_override)
            .field("tenant", &self.tenant)
            .field("apple", &self.apple.as_ref().map(|_| "<redacted>"))
            .field("oidc_issuer", &self.oidc_issuer)
            .finish()
    }
}
// URL building, code exchange and userinfo lookup for a configured provider.
impl OAuthConfig {
/// Resolves the provider spec. A configured `oidc_issuer` wins (via
/// `provider::oidc_cache::resolve`); otherwise the builtin spec named by
/// `self.provider` is used.
///
/// # Errors
/// Returns an error if the issuer cannot be resolved or the provider id is unknown.
fn resolved_spec(&self) -> Result<provider::ResolvedSpec, String> {
if let Some(issuer) = self.oidc_issuer.as_deref() {
return provider::oidc_cache::resolve(issuer);
}
provider::find_spec(&self.provider)
.map(provider::ResolvedSpec::Static)
.ok_or_else(|| format!("unknown OAuth provider: {}", self.provider))
}
/// Copies this config into the `provider::ProviderConfig` that is passed to
/// `provider::resolve_endpoint`.
fn provider_cfg(&self) -> provider::ProviderConfig {
provider::ProviderConfig {
provider: self.provider.clone(),
client_id: self.client_id.clone(),
client_secret: self.client_secret.clone(),
redirect_uri: self.redirect_uri.clone(),
scopes_override: self.scopes_override.clone(),
tenant: self.tenant.clone(),
apple: self.apple.clone(),
oidc_issuer: self.oidc_issuer.clone(),
}
}
/// Authorization URL without PKCE or `state`.
///
/// Returns an empty string if the URL cannot be built; the error is discarded.
pub fn auth_url(&self) -> String {
match self.build_auth_url(None) {
Ok(u) => u,
Err(_) => String::new(),
}
}
/// `auth_url()` plus a URL-encoded `state` parameter; empty string if
/// `auth_url()` failed.
///
/// Unlike `auth_url_with_pkce`, an empty `state` still appends `&state=`.
pub fn auth_url_with_state(&self, state: &str) -> String {
let base = self.auth_url();
if base.is_empty() {
return base;
}
format!("{}&state={}", base, url_encode(state))
}
/// Authorization URL with an S256 PKCE challenge when the provider spec
/// requires PKCE.
///
/// Returns `(url, code_verifier)`. The verifier is `Some` only when PKCE was
/// used, and must be passed to `exchange_code_full_pkce` later. `state` is
/// appended only when non-empty.
///
/// # Errors
/// Propagates spec-resolution and URL-building errors.
pub fn auth_url_with_pkce(&self, state: &str) -> Result<(String, Option<String>), String> {
// NOTE(review): the spec is resolved here and again inside build_auth_url.
let spec = self.resolved_spec()?;
let pkce = if spec.requires_pkce() {
Some(generate_pkce())
} else {
None
};
let challenge = pkce.as_ref().map(|p| p.code_challenge.as_str());
let mut url = self.build_auth_url(challenge)?;
if !state.is_empty() {
url.push_str(&format!("&state={}", url_encode(state)));
}
Ok((url, pkce.map(|p| p.code_verifier)))
}
/// Builds `<authorize endpoint>?<client id>&redirect_uri&response_type=code&scope`.
///
/// It then appends any spec-specific extra query and, when given, the PKCE
/// challenge (method S256).
///
/// # Errors
/// Fails if the spec cannot be resolved or it has no authorization endpoint.
fn build_auth_url(&self, pkce_challenge: Option<&str>) -> Result<String, String> {
let spec = self.resolved_spec()?;
let cfg = self.provider_cfg();
let auth = provider::resolve_endpoint(spec.auth_url(), &cfg);
if auth.is_empty() {
return Err(format!(
"provider {} has no authorization endpoint",
self.provider
));
}
let scopes_default = spec.scopes().to_string();
let scopes_raw = self.scopes_override.as_deref().unwrap_or(&scopes_default);
// Scopes are configured whitespace-separated and re-joined with the
// provider's separator before encoding.
let scopes_joined = scopes_raw
.split_whitespace()
.collect::<Vec<_>>()
.join(spec.scope_separator());
// The client-id parameter name is provider-specific (spec.client_id_param()).
let mut url = format!(
"{auth}?{cid_param}={cid}&redirect_uri={ruri}&response_type=code&scope={scope}",
cid_param = spec.client_id_param(),
cid = url_encode(&self.client_id),
ruri = url_encode(&self.redirect_uri),
scope = url_encode(&scopes_joined),
);
// Extra query text is appended verbatim (not encoded here).
if !spec.auth_query_extra().is_empty() {
url.push('&');
url.push_str(spec.auth_query_extra());
}
// The challenge comes from generate_pkce (base64url), so it is appended unencoded.
if let Some(challenge) = pkce_challenge {
url.push_str("&code_challenge=");
url.push_str(challenge);
url.push_str("&code_challenge_method=S256");
}
Ok(url)
}
/// Resolved token endpoint, or an empty string if the spec cannot be resolved.
pub fn token_url(&self) -> String {
match self.resolved_spec() {
Ok(spec) => provider::resolve_endpoint(spec.token_url(), &self.provider_cfg()),
Err(_) => String::new(),
}
}
/// Resolved userinfo endpoint, or an empty string when there is none or the
/// spec cannot be resolved.
pub fn userinfo_url(&self) -> String {
match self.resolved_spec() {
Ok(spec) => match spec.userinfo_url() {
Some(u) => provider::resolve_endpoint(u, &self.provider_cfg()),
None => String::new(),
},
Err(_) => String::new(),
}
}
/// Exchanges an authorization code without a PKCE verifier.
pub fn exchange_code_full(&self, code: &str) -> Result<TokenSet, String> {
self.exchange_code_full_pkce(code, None)
}
/// Exchanges an authorization code (plus optional PKCE verifier) for tokens.
///
/// The request shape follows `spec.token_exchange()`:
/// - form body with client credentials;
/// - Apple form body with a freshly minted JWT client secret;
/// - form body with HTTP Basic credentials;
/// - JSON body with credentials inline;
/// - JSON body with HTTP Basic credentials.
///
/// # Errors
/// HTTP errors are passed through `sanitize_token_error`; a missing `apple`
/// config or an unparsable response also returns an error.
pub fn exchange_code_full_pkce(
&self,
code: &str,
code_verifier: Option<&str>,
) -> Result<TokenSet, String> {
let spec = self.resolved_spec()?;
let cfg = self.provider_cfg();
let token_url = provider::resolve_endpoint(spec.token_url(), &cfg);
// Pre-formatted `&code_verifier=...` suffix for the form-encoded shapes (empty without PKCE).
let pkce_field = code_verifier
.map(|v| format!("&code_verifier={}", url_encode(v)))
.unwrap_or_default();
let out = match spec.token_exchange() {
provider::TokenExchangeShape::Standard => {
let body = format!(
"code={code}&{cid_param}={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
code = url_encode(code),
cid_param = spec.client_id_param(),
cid = url_encode(&self.client_id),
secret = url_encode(&self.client_secret),
ruri = url_encode(&self.redirect_uri),
pkce = pkce_field,
);
http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
}
provider::TokenExchangeShape::AppleJwt => {
// The static client_secret is ignored; a signed JWT is minted per exchange.
let apple = self.apple.as_ref().ok_or(
"apple provider requires `apple` config (team_id, key_id, private_key_pem)",
)?;
let signed_secret = apple_jwt::mint_client_secret(apple, &self.client_id)?;
let body = format!(
"code={code}&client_id={cid}&client_secret={secret}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
code = url_encode(code),
cid = url_encode(&self.client_id),
secret = url_encode(&signed_secret),
ruri = url_encode(&self.redirect_uri),
pkce = pkce_field,
);
http_post_form(&token_url, &body, true).map_err(sanitize_token_error)?
}
provider::TokenExchangeShape::BasicAuth => {
// Credentials go in the Authorization header, not the body.
let body = format!(
"code={code}&redirect_uri={ruri}&grant_type=authorization_code{pkce}",
code = url_encode(code),
ruri = url_encode(&self.redirect_uri),
pkce = pkce_field,
);
http_post_form_basic(&token_url, &body, &self.client_id, &self.client_secret)
.map_err(sanitize_token_error)?
}
provider::TokenExchangeShape::JsonBody => {
let mut json = serde_json::Map::new();
json.insert("grant_type".into(), "authorization_code".into());
json.insert("code".into(), code.into());
json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
json.insert("client_id".into(), self.client_id.clone().into());
json.insert("client_secret".into(), self.client_secret.clone().into());
if let Some(v) = code_verifier {
json.insert("code_verifier".into(), v.to_string().into());
}
let body = serde_json::Value::Object(json).to_string();
http_post_json(&token_url, &body, None).map_err(sanitize_token_error)?
}
provider::TokenExchangeShape::BasicAuthJsonBody => {
// JSON body without credentials; client id/secret go via HTTP Basic.
let mut json = serde_json::Map::new();
json.insert("grant_type".into(), "authorization_code".into());
json.insert("code".into(), code.into());
json.insert("redirect_uri".into(), self.redirect_uri.clone().into());
if let Some(v) = code_verifier {
json.insert("code_verifier".into(), v.to_string().into());
}
let body = serde_json::Value::Object(json).to_string();
http_post_json(
&token_url,
&body,
Some((&self.client_id, &self.client_secret)),
)
.map_err(sanitize_token_error)?
}
};
parse_token_response(&out)
}
/// Convenience wrapper returning only the access token.
pub fn exchange_code(&self, code: &str) -> Result<String, String> {
Ok(self.exchange_code_full(code)?.access_token)
}
/// Returns `(email, name)` from the provider's userinfo.
pub fn fetch_userinfo(&self, access_token: &str) -> Result<(String, Option<String>), String> {
let info = self.fetch_userinfo_full(access_token)?;
Ok((info.email, info.name))
}
/// Full userinfo lookup without an id_token. This fails for Apple, which
/// requires the id_token.
pub fn fetch_userinfo_full(&self, access_token: &str) -> Result<UserInfo, String> {
self.fetch_userinfo_with_id_token(access_token, None)
}
/// Resolves the user's identity.
///
/// - Apple: claims are decoded from `id_token` with no HTTP call (and no
///   signature check in `parse_apple_id_token`).
/// - Linear: a GraphQL `viewer` query.
/// - Everyone else: calls the userinfo endpoint with the bearer token and
///   parses it per `spec.userinfo_parser()`.
///
/// # Errors
/// Missing endpoint, HTTP failure (sanitized), non-JSON body, or missing id/email fields.
pub fn fetch_userinfo_with_id_token(
&self,
access_token: &str,
id_token: Option<&str>,
) -> Result<UserInfo, String> {
let spec = self.resolved_spec()?;
let cfg = self.provider_cfg();
// These two parsers never touch the userinfo endpoint; the match below marks them unreachable.
if matches!(spec.userinfo_parser(), provider::UserinfoParser::AppleIdToken) {
let token = id_token
.ok_or("apple login requires the id_token from the token response")?;
return parse_apple_id_token(token, &self.provider);
}
if matches!(spec.userinfo_parser(), provider::UserinfoParser::LinearGraphql) {
return fetch_linear_userinfo(&self.provider, access_token);
}
let url = match spec.userinfo_url() {
Some(u) => provider::resolve_endpoint(u, &cfg),
None => return Err(format!("provider {} has no userinfo endpoint", self.provider)),
};
let out = match spec.userinfo_method() {
provider::UserinfoMethod::Get => http_get_bearer(&url, access_token),
provider::UserinfoMethod::Post => http_post_bearer(&url, access_token),
}
.map_err(sanitize_token_error)?;
let parsed: serde_json::Value =
serde_json::from_str(&out).map_err(|e| format!("userinfo not valid JSON: {e}"))?;
match spec.userinfo_parser() {
provider::UserinfoParser::Oidc => {
// Standard OIDC claims: `email` and `sub` are required, `name` is optional.
// NOTE(review): `email_verified` is not checked — confirm callers do not link accounts by email alone.
let email = parsed
.get("email")
.and_then(|v| v.as_str())
.ok_or("no email in userinfo")?
.to_string();
let name = parsed
.get("name")
.and_then(|v| v.as_str())
.map(String::from);
let provider_account_id = parsed
.get("sub")
.and_then(|v| v.as_str())
.ok_or("no sub in userinfo")?
.to_string();
Ok(UserInfo {
provider: self.provider.clone(),
provider_account_id,
email,
name,
})
}
provider::UserinfoParser::GitHub => {
// Display name falls back to `login`.
let name = parsed
.get("name")
.and_then(|v| v.as_str())
.or_else(|| parsed.get("login").and_then(|v| v.as_str()))
.map(String::from);
let email = parsed
.get("email")
.and_then(|v| v.as_str())
.map(String::from);
// Profile email may be null; fall back to /user/emails (primary + verified).
// Errors from that fallback are swallowed into the generic message below.
let email = email
.or_else(|| fetch_github_primary_email(access_token).ok())
.ok_or("no accessible email on GitHub account")?;
// `id` accepted as either a JSON integer or a string.
let provider_account_id = parsed
.get("id")
.map(|v| {
v.as_i64()
.map(|n| n.to_string())
.or_else(|| v.as_str().map(String::from))
.unwrap_or_default()
})
.filter(|s| !s.is_empty())
.ok_or("no id in userinfo")?;
Ok(UserInfo {
provider: self.provider.clone(),
provider_account_id,
email,
name,
})
}
provider::UserinfoParser::Custom {
id_path,
email_path,
name_path,
} => {
// Fields located by JSON Pointer paths from the spec.
let provider_account_id = json_pointer_string(&parsed, id_path)
.ok_or_else(|| format!("no id at {id_path} in userinfo"))?;
let raw_email = json_pointer_string(&parsed, email_path)
.ok_or_else(|| format!("no email at {email_path} in userinfo"))?;
// Only twitter/reddit may yield a non-address here; they get a synthetic
// address on a `.invalid` domain. Any other provider is rejected.
let email = if !raw_email.contains('@') {
let domain = match self.provider.as_str() {
"twitter" => "x.invalid",
"reddit" => "reddit.invalid",
other => return Err(format!(
"{other}: userinfo `email` field is not an email address (got {raw_email:?}); refusing to synthesize",
)),
};
format!("{raw_email}@{domain}")
} else {
raw_email
};
let name = name_path.and_then(|p| json_pointer_string(&parsed, p));
Ok(UserInfo {
provider: self.provider.clone(),
provider_account_id,
email,
name,
})
}
provider::UserinfoParser::AppleIdToken => unreachable!("handled above"),
provider::UserinfoParser::LinearGraphql => unreachable!("handled above"),
}
}
}
/// PKCE values produced by `generate_pkce`.
struct PkcePair {
/// Secret kept server-side and sent at token exchange.
code_verifier: String,
/// base64url(SHA-256(code_verifier)), sent in the authorization URL.
code_challenge: String,
}
/// Creates a fresh S256 PKCE pair.
///
/// The verifier is base64url of 32 random bytes from `rand::thread_rng()`.
/// The challenge is base64url of the SHA-256 of the verifier's ASCII bytes.
fn generate_pkce() -> PkcePair {
    use rand::RngCore;
    use sha2::{Digest, Sha256};
    let mut entropy = [0u8; 32];
    rand::thread_rng().fill_bytes(&mut entropy);
    let code_verifier = apple_jwt::base64_url(entropy);
    let digest = Sha256::digest(code_verifier.as_bytes());
    let code_challenge = apple_jwt::base64_url(digest);
    PkcePair {
        code_verifier,
        code_challenge,
    }
}
/// Extracts `sub` and `email` from the claims segment of an Apple id_token.
///
/// Only the payload is decoded; the JWT signature is not verified here. The
/// token is expected to come straight from Apple's token endpoint. Apple
/// gives no name in the token, so `name` is `None`.
fn parse_apple_id_token(id_token: &str, provider: &str) -> Result<UserInfo, String> {
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    let mut segments = id_token.split('.');
    segments.next().ok_or("apple id_token: missing header")?;
    let claims_segment = segments.next().ok_or("apple id_token: missing claims")?;
    let claims: serde_json::Value = URL_SAFE_NO_PAD
        .decode(claims_segment)
        .map_err(|e| format!("apple id_token claims not base64: {e}"))
        .and_then(|raw| {
            serde_json::from_slice(&raw).map_err(|e| format!("apple id_token claims not JSON: {e}"))
        })?;
    let claim = |name: &str| claims.get(name).and_then(|v| v.as_str()).map(str::to_owned);
    let provider_account_id = claim("sub").ok_or("apple id_token: missing sub")?;
    let email = claim("email")
        .ok_or("apple id_token: missing email (was the `email` scope requested?)")?;
    Ok(UserInfo {
        provider: provider.to_owned(),
        provider_account_id,
        email,
        name: None,
    })
}
/// Masks credential-bearing values in an error string before it leaves the module.
///
/// For each sensitive key, both `key=value` (form/query) and `"key": "value"`
/// (JSON string) occurrences are replaced with `***`, form first and then JSON.
fn sanitize_token_error(err: String) -> String {
    const SENSITIVE: [&str; 7] = [
        "client_secret",
        "code_verifier",
        "client_assertion",
        "refresh_token",
        "access_token",
        "id_token",
        "code",
    ];
    SENSITIVE.iter().fold(err, |text, key| {
        let form_masked = redact_param_form(&text, key);
        redact_param_json(&form_masked, key)
    })
}
/// Replaces the value after every `key=` occurrence with `***`.
///
/// A value runs until `&`, newline, `"`, space, `'` or the end of input; the
/// terminator itself is kept. Matching is textual at every character
/// position, so `error_code=` is also masked for key `code` (over-redaction
/// errs on the safe side).
fn redact_param_form(input: &str, key: &str) -> String {
    let needle = format!("{key}=");
    let is_terminator = |c: char| matches!(c, '&' | '\n' | '"' | ' ' | '\'');
    let mut redacted = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(first) = rest.chars().next() {
        match rest.strip_prefix(needle.as_str()) {
            Some(after) => {
                redacted.push_str(&needle);
                redacted.push_str("***");
                let value_len = after.find(is_terminator).unwrap_or(after.len());
                rest = &after[value_len..];
            }
            None => {
                redacted.push(first);
                rest = &rest[first.len_utf8()..];
            }
        }
    }
    redacted
}
/// Replaces JSON string values of `"key"` with `***`, keeping the surrounding
/// quotes and whitespace.
///
/// Only `"key" : "..."` with a string value is masked. A key not followed by
/// `:` or by a string value (e.g. `"code": 400`) is copied unchanged. Escaped
/// quotes inside the value are honoured. An unterminated value masks
/// everything to the end of the input.
fn redact_param_json(input: &str, key: &str) -> String {
    let needle = format!("\"{key}\"");
    let mut redacted = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(first) = rest.chars().next() {
        let Some(after_key) = rest.strip_prefix(needle.as_str()) else {
            redacted.push(first);
            rest = &rest[first.len_utf8()..];
            continue;
        };
        // Optional whitespace, then the ':' separator.
        let at_colon = after_key.trim_start();
        let Some(after_colon) = at_colon.strip_prefix(':') else {
            let (seen, tail) = rest.split_at(rest.len() - at_colon.len());
            redacted.push_str(seen);
            rest = tail;
            continue;
        };
        // Optional whitespace, then the opening quote of a string value.
        let at_value = after_colon.trim_start();
        let Some(value) = at_value.strip_prefix('"') else {
            let (seen, tail) = rest.split_at(rest.len() - at_value.len());
            redacted.push_str(seen);
            rest = tail;
            continue;
        };
        redacted.push_str(&rest[..rest.len() - value.len()]);
        redacted.push_str("***");
        let mut escaped = false;
        let closing = value.char_indices().find_map(|(idx, ch)| {
            if ch == '"' && !escaped {
                return Some(idx);
            }
            escaped = ch == '\\' && !escaped;
            None
        });
        rest = match closing {
            Some(idx) => {
                redacted.push('"');
                &value[idx + 1..]
            }
            None => "",
        };
    }
    redacted
}
/// Looks up the Linear user through the GraphQL `viewer` query.
///
/// `id` and `email` are required; `name` is optional.
fn fetch_linear_userinfo(provider: &str, access_token: &str) -> Result<UserInfo, String> {
    const VIEWER_QUERY: &str = r#"{"query":"query { viewer { id email name } }"}"#;
    let bearer = format!("Bearer {access_token}");
    let response = ureq_agent()
        .post("https://api.linear.app/graphql")
        .set("Authorization", &bearer)
        .set("Content-Type", "application/json")
        .set("Accept", "application/json")
        .send_string(VIEWER_QUERY)
        .map_err(|e| format!("linear graphql: {e}"))?;
    let raw = response.into_string().map_err(|e| format!("read body: {e}"))?;
    let document: serde_json::Value =
        serde_json::from_str(&raw).map_err(|e| format!("linear graphql not JSON: {e}"))?;
    let viewer = document
        .pointer("/data/viewer")
        .ok_or("linear graphql: no /data/viewer")?;
    let field = |key: &str| viewer.get(key).and_then(|v| v.as_str());
    let provider_account_id = field("id").ok_or("linear graphql: no id")?.to_string();
    let email = field("email").ok_or("linear graphql: no email")?.to_string();
    let name = field("name").map(String::from);
    Ok(UserInfo {
        provider: provider.to_string(),
        provider_account_id,
        email,
        name,
    })
}
/// Reads the value at JSON Pointer `path` as a string.
///
/// Strings are returned as-is and integers (signed or unsigned) are
/// stringified. Anything else (or a missing node) gives `None`.
fn json_pointer_string(v: &serde_json::Value, path: &str) -> Option<String> {
    let node = v.pointer(path)?;
    node.as_str()
        .map(str::to_owned)
        .or_else(|| node.as_i64().map(|n| n.to_string()))
        .or_else(|| node.as_u64().map(|n| n.to_string()))
}
/// Normalized identity returned by the userinfo lookups.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInfo {
/// Provider id this identity came from (copied from the config).
pub provider: String,
/// Provider-side stable account id (e.g. OIDC `sub`, GitHub `id`).
pub provider_account_id: String,
/// Email address; for twitter/reddit it may be a synthetic `@*.invalid` address.
pub email: String,
/// Display name, when the provider supplied one.
pub name: Option<String>,
}
/// Tokens parsed from a token-endpoint response by `parse_token_response`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenSet {
pub access_token: String,
pub refresh_token: Option<String>,
pub id_token: Option<String>,
/// Absolute expiry in Unix seconds, computed as now + `expires_in`.
pub expires_at: Option<u64>,
/// Granted scope string exactly as returned by the provider.
pub scope: Option<String>,
}
fn parse_token_response(body: &str) -> Result<TokenSet, String> {
let json: serde_json::Value = serde_json::from_str(body).unwrap_or_else(|_| {
let mut map = serde_json::Map::new();
for pair in body.split('&') {
if let Some((k, v)) = pair.split_once('=') {
map.insert(k.to_string(), serde_json::Value::String(v.to_string()));
}
}
serde_json::Value::Object(map)
});
let access_token = json
.get("access_token")
.and_then(|v| v.as_str())
.ok_or_else(|| format!("no access_token in token response: {body}"))?
.to_string();
let refresh_token = json
.get("refresh_token")
.and_then(|v| v.as_str())
.map(String::from);
let id_token = json
.get("id_token")
.and_then(|v| v.as_str())
.map(String::from);
let expires_at = json
.get("expires_in")
.and_then(|v| {
v.as_u64()
.or_else(|| v.as_str().and_then(|s| s.parse().ok()))
})
.map(|secs| now_secs().saturating_add(secs));
let scope = json.get("scope").and_then(|v| v.as_str()).map(String::from);
Ok(TokenSet {
access_token,
refresh_token,
id_token,
expires_at,
scope,
})
}
/// Percent-encodes every byte except the RFC 3986 unreserved set
/// (`A-Z a-z 0-9 - _ . ~`), using uppercase hex.
fn url_encode(s: &str) -> String {
    const HEX_UPPER: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, byte| {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_UPPER[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_UPPER[usize::from(byte & 0x0F)]));
        }
        encoded
    })
}
/// Per-phase (connect / read / write) timeout for outbound OAuth HTTP calls.
const HTTP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

/// Fresh `ureq` agent with [`HTTP_TIMEOUT`] on every phase and the `pylon/0.1` user agent.
fn ureq_agent() -> ureq::Agent {
    let builder = ureq::AgentBuilder::new();
    let builder = builder
        .timeout_connect(HTTP_TIMEOUT)
        .timeout_read(HTTP_TIMEOUT)
        .timeout_write(HTTP_TIMEOUT);
    builder.user_agent("pylon/0.1").build()
}
fn http_post_form(url: &str, body: &str, accept_json: bool) -> Result<String, String> {
let agent = ureq_agent();
let mut req = agent
.post(url)
.set("Content-Type", "application/x-www-form-urlencoded");
if accept_json {
req = req.set("Accept", "application/json");
}
match req.send_string(body) {
Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
Err(ureq::Error::Status(code, resp)) => {
let body = resp.into_string().unwrap_or_default();
Err(format!("HTTP {code}: {body}"))
}
Err(e) => Err(format!("HTTP error: {e}")),
}
}
fn http_post_form_basic(
url: &str,
body: &str,
client_id: &str,
client_secret: &str,
) -> Result<String, String> {
use base64::{engine::general_purpose::STANDARD, Engine};
let creds = format!("{client_id}:{client_secret}");
let basic = STANDARD.encode(creds.as_bytes());
let agent = ureq_agent();
match agent
.post(url)
.set("Content-Type", "application/x-www-form-urlencoded")
.set("Accept", "application/json")
.set("Authorization", &format!("Basic {basic}"))
.send_string(body)
{
Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
Err(ureq::Error::Status(code, resp)) => {
let body = resp.into_string().unwrap_or_default();
Err(format!("HTTP {code}: {body}"))
}
Err(e) => Err(format!("HTTP error: {e}")),
}
}
fn http_post_json(
url: &str,
body: &str,
basic_creds: Option<(&str, &str)>,
) -> Result<String, String> {
let agent = ureq_agent();
let mut req = agent
.post(url)
.set("Content-Type", "application/json")
.set("Accept", "application/json");
if let Some((id, secret)) = basic_creds {
use base64::{engine::general_purpose::STANDARD, Engine};
let creds = STANDARD.encode(format!("{id}:{secret}").as_bytes());
req = req.set("Authorization", &format!("Basic {creds}"));
}
req = req.set("Notion-Version", "2022-06-28");
match req.send_string(body) {
Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
Err(ureq::Error::Status(code, resp)) => {
let body = resp.into_string().unwrap_or_default();
Err(format!("HTTP {code}: {body}"))
}
Err(e) => Err(format!("HTTP error: {e}")),
}
}
fn http_post_bearer(url: &str, token: &str) -> Result<String, String> {
let agent = ureq_agent();
match agent
.post(url)
.set("Authorization", &format!("Bearer {token}"))
.set("Accept", "application/json")
.call()
{
Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
Err(ureq::Error::Status(code, resp)) => {
let body = resp.into_string().unwrap_or_default();
Err(format!("HTTP {code}: {body}"))
}
Err(e) => Err(format!("HTTP error: {e}")),
}
}
fn http_get_bearer(url: &str, token: &str) -> Result<String, String> {
let agent = ureq_agent();
match agent
.get(url)
.set("Authorization", &format!("Bearer {token}"))
.set("Accept", "application/json")
.call()
{
Ok(resp) => resp.into_string().map_err(|e| format!("read body: {e}")),
Err(ureq::Error::Status(code, resp)) => {
let body = resp.into_string().unwrap_or_default();
Err(format!("HTTP {code}: {body}"))
}
Err(e) => Err(format!("HTTP error: {e}")),
}
}
/// Fetches the GitHub account's email that is both `primary` and `verified`.
///
/// Used when the `/user` profile has no public email.
fn fetch_github_primary_email(token: &str) -> Result<String, String> {
    let raw = http_get_bearer("https://api.github.com/user/emails", token)?;
    let emails: serde_json::Value =
        serde_json::from_str(&raw).map_err(|e| format!("emails not JSON: {e}"))?;
    let flag = |entry: &serde_json::Value, key: &str| {
        entry.get(key).and_then(|v| v.as_bool()).unwrap_or(false)
    };
    let chosen = emails.as_array().and_then(|list| {
        list.iter()
            .find(|entry| flag(*entry, "primary") && flag(*entry, "verified"))
    });
    chosen
        .and_then(|entry| entry.get("email"))
        .and_then(|v| v.as_str())
        .map(String::from)
        .ok_or_else(|| "no primary verified email on GitHub".into())
}
/// OAuth provider configs keyed by provider id.
pub struct OAuthRegistry {
/// provider id -> config (the key equals `OAuthConfig::provider`, see `register`).
providers: std::collections::HashMap<String, OAuthConfig>,
}
impl Default for OAuthRegistry {
fn default() -> Self {
Self::new()
}
}
impl OAuthRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            providers: std::collections::HashMap::new(),
        }
    }

    /// Registers `config` under `config.provider`, replacing any previous entry.
    pub fn register(&mut self, config: OAuthConfig) {
        self.providers.insert(config.provider.clone(), config);
    }

    /// Looks up the config registered for `provider`.
    pub fn get(&self, provider: &str) -> Option<&OAuthConfig> {
        self.providers.get(provider)
    }

    /// Builds a registry from `PYLON_OAUTH_*` environment variables.
    ///
    /// 1. Each builtin provider `<ID>` is registered when
    ///    `PYLON_OAUTH_<ID>_CLIENT_ID` (and, except for Apple,
    ///    `_CLIENT_SECRET`) is set. Optional: `_REDIRECT`, `_SCOPES`, `_TENANT`.
    ///    Apple also needs `_TEAM_ID`, `_KEY_ID` and `_PRIVATE_KEY`.
    /// 2. Each `PYLON_OAUTH_<NAME>_OIDC_ISSUER` whose name is not a builtin is
    ///    registered as a generic OIDC provider `<name>` (lowercased). It also
    ///    needs `_CLIENT_ID`; `_CLIENT_SECRET`, `_REDIRECT` and `_SCOPES` are
    ///    optional.
    ///
    /// Environment entries that are not valid UTF-8 are ignored rather than
    /// causing a panic.
    pub fn from_env() -> Self {
        let mut reg = Self::new();
        reg.register_builtins_from_env();
        reg.register_oidc_from_env();
        reg
    }

    /// Registers every builtin provider whose credentials are present in the environment.
    fn register_builtins_from_env(&mut self) {
        for spec in provider::builtin::all() {
            let upper = spec.id.to_ascii_uppercase();
            let prefix = format!("PYLON_OAUTH_{upper}");
            let Ok(client_id) = std::env::var(format!("{prefix}_CLIENT_ID")) else {
                continue;
            };
            // Apple may omit the static secret; presumably its AppleJwt exchange
            // signs one from the `apple` key material instead.
            let client_secret = match std::env::var(format!("{prefix}_CLIENT_SECRET")) {
                Ok(v) => v,
                Err(_) if spec.id == "apple" => String::new(),
                Err(_) => continue,
            };
            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT")).unwrap_or_else(|_| {
                format!("http://localhost:3000/api/auth/callback/{}", spec.id)
            });
            let scopes_override = std::env::var(format!("{prefix}_SCOPES")).ok();
            let tenant = std::env::var(format!("{prefix}_TENANT")).ok();
            let apple = if spec.id == "apple" {
                match (
                    std::env::var(format!("{prefix}_TEAM_ID")),
                    std::env::var(format!("{prefix}_KEY_ID")),
                    std::env::var(format!("{prefix}_PRIVATE_KEY")),
                ) {
                    (Ok(team_id), Ok(key_id), Ok(private_key_pem)) => {
                        Some(provider::AppleConfig {
                            team_id,
                            key_id,
                            private_key_pem,
                        })
                    }
                    // Without all three Apple cannot be used; skip it entirely.
                    _ => continue,
                }
            } else {
                None
            };
            self.register(OAuthConfig {
                provider: spec.id.to_string(),
                client_id,
                client_secret,
                redirect_uri,
                scopes_override,
                tenant,
                apple,
                oidc_issuer: None,
            });
        }
    }

    /// Registers generic OIDC providers declared via `PYLON_OAUTH_<NAME>_OIDC_ISSUER`.
    fn register_oidc_from_env(&mut self) {
        // `vars_os`, not `vars`: `std::env::vars()` panics mid-iteration if ANY
        // variable in the process environment is not valid Unicode, which would
        // take down the whole registry (and `shared()`) at startup.
        for (raw_key, raw_value) in std::env::vars_os() {
            let (Ok(key), Ok(issuer)) = (raw_key.into_string(), raw_value.into_string()) else {
                continue;
            };
            let Some(name_upper) = key
                .strip_prefix("PYLON_OAUTH_")
                .and_then(|rest| rest.strip_suffix("_OIDC_ISSUER"))
            else {
                continue;
            };
            // `PYLON_OAUTH__OIDC_ISSUER` would otherwise register a provider named "".
            if name_upper.is_empty() {
                continue;
            }
            let name = name_upper.to_ascii_lowercase();
            // Builtin ids are configured by `register_builtins_from_env`.
            if provider::find_spec(&name).is_some() {
                continue;
            }
            let prefix = format!("PYLON_OAUTH_{name_upper}");
            let Ok(client_id) = std::env::var(format!("{prefix}_CLIENT_ID")) else {
                continue;
            };
            let client_secret =
                std::env::var(format!("{prefix}_CLIENT_SECRET")).unwrap_or_default();
            let redirect_uri = std::env::var(format!("{prefix}_REDIRECT"))
                .unwrap_or_else(|_| format!("http://localhost:3000/api/auth/callback/{name}"));
            self.register(OAuthConfig {
                provider: name,
                client_id,
                client_secret,
                redirect_uri,
                scopes_override: std::env::var(format!("{prefix}_SCOPES")).ok(),
                tenant: None,
                apple: None,
                oidc_issuer: Some(issuer),
            });
        }
    }

    /// Ids of all registered providers (unordered).
    pub fn ids(&self) -> impl Iterator<Item = &str> {
        self.providers.keys().map(|s| s.as_str())
    }

    /// Process-wide registry, built from the environment on first use.
    pub fn shared() -> &'static OAuthRegistry {
        static CELL: std::sync::OnceLock<OAuthRegistry> = std::sync::OnceLock::new();
        CELL.get_or_init(Self::from_env)
    }
}
/// Server-side record for one in-flight OAuth authorization request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthState {
    /// Provider the flow was started for; checked again on callback.
    pub provider: String,
    /// Where to send the user after a successful login.
    pub callback_url: String,
    /// Where to send the user when the login fails.
    pub error_callback_url: String,
    /// PKCE verifier to send at token exchange, if PKCE was used.
    pub pkce_verifier: Option<String>,
    /// Unix seconds after which the state is no longer accepted.
    pub expires_at: u64,
}

/// Storage for OAuth `state` tokens.
pub trait OAuthStateBackend: Send + Sync {
    /// Stores `state` under `token`.
    fn put(&self, token: &str, state: &OAuthState);
    /// Removes and returns the state for `token` unless it expired at `now_unix_secs`.
    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState>;
}

/// Process-local [`OAuthStateBackend`] backed by a mutex-guarded map.
pub struct InMemoryOAuthBackend {
    states: Mutex<HashMap<String, OAuthState>>,
}

impl InMemoryOAuthBackend {
    /// Map size from which `put` sweeps expired entries (see `put`).
    const PRUNE_MIN_LEN: usize = 1024;

    pub fn new() -> Self {
        Self {
            states: Mutex::new(HashMap::new()),
        }
    }
}

impl Default for InMemoryOAuthBackend {
    fn default() -> Self {
        Self::new()
    }
}

impl OAuthStateBackend for InMemoryOAuthBackend {
    fn put(&self, token: &str, state: &OAuthState) {
        let mut states = self.states.lock().unwrap();
        states.insert(token.to_string(), state.clone());
        // States from abandoned logins are never `take`n, so without a sweep the
        // map grows forever; anyone who can start logins can inflate it. Sweep
        // whenever the size reaches a power of two >= PRUNE_MIN_LEN, which keeps
        // the cost amortized O(1) per insert.
        let len = states.len();
        if len >= Self::PRUNE_MIN_LEN && len.is_power_of_two() {
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs())
                .unwrap_or_default();
            // Same expiry rule as `take`: dead once `expires_at <= now`.
            states.retain(|_, s| s.expires_at > now);
        }
    }

    fn take(&self, token: &str, now_unix_secs: u64) -> Option<OAuthState> {
        let mut s = self.states.lock().unwrap();
        // Removed unconditionally: a state is single-use even if it turns out expired.
        let entry = s.remove(token)?;
        if entry.expires_at <= now_unix_secs {
            return None;
        }
        Some(entry)
    }
}
pub struct OAuthStateStore {
backend: Box<dyn OAuthStateBackend>,
}
impl Default for OAuthStateStore {
fn default() -> Self {
Self::new()
}
}
impl OAuthStateStore {
pub fn new() -> Self {
Self {
backend: Box::new(InMemoryOAuthBackend::new()),
}
}
pub fn with_backend(backend: Box<dyn OAuthStateBackend>) -> Self {
Self { backend }
}
pub fn create(&self, provider: &str, callback_url: &str, error_callback_url: &str) -> String {
self.create_with_pkce(provider, callback_url, error_callback_url, None)
}
pub fn create_with_pkce(
&self,
provider: &str,
callback_url: &str,
error_callback_url: &str,
pkce_verifier: Option<String>,
) -> String {
use std::time::{SystemTime, UNIX_EPOCH};
let token = generate_token();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let state = OAuthState {
provider: provider.to_string(),
callback_url: callback_url.to_string(),
error_callback_url: error_callback_url.to_string(),
pkce_verifier,
expires_at: now + 600,
};
self.backend.put(&token, &state);
token
}
pub fn validate(&self, state: &str, expected_provider: &str) -> Option<OAuthState> {
use std::time::{SystemTime, UNIX_EPOCH};
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let entry = self.backend.take(state, now)?;
if entry.provider != expected_provider {
return None;
}
Some(entry)
}
}
/// Accepts `url` only if it is `http(s)://` and its origin exactly equals one of
/// `trusted_origins`.
///
/// The comparison is byte-for-byte. Case, default ports and userinfo are not
/// normalized, so anything unusual is rejected rather than accepted.
pub fn validate_trusted_redirect(
    url: &str,
    trusted_origins: &[String],
) -> Result<(), TrustedOriginError> {
    if url.is_empty() {
        return Err(TrustedOriginError::Empty);
    }
    let has_http_scheme = ["http://", "https://"]
        .iter()
        .any(|&scheme| url.starts_with(scheme));
    if !has_http_scheme {
        return Err(TrustedOriginError::NotHttp);
    }
    let origin = origin_of(url);
    match trusted_origins.iter().any(|trusted| *trusted == origin) {
        true => Ok(()),
        false => Err(TrustedOriginError::NotTrusted { origin }),
    }
}

/// Why a redirect URL was refused by [`validate_trusted_redirect`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustedOriginError {
    /// The URL was an empty string.
    Empty,
    /// The URL did not start with `http://` or `https://`.
    NotHttp,
    /// The URL's origin is not in the trusted list.
    NotTrusted { origin: String },
}

impl std::fmt::Display for TrustedOriginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            TrustedOriginError::Empty => "redirect URL is empty",
            TrustedOriginError::NotHttp => "redirect URL must use http:// or https:// scheme",
            TrustedOriginError::NotTrusted { origin } => {
                return write!(f, "redirect origin {origin:?} is not in PYLON_TRUSTED_ORIGINS");
            }
        };
        f.write_str(message)
    }
}

/// `scheme://authority` prefix of `url`: everything before the first `/`, `?`
/// or `#` after `://`.
///
/// Input without `://` is returned with trailing slashes removed.
pub fn origin_of(url: &str) -> String {
    let Some(scheme_end) = url.find("://").map(|idx| idx + 3) else {
        return url.trim_end_matches('/').to_string();
    };
    let authority = &url[scheme_end..];
    let authority_len = authority
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(authority.len());
    url[..scheme_end + authority_len].to_owned()
}
/// Persistence for magic-link / one-time email codes, keyed by email address.
pub trait MagicCodeBackend: Send + Sync {
/// Stores (or replaces) the code for `email`.
fn put(&self, email: &str, code: &MagicCode);
/// Returns the stored code for `email`, if any.
fn get(&self, email: &str) -> Option<MagicCode>;
/// Deletes the code for `email` (no-op when absent in the in-memory impl).
fn remove(&self, email: &str);
/// Increments the failed-attempt counter for `email`'s code.
fn bump_attempts(&self, email: &str);
/// Every stored code; `MagicCodeStore::with_backend` uses this to seed its
/// cache with the unexpired ones.
fn load_all(&self) -> Vec<MagicCode>;
}
/// Process-local [`MagicCodeBackend`]: a mutex-guarded map from email to code.
/// Nothing survives a restart.
pub struct InMemoryMagicCodeBackend {
    codes: Mutex<HashMap<String, MagicCode>>,
}
impl InMemoryMagicCodeBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        InMemoryMagicCodeBackend {
            codes: Mutex::default(),
        }
    }
}
impl Default for InMemoryMagicCodeBackend {
    fn default() -> Self {
        InMemoryMagicCodeBackend::new()
    }
}
impl MagicCodeBackend for InMemoryMagicCodeBackend {
    /// Stores a copy of `code` under `email`, replacing any previous entry.
    fn put(&self, email: &str, code: &MagicCode) {
        let mut table = self.codes.lock().unwrap();
        table.insert(email.to_owned(), code.clone());
    }
    /// Returns a copy of the entry for `email`.
    fn get(&self, email: &str) -> Option<MagicCode> {
        let table = self.codes.lock().unwrap();
        table.get(email).cloned()
    }
    /// Drops the entry for `email`; a missing entry is not an error.
    fn remove(&self, email: &str) {
        let mut table = self.codes.lock().unwrap();
        table.remove(email);
    }
    /// Increments the failure counter, saturating at `u32::MAX`.
    fn bump_attempts(&self, email: &str) {
        let mut table = self.codes.lock().unwrap();
        if let Some(entry) = table.get_mut(email) {
            entry.attempts = entry.attempts.saturating_add(1);
        }
    }
    /// Snapshots every stored entry.
    fn load_all(&self) -> Vec<MagicCode> {
        let table = self.codes.lock().unwrap();
        table.values().cloned().collect()
    }
}
/// Issues and checks short numeric one-time codes, keyed by email address.
///
/// Lookups are served from `cache`; mutations are mirrored to `backend`.
pub struct MagicCodeStore {
    cache: Mutex<HashMap<String, MagicCode>>,
    backend: Box<dyn MagicCodeBackend>,
}
/// A pending magic code for one email address.
#[derive(Debug, Clone)]
pub struct MagicCode {
    /// Address the code was issued for; also its key in the store.
    pub email: String,
    /// The code the user must submit.
    pub code: String,
    /// Expiry instant, in `now_secs()` seconds; rejected once `now >= expires_at`.
    pub expires_at: u64,
    /// Failed verification attempts recorded so far.
    pub attempts: u32,
}
/// Failed verifications after which a code is locked out.
const MAX_ATTEMPTS: u32 = 5;
/// Minimum spacing, in seconds, between two codes issued for the same email.
const CREATE_COOLDOWN_SECS: u64 = 60;
/// Why issuing or verifying a magic code failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MagicCodeError {
    /// No code is pending for this email.
    NotFound,
    /// The code has accumulated `MAX_ATTEMPTS` failures.
    TooManyAttempts,
    /// The submitted code did not match.
    BadCode,
    /// The code's expiry has passed.
    Expired,
    /// A code was issued too recently; try again after `retry_after_secs`.
    Throttled { retry_after_secs: u64 },
}
impl Default for MagicCodeStore {
    fn default() -> Self {
        MagicCodeStore::new()
    }
}
impl MagicCodeStore {
    /// How long a freshly issued code stays valid, in seconds.
    ///
    /// `try_create` also reconstructs a code's issue time as
    /// `expires_at - CODE_TTL_SECS`, so both uses must share this one value.
    const CODE_TTL_SECS: u64 = 600;

    /// Creates a store backed by [`InMemoryMagicCodeBackend`].
    pub fn new() -> Self {
        Self::with_backend(Box::new(InMemoryMagicCodeBackend::new()))
    }

    /// Creates a store over `backend`, warming the cache with every code from
    /// `load_all` that has not yet expired.
    pub fn with_backend(backend: Box<dyn MagicCodeBackend>) -> Self {
        let now = now_secs();
        let cache: HashMap<String, MagicCode> = backend
            .load_all()
            .into_iter()
            .filter(|c| c.expires_at > now)
            .map(|c| (c.email.clone(), c))
            .collect();
        Self {
            cache: Mutex::new(cache),
            backend,
        }
    }

    /// Issues a code for `email`, returning an empty string when throttled.
    ///
    /// Prefer [`Self::try_create`], which reports the throttle explicitly.
    pub fn create(&self, email: &str) -> String {
        self.try_create(email).unwrap_or_default()
    }

    /// Issues a new code for `email`, replacing any previous one.
    ///
    /// Replacing resets the attempt counter, so each cooldown window grants a
    /// fresh `MAX_ATTEMPTS` guesses.
    ///
    /// # Errors
    /// [`MagicCodeError::Throttled`] if a still-valid code for `email` was
    /// issued less than `CREATE_COOLDOWN_SECS` seconds ago.
    pub fn try_create(&self, email: &str) -> Result<String, MagicCodeError> {
        let now = now_secs();
        let mut codes = self.cache.lock().unwrap();
        if let Some(existing) = codes.get(email) {
            if existing.expires_at > now {
                // The issue time is not stored; derive it from the expiry.
                let created_at = existing.expires_at.saturating_sub(Self::CODE_TTL_SECS);
                let age = now.saturating_sub(created_at);
                if age < CREATE_COOLDOWN_SECS {
                    return Err(MagicCodeError::Throttled {
                        retry_after_secs: CREATE_COOLDOWN_SECS - age,
                    });
                }
            }
        }
        let code = generate_magic_code();
        let mc = MagicCode {
            email: email.to_string(),
            code: code.clone(),
            expires_at: now + Self::CODE_TTL_SECS,
            attempts: 0,
        };
        codes.insert(email.to_string(), mc.clone());
        self.backend.put(email, &mc);
        Ok(code)
    }

    /// Returns `true` iff `code` is accepted by [`Self::try_verify`]
    /// (which consumes it on success).
    pub fn verify(&self, email: &str, code: &str) -> bool {
        self.try_verify(email, code).is_ok()
    }

    /// Snapshots every cached code, including expired and locked-out ones.
    /// Returns an empty list if the cache mutex is poisoned.
    pub fn list_all_unfiltered(&self) -> Vec<MagicCode> {
        self.cache
            .lock()
            .map(|m| m.values().cloned().collect())
            .unwrap_or_default()
    }

    /// Checks `code` against the pending code for `email`.
    ///
    /// A match removes the code (single use). A mismatch increments the
    /// attempt counter; once it reaches `MAX_ATTEMPTS` the code stays locked
    /// until it expires or a new one is issued.
    ///
    /// # Errors
    /// `NotFound` if nothing is pending; `Expired` (the entry is evicted);
    /// `TooManyAttempts` when locked out; `BadCode` on a mismatch.
    pub fn try_verify(&self, email: &str, code: &str) -> Result<(), MagicCodeError> {
        let now = now_secs();
        let mut codes = self.cache.lock().unwrap();
        let mc = match codes.get_mut(email) {
            Some(m) => m,
            None => return Err(MagicCodeError::NotFound),
        };
        // Expiry is checked before the lockout so an abandoned, locked-out code
        // is still evicted from cache and backend once it expires.
        if mc.expires_at <= now {
            codes.remove(email);
            self.backend.remove(email);
            return Err(MagicCodeError::Expired);
        }
        if mc.attempts >= MAX_ATTEMPTS {
            return Err(MagicCodeError::TooManyAttempts);
        }
        if !constant_time_eq(mc.code.as_bytes(), code.as_bytes()) {
            mc.attempts += 1;
            self.backend.bump_attempts(email);
            return Err(if mc.attempts >= MAX_ATTEMPTS {
                MagicCodeError::TooManyAttempts
            } else {
                MagicCodeError::BadCode
            });
        }
        codes.remove(email);
        self.backend.remove(email);
        Ok(())
    }
}
/// Encodes `bytes` as lowercase hexadecimal, two characters per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
/// Returns a uniformly drawn six-digit decimal code, `"000000"` to `"999999"`.
fn generate_magic_code() -> String {
    use rand::Rng;
    let value: u32 = rand::thread_rng().gen_range(0..1_000_000);
    format!("{value:06}")
}
/// Returns a fresh opaque token: `"pylon_"` followed by 64 lowercase hex
/// characters encoding 32 random bytes.
fn generate_token() -> String {
    use rand::Rng;
    let entropy: [u8; 32] = rand::thread_rng().gen();
    let mut token = String::from("pylon_");
    token.push_str(&hex_encode(&entropy));
    token
}
use std::collections::HashMap;
use std::sync::Mutex;
/// Persistence hook for [`SessionStore`].
pub trait SessionBackend: Send + Sync {
    /// Persists `session`, whether newly created or updated.
    fn save(&self, session: &Session);
    /// Deletes the persisted session whose token is `token`, if any.
    fn remove(&self, token: &str);
    /// Returns every persisted session, expired or not.
    fn load_all(&self) -> Vec<Session>;
}
/// Token-keyed session table with optional write-through persistence.
///
/// All lookups are served from memory; when a backend is present, created,
/// updated, revoked and swept sessions are written through to it.
pub struct SessionStore {
    /// Sessions keyed by their token.
    sessions: Mutex<HashMap<String, Session>>,
    /// Optional persistence layer; `None` keeps sessions in memory only.
    backend: Option<Box<dyn SessionBackend>>,
    /// Lifetime in seconds for sessions this store creates or refreshes.
    default_lifetime_secs: u64,
}
impl Default for SessionStore {
    fn default() -> Self {
        SessionStore::new()
    }
}
impl SessionStore {
    /// Creates an empty, memory-only store whose sessions last
    /// `Session::DEFAULT_LIFETIME_SECS`.
    pub fn new() -> Self {
        Self {
            sessions: Mutex::new(HashMap::new()),
            backend: None,
            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
        }
    }

    /// Builder-style override of the lifetime (seconds) given to sessions
    /// created or refreshed by this store.
    pub fn with_lifetime(mut self, lifetime_secs: u64) -> Self {
        self.default_lifetime_secs = lifetime_secs;
        self
    }

    /// Creates a store over `backend`, preloading every non-expired session
    /// returned by `load_all`. Expired rows are skipped, not deleted.
    pub fn with_backend(backend: Box<dyn SessionBackend>) -> Self {
        let mut map = HashMap::new();
        for s in backend.load_all() {
            if !s.is_expired() {
                map.insert(s.token.clone(), s);
            }
        }
        Self {
            sessions: Mutex::new(map),
            backend: Some(backend),
            default_lifetime_secs: Session::DEFAULT_LIFETIME_SECS,
        }
    }

    /// Issues a new session for `user_id` and writes it through to the backend.
    pub fn create(&self, user_id: String) -> Session {
        let session = Session::with_lifetime(user_id, self.default_lifetime_secs);
        let mut sessions = self.sessions.lock().unwrap();
        sessions.insert(session.token.clone(), session.clone());
        if let Some(b) = &self.backend {
            b.save(&session);
        }
        session
    }

    /// Returns the session for `token` if it exists and has not expired.
    ///
    /// An expired session found here is evicted from memory and from the
    /// backend before `None` is returned.
    pub fn get(&self, token: &str) -> Option<Session> {
        let mut sessions = self.sessions.lock().unwrap();
        match sessions.get(token) {
            Some(s) if s.is_expired() => {
                sessions.remove(token);
                // Keep the backend consistent with memory, as `sweep_expired` does.
                if let Some(b) = &self.backend {
                    b.remove(token);
                }
                None
            }
            Some(s) => Some(s.clone()),
            None => None,
        }
    }

    /// Maps an optional token to an [`AuthContext`]: the session's context
    /// when the token names a live session, anonymous otherwise.
    pub fn resolve(&self, token: Option<&str>) -> AuthContext {
        match token {
            Some(t) => match self.get(t) {
                Some(session) => session.to_auth_context(),
                None => AuthContext::anonymous(),
            },
            None => AuthContext::anonymous(),
        }
    }

    /// Rotates `old_token`. The old session is always removed; if it was still
    /// live, a new session (new token, fresh lifetime) is issued for the same
    /// user, carrying over its device label and tenant selection.
    ///
    /// Returns `None` if `old_token` is unknown or had expired.
    pub fn refresh(&self, old_token: &str) -> Option<Session> {
        let mut sessions = self.sessions.lock().unwrap();
        let old = sessions.remove(old_token)?;
        if let Some(b) = &self.backend {
            b.remove(old_token);
        }
        if old.is_expired() {
            return None;
        }
        let mut new = Session::with_lifetime(old.user_id.clone(), self.default_lifetime_secs);
        new.device = old.device.clone();
        // Carry the tenant too; otherwise every refresh clears what `set_tenant` stored.
        new.tenant_id = old.tenant_id.clone();
        sessions.insert(new.token.clone(), new.clone());
        if let Some(b) = &self.backend {
            b.save(&new);
        }
        Some(new)
    }

    /// Snapshots every in-memory session, including expired ones. Returns an
    /// empty list if the mutex is poisoned.
    pub fn list_all_unfiltered(&self) -> Vec<Session> {
        self.sessions
            .lock()
            .map(|m| m.values().cloned().collect())
            .unwrap_or_default()
    }

    /// Returns the non-expired sessions belonging to `user_id`.
    pub fn list_for_user(&self, user_id: &str) -> Vec<Session> {
        let sessions = self.sessions.lock().unwrap();
        sessions
            .values()
            .filter(|s| s.user_id == user_id && !s.is_expired())
            .cloned()
            .collect()
    }

    /// Revokes every session of `user_id`, expired or not; returns how many.
    pub fn revoke_all_for_user(&self, user_id: &str) -> usize {
        self.remove_matching(|s| s.user_id == user_id)
    }

    /// Removes every expired session from memory and backend; returns how many.
    pub fn sweep_expired(&self) -> usize {
        self.remove_matching(|s| s.is_expired())
    }

    /// Sets the device label of the session for `token` and writes it through.
    /// Returns `false` if the token is unknown. Expiry is not checked.
    pub fn set_device(&self, token: &str, device: String) -> bool {
        let mut sessions = self.sessions.lock().unwrap();
        if let Some(s) = sessions.get_mut(token) {
            s.device = Some(device);
            if let Some(b) = &self.backend {
                b.save(s);
            }
            true
        } else {
            false
        }
    }

    /// Creates a session for a fresh `guest_<32 hex chars>` user id.
    pub fn create_guest(&self) -> Session {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let bytes: [u8; 16] = rng.gen();
        let guest_id = format!("guest_{}", hex_encode(&bytes));
        self.create(guest_id)
    }

    /// Re-points the session for `token` at `real_user_id`, keeping the token.
    /// Returns `false` if the token is unknown.
    ///
    /// NOTE(review): the token is not rotated, so a token planted before login
    /// stays valid after the upgrade (session fixation); callers may want to
    /// follow up with `refresh`.
    pub fn upgrade(&self, token: &str, real_user_id: String) -> bool {
        let mut sessions = self.sessions.lock().unwrap();
        if let Some(session) = sessions.get_mut(token) {
            session.user_id = real_user_id;
            if let Some(b) = &self.backend {
                b.save(session);
            }
            true
        } else {
            false
        }
    }

    /// Sets (or clears) the tenant of the session for `token` and writes it
    /// through. Returns `false` if the token is unknown.
    pub fn set_tenant(&self, token: &str, tenant_id: Option<String>) -> bool {
        let mut sessions = self.sessions.lock().unwrap();
        if let Some(session) = sessions.get_mut(token) {
            session.tenant_id = tenant_id;
            if let Some(b) = &self.backend {
                b.save(session);
            }
            true
        } else {
            false
        }
    }

    /// Deletes the session for `token`; returns whether one existed.
    pub fn revoke(&self, token: &str) -> bool {
        let mut sessions = self.sessions.lock().unwrap();
        let removed = sessions.remove(token).is_some();
        if removed {
            if let Some(b) = &self.backend {
                b.remove(token);
            }
        }
        removed
    }

    /// Removes every session satisfying `pred` from memory and from the
    /// backend, under a single lock; returns how many were removed.
    fn remove_matching(&self, pred: impl Fn(&Session) -> bool) -> usize {
        let mut sessions = self.sessions.lock().unwrap();
        let doomed: Vec<String> = sessions
            .iter()
            .filter(|(_, s)| pred(s))
            .map(|(t, _)| t.clone())
            .collect();
        for t in &doomed {
            sessions.remove(t);
            if let Some(b) = &self.backend {
                b.remove(t);
            }
        }
        doomed.len()
    }
}
/// A link between a local user and an external identity (OAuth provider
/// account or credential), including any tokens the provider issued.
///
/// `Debug` is implemented by hand so that tokens and the password field are
/// never written to logs: only their presence is shown.
#[derive(Clone, PartialEq, Eq)]
pub struct Account {
    pub id: String,
    pub user_id: String,
    pub provider_id: String,
    pub account_id: String,
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    pub access_token_expires_at: Option<u64>,
    pub refresh_token_expires_at: Option<u64>,
    pub scope: Option<String>,
    pub password: Option<String>,
    pub created_at: u64,
    pub updated_at: u64,
}
impl std::fmt::Debug for Account {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a secret is set, never its value.
        fn redact(secret: Option<&str>) -> Option<&'static str> {
            secret.map(|_| "***")
        }
        f.debug_struct("Account")
            .field("id", &self.id)
            .field("user_id", &self.user_id)
            .field("provider_id", &self.provider_id)
            .field("account_id", &self.account_id)
            .field("access_token", &redact(self.access_token.as_deref()))
            .field("refresh_token", &redact(self.refresh_token.as_deref()))
            .field("id_token", &redact(self.id_token.as_deref()))
            .field("access_token_expires_at", &self.access_token_expires_at)
            .field("refresh_token_expires_at", &self.refresh_token_expires_at)
            .field("scope", &self.scope)
            .field("password", &redact(self.password.as_deref()))
            .field("created_at", &self.created_at)
            .field("updated_at", &self.updated_at)
            .finish()
    }
}
impl Account {
pub fn new(user_id: String, info: &UserInfo, tokens: &TokenSet) -> Self {
let now = now_secs();
Self {
id: generate_token(),
user_id,
provider_id: info.provider.clone(),
account_id: info.provider_account_id.clone(),
access_token: Some(tokens.access_token.clone()),
refresh_token: tokens.refresh_token.clone(),
id_token: tokens.id_token.clone(),
access_token_expires_at: tokens.expires_at,
refresh_token_expires_at: None,
scope: tokens.scope.clone(),
password: None,
created_at: now,
updated_at: now,
}
}
pub fn access_token_expired(&self) -> bool {
match self.access_token_expires_at {
Some(ts) => now_secs() >= ts,
None => false,
}
}
}
/// Storage for account links, looked up by `(provider_id, account_id)` or by
/// owning user.
pub trait AccountBackend: Send + Sync {
    /// Inserts or updates `account`.
    fn upsert(&self, account: &Account);
    /// Returns the account for `account_id` at `provider_id`, if linked.
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account>;
    /// Returns every account linked to `user_id`.
    fn find_for_user(&self, user_id: &str) -> Vec<Account>;
    /// Removes the link; returns whether a link was removed.
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool;
    /// Removes every account linked to `user_id` and returns how many links
    /// were actually removed.
    ///
    /// The default looks the accounts up and then unlinks them one at a time,
    /// so it is not atomic; backends that can do this in one operation should
    /// override it.
    fn delete_for_user(&self, user_id: &str) -> usize {
        let mut removed = 0;
        for a in self.find_for_user(user_id) {
            // Count only links that still existed when we got to them.
            if self.unlink(&a.provider_id, &a.account_id) {
                removed += 1;
            }
        }
        removed
    }
    /// Returns every stored account.
    fn list_all(&self) -> Vec<Account>;
}
/// Process-local [`AccountBackend`] keyed by `(provider_id, account_id)`.
/// Nothing survives a restart.
pub struct InMemoryAccountBackend {
    accounts: Mutex<HashMap<(String, String), Account>>,
}
impl InMemoryAccountBackend {
    /// Creates an empty backend.
    pub fn new() -> Self {
        InMemoryAccountBackend {
            accounts: Mutex::default(),
        }
    }
}
impl Default for InMemoryAccountBackend {
    fn default() -> Self {
        InMemoryAccountBackend::new()
    }
}
impl AccountBackend for InMemoryAccountBackend {
    /// Stores `account` under its `(provider_id, account_id)` key, replacing
    /// any existing entry.
    fn upsert(&self, account: &Account) {
        let key = (account.provider_id.clone(), account.account_id.clone());
        self.accounts.lock().unwrap().insert(key, account.clone());
    }
    fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
        self.accounts
            .lock()
            .unwrap()
            .get(&(provider_id.to_string(), account_id.to_string()))
            .cloned()
    }
    fn find_for_user(&self, user_id: &str) -> Vec<Account> {
        self.accounts
            .lock()
            .unwrap()
            .values()
            .filter(|a| a.user_id == user_id)
            .cloned()
            .collect()
    }
    fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
        self.accounts
            .lock()
            .unwrap()
            .remove(&(provider_id.to_string(), account_id.to_string()))
            .is_some()
    }
    /// Removes all of `user_id`'s accounts under one lock. Unlike the trait's
    /// find-then-unlink default, a concurrent `upsert` that relinks a key to
    /// another user cannot slip in between and get deleted.
    fn delete_for_user(&self, user_id: &str) -> usize {
        let mut accounts = self.accounts.lock().unwrap();
        let before = accounts.len();
        accounts.retain(|_, a| a.user_id != user_id);
        before - accounts.len()
    }
    fn list_all(&self) -> Vec<Account> {
        self.accounts.lock().unwrap().values().cloned().collect()
    }
}
/// Facade over an [`AccountBackend`]; every method delegates to the backend.
pub struct AccountStore {
    backend: Box<dyn AccountBackend>,
}
impl Default for AccountStore {
    fn default() -> Self {
        AccountStore::new()
    }
}
impl AccountStore {
    /// Creates a store backed by [`InMemoryAccountBackend`].
    pub fn new() -> Self {
        Self::with_backend(Box::new(InMemoryAccountBackend::new()))
    }
    /// Creates a store over a caller-supplied backend.
    pub fn with_backend(backend: Box<dyn AccountBackend>) -> Self {
        AccountStore { backend }
    }
    /// Inserts or updates `account`.
    pub fn upsert(&self, account: &Account) {
        let backend = &self.backend;
        backend.upsert(account)
    }
    /// Looks up the account for `account_id` at `provider_id`.
    pub fn find_by_provider(&self, provider_id: &str, account_id: &str) -> Option<Account> {
        let backend = &self.backend;
        backend.find_by_provider(provider_id, account_id)
    }
    /// Lists the accounts linked to `user_id`.
    pub fn find_for_user(&self, user_id: &str) -> Vec<Account> {
        let backend = &self.backend;
        backend.find_for_user(user_id)
    }
    /// Removes all accounts of `user_id`, returning the backend's count.
    pub fn delete_for_user(&self, user_id: &str) -> usize {
        let backend = &self.backend;
        backend.delete_for_user(user_id)
    }
    /// Removes one link; returns the backend's result.
    pub fn unlink(&self, provider_id: &str, account_id: &str) -> bool {
        let backend = &self.backend;
        backend.unlink(provider_id, account_id)
    }
    /// Returns every account the backend holds.
    pub fn list_all_unfiltered(&self) -> Vec<Account> {
        let backend = &self.backend;
        backend.list_all()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn anonymous_context() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_authenticated());
        assert!(ctx.user_id.is_none());
    }

    #[test]
    fn authenticated_context() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
    }

    #[test]
    fn from_api_key_carries_scope_metadata() {
        let ctx = AuthContext::from_api_key(
            "user-1".into(),
            "key_abc".into(),
            Some("read,write".into()),
        );
        assert!(ctx.is_authenticated());
        assert!(ctx.is_api_key_auth());
        assert_eq!(ctx.user_id.as_deref(), Some("user-1"));
        assert_eq!(ctx.api_key_id.as_deref(), Some("key_abc"));
        assert_eq!(ctx.api_key_scopes.as_deref(), Some("read,write"));
    }

    #[test]
    fn session_auth_is_not_api_key_auth() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_api_key_auth());
        assert!(ctx.api_key_id.is_none());
    }

    #[test]
    fn auth_mode_public_allows_anonymous() {
        let mode = AuthMode::Public;
        assert!(mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_user_requires_authenticated() {
        let mode = AuthMode::User;
        assert!(!mode.check(&AuthContext::anonymous()));
        assert!(mode.check(&AuthContext::authenticated("user-1".into())));
    }

    #[test]
    fn auth_mode_from_str() {
        assert_eq!(AuthMode::from_str("public"), Some(AuthMode::Public));
        assert_eq!(AuthMode::from_str("user"), Some(AuthMode::User));
        assert_eq!(AuthMode::from_str("admin"), None);
    }

    #[test]
    fn session_store_create_and_get() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(!session.token.is_empty());
        assert!(session.token.starts_with("pylon_"));
        let retrieved = store.get(&session.token).unwrap();
        assert_eq!(retrieved.user_id, "user-1");
    }

    #[test]
    fn session_store_resolve() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert_eq!(ctx.user_id, Some("user-1".into()));
        let anon = store.resolve(None);
        assert!(!anon.is_authenticated());
        let bad = store.resolve(Some("invalid-token"));
        assert!(!bad.is_authenticated());
    }

    #[test]
    fn session_store_revoke() {
        let store = SessionStore::new();
        let session = store.create("user-1".into());
        assert!(store.revoke(&session.token));
        assert!(store.get(&session.token).is_none());
        // Revoking an already-revoked token reports that nothing was removed.
        assert!(!store.revoke(&session.token));
    }

    #[test]
    fn session_refresh_rotates_token() {
        let store = SessionStore::new();
        let old = store.create("user-1".into());
        let new = store.refresh(&old.token).expect("refresh a live session");
        assert_ne!(new.token, old.token);
        assert_eq!(new.user_id, "user-1");
        assert!(store.get(&old.token).is_none());
        assert_eq!(
            store.get(&new.token).map(|s| s.user_id),
            Some("user-1".to_string())
        );
        // The old token was consumed by the first refresh.
        assert!(store.refresh(&old.token).is_none());
    }

    #[test]
    fn session_to_auth_context() {
        let session = Session::new("user-42".into());
        let ctx = session.to_auth_context();
        assert_eq!(ctx.user_id, Some("user-42".into()));
    }

    #[test]
    fn admin_context() {
        let ctx = AuthContext::admin();
        assert!(ctx.is_admin);
        assert!(ctx.is_authenticated());
    }

    #[test]
    fn anonymous_not_admin() {
        let ctx = AuthContext::anonymous();
        assert!(!ctx.is_admin);
    }

    #[test]
    fn authenticated_not_admin() {
        let ctx = AuthContext::authenticated("user-1".into());
        assert!(!ctx.is_admin);
    }

    #[test]
    fn magic_code_create_and_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert_eq!(code.len(), 6);
        assert!(store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_wrong_code_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        // A fixed "000000" guess would collide with the random code once in a
        // million runs; pick a guess that is guaranteed to differ.
        let wrong = if code == "000000" { "111111" } else { "000000" };
        assert!(!store.verify("test@example.com", wrong));
    }

    #[test]
    fn magic_code_wrong_email_rejected() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(!store.verify("other@example.com", &code));
    }

    #[test]
    fn magic_code_consumed_after_verify() {
        let store = MagicCodeStore::new();
        let code = store.create("test@example.com");
        assert!(store.verify("test@example.com", &code));
        assert!(!store.verify("test@example.com", &code));
    }

    #[test]
    fn magic_code_different_emails_independent() {
        let store = MagicCodeStore::new();
        let code1 = store.create("alice@example.com");
        let code2 = store.create("bob@example.com");
        assert!(store.verify("alice@example.com", &code1));
        assert!(store.verify("bob@example.com", &code2));
    }

    #[test]
    fn magic_code_locks_out_after_max_attempts() {
        let store = MagicCodeStore::new();
        let email = "lock@example.com";
        let code = store.create(email);
        let wrong = if code == "000000" { "111111" } else { "000000" };
        for _ in 0..MAX_ATTEMPTS - 1 {
            assert_eq!(store.try_verify(email, wrong), Err(MagicCodeError::BadCode));
        }
        assert_eq!(
            store.try_verify(email, wrong),
            Err(MagicCodeError::TooManyAttempts)
        );
        // Even the right code is refused once locked out.
        assert_eq!(
            store.try_verify(email, &code),
            Err(MagicCodeError::TooManyAttempts)
        );
    }

    #[test]
    fn magic_code_recreate_is_throttled() {
        let store = MagicCodeStore::new();
        let email = "throttle@example.com";
        let first = store.try_create(email).expect("first create succeeds");
        match store.try_create(email) {
            Err(MagicCodeError::Throttled { retry_after_secs }) => {
                assert!(retry_after_secs > 0 && retry_after_secs <= CREATE_COOLDOWN_SECS);
            }
            other => panic!("expected Throttled, got {other:?}"),
        }
        // The throttled attempt must not have replaced the first code.
        assert!(store.verify(email, &first));
    }

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(b"hello", b"hello"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(b"hello", b"world"));
        assert!(!constant_time_eq(b"hello", b"hell"));
        assert!(!constant_time_eq(b"a", b"b"));
    }

    #[test]
    fn generated_tokens_are_unique() {
        let t1 = generate_token();
        let t2 = generate_token();
        assert_ne!(t1, t2);
        assert!(t1.starts_with("pylon_"));
        assert!(t2.starts_with("pylon_"));
        assert_eq!(t1.len(), 6 + 64);
    }

    #[test]
    fn oauth_registry_empty() {
        let reg = OAuthRegistry::new();
        assert!(reg.get("google").is_none());
    }

    #[test]
    fn oauth_registry_register_and_get() {
        let mut reg = OAuthRegistry::new();
        reg.register(OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/callback".into(),
            ..Default::default()
        });
        let config = reg.get("google").unwrap();
        assert_eq!(config.client_id, "test-id");
        assert!(config.auth_url().contains("accounts.google.com"));
    }

    #[test]
    fn every_builtin_provider_routes_through_oauth_config() {
        for spec in provider::builtin::all() {
            let cfg = OAuthConfig {
                provider: spec.id.into(),
                client_id: "cid".into(),
                client_secret: "csecret".into(),
                redirect_uri: "https://app/cb".into(),
                tenant: if spec.id == "microsoft" {
                    Some("contoso".into())
                } else {
                    None
                },
                apple: if spec.id == "apple" {
                    Some(provider::AppleConfig {
                        team_id: "T".into(),
                        key_id: "K".into(),
                        private_key_pem: "no".into(),
                    })
                } else {
                    None
                },
                ..Default::default()
            };
            let auth = cfg.auth_url();
            assert!(!auth.is_empty(), "{}: empty auth_url", spec.id);
            let expected_param = format!("{}=cid", spec.client_id_param);
            assert!(
                auth.contains(&expected_param),
                "{}: missing {}; got auth_url: {}",
                spec.id,
                expected_param,
                auth,
            );
            assert!(!cfg.token_url().is_empty(), "{}: empty token_url", spec.id);
            if spec.id == "apple" {
                assert!(
                    auth.contains("response_mode=form_post"),
                    "apple auth_url must include response_mode=form_post; got {auth}"
                );
            }
        }
    }

    #[test]
    fn microsoft_tenant_placeholder_resolves() {
        let cfg = OAuthConfig {
            provider: "microsoft".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            tenant: Some("contoso.onmicrosoft.com".into()),
            ..Default::default()
        };
        assert!(cfg.auth_url().contains("/contoso.onmicrosoft.com/"));
        assert!(cfg.token_url().contains("/contoso.onmicrosoft.com/"));
    }

    #[test]
    fn microsoft_default_tenant_common() {
        let cfg = OAuthConfig {
            provider: "microsoft".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            ..Default::default()
        };
        assert!(cfg.auth_url().contains("/common/"));
        assert!(cfg.token_url().contains("/common/"));
    }

    #[test]
    fn scopes_override_replaces_spec_default() {
        let cfg = OAuthConfig {
            provider: "github".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            scopes_override: Some("repo user:email".into()),
            ..Default::default()
        };
        let auth = cfg.auth_url();
        assert!(auth.contains("scope=repo%20user%3Aemail"), "got: {auth}");
    }

    #[test]
    fn apple_exchange_requires_apple_config() {
        let cfg = OAuthConfig {
            provider: "apple".into(),
            client_id: "com.example.app".into(),
            client_secret: String::new(),
            redirect_uri: "https://app/cb".into(),
            apple: None,
            ..Default::default()
        };
        let err = cfg.exchange_code_full("x").unwrap_err();
        assert!(err.contains("apple provider requires"), "got: {err}");
    }

    #[test]
    fn oidc_issuer_uses_discovered_endpoints() {
        let issuer = "https://acme.test.invalid";
        provider::oidc_cache::insert_for_test(
            issuer,
            provider::DiscoveredSpec {
                auth_url: "https://acme.test.invalid/authorize".into(),
                token_url: "https://acme.test.invalid/oauth/token".into(),
                userinfo_url: Some("https://acme.test.invalid/userinfo".into()),
                scopes: "openid email profile".into(),
                userinfo_parser: provider::UserinfoParser::Oidc,
                token_exchange: provider::TokenExchangeShape::Standard,
            },
        );
        let cfg = OAuthConfig {
            provider: "auth0".into(),
            client_id: "id".into(),
            client_secret: "secret".into(),
            redirect_uri: "https://app/cb".into(),
            oidc_issuer: Some(issuer.into()),
            ..Default::default()
        };
        assert!(cfg.auth_url().starts_with("https://acme.test.invalid/authorize?"));
        assert_eq!(cfg.token_url(), "https://acme.test.invalid/oauth/token");
        assert_eq!(cfg.userinfo_url(), "https://acme.test.invalid/userinfo");
    }

    #[test]
    fn apple_auth_url_includes_form_post() {
        let cfg = OAuthConfig {
            provider: "apple".into(),
            client_id: "com.example.app".into(),
            client_secret: String::new(),
            redirect_uri: "https://app/cb".into(),
            apple: Some(provider::AppleConfig {
                team_id: "T".into(),
                key_id: "K".into(),
                private_key_pem: "no".into(),
            }),
            ..Default::default()
        };
        let auth = cfg.auth_url();
        assert!(auth.contains("response_mode=form_post"), "got: {auth}");
        assert_eq!(cfg.userinfo_url(), "");
    }

    #[test]
    fn apple_id_token_decode_extracts_identity() {
        use base64::Engine;
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b"{\"alg\":\"none\"}");
        let claims = serde_json::json!({
            "iss": "https://appleid.apple.com",
            "sub": "001234.abc.def",
            "aud": "com.example.app",
            "email": "user@privaterelay.appleid.com",
            "email_verified": "true",
        });
        let claims_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(claims.to_string().as_bytes());
        let id_token = format!("{header}.{claims_b64}.signature_ignored");
        let cfg = OAuthConfig {
            provider: "apple".into(),
            client_id: "com.example.app".into(),
            client_secret: String::new(),
            redirect_uri: "https://app/cb".into(),
            apple: Some(provider::AppleConfig {
                team_id: "T".into(),
                key_id: "K".into(),
                private_key_pem: "no".into(),
            }),
            ..Default::default()
        };
        let info = cfg
            .fetch_userinfo_with_id_token("ignored", Some(&id_token))
            .expect("apple id_token decode");
        assert_eq!(info.provider_account_id, "001234.abc.def");
        assert_eq!(info.email, "user@privaterelay.appleid.com");
        let err = cfg.fetch_userinfo_full("token").unwrap_err();
        assert!(err.contains("apple login requires"), "got: {err}");
    }

    #[test]
    fn twitter_auth_url_includes_pkce() {
        let cfg = OAuthConfig {
            provider: "twitter".into(),
            client_id: "tw_client".into(),
            client_secret: "tw_secret".into(),
            redirect_uri: "https://app/cb".into(),
            ..Default::default()
        };
        let (url, verifier) = cfg.auth_url_with_pkce("state123").expect("twitter pkce");
        let v = verifier.expect("twitter must produce verifier");
        assert!(v.len() >= 43, "PKCE verifier must be 43+ chars: got {v}");
        assert!(url.contains("code_challenge="), "got: {url}");
        assert!(url.contains("code_challenge_method=S256"), "got: {url}");
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "g".into(),
            client_secret: "g".into(),
            redirect_uri: "https://app/cb".into(),
            ..Default::default()
        };
        let (gurl, gverifier) = google.auth_url_with_pkce("st").expect("google");
        assert!(gverifier.is_none(), "google should not add PKCE");
        assert!(!gurl.contains("code_challenge"), "got: {gurl}");
    }

    #[test]
    fn tiktok_uses_client_key_and_comma_scopes() {
        let cfg = OAuthConfig {
            provider: "tiktok".into(),
            client_id: "tk_client".into(),
            client_secret: "tk_secret".into(),
            redirect_uri: "https://app/cb".into(),
            scopes_override: Some("user.info.basic video.list".into()),
            ..Default::default()
        };
        let auth = cfg.auth_url();
        assert!(auth.contains("client_key=tk_client"), "got: {auth}");
        assert!(auth.contains("user.info.basic%2Cvideo.list"), "got: {auth}");
        assert!(!auth.contains("user.info.basic%20video.list"), "got: {auth}");
    }

    #[test]
    fn token_exchange_url_encodes_code() {
        let raw = "code+with/special=chars";
        let encoded = url_encode(raw);
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
        assert!(encoded.contains("%2B"));
        assert!(encoded.contains("%2F"));
        assert!(encoded.contains("%3D"));
    }

    #[test]
    fn sanitize_token_error_redacts_secrets() {
        let raw = "HTTP 400: error=invalid_grant&client_secret=sk_real_secret_value&code_verifier=verifierxyz&hint=check%20your%20code";
        let scrubbed = sanitize_token_error(raw.into());
        assert!(!scrubbed.contains("sk_real_secret_value"));
        assert!(!scrubbed.contains("verifierxyz"));
        assert!(scrubbed.contains("client_secret=***"));
        assert!(scrubbed.contains("code_verifier=***"));
        assert!(scrubbed.contains("invalid_grant"));
        assert!(scrubbed.contains("hint=check%20your%20code"));
    }

    #[test]
    fn sanitize_token_error_redacts_json_secrets() {
        let raw = r#"HTTP 400: {"error":"invalid_grant","client_secret":"sk_jsonleak","refresh_token":"rt_abcxyz","id_token":"ey.payload.sig"}"#;
        let scrubbed = sanitize_token_error(raw.into());
        assert!(!scrubbed.contains("sk_jsonleak"), "got: {scrubbed}");
        assert!(!scrubbed.contains("rt_abcxyz"), "got: {scrubbed}");
        assert!(!scrubbed.contains("ey.payload.sig"), "got: {scrubbed}");
        assert!(scrubbed.contains(r#""client_secret":"***""#), "got: {scrubbed}");
        assert!(scrubbed.contains(r#""refresh_token":"***""#), "got: {scrubbed}");
        assert!(scrubbed.contains(r#""id_token":"***""#), "got: {scrubbed}");
        assert!(scrubbed.contains("invalid_grant"));
    }

    #[test]
    fn sanitize_token_error_handles_utf8() {
        let raw = "HTTP 400: ⚠️ provider says the secret is wrong: client_secret=sk_x";
        let scrubbed = sanitize_token_error(raw.into());
        assert!(scrubbed.contains("⚠️"), "non-ASCII chars must survive: {scrubbed}");
        assert!(!scrubbed.contains("sk_x"));
        assert!(scrubbed.contains("client_secret=***"));
    }

    #[test]
    fn oidc_discovery_picks_token_auth_method() {
        let json_post = r#"{
            "issuer": "https://acme.test/",
            "authorization_endpoint": "https://acme.test/auth",
            "token_endpoint": "https://acme.test/token",
            "token_endpoint_auth_methods_supported": ["client_secret_post"]
        }"#;
        let spec = provider::OidcDiscoveryDoc::parse(json_post).unwrap().into_spec();
        assert!(matches!(
            spec.token_exchange,
            provider::TokenExchangeShape::Standard
        ));
        let json_default = r#"{
            "issuer": "https://acme.test/",
            "authorization_endpoint": "https://acme.test/auth",
            "token_endpoint": "https://acme.test/token"
        }"#;
        let spec = provider::OidcDiscoveryDoc::parse(json_default)
            .unwrap()
            .into_spec();
        assert!(matches!(
            spec.token_exchange,
            provider::TokenExchangeShape::BasicAuth
        ));
    }

    #[test]
    fn oidc_discovery_rejects_incomplete_doc() {
        let json = r#"{
            "issuer": "https://acme.test/",
            "authorization_endpoint": "https://acme.test/auth"
        }"#;
        let err = provider::OidcDiscoveryDoc::parse(json).unwrap_err();
        assert!(err.contains("token_endpoint"), "got: {err}");
    }

    #[test]
    fn from_env_picks_up_discord() {
        const KEY_ID: &str = "PYLON_OAUTH_DISCORD_CLIENT_ID";
        const KEY_SECRET: &str = "PYLON_OAUTH_DISCORD_CLIENT_SECRET";
        // Removes the variables on drop, so a failing assertion below cannot
        // leave them set for the rest of the test process.
        struct EnvGuard(&'static [&'static str]);
        impl Drop for EnvGuard {
            fn drop(&mut self) {
                for key in self.0 {
                    std::env::remove_var(key);
                }
            }
        }
        let _guard = EnvGuard(&[KEY_ID, KEY_SECRET]);
        std::env::set_var(KEY_ID, "discord-test-id");
        std::env::set_var(KEY_SECRET, "discord-test-secret");
        let reg = OAuthRegistry::from_env();
        let discord = reg.get("discord").expect("discord registered");
        assert_eq!(discord.client_id, "discord-test-id");
        assert!(discord.auth_url().contains("discord.com"));
    }

    #[test]
    fn guest_session() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        assert!(!session.token.is_empty());
        let ctx = store.resolve(Some(&session.token));
        assert!(ctx.is_authenticated());
        assert!(ctx.user_id.unwrap().starts_with("guest_"));
    }

    #[test]
    fn upgrade_guest_to_real_user() {
        let store = SessionStore::new();
        let session = store.create_guest();
        assert!(session.user_id.starts_with("guest_"));
        let upgraded = store.upgrade(&session.token, "real-user-123".into());
        assert!(upgraded);
        let ctx = store.resolve(Some(&session.token));
        assert_eq!(ctx.user_id, Some("real-user-123".into()));
    }

    #[test]
    fn upgrade_invalid_token_fails() {
        let store = SessionStore::new();
        let upgraded = store.upgrade("nonexistent-token", "user".into());
        assert!(!upgraded);
    }

    #[test]
    fn guest_context() {
        let ctx = AuthContext::guest("guest_123".into());
        assert!(!ctx.is_authenticated());
        assert!(ctx.is_guest);
        assert!(!ctx.is_admin);
        assert_eq!(ctx.user_id, Some("guest_123".into()));
        assert!(!AuthMode::User.check(&ctx));
        assert!(AuthMode::Public.check(&ctx));
    }

    #[test]
    fn oauth_token_urls() {
        let google = OAuthConfig {
            provider: "google".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
            ..Default::default()
        };
        assert_eq!(google.token_url(), "https://oauth2.googleapis.com/token");
        let github = OAuthConfig {
            provider: "github".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
            ..Default::default()
        };
        assert_eq!(
            github.token_url(),
            "https://github.com/login/oauth/access_token"
        );
        let unknown = OAuthConfig {
            provider: "unknown".into(),
            client_id: "x".into(),
            client_secret: "x".into(),
            redirect_uri: "x".into(),
            ..Default::default()
        };
        assert_eq!(unknown.token_url(), "");
        assert!(unknown.auth_url().is_empty());
    }

    #[test]
    fn oauth_auth_url_github() {
        let config = OAuthConfig {
            provider: "github".into(),
            client_id: "gh-id".into(),
            client_secret: "gh-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
            ..Default::default()
        };
        assert!(config.auth_url().contains("github.com"));
        assert!(config.auth_url().contains("gh-id"));
    }

    #[test]
    fn oauth_auth_url_with_state() {
        let config = OAuthConfig {
            provider: "google".into(),
            client_id: "test-id".into(),
            client_secret: "test-secret".into(),
            redirect_uri: "http://localhost/cb".into(),
            ..Default::default()
        };
        let url = config.auth_url_with_state("random_state_123");
        assert!(url.contains("&state=random_state_123"));
    }

    #[test]
    fn oauth_state_store_create_and_validate() {
        let store = OAuthStateStore::new();
        let token = store.create("google", "https://app/cb", "https://app/login");
        let rec = store.validate(&token, "google").expect("valid first time");
        assert_eq!(rec.callback_url, "https://app/cb");
        assert_eq!(rec.error_callback_url, "https://app/login");
        assert!(store.validate(&token, "google").is_none());
    }

    #[test]
    fn oauth_state_store_wrong_provider_rejected() {
        let store = OAuthStateStore::new();
        let token = store.create("google", "https://app/cb", "https://app/cb");
        assert!(store.validate(&token, "github").is_none());
    }

    #[test]
    fn oauth_state_store_invalid_state_rejected() {
        let store = OAuthStateStore::new();
        assert!(store.validate("nonexistent", "google").is_none());
    }

    #[test]
    fn validate_trusted_redirect_basics() {
        let trusted = vec!["http://localhost:3000".to_string()];
        assert!(validate_trusted_redirect("http://localhost:3000/dashboard", &trusted).is_ok());
        assert!(validate_trusted_redirect("http://localhost:3000", &trusted).is_ok());
        assert!(validate_trusted_redirect("http://localhost:3000/x?y=1", &trusted).is_ok());
        assert!(matches!(
            validate_trusted_redirect("http://localhost:4321/dashboard", &trusted),
            Err(TrustedOriginError::NotTrusted { .. })
        ));
        // Userinfo tricks must not smuggle a foreign host past the check.
        assert!(matches!(
            validate_trusted_redirect("http://localhost:3000@evil.com/", &trusted),
            Err(TrustedOriginError::NotTrusted { .. })
        ));
        assert!(matches!(
            validate_trusted_redirect("javascript:alert(1)", &trusted),
            Err(TrustedOriginError::NotHttp)
        ));
        assert!(matches!(
            validate_trusted_redirect("", &trusted),
            Err(TrustedOriginError::Empty)
        ));
    }

    #[test]
    fn origin_of_stops_at_path_query_and_fragment() {
        assert_eq!(origin_of("https://app.test:8443/a/b?c=d#e"), "https://app.test:8443");
        assert_eq!(origin_of("https://app.test?x=1"), "https://app.test");
        assert_eq!(origin_of("https://app.test#frag"), "https://app.test");
        assert_eq!(origin_of("https://user@app.test/"), "https://user@app.test");
    }
}