use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use sqlx::FromRow;
use super::errors::OAuth2Error;
use super::main::OidcIdInfo;
use super::provider::{ProviderConfig, ProviderName};
use crate::session::UserId;
use crate::storage::CacheData;
/// A third-party (OAuth2 / OIDC) account record.
///
/// Derives `FromRow`, so it can be loaded straight from a SQL row (columns are matched to
/// fields by name).
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct OAuth2Account {
    /// Storage-assigned sequence number; excluded from serialized output.
    #[serde(skip_serializing)]
    pub sequence_number: Option<i64>,
    /// Account identifier. Left empty by the constructors in this file
    /// (presumably assigned when the record is persisted — confirm).
    pub id: String,
    /// Identifier of the owning local user. Left empty by the constructors in this file.
    pub user_id: String,
    /// Provider name (the `Display` form of `ProviderName` in the constructors below).
    pub provider: String,
    /// Provider-scoped user id; the constructors in this file build it as `"{provider}_{sub}"`.
    pub provider_user_id: String,
    /// Display name; the constructors fall back to the email when no `name` claim exists.
    pub name: String,
    /// Email address, or `preferred_username` when the provider sends no email.
    pub email: String,
    /// Profile picture URL, if the provider supplied one.
    pub picture: Option<String>,
    /// Extra provider claims stored as JSON (e.g. family/given name, `hd`, email-verified flag).
    pub metadata: Value,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-update timestamp (UTC).
    pub updated_at: DateTime<Utc>,
}
impl Default for OAuth2Account {
    /// Returns an empty, unpersisted account.
    ///
    /// String fields are empty, optional fields are `None`, `metadata` is JSON `null`, and both
    /// timestamps are set to the same current instant.
    fn default() -> Self {
        // Sample the clock once: two separate `Utc::now()` calls can return different
        // instants, which would make a brand-new record look as if it had been updated.
        let now = Utc::now();
        Self {
            sequence_number: None,
            id: String::new(),
            user_id: String::new(),
            provider: String::new(),
            provider_user_id: String::new(),
            name: String::new(),
            email: String::new(),
            picture: None,
            metadata: Value::Null,
            created_at: now,
            updated_at: now,
        }
    }
}
/// Claims returned by the provider's `/userinfo` endpoint.
///
/// Only `sub` is required. Every other claim is optional and, when present alongside the
/// matching ID-token claim, is compared against it by `validate_claim_match`.
///
/// `name` and `picture` are declared `pub`, but the struct is `pub(crate)`, so they are
/// effectively crate-visible like the other fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OidcUserInfo {
    /// Subject identifier of the end user.
    pub(crate) sub: String,
    /// Surname claim.
    pub(crate) family_name: Option<String>,
    /// Full display name claim.
    pub name: Option<String>,
    /// Profile picture URL claim.
    pub picture: Option<String>,
    /// Email address claim.
    pub(crate) email: Option<String>,
    /// Given (first) name claim.
    pub(crate) given_name: Option<String>,
    /// Hosted-domain claim (`hd`); presumably Google Workspace-specific — confirm.
    pub(crate) hd: Option<String>,
    /// Whether the provider reports the email as verified.
    pub(crate) email_verified: Option<bool>,
    /// Preferred username; used as an email fallback by the account constructors.
    pub(crate) preferred_username: Option<String>,
}
/// Builds an [`OAuth2Account`] from ID-token claims alone (no `/userinfo` data).
///
/// * `email` is the `email` claim, falling back to `preferred_username`.
/// * `name` is the `name` claim, falling back to the resolved email.
/// * `provider_user_id` is namespaced as `"{provider}_{sub}"`.
/// * `id` / `user_id` are left empty for the caller to fill in.
///
/// # Errors
///
/// Returns [`OAuth2Error::Validation`] when the token has neither `email` nor
/// `preferred_username`.
pub(crate) fn oauth2_account_from_idinfo(
    idinfo: &OidcIdInfo,
    ctx: &ProviderConfig,
) -> Result<OAuth2Account, OAuth2Error> {
    let provider_name = ctx.provider_name;
    let email = idinfo
        .email
        .clone()
        .or_else(|| idinfo.preferred_username.clone())
        .ok_or_else(|| {
            OAuth2Error::Validation(format!(
                "OIDC id_token from '{provider_name}' is missing both `email` and `preferred_username` claims"
            ))
        })?;
    let name = idinfo.name.clone().unwrap_or_else(|| email.clone());
    // One clock sample so the new record has identical created/updated timestamps.
    let now = Utc::now();
    Ok(OAuth2Account {
        sequence_number: None,
        id: String::new(),
        user_id: String::new(),
        name,
        email,
        picture: idinfo.picture.clone(),
        provider: provider_name.to_string(),
        provider_user_id: format!("{}_{}", provider_name, idinfo.sub),
        metadata: json!({
            "family_name": idinfo.family_name,
            "given_name": idinfo.given_name,
            "hd": idinfo.hd,
            // `email_verified` is the OIDC claim name and the key that
            // `oauth2_account_from_idinfo_and_userinfo` writes, so readers see one key
            // whichever path created the account. `verified_email` is kept as a legacy alias
            // for readers of metadata written by earlier versions of this function.
            "email_verified": idinfo.email_verified,
            "verified_email": idinfo.email_verified,
        }),
        created_at: now,
        updated_at: now,
    })
}
/// Builds an [`OAuth2Account`] by merging ID-token claims with the `/userinfo` response.
///
/// The two claim sets are first cross-checked with `validate_claim_match`. Each field then
/// prefers the ID-token value and falls back to the userinfo value:
///
/// * `email` falls back further to `preferred_username` (ID token first, then userinfo).
/// * `name` falls back to the resolved email.
/// * `provider_user_id` is `"{provider}_{sub}"` using the ID token's `sub`.
/// * `id` / `user_id` are left empty for the caller.
///
/// # Errors
///
/// * Any error returned by `validate_claim_match` (e.g. [`OAuth2Error::ClaimMismatch`]).
/// * [`OAuth2Error::Validation`] when neither source carries `email` or `preferred_username`.
pub(crate) fn oauth2_account_from_idinfo_and_userinfo(
    idinfo: &OidcIdInfo,
    userinfo: &OidcUserInfo,
    ctx: &ProviderConfig,
) -> Result<OAuth2Account, OAuth2Error> {
    validate_claim_match(idinfo, userinfo, ctx)?;
    let provider_name = ctx.provider_name;
    // Preference: id_token email, userinfo email, then the preferred_username pair.
    let email = idinfo
        .email
        .clone()
        .or_else(|| userinfo.email.clone())
        .or_else(|| idinfo.preferred_username.clone())
        .or_else(|| userinfo.preferred_username.clone())
        .ok_or_else(|| {
            OAuth2Error::Validation(format!(
                "OIDC response from '{provider_name}' is missing `email` / `preferred_username` in both id_token and userinfo"
            ))
        })?;
    let name = idinfo
        .name
        .clone()
        .or_else(|| userinfo.name.clone())
        .unwrap_or_else(|| email.clone());
    let picture = idinfo.picture.clone().or_else(|| userinfo.picture.clone());
    let family_name = idinfo
        .family_name
        .clone()
        .or_else(|| userinfo.family_name.clone());
    let given_name = idinfo
        .given_name
        .clone()
        .or_else(|| userinfo.given_name.clone());
    let hd = idinfo.hd.clone().or_else(|| userinfo.hd.clone());
    let email_verified = idinfo.email_verified.or(userinfo.email_verified);
    // One clock sample so the new record has identical created/updated timestamps.
    let now = Utc::now();
    Ok(OAuth2Account {
        sequence_number: None,
        id: String::new(),
        user_id: String::new(),
        name,
        email,
        picture,
        provider: provider_name.to_string(),
        provider_user_id: format!("{}_{}", provider_name, idinfo.sub),
        metadata: json!({
            "family_name": family_name,
            "given_name": given_name,
            "hd": hd,
            "email_verified": email_verified,
        }),
        created_at: now,
        updated_at: now,
    })
}
/// Cross-checks the claims that appear in both the ID token and the `/userinfo` response.
///
/// * `sub` must match exactly. OpenID Connect Core 1.0 §5.3.2 requires that userinfo values
///   are not used when the userinfo `sub` differs from the ID token's.
/// * Identity claims (`email`, `email_verified`, `preferred_username`, `hd`) must match
///   whenever both sides carry them.
/// * Display claims (`name`, `picture`, `family_name`, `given_name`) are rejected only when
///   `ctx.strict_display_claims` is set. Otherwise a mismatch is logged as a warning.
///
/// A claim that is absent on either side is never treated as a mismatch.
///
/// # Errors
///
/// Returns [`OAuth2Error::ClaimMismatch`] for the first mismatching claim.
fn validate_claim_match(
    idinfo: &OidcIdInfo,
    userinfo: &OidcUserInfo,
    ctx: &ProviderConfig,
) -> Result<(), OAuth2Error> {
    let provider = ctx.provider_name;
    // `sub` binds the /userinfo response to the authenticated subject. Without this check,
    // a response describing a different user could supply the email / name fallbacks that
    // the caller merges into the account.
    check_strict(
        "sub",
        Some(idinfo.sub.as_str()),
        Some(userinfo.sub.as_str()),
        provider,
    )?;
    check_strict(
        "email",
        idinfo.email.as_deref(),
        userinfo.email.as_deref(),
        provider,
    )?;
    check_strict_bool(
        "email_verified",
        idinfo.email_verified,
        userinfo.email_verified,
        provider,
    )?;
    check_strict(
        "preferred_username",
        idinfo.preferred_username.as_deref(),
        userinfo.preferred_username.as_deref(),
        provider,
    )?;
    check_strict("hd", idinfo.hd.as_deref(), userinfo.hd.as_deref(), provider)?;
    let strict = ctx.strict_display_claims;
    check_display(
        "name",
        idinfo.name.as_deref(),
        userinfo.name.as_deref(),
        provider,
        strict,
    )?;
    check_display(
        "picture",
        idinfo.picture.as_deref(),
        userinfo.picture.as_deref(),
        provider,
        strict,
    )?;
    check_display(
        "family_name",
        idinfo.family_name.as_deref(),
        userinfo.family_name.as_deref(),
        provider,
        strict,
    )?;
    check_display(
        "given_name",
        idinfo.given_name.as_deref(),
        userinfo.given_name.as_deref(),
        provider,
        strict,
    )?;
    Ok(())
}
/// Rejects a claim whose ID-token and userinfo values are both present but differ.
///
/// Returns `Ok(())` when either side is absent or both are equal. Otherwise returns an
/// [`OAuth2Error::ClaimMismatch`] carrying `field`, both values and the provider name.
fn check_strict(
    field: &'static str,
    idinfo_value: Option<&str>,
    userinfo_value: Option<&str>,
    provider: ProviderName,
) -> Result<(), OAuth2Error> {
    let (Some(from_token), Some(from_userinfo)) = (idinfo_value, userinfo_value) else {
        return Ok(());
    };
    if from_token == from_userinfo {
        return Ok(());
    }
    Err(OAuth2Error::ClaimMismatch {
        field,
        idinfo_value: from_token.to_owned(),
        userinfo_value: from_userinfo.to_owned(),
        provider: provider.to_string(),
    })
}
fn check_strict_bool(
field: &'static str,
idinfo_value: Option<bool>,
userinfo_value: Option<bool>,
provider: ProviderName,
) -> Result<(), OAuth2Error> {
match (idinfo_value, userinfo_value) {
(Some(a), Some(b)) if a != b => Err(OAuth2Error::ClaimMismatch {
field,
idinfo_value: a.to_string(),
userinfo_value: b.to_string(),
provider: provider.to_string(),
}),
_ => Ok(()),
}
}
/// Compares a display-only claim (name, picture, ...) between the ID token and `/userinfo`.
///
/// Absent or equal values are always accepted. For a real mismatch:
///
/// * with `strict` set, returns [`OAuth2Error::ClaimMismatch`];
/// * otherwise, emits a `tracing` warning tagged `security_event = "oauth2_claim_mismatch"`
///   and returns `Ok(())`.
fn check_display(
    field: &'static str,
    idinfo_value: Option<&str>,
    userinfo_value: Option<&str>,
    provider: ProviderName,
    strict: bool,
) -> Result<(), OAuth2Error> {
    let (Some(from_token), Some(from_userinfo)) = (idinfo_value, userinfo_value) else {
        return Ok(());
    };
    if from_token == from_userinfo {
        return Ok(());
    }
    if strict {
        return Err(OAuth2Error::ClaimMismatch {
            field,
            idinfo_value: from_token.to_owned(),
            userinfo_value: from_userinfo.to_owned(),
            provider: provider.to_string(),
        });
    }
    tracing::warn!(
        security_event = "oauth2_claim_mismatch",
        field = field,
        idinfo_value = from_token,
        userinfo_value = from_userinfo,
        provider = %provider,
        "claim mismatch between id_token and /userinfo; using id_token value (set OAUTH2_<provider>_STRICT_DISPLAY_CLAIMS=true to reject)"
    );
    Ok(())
}
/// Parameters packed into the OAuth2 `state` value.
///
/// `OAuth2State::new` accepts a state only if it is base64url (no padding) encoded JSON that
/// deserializes into this struct.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub(crate) struct StateParams {
    /// Id of the stored CSRF token (presumably a cache key — confirm).
    pub(crate) csrf_id: String,
    /// Id of the stored OIDC nonce (presumably a cache key — confirm).
    pub(crate) nonce_id: String,
    /// Id of the stored PKCE value (presumably a cache key — confirm).
    pub(crate) pkce_id: String,
    /// Optional id of additional stored data; its contents are not visible here.
    pub(crate) misc_id: Option<String>,
    /// Optional id referring to the requested flow mode (presumably an `OAuth2Mode` — confirm).
    pub(crate) mode_id: Option<String>,
    /// Name of the provider the flow was started for.
    pub(crate) provider: String,
}
/// A token value persisted in the cache, stored as JSON through the `CacheData` conversions.
#[derive(Serialize, Clone, Deserialize, Debug)]
pub(crate) struct StoredToken {
    /// The token value itself.
    pub(crate) token: String,
    /// Absolute expiry instant (UTC).
    pub(crate) expires_at: DateTime<Utc>,
    /// User-Agent recorded with the token, if any (presumably for a binding check — confirm).
    pub(crate) user_agent: Option<String>,
    /// Time-to-live; the unit is not established here (presumably seconds — confirm).
    pub(crate) ttl: u64,
}
/// Parameters received on the authorization callback.
#[derive(Debug, Deserialize)]
pub struct AuthResponse {
    /// Authorization code to exchange at the token endpoint.
    pub(crate) code: String,
    /// Opaque `state` value echoed back by the provider.
    pub state: String,
    // NOTE(review): there is no `#[serde(rename = "id_token")]`, so serde matches this field
    // against a parameter literally named `_id_token`. An `id_token` parameter is ignored and
    // this stays `None`. The field is never read in the code visible here; confirm whether it
    // should be renamed or removed.
    _id_token: Option<String>,
}
/// Response body from the provider's token endpoint.
#[derive(Debug, Deserialize, Serialize)]
pub(super) struct OidcTokenResponse {
    /// Access token issued by the provider.
    pub(super) access_token: String,
    /// Token type string as returned by the provider (typically `"Bearer"` — not checked here).
    token_type: String,
    /// Lifetime of the access token, if reported (presumably seconds, per RFC 6749 — confirm).
    expires_in: Option<u64>,
    /// Refresh token, if one was issued.
    refresh_token: Option<String>,
    /// Granted scopes, if reported.
    scope: Option<String>,
    /// OIDC ID token (a JWT), present when the provider returns one.
    pub(super) id_token: Option<String>,
}
impl From<StoredToken> for CacheData {
fn from(data: StoredToken) -> Self {
Self {
value: serde_json::to_string(&data).expect("Failed to serialize StoredToken"),
}
}
}
impl TryFrom<CacheData> for StoredToken {
    type Error = OAuth2Error;
    /// Parses the cached JSON payload back into a [`StoredToken`].
    ///
    /// # Errors
    ///
    /// Returns [`OAuth2Error::Storage`] with the `serde_json` error text when the payload is
    /// not a valid serialized token.
    fn try_from(data: CacheData) -> Result<Self, Self::Error> {
        let raw = data.value.as_str();
        serde_json::from_str::<Self>(raw).map_err(|err| OAuth2Error::Storage(err.to_string()))
    }
}
/// A single-field lookup key for OAuth2 accounts. Each variant carries a validated newtype.
// NOTE(review): `dead_code` is allowed, presumably because not every variant is constructed
// yet — confirm, and drop the allow once all variants are in use.
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
pub(crate) enum AccountSearchField {
    /// Match on the account id.
    Id(AccountId),
    /// Match on the owning local user id.
    UserId(UserId),
    /// Match on the provider name.
    Provider(Provider),
    /// Match on the provider-scoped user id.
    ProviderUserId(ProviderUserId),
    /// Match on the display name.
    Name(DisplayName),
    /// Match on the email address.
    Email(Email),
}
/// What an OAuth2 flow should do with the authenticated identity.
///
/// Serialized in `snake_case` (`add_to_user`, `create_user`, `login`, `create_user_or_login`).
/// These are the same strings used by `as_str` / `FromStr`. The per-variant behavior is
/// implemented by callers outside this file; the notes below are inferred from the names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OAuth2Mode {
    /// Presumably: link the provider account to an existing user.
    AddToUser,
    /// Presumably: create a new user from the provider account.
    CreateUser,
    /// Presumably: log in a user already linked to the provider account.
    Login,
    /// Presumably: log in if a linked user exists, otherwise create one.
    CreateUserOrLogin,
}
impl OAuth2Mode {
pub fn as_str(&self) -> &'static str {
match self {
Self::AddToUser => "add_to_user",
Self::CreateUser => "create_user",
Self::Login => "login",
Self::CreateUserOrLogin => "create_user_or_login",
}
}
}
impl std::str::FromStr for OAuth2Mode {
type Err = OAuth2Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"add_to_user" => Ok(Self::AddToUser),
"create_user" => Ok(Self::CreateUser),
"login" => Ok(Self::Login),
"create_user_or_login" => Ok(Self::CreateUserOrLogin),
_ => Err(OAuth2Error::InvalidMode(s.to_string())),
}
}
}
/// A validated account identifier.
#[derive(Debug, Clone, PartialEq)]
pub struct AccountId(String);
impl AccountId {
    /// Validates and wraps an account identifier.
    ///
    /// Rules are checked in this order:
    ///
    /// 1. non-empty;
    /// 2. at most 255 bytes;
    /// 3. only ASCII letters/digits and `- _ . @ +`;
    /// 4. no `..`, `--` or `__`.
    ///
    /// # Errors
    ///
    /// Returns `OAuth2Error::Validation` describing the first rule that failed.
    pub fn new(id: String) -> Result<Self, crate::oauth2::OAuth2Error> {
        use crate::oauth2::OAuth2Error;

        let fail = |why: &str| -> Result<Self, OAuth2Error> {
            Err(OAuth2Error::Validation(why.to_string()))
        };
        let permitted = |c: char| c.is_ascii_alphanumeric() || "-_.@+".contains(c);

        if id.is_empty() {
            return fail("Account ID cannot be empty");
        }
        if id.len() > 255 {
            return fail("Account ID too long");
        }
        if !id.chars().all(permitted) {
            return fail("Account ID contains invalid characters");
        }
        if ["..", "--", "__"].iter().any(|seq| id.contains(*seq)) {
            return fail("Account ID contains dangerous character sequences");
        }
        Ok(Self(id))
    }

    /// Returns the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// A validated provider name.
#[derive(Debug, Clone, PartialEq)]
pub struct Provider(String);
impl Provider {
    /// Validates and wraps a provider name.
    ///
    /// Rules are checked in this order:
    ///
    /// 1. non-empty;
    /// 2. at most 50 bytes;
    /// 3. only ASCII letters/digits and `- _ .`;
    /// 4. does not start with `-`, `_` or `.`.
    ///
    /// # Errors
    ///
    /// Returns `OAuth2Error::Validation` describing the first rule that failed.
    pub fn new(provider: String) -> Result<Self, crate::oauth2::OAuth2Error> {
        use crate::oauth2::OAuth2Error;

        let is_separator = |c: char| matches!(c, '-' | '_' | '.');
        let complaint = if provider.is_empty() {
            Some("Provider name cannot be empty")
        } else if provider.len() > 50 {
            Some("Provider name too long")
        } else if !provider
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || is_separator(c))
        {
            Some("Provider name contains invalid characters")
        } else if provider.starts_with(is_separator) {
            Some("Provider name cannot start with special characters")
        } else {
            None
        };

        match complaint {
            Some(text) => Err(OAuth2Error::Validation(text.to_string())),
            None => Ok(Provider(provider)),
        }
    }

    /// Returns the provider name as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// A validated provider-side user identifier.
#[derive(Debug, Clone, PartialEq)]
pub struct ProviderUserId(String);
impl ProviderUserId {
    /// Validates and wraps a provider user id.
    ///
    /// Rules are checked in this order:
    ///
    /// 1. non-empty;
    /// 2. at most 512 bytes;
    /// 3. only ASCII letters/digits and `- _ . @ + = ( ) |`;
    /// 4. no `..`, `--` or `__`.
    ///
    /// # Errors
    ///
    /// Returns `OAuth2Error::Validation` describing the first rule that failed.
    pub fn new(id: String) -> Result<Self, crate::oauth2::OAuth2Error> {
        use crate::oauth2::OAuth2Error;

        const EXTRA: &[char] = &['-', '_', '.', '@', '+', '=', '(', ')', '|'];
        let reason = if id.is_empty() {
            Some("Provider user ID cannot be empty")
        } else if id.len() > 512 {
            Some("Provider user ID too long")
        } else if id
            .chars()
            .any(|c| !c.is_ascii_alphanumeric() && !EXTRA.contains(&c))
        {
            Some("Provider user ID contains invalid characters")
        } else if ["..", "--", "__"].iter().any(|seq| id.contains(*seq)) {
            Some("Provider user ID contains dangerous character sequences")
        } else {
            None
        };

        match reason {
            None => Ok(ProviderUserId(id)),
            Some(text) => Err(OAuth2Error::Validation(text.to_string())),
        }
    }

    /// Returns the identifier as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// A validated human-readable display name.
#[derive(Debug, Clone, PartialEq)]
pub struct DisplayName(String);
impl DisplayName {
    /// Validates and wraps a display name.
    ///
    /// Rejects, in this order:
    ///
    /// 1. empty input;
    /// 2. more than 100 bytes of UTF-8;
    /// 3. whitespace-only input;
    /// 4. names containing `..`, `--` or `__`.
    ///
    /// NOTE(review): the length limit counts bytes, not characters, so names in multi-byte
    /// scripts hit it well before 100 characters — confirm that is intended.
    ///
    /// # Errors
    ///
    /// Returns `OAuth2Error::Validation` naming the first rule that failed.
    // `dead_code` is allowed, presumably because nothing outside tests constructs this yet — confirm.
    #[allow(dead_code)]
    pub fn new(name: String) -> Result<Self, crate::oauth2::OAuth2Error> {
        use crate::oauth2::OAuth2Error;

        let verdict: Result<(), &'static str> = if name.is_empty() {
            Err("Display name cannot be empty")
        } else if name.len() > 100 {
            Err("Display name too long")
        } else if name.trim().is_empty() {
            Err("Display name cannot consist only of whitespace")
        } else if ["..", "--", "__"].iter().any(|seq| name.contains(*seq)) {
            Err("Display name contains dangerous character sequences")
        } else {
            Ok(())
        };

        verdict
            .map(|()| DisplayName(name))
            .map_err(|msg| OAuth2Error::Validation(msg.to_string()))
    }

    /// Returns the display name as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// A structurally validated email address.
#[derive(Debug, Clone, PartialEq)]
pub struct Email(String);
impl Email {
    /// Validates and wraps an email address.
    ///
    /// The input must be 3 to 254 bytes long and contain exactly one `@`, with non-empty text
    /// on both sides. No further syntax checks are performed.
    ///
    /// # Errors
    ///
    /// Returns `OAuth2Error::Validation` describing the first failing check.
    // `dead_code` is allowed, presumably because nothing outside tests constructs this yet — confirm.
    #[allow(dead_code)]
    pub fn new(email: String) -> Result<Self, crate::oauth2::OAuth2Error> {
        use crate::oauth2::OAuth2Error;

        let invalid = |reason: &str| OAuth2Error::Validation(reason.to_string());

        match email.len() {
            0 => return Err(invalid("Email cannot be empty")),
            1 | 2 => return Err(invalid("Email too short")),
            n if n > 254 => return Err(invalid("Email too long")),
            _ => {}
        }
        let Some((local, domain)) = email.split_once('@') else {
            return Err(invalid("Email must contain @ symbol"));
        };
        // A second '@' in the remainder means the address had more than one.
        if local.is_empty() || domain.is_empty() || domain.contains('@') {
            return Err(invalid("Email format is invalid"));
        }
        Ok(Email(email))
    }

    /// Returns the address as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// An OAuth2 `state` parameter that has passed structural validation.
///
/// The wrapped text is the original base64url (no padding) encoding of a JSON
/// [`StateParams`] object. [`OAuth2State::new`] checks that it decodes, but stores the
/// encoded form unchanged.
#[derive(Debug, Clone, PartialEq)]
pub struct OAuth2State(String);
impl std::fmt::Display for OAuth2State {
    /// Writes the encoded state string unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
impl OAuth2State {
    /// Validates an incoming `state` value and wraps it.
    ///
    /// Checks, in order:
    ///
    /// 1. non-empty;
    /// 2. 10 to 8192 bytes long;
    /// 3. valid base64url without padding;
    /// 4. the decoded payload is UTF-8;
    /// 5. the payload is JSON that deserializes into [`StateParams`].
    ///
    /// The decoded parameters are discarded; only the original string is kept.
    ///
    /// # Errors
    ///
    /// Returns [`OAuth2Error::DecodeState`] describing the first failing check.
    pub fn new(state: String) -> Result<Self, super::errors::OAuth2Error> {
        use super::errors::OAuth2Error;
        use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};

        let length_problem = match state.len() {
            0 => Some("OAuth2 state cannot be empty"),
            1..=9 => Some("OAuth2 state too short"),
            n if n > 8192 => Some("OAuth2 state too long"),
            _ => None,
        };
        if let Some(reason) = length_problem {
            return Err(OAuth2Error::DecodeState(reason.to_string()));
        }

        let bytes = URL_SAFE_NO_PAD
            .decode(&state)
            .map_err(|e| OAuth2Error::DecodeState(format!("Invalid base64url encoding: {e}")))?;
        let json = String::from_utf8(bytes).map_err(|e| {
            OAuth2Error::DecodeState(format!("Invalid UTF-8 in decoded state: {e}"))
        })?;
        // Parsed only to prove the shape; the value itself is not needed here.
        serde_json::from_str::<StateParams>(&json)
            .map_err(|e| OAuth2Error::DecodeState(format!("Invalid JSON in decoded state: {e}")))?;

        Ok(Self(state))
    }

    /// Returns the encoded state string.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Reports whether the encoded state string contains `needle`.
    pub fn contains(&self, needle: char) -> bool {
        self.as_str().contains(needle)
    }
}
/// Kind of per-flow token tracked during an OAuth2 authorization.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TokenType {
    /// CSRF token.
    Csrf,
    /// OIDC nonce.
    Nonce,
    /// PKCE value (presumably the code verifier — confirm at call sites).
    Pkce,
}
impl std::fmt::Display for TokenType {
    /// Writes the lowercase name returned by [`TokenType::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
impl TokenType {
    /// Returns the lowercase name of this token kind: `"csrf"`, `"nonce"` or `"pkce"`.
    ///
    /// The result is a `'static` literal, like `OAuth2Mode::as_str`, so it can outlive the
    /// `TokenType` value it was read from.
    pub fn as_str(&self) -> &'static str {
        match self {
            TokenType::Csrf => "csrf",
            TokenType::Nonce => "nonce",
            TokenType::Pkce => "pkce",
        }
    }
}
/// Response carrying a freshly issued nonce and the id it was stored under
/// (presumably returned to the browser before a FedCM request — confirm).
#[derive(Debug, Serialize, Deserialize)]
pub struct FedCMNonceResponse {
    /// The nonce value itself.
    pub nonce: String,
    /// Identifier under which the nonce was stored, echoed back by the client.
    pub nonce_id: String,
}
/// Request body posted back after a FedCM sign-in.
#[derive(Debug, Deserialize)]
pub struct FedCMCallbackRequest {
    /// Credential returned by the browser's FedCM API (presumably an ID-token JWT — confirm).
    pub credential: String,
    /// Identifier of the nonce previously issued in a `FedCMNonceResponse`.
    pub nonce_id: String,
    /// Optional requested mode string (presumably parsed as an `OAuth2Mode` — confirm).
    pub mode: Option<String>,
}
#[cfg(test)]
mod tests;