use std::collections::HashMap;
use base64::Engine as _;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors produced while validating SSO tokens and parsing provider callbacks.
#[derive(Debug, Error)]
pub enum SsoError {
    /// The token's `exp` claim is not later than the current time.
    #[error("token expired")]
    TokenExpired,
    /// The token's `iss` claim does not exactly match the configured issuer URL.
    #[error("invalid issuer: expected {expected}, got {got}")]
    InvalidIssuer {
        /// The issuer URL from the configuration.
        expected: String,
        /// The `iss` value found in the token.
        got: String
    },
    /// The token's audience does not include the configured client id.
    #[error("invalid audience")]
    InvalidAudience,
    /// The token or callback query could not be parsed, or a required claim/parameter
    /// is missing; the string describes what went wrong.
    #[error("malformed token: {0}")]
    MalformedToken(String),
    /// The requested provider type is not supported.
    /// NOTE(review): not constructed anywhere in this file's visible code — presumably
    /// used by callers for non-OIDC providers; confirm.
    #[error("unsupported provider type")]
    UnsupportedProvider,
    /// A JWT segment was not valid base64url; holds the decoder's error message.
    #[error("base64 decode error: {0}")]
    Base64Error(String),
}
/// The single-sign-on protocol spoken by an identity provider.
///
/// NOTE(review): this file only contains an OIDC validator (`OidcValidator`); SAML
/// handling, if any, lives elsewhere — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SsoProviderType {
    /// OpenID Connect.
    Oidc,
    /// SAML.
    Saml,
}
/// Configuration for one SSO identity provider.
///
/// The `Default` implementation selects OIDC with empty URLs/client id and the
/// `openid`, `profile` and `email` scopes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoConfig {
    /// Which protocol the provider speaks.
    pub provider_type: SsoProviderType,
    /// Issuer identifier. ID tokens must carry exactly this string in `iss`, and the
    /// authorization URL is built as `{issuer_url without trailing '/'}/authorize`.
    pub issuer_url: String,
    /// Client id registered with the provider; must appear in the token's `aud` claim
    /// and is sent as `client_id` in the authorization request.
    pub client_id: String,
    /// Redirect URI sent as `redirect_uri` in the authorization request.
    pub redirect_uri: String,
    /// Scopes requested in the authorization request, joined with spaces.
    pub scopes: Vec<String>,
}
impl Default for SsoConfig {
    /// Builds an OIDC configuration with empty issuer, client id and redirect URI that
    /// requests the `openid`, `profile` and `email` scopes.
    fn default() -> Self {
        let scopes = ["openid", "profile", "email"]
            .iter()
            .map(|scope| (*scope).to_owned())
            .collect();
        Self {
            scopes,
            redirect_uri: String::new(),
            client_id: String::new(),
            issuer_url: String::new(),
            provider_type: SsoProviderType::Oidc,
        }
    }
}
/// Identity information extracted from a validated ID token.
#[derive(Debug, Clone)]
pub struct SsoUserInfo {
    /// The token's `sub` claim.
    pub subject: String,
    /// The token's `email` claim, if present and a string.
    pub email: Option<String>,
    /// The token's `name` claim, if present and a string.
    pub name: Option<String>,
    /// Entries from the `groups` claim followed by entries from the `roles` claim
    /// (each may be a string or an array of strings).
    pub groups: Vec<String>,
    /// Every claim from the token payload, including the ones copied above.
    pub raw_claims: HashMap<String, serde_json::Value>,
}
/// The `code` and `state` parameters from an OIDC authorization callback query,
/// as returned by `OidcValidator::parse_callback`.
#[derive(Debug, Clone)]
pub struct OidcCallback {
    /// The authorization code (`code` query parameter).
    pub code: String,
    /// The `state` query parameter; callers should compare it with the value they
    /// passed to `OidcValidator::authorization_url`.
    pub state: String,
}
/// Checks OIDC ID-token claims and builds authorization URLs for one configured provider.
pub struct OidcValidator {
    /// Provider settings used for issuer/audience checks and URL construction.
    config: SsoConfig,
}
impl OidcValidator {
    /// Creates a validator that checks tokens and builds URLs against `config`.
    pub fn new(config: SsoConfig) -> Self {
        Self { config }
    }

    /// Checks the claims of an OIDC ID token and extracts the user's identity.
    ///
    /// The payload is decoded with `parse_jwt_claims` and then checked in this order:
    ///
    /// 1. `exp` is present, numeric (integer or fractional NumericDate) and strictly in
    ///    the future;
    /// 2. `iss` equals `config.issuer_url` exactly (no trailing-slash normalisation);
    /// 3. `aud` is a string equal to `config.client_id` or an array containing it;
    /// 4. if `azp` is present, it equals `config.client_id`;
    /// 5. `sub` is present and is a string.
    ///
    /// `groups` and `roles` claims (string or array of strings) are concatenated into
    /// `SsoUserInfo::groups`, and the full claim set is kept in `raw_claims`.
    ///
    /// # Security
    ///
    /// The JWT signature is **not** verified, and `nonce`, `nbf` and `iat` are not
    /// checked. Only pass tokens received directly from the provider's token endpoint
    /// over TLS, or whose signature has already been verified. If a nonce was sent via
    /// [`Self::authorization_url`], compare it against `raw_claims["nonce"]`.
    ///
    /// # Errors
    ///
    /// - `SsoError::MalformedToken` / `SsoError::Base64Error` if the token cannot be
    ///   decoded or `exp`, `iss` or `sub` is missing;
    /// - `SsoError::TokenExpired` if `exp` is not in the future;
    /// - `SsoError::InvalidIssuer` if `iss` does not match;
    /// - `SsoError::InvalidAudience` if `aud` or `azp` does not match the client id.
    pub fn validate_id_token(&self, id_token: &str) -> Result<SsoUserInfo, SsoError> {
        let claims = parse_jwt_claims(id_token)?;

        let exp = claims
            .get("exp")
            .and_then(numeric_date)
            .ok_or_else(|| SsoError::MalformedToken("missing 'exp' claim".to_string()))?;
        // A system clock set before the Unix epoch is treated as time 0 instead of failing.
        let now_ts = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs() as i64)
            .unwrap_or(0);
        if exp <= now_ts {
            return Err(SsoError::TokenExpired);
        }

        let iss = claims
            .get("iss")
            .and_then(|v| v.as_str())
            .ok_or_else(|| SsoError::MalformedToken("missing 'iss' claim".to_string()))?
            .to_string();
        if iss != self.config.issuer_url {
            return Err(SsoError::InvalidIssuer {
                expected: self.config.issuer_url.clone(),
                got: iss,
            });
        }

        let client_id = self.config.client_id.as_str();
        let aud_matches = match claims.get("aud") {
            Some(serde_json::Value::String(s)) => s.as_str() == client_id,
            Some(serde_json::Value::Array(arr)) => {
                arr.iter().any(|v| v.as_str() == Some(client_id))
            }
            _ => false,
        };
        // OIDC Core §3.1.3.7: when an `azp` (authorized party) claim is present it must
        // name this client, otherwise the token was issued to someone else.
        let azp_matches = match claims.get("azp") {
            None => true,
            Some(v) => v.as_str() == Some(client_id),
        };
        if !aud_matches || !azp_matches {
            return Err(SsoError::InvalidAudience);
        }

        let subject = claims
            .get("sub")
            .and_then(|v| v.as_str())
            .ok_or_else(|| SsoError::MalformedToken("missing 'sub' claim".to_string()))?
            .to_string();
        let email = claims
            .get("email")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let name = claims
            .get("name")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let groups = extract_string_list(&claims, "groups")
            .into_iter()
            .chain(extract_string_list(&claims, "roles"))
            .collect();

        Ok(SsoUserInfo {
            subject,
            email,
            name,
            groups,
            raw_claims: claims,
        })
    }

    /// Builds the authorization-code request URL for this provider.
    ///
    /// The endpoint is `{issuer_url}/authorize` (with any trailing `/` removed from the
    /// issuer first); providers whose discovery document advertises a different
    /// authorization endpoint are not handled here. `response_type=code`, `client_id`,
    /// `redirect_uri`, space-joined `scope`, `state` and `nonce` are sent as
    /// form-urlencoded query parameters.
    pub fn authorization_url(&self, state: &str, nonce: &str) -> String {
        let scope = self.config.scopes.join(" ");
        let params = [
            ("response_type", "code"),
            ("client_id", self.config.client_id.as_str()),
            ("redirect_uri", self.config.redirect_uri.as_str()),
            ("scope", scope.as_str()),
            ("state", state),
            ("nonce", nonce),
        ];
        let query = params
            .iter()
            .map(|(k, v)| format!("{}={}", k, percent_encode(v)))
            .collect::<Vec<_>>()
            .join("&");
        format!(
            "{}/authorize?{}",
            self.config.issuer_url.trim_end_matches('/'),
            query
        )
    }

    /// Extracts `code` and `state` from the query string of the provider's redirect.
    ///
    /// A single leading `?` is accepted. Keys and values are form-urlencoding decoded
    /// (`+` → space, `%XX` → byte), so a `state` produced by
    /// [`Self::authorization_url`] round-trips to its original value. When a parameter
    /// repeats, the last occurrence wins; unknown parameters are ignored.
    ///
    /// # Errors
    ///
    /// Returns `SsoError::MalformedToken` if the provider reported an OAuth `error`
    /// (its `error_description` is included when present), or if `code` or `state`
    /// is missing.
    pub fn parse_callback(&self, query: &str) -> Result<OidcCallback, SsoError> {
        let query = query.strip_prefix('?').unwrap_or(query);

        let mut code: Option<String> = None;
        let mut state: Option<String> = None;
        let mut error: Option<String> = None;
        let mut error_description: Option<String> = None;
        for pair in query.split('&') {
            let mut parts = pair.splitn(2, '=');
            let key = percent_decode(parts.next().unwrap_or("").trim());
            let value = percent_decode(parts.next().unwrap_or("").trim());
            match key.as_str() {
                "code" => code = Some(value),
                "state" => state = Some(value),
                "error" => error = Some(value),
                "error_description" => error_description = Some(value),
                _ => {}
            }
        }

        // RFC 6749 §4.1.2.1: an error redirect carries `error` instead of `code`;
        // surface it rather than the less helpful "missing 'code'".
        if let Some(err) = error {
            let detail = match error_description {
                Some(desc) => format!("authorization server returned error '{}': {}", err, desc),
                None => format!("authorization server returned error '{}'", err),
            };
            return Err(SsoError::MalformedToken(detail));
        }

        let code =
            code.ok_or_else(|| SsoError::MalformedToken("missing 'code' parameter".to_string()))?;
        let state = state
            .ok_or_else(|| SsoError::MalformedToken("missing 'state' parameter".to_string()))?;
        Ok(OidcCallback { code, state })
    }
}

/// Interprets a JWT NumericDate claim as whole seconds since the Unix epoch.
///
/// RFC 7519 permits non-integer values; those are truncated toward zero (saturating at
/// the `i64` range). Non-numeric values yield `None`.
fn numeric_date(value: &serde_json::Value) -> Option<i64> {
    value.as_i64().or_else(|| value.as_f64().map(|secs| secs as i64))
}

/// Decodes one `application/x-www-form-urlencoded` component.
///
/// `+` becomes a space and each well-formed `%XX` escape becomes the byte it encodes.
/// A `%` not followed by two hex digits is kept literally, and byte sequences that are
/// not valid UTF-8 are replaced with U+FFFD, so decoding never fails.
fn percent_decode(s: &str) -> String {
    fn hex_digit(byte: u8) -> Option<u8> {
        char::from(byte).to_digit(16).map(|d| d as u8)
    }

    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_digit);
                let lo = bytes.get(i + 2).copied().and_then(hex_digit);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push((hi << 4) | lo);
                    i += 3;
                } else {
                    decoded.push(b'%');
                    i += 1;
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Decodes the payload (claims) segment of a compact JWT **without verifying it**.
///
/// The token must consist of exactly three dot-separated segments
/// (`header.payload.signature`). Only the middle one is read: it is decoded as unpadded
/// base64url and parsed as a JSON object. The header and signature are ignored, so the
/// returned claims are unauthenticated.
///
/// # Errors
///
/// - `SsoError::MalformedToken` if the segment count is not three, or the payload is
///   not UTF-8 or not a JSON object;
/// - `SsoError::Base64Error` if the payload is not valid unpadded base64url.
pub(crate) fn parse_jwt_claims(
    token: &str,
) -> Result<HashMap<String, serde_json::Value>, SsoError> {
    // Split on every dot: `splitn(3, '.')` would fold any extra segments (e.g. a
    // five-part JWE) into the "signature" and wrongly accept the token.
    let segments: Vec<&str> = token.split('.').collect();
    let payload_b64 = match segments.as_slice() {
        [_header, payload, _signature] => *payload,
        _ => {
            return Err(SsoError::MalformedToken(
                "JWT must have three dot-separated segments".to_string(),
            ))
        }
    };
    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(payload_b64)
        .map_err(|e| SsoError::Base64Error(e.to_string()))?;
    let json_str = std::str::from_utf8(&decoded)
        .map_err(|e| SsoError::MalformedToken(format!("payload is not valid UTF-8: {}", e)))?;
    let claims: HashMap<String, serde_json::Value> = serde_json::from_str(json_str)
        .map_err(|e| SsoError::MalformedToken(format!("payload JSON parse error: {}", e)))?;
    Ok(claims)
}
/// Reads the claim `key` as a list of strings.
///
/// A JSON array yields its string elements (non-string elements are skipped), a single
/// JSON string yields a one-element list, and anything else (missing key, number,
/// boolean, object, null) yields an empty list.
fn extract_string_list(claims: &HashMap<String, serde_json::Value>, key: &str) -> Vec<String> {
    let value = match claims.get(key) {
        Some(value) => value,
        None => return Vec::new(),
    };
    if let Some(single) = value.as_str() {
        return vec![single.to_owned()];
    }
    value
        .as_array()
        .map(|items| {
            items
                .iter()
                .filter_map(serde_json::Value::as_str)
                .map(str::to_owned)
                .collect()
        })
        .unwrap_or_default()
}
/// Percent-encodes `s` for use as a query-string value (form-urlencoded style).
///
/// RFC 3986 unreserved characters (`A-Z a-z 0-9 - _ . ~`) pass through unchanged,
/// a space becomes `+`, and every other byte of the UTF-8 encoding is written as
/// `%XX` with uppercase hex digits.
fn percent_encode(s: &str) -> String {
    use std::fmt::Write as _;
    s.bytes().fold(String::with_capacity(s.len()), |mut encoded, b| {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(b));
        } else if b == b' ' {
            encoded.push('+');
        } else {
            // Writing into a String cannot fail.
            let _ = write!(encoded, "%{:02X}", b);
        }
        encoded
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JWT-shaped string `header.payload.fakesig` around `payload`.
    ///
    /// The signature segment is a placeholder; `parse_jwt_claims` only decodes the
    /// payload segment.
    fn build_fake_jwt(payload: &serde_json::Value) -> String {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let payload_b64 =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string());
        format!("{}.{}.fakesig", header, payload_b64)
    }

    /// Unix timestamp one hour from now, for tokens that must not be expired.
    fn future_exp() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs() as i64 + 3600)
            .unwrap_or(9_999_999_999)
    }

    fn make_config() -> SsoConfig {
        SsoConfig {
            provider_type: SsoProviderType::Oidc,
            issuer_url: "https://accounts.example.com".to_string(),
            client_id: "test-client".to_string(),
            redirect_uri: "https://app.example.com/callback".to_string(),
            scopes: vec![
                "openid".to_string(),
                "profile".to_string(),
                "email".to_string(),
            ],
        }
    }

    #[test]
    fn test_sso_config_oidc_serialization() {
        let config = make_config();
        let json = serde_json::to_string(&config).expect("serialize");
        let restored: SsoConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.issuer_url, config.issuer_url);
        assert_eq!(restored.client_id, config.client_id);
        assert_eq!(restored.redirect_uri, config.redirect_uri);
        assert_eq!(restored.scopes, config.scopes);
        assert_eq!(restored.provider_type, SsoProviderType::Oidc);
    }

    #[test]
    fn test_sso_config_default() {
        let config = SsoConfig::default();
        assert_eq!(config.provider_type, SsoProviderType::Oidc);
        assert!(config.issuer_url.is_empty());
        assert!(config.client_id.is_empty());
        assert!(config.redirect_uri.is_empty());
        assert_eq!(config.scopes, vec!["openid", "profile", "email"]);
    }

    #[test]
    fn test_authorization_url_contains_params() {
        let validator = OidcValidator::new(make_config());
        let url = validator.authorization_url("state-abc", "nonce-xyz");
        assert!(url.contains("client_id=test-client"), "missing client_id");
        assert!(url.contains("redirect_uri="), "missing redirect_uri");
        assert!(url.contains("scope="), "missing scope");
        assert!(url.contains("state=state-abc"), "missing state");
        assert!(url.contains("nonce=nonce-xyz"), "missing nonce");
        assert!(url.contains("response_type=code"), "missing response_type");
    }

    #[test]
    fn test_authorization_url_encodes_values() {
        let validator = OidcValidator::new(make_config());
        let url = validator.authorization_url("a b", "n");
        assert!(url.starts_with("https://accounts.example.com/authorize?"), "url: {}", url);
        assert!(
            url.contains("redirect_uri=https%3A%2F%2Fapp.example.com%2Fcallback"),
            "url: {}",
            url
        );
        assert!(url.contains("scope=openid+profile+email"), "url: {}", url);
        assert!(url.contains("state=a+b"), "url: {}", url);
    }

    #[test]
    fn test_parse_callback_valid() {
        let validator = OidcValidator::new(make_config());
        let cb = validator
            .parse_callback("code=authcode123&state=mystate456")
            .expect("parse callback");
        assert_eq!(cb.code, "authcode123");
        assert_eq!(cb.state, "mystate456");
    }

    #[test]
    fn test_parse_callback_missing_state() {
        let validator = OidcValidator::new(make_config());
        let err = validator
            .parse_callback("code=authcode123")
            .expect_err("should fail without state");
        assert!(
            matches!(err, SsoError::MalformedToken(_)),
            "expected MalformedToken, got: {}",
            err
        );
    }

    #[test]
    fn test_validate_id_token_expired() {
        let validator = OidcValidator::new(make_config());
        let payload = serde_json::json!({
            "sub": "user-001",
            "iss": "https://accounts.example.com",
            "aud": "test-client",
            "exp": 1_000_000_i64, "iat": 900_000_i64,
            "email": "alice@example.com"
        });
        let token = build_fake_jwt(&payload);
        let err = validator
            .validate_id_token(&token)
            .expect_err("should fail with expired token");
        assert!(
            matches!(err, SsoError::TokenExpired),
            "expected TokenExpired, got: {}",
            err
        );
    }

    #[test]
    fn test_validate_id_token_wrong_issuer() {
        let validator = OidcValidator::new(make_config());
        let exp = future_exp();
        let payload = serde_json::json!({
            "sub": "user-001",
            "iss": "https://evil.example.com",
            "aud": "test-client",
            "exp": exp,
            "iat": exp - 60,
            "email": "alice@example.com"
        });
        let token = build_fake_jwt(&payload);
        let err = validator
            .validate_id_token(&token)
            .expect_err("should fail with wrong issuer");
        assert!(
            matches!(err, SsoError::InvalidIssuer { .. }),
            "expected InvalidIssuer, got: {}",
            err
        );
    }

    #[test]
    fn test_validate_id_token_wrong_audience() {
        let validator = OidcValidator::new(make_config());
        let exp = future_exp();
        let payload = serde_json::json!({
            "sub": "user-001",
            "iss": "https://accounts.example.com",
            "aud": "some-other-client",
            "exp": exp
        });
        let err = validator
            .validate_id_token(&build_fake_jwt(&payload))
            .expect_err("should fail with wrong audience");
        assert!(
            matches!(err, SsoError::InvalidAudience),
            "expected InvalidAudience, got: {}",
            err
        );
    }

    #[test]
    fn test_validate_id_token_missing_sub() {
        let validator = OidcValidator::new(make_config());
        let payload = serde_json::json!({
            "iss": "https://accounts.example.com",
            "aud": "test-client",
            "exp": future_exp()
        });
        let err = validator
            .validate_id_token(&build_fake_jwt(&payload))
            .expect_err("should fail without sub");
        assert!(
            matches!(err, SsoError::MalformedToken(_)),
            "expected MalformedToken, got: {}",
            err
        );
    }

    #[test]
    fn test_validate_id_token_valid_claims() {
        let validator = OidcValidator::new(make_config());
        let exp = future_exp();
        let payload = serde_json::json!({
            "sub": "user-42",
            "iss": "https://accounts.example.com",
            "aud": "test-client",
            "exp": exp,
            "iat": exp - 60,
            "email": "alice@example.com",
            "name": "Alice Smith",
            "groups": ["engineers", "rdf-users"]
        });
        let token = build_fake_jwt(&payload);
        let user_info = validator
            .validate_id_token(&token)
            .expect("valid token should be accepted");
        assert_eq!(user_info.subject, "user-42");
        assert_eq!(user_info.email.as_deref(), Some("alice@example.com"));
        assert_eq!(user_info.name.as_deref(), Some("Alice Smith"));
        assert!(user_info.groups.contains(&"engineers".to_string()));
        assert!(user_info.groups.contains(&"rdf-users".to_string()));
    }

    #[test]
    fn test_validate_id_token_audience_array_and_roles() {
        let validator = OidcValidator::new(make_config());
        let payload = serde_json::json!({
            "sub": "user-7",
            "iss": "https://accounts.example.com",
            "aud": ["other-client", "test-client"],
            "exp": future_exp(),
            "roles": "admin"
        });
        let user_info = validator
            .validate_id_token(&build_fake_jwt(&payload))
            .expect("array audience containing the client id should be accepted");
        assert_eq!(user_info.subject, "user-7");
        assert_eq!(user_info.email, None);
        assert_eq!(user_info.groups, vec!["admin"]);
    }

    #[test]
    fn test_parse_jwt_claims_rejects_wrong_segment_count() {
        let err = parse_jwt_claims("only.two").expect_err("two segments should be rejected");
        assert!(
            matches!(err, SsoError::MalformedToken(_)),
            "expected MalformedToken, got: {}",
            err
        );
    }

    #[test]
    fn test_extract_string_list_shapes() {
        let mut claims = HashMap::new();
        claims.insert("arr".to_string(), serde_json::json!(["a", 1, "b"]));
        claims.insert("single".to_string(), serde_json::json!("solo"));
        claims.insert("num".to_string(), serde_json::json!(7));
        assert_eq!(extract_string_list(&claims, "arr"), vec!["a", "b"]);
        assert_eq!(extract_string_list(&claims, "single"), vec!["solo"]);
        assert!(extract_string_list(&claims, "num").is_empty());
        assert!(extract_string_list(&claims, "missing").is_empty());
    }

    #[test]
    fn test_percent_encode_reserved_characters() {
        assert_eq!(percent_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(percent_encode("a b/c"), "a+b%2Fc");
    }

    #[test]
    fn test_sso_user_info_fields() {
        let mut raw = HashMap::new();
        raw.insert(
            "custom_claim".to_string(),
            serde_json::Value::String("value".to_string()),
        );
        let info = SsoUserInfo {
            subject: "sub-100".to_string(),
            email: Some("bob@corp.com".to_string()),
            name: Some("Bob Builder".to_string()),
            groups: vec!["builders".to_string()],
            raw_claims: raw,
        };
        assert_eq!(info.subject, "sub-100");
        assert_eq!(info.email.as_deref(), Some("bob@corp.com"));
        assert_eq!(info.name.as_deref(), Some("Bob Builder"));
        assert_eq!(info.groups, vec!["builders"]);
        assert!(info.raw_claims.contains_key("custom_claim"));
    }
}