use crate::errors::{CoreError, CoreResult, ProtocolError};
use crate::time::SharedClock;
use crate::tokens::{self, TokenLifetimes, TokenSet};
use chrono::Duration;
use ed25519_dalek::SigningKey;
use sui_id_shared::ids::{ClientId, UserId};
use sui_id_store::models::{AuditLogRow, AuthorizationCodeRow, RefreshTokenRow};
use sui_id_store::repos::{audit, auth_codes, clients, refresh_tokens, signing_keys, users};
use sui_id_store::Database;
/// Lifetime of an authorization code, in seconds. It is added to the issue
/// time to compute the `expires_at` stored with each code.
const AUTH_CODE_LIFETIME_SECS: i64 = 60;
/// Parameters of an incoming authorization request, as checked by
/// `begin_authorization`.
#[derive(Debug, Clone)]
pub struct AuthorizeParams {
/// Identifier of the requesting client; used to load the client record.
pub client_id: ClientId,
/// Redirect target; must be exactly equal to one of the client's registered URIs.
pub redirect_uri: String,
/// Requested response type; only `code` is accepted.
pub response_type: String,
/// Whitespace-separated list of requested scopes.
pub scope: String,
/// Optional value handed back unchanged in the authorization response.
pub state: Option<String>,
/// Optional value stored with the authorization code and forwarded to token issuance.
pub nonce: Option<String>,
/// PKCE code challenge; must be non-empty.
pub code_challenge: String,
/// PKCE challenge method; must be `S256`.
pub code_challenge_method: String,
}
/// An authorization request that passed `begin_authorization`; it is the input
/// of `complete_authorization`.
#[derive(Debug, Clone)]
pub struct AcceptedAuthorize {
/// The request parameters exactly as they were submitted.
pub params: AuthorizeParams,
}
/// Validates an authorization request before the user is asked to sign in.
///
/// Checks are performed in this order, and the first failure is returned:
/// 1. `response_type` must be exactly `code`.
/// 2. A PKCE `code_challenge` must be present and its method must be `S256`.
/// 3. The client must exist and be neither disabled nor deleted.
/// 4. `redirect_uri` must exactly match one of the client's registered URIs.
/// 5. Every requested scope must be permitted for the client.
///
/// On success the parameters are returned unchanged, wrapped in
/// [`AcceptedAuthorize`].
///
/// # Errors
/// * `UnsupportedResponseType` - `response_type` is not `code`.
/// * `InvalidRequest` - missing challenge, non-`S256` method, or unregistered
///   redirect URI.
/// * `InvalidClient` - no client with this `client_id`.
/// * `UnauthorizedClient` - the client is disabled or deleted.
/// * `InvalidScope` - a requested scope is not allowed for the client.
/// * Any other store error, converted through `CoreError::from`.
pub async fn begin_authorization(
    db: &Database,
    params: AuthorizeParams,
) -> CoreResult<AcceptedAuthorize> {
    if params.response_type != "code" {
        return Err(CoreError::Protocol {
            code: ProtocolError::UnsupportedResponseType,
            description: format!(
                "only response_type=code is supported, got {}",
                params.response_type
            ),
        });
    }
    if params.code_challenge.is_empty() {
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidRequest,
            description: "code_challenge is required (PKCE)".into(),
        });
    }
    if params.code_challenge_method != "S256" {
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidRequest,
            description: "code_challenge_method must be S256".into(),
        });
    }

    let client = clients::get(db, params.client_id).await.map_err(|e| match e {
        sui_id_store::StoreError::NotFound => CoreError::Protocol {
            code: ProtocolError::InvalidClient,
            description: "unknown client_id".into(),
        },
        other => CoreError::from(other),
    })?;
    if client.is_disabled || client.is_deleted {
        return Err(CoreError::Protocol {
            code: ProtocolError::UnauthorizedClient,
            description: "client is not allowed to use the authorization endpoint".into(),
        });
    }

    // Exact string comparison only: no prefix, wildcard or case-insensitive match.
    if !is_redirect_uri_registered(&client.redirect_uris, &params.redirect_uri) {
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidRequest,
            description: "redirect_uri does not match a registered URI".into(),
        });
    }
    enforce_scope_policy(&client.allowed_scopes, &params.scope, &client.name)?;

    Ok(AcceptedAuthorize { params })
}
/// Returns `true` when `submitted` is byte-for-byte equal to one of the
/// `registered` redirect URIs.
///
/// The comparison is a plain string equality: no normalisation, no case
/// folding and no prefix matching. An empty registry never matches.
pub fn is_redirect_uri_registered(registered: &[String], submitted: &str) -> bool {
    registered
        .iter()
        .map(String::as_str)
        .any(|candidate| candidate == submitted)
}
/// Checks every whitespace-separated scope in `requested` against the client's
/// `allowed` list (also whitespace-separated).
///
/// An `allowed` list that is empty or only whitespace imposes no restriction.
/// Otherwise the first requested scope that is not in the list produces an
/// `InvalidScope` error whose description names the scope, the client and the
/// allowed list.
fn enforce_scope_policy(allowed: &str, requested: &str, client_name: &str) -> CoreResult<()> {
    let allowed = allowed.trim();
    if allowed.is_empty() {
        return Ok(());
    }

    let permitted: std::collections::HashSet<&str> = allowed.split_whitespace().collect();
    let first_rejected = requested
        .split_whitespace()
        .find(|scope| !permitted.contains(*scope));

    match first_rejected {
        None => Ok(()),
        Some(tok) => Err(CoreError::Protocol {
            code: ProtocolError::InvalidScope,
            description: format!(
                "scope {tok:?} is not permitted for client {:?} \
                 (allowed: {:?}). \
                 Go to Admin → Clients → edit this client and add {tok:?} \
                 to the Allowed scopes field.",
                client_name, allowed,
            ),
        }),
    }
}
/// Data for the redirect back to the client after a successful authorization;
/// produced by `complete_authorization`.
#[derive(Debug, Clone)]
pub struct AuthorizationResponseRedirect {
/// The redirect URI taken from the accepted request.
pub redirect_uri: String,
/// The plaintext authorization code; only its SHA-256 hex digest is written to the store row.
pub code: String,
/// The `state` value from the request, if one was supplied.
pub state: Option<String>,
}
pub async fn complete_authorization(
db: &Database,
clock: &SharedClock,
user_id: UserId,
auth_methods: &[sui_id_shared::AuthMethod],
accepted: AcceptedAuthorize,
) -> CoreResult<AuthorizationResponseRedirect> {
let now = clock.now();
let code_plain = tokens::random_token(32);
let code_hash = tokens::sha256_hex(&code_plain);
let row = AuthorizationCodeRow {
code_hash,
client_id: accepted.params.client_id,
user_id,
redirect_uri: accepted.params.redirect_uri.clone(),
scope: accepted.params.scope.clone(),
nonce: accepted.params.nonce.clone(),
code_challenge: accepted.params.code_challenge.clone(),
code_challenge_method: accepted.params.code_challenge_method.clone(),
expires_at: now + Duration::seconds(AUTH_CODE_LIFETIME_SECS),
consumed: false,
created_at: now,
auth_methods: auth_methods.to_vec(),
};
auth_codes::insert(db, &row).await?;
Ok(AuthorizationResponseRedirect {
redirect_uri: accepted.params.redirect_uri,
code: code_plain,
state: accepted.params.state,
})
}
/// Token-endpoint request for the authorization-code grant; input of
/// `exchange_code`.
#[derive(Debug, Clone)]
pub struct CodeExchangeRequest {
/// Plaintext authorization code; looked up by its SHA-256 hex digest.
pub code: String,
/// Must equal the redirect URI stored with the code.
pub redirect_uri: String,
/// Client redeeming the code; must be the client the code was issued to.
pub client_id: ClientId,
/// Client secret; required when the client is confidential, not checked otherwise.
pub client_secret: Option<String>,
/// PKCE verifier checked against the stored code challenge.
pub code_verifier: String,
}
/// Token-endpoint request for the refresh-token grant; input of
/// `exchange_refresh`.
#[derive(Debug, Clone)]
pub struct RefreshExchangeRequest {
/// The refresh token presented by the client.
pub refresh_token: String,
/// Client presenting the token; must be the client the token was issued to.
pub client_id: ClientId,
/// Client secret; required when the client is confidential, not checked otherwise.
pub client_secret: Option<String>,
}
/// Issuer-wide settings passed to every token issuance.
#[derive(Debug, Clone, Copy)]
pub struct IssuanceContext<'a> {
/// Issuer identifier forwarded to `tokens::issue_token_set`.
pub issuer: &'a str,
/// Token lifetimes; `refresh_secs` sets the expiry of stored refresh tokens.
pub lifetimes: TokenLifetimes,
}
pub async fn exchange_code(
db: &Database,
clock: &SharedClock,
ctx: IssuanceContext<'_>,
req: CodeExchangeRequest,
) -> CoreResult<TokenSet> {
let client = clients::get(db, req.client_id).await.map_err(|e| match e {
sui_id_store::StoreError::NotFound => CoreError::Protocol {
code: ProtocolError::InvalidClient,
description: "unknown client".into(),
},
other => CoreError::from(other),
})?;
if client.is_disabled || client.is_deleted {
return Err(CoreError::Protocol {
code: ProtocolError::UnauthorizedClient,
description: "client is not allowed".into(),
});
}
authenticate_client(&client, req.client_secret.as_deref()).await?;
let code_hash = tokens::sha256_hex(&req.code);
let row = auth_codes::consume(db, &code_hash).await.map_err(|e| match e {
sui_id_store::StoreError::NotFound => CoreError::Protocol {
code: ProtocolError::InvalidGrant,
description: "code is unknown, expired, or already used".into(),
},
other => CoreError::from(other),
})?;
if row.client_id != req.client_id {
return Err(CoreError::Protocol {
code: ProtocolError::InvalidGrant,
description: "code was issued to a different client".into(),
});
}
if row.redirect_uri != req.redirect_uri {
return Err(CoreError::Protocol {
code: ProtocolError::InvalidGrant,
description: "redirect_uri does not match the original".into(),
});
}
tokens::verify_pkce(&row.code_challenge_method, &req.code_verifier, &row.code_challenge)?;
let user = users::get(db, row.user_id).await.map_err(|e| match e {
sui_id_store::StoreError::NotFound => CoreError::Protocol {
code: ProtocolError::InvalidGrant,
description: "user not found".into(),
},
other => CoreError::from(other),
})?;
if user.is_disabled || user.is_deleted {
let _ = audit::append(
db,
&AuditLogRow {
at: chrono::Utc::now(),
actor: Some(row.user_id),
action: "oauth2.exchange_code.user_revoked".into(),
target: Some(req.client_id.to_string()),
result: "denied".into(),
note: Some("user disabled or deleted during auth-code exchange window".into()),
},
).await;
return Err(CoreError::Protocol {
code: ProtocolError::InvalidGrant,
description: "user account is not active".into(),
});
}
issue_for(
db,
clock,
ctx,
row.user_id,
req.client_id,
&row.scope,
row.nonce.as_deref(),
&row.auth_methods,
user.email.as_deref().map(|addr| (addr, user.email_verified_at.is_some())),
).await
}
/// Rotates a refresh token: validates the presented token, revokes it and
/// issues a new token set in the same token family.
///
/// Steps, in order:
/// 1. Load the client, reject unknown, disabled or deleted clients, and
///    authenticate it with `authenticate_client`.
/// 2. Look the token up and check that it belongs to `req.client_id`.
/// 3. If the token was already revoked, treat the request as token reuse:
///    revoke the whole family and write an audit entry (both best effort),
///    then fail.
/// 4. Reject expired tokens.
/// 5. Load the user and refuse rotation when the account is missing, disabled
///    or deleted. This mirrors the check in `exchange_code`; without it a
///    deactivated account could keep minting tokens for as long as it kept
///    rotating.
/// 6. Revoke the presented token and issue a new set. No nonce is carried
///    over. The e-mail is passed to issuance only when the stored scope
///    contains `email`.
///
/// # Errors
/// * `InvalidClient` - unknown client or failed client authentication.
/// * `UnauthorizedClient` - the client is disabled or deleted.
/// * `InvalidGrant` - unknown, foreign, revoked or expired token, or the user
///   is missing or not active.
/// * Any other store or issuance error.
pub async fn exchange_refresh(
    db: &Database,
    clock: &SharedClock,
    ctx: IssuanceContext<'_>,
    req: RefreshExchangeRequest,
) -> CoreResult<TokenSet> {
    let client = clients::get(db, req.client_id).await.map_err(|e| match e {
        sui_id_store::StoreError::NotFound => CoreError::Protocol {
            code: ProtocolError::InvalidClient,
            description: "unknown client".into(),
        },
        other => CoreError::from(other),
    })?;
    if client.is_disabled || client.is_deleted {
        return Err(CoreError::Protocol {
            code: ProtocolError::UnauthorizedClient,
            description: "client is not allowed".into(),
        });
    }
    authenticate_client(&client, req.client_secret.as_deref()).await?;

    let row = match refresh_tokens::find_any(db, &req.refresh_token).await {
        Ok(r) => r,
        Err(sui_id_store::StoreError::NotFound) => {
            return Err(CoreError::Protocol {
                code: ProtocolError::InvalidGrant,
                description: "refresh token is unknown".into(),
            });
        }
        Err(e) => return Err(e.into()),
    };
    if row.client_id != req.client_id {
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidGrant,
            description: "refresh token was issued to a different client".into(),
        });
    }

    if row.revoked_at.is_some() {
        // A revoked token being presented again is handled as reuse of a
        // rotated token. Both calls are best effort so the request is denied
        // even if they fail.
        let _ = refresh_tokens::revoke_family(db, &row.family_id).await;
        let _ = audit::append(
            db,
            &AuditLogRow {
                at: clock.now(),
                actor: Some(row.user_id),
                action: "auth.refresh.theft_detected".into(),
                target: Some(row.user_id.to_string()),
                result: "denied".into(),
                note: Some(format!(
                    "revoked refresh-token family={} client_id={}",
                    row.family_id, row.client_id
                )),
            },
        )
        .await;
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidGrant,
            description: "refresh token is unknown or revoked".into(),
        });
    }
    if row.expires_at <= clock.now() {
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidGrant,
            description: "refresh token has expired".into(),
        });
    }

    // Fail closed on the account state before rotating anything.
    let user = users::get(db, row.user_id).await.map_err(|e| match e {
        sui_id_store::StoreError::NotFound => CoreError::Protocol {
            code: ProtocolError::InvalidGrant,
            description: "user not found".into(),
        },
        other => CoreError::from(other),
    })?;
    if user.is_disabled || user.is_deleted {
        return Err(CoreError::Protocol {
            code: ProtocolError::InvalidGrant,
            description: "user account is not active".into(),
        });
    }

    refresh_tokens::revoke(db, &row.id).await?;

    // The e-mail is only handed to issuance when the grant covers `email`.
    let wants_email = row.scope.split_whitespace().any(|s| s == "email");
    let email_arg: Option<(&str, bool)> = if wants_email {
        user.email
            .as_deref()
            .map(|addr| (addr, user.email_verified_at.is_some()))
    } else {
        None
    };

    issue_for_with_family(
        db,
        clock,
        ctx,
        row.user_id,
        row.client_id,
        &row.scope,
        None,
        &row.auth_methods,
        Some(row.family_id.clone()),
        email_arg,
    )
    .await
}
/// Authenticates a client at the token endpoint.
///
/// Clients that are not confidential pass without any check. A confidential
/// client must have a stored secret hash and must present a secret that
/// `crate::password::verify_password` accepts against that hash.
///
/// # Errors
/// Every failure is reported as `InvalidClient`; only the description differs.
async fn authenticate_client(
    client: &sui_id_store::models::ClientRow,
    secret: Option<&str>,
) -> CoreResult<()> {
    // One place to build the error shared by all failure paths.
    let invalid_client = |description: &str| CoreError::Protocol {
        code: ProtocolError::InvalidClient,
        description: description.into(),
    };

    if !client.confidential {
        return Ok(());
    }

    let stored_hash = match client.secret_hash.as_deref() {
        Some(hash) => hash,
        None => {
            return Err(invalid_client(
                "client is confidential but has no stored secret",
            ))
        }
    };
    let presented = match secret {
        Some(value) => value,
        None => return Err(invalid_client("client_secret is required")),
    };

    crate::password::verify_password(presented, stored_hash)
        .map_err(|_| invalid_client("client authentication failed"))
}
/// Issues a token set for a brand-new grant.
///
/// Forwards every argument to `issue_for_with_family` and passes no existing
/// refresh-token family.
// The argument list mirrors `issue_for_with_family`, hence the lint exemption.
#[allow(clippy::too_many_arguments)]
async fn issue_for(
    db: &Database,
    clock: &SharedClock,
    ctx: IssuanceContext<'_>,
    user_id: UserId,
    client_id: ClientId,
    scope: &str,
    nonce: Option<&str>,
    auth_methods: &[sui_id_shared::AuthMethod],
    user_email: Option<(&str, bool)>,
) -> CoreResult<TokenSet> {
    // A first issuance has no family to continue.
    let no_existing_family: Option<String> = None;
    issue_for_with_family(
        db,
        clock,
        ctx,
        user_id,
        client_id,
        scope,
        nonce,
        auth_methods,
        no_existing_family,
        user_email,
    )
    .await
}
/// Signs a token set with the active signing key and stores the matching
/// refresh-token row.
///
/// * An ID token is requested only when `scope` contains `openid`.
/// * The refresh row gets a fresh random id. With `family_id == None` that id
///   also becomes the family id; otherwise the supplied family is kept.
/// * The refresh row expires `ctx.lifetimes.refresh_secs` seconds after
///   `clock.now()`.
///
/// # Errors
/// * `CoreError::Internal` - no active signing key, or the unsealed private key
///   is not exactly 32 bytes long.
/// * Any other store or token-issuance error.
// Every argument is an independent input of the issuance; bundling them would
// only move the list into a one-off struct.
#[allow(clippy::too_many_arguments)]
async fn issue_for_with_family(
    db: &Database,
    clock: &SharedClock,
    ctx: IssuanceContext<'_>,
    user_id: UserId,
    client_id: ClientId,
    scope: &str,
    nonce: Option<&str>,
    auth_methods: &[sui_id_shared::AuthMethod],
    family_id: Option<String>,
    user_email: Option<(&str, bool)>,
) -> CoreResult<TokenSet> {
    // A missing active key is a server-side fault, not a protocol error.
    let active_key = match signing_keys::active(db).await {
        Ok(key) => key,
        Err(sui_id_store::StoreError::NotFound) => return Err(CoreError::Internal),
        Err(other) => return Err(CoreError::from(other)),
    };

    let raw_key = signing_keys::unseal_private(db, &active_key).await?;
    let seed: [u8; 32] = raw_key
        .as_slice()
        .try_into()
        .map_err(|_| CoreError::Internal)?;
    let signing_key = SigningKey::from_bytes(&seed);

    let wants_id_token = scope.split_whitespace().any(|s| s == "openid");
    let key_id = active_key.id.to_string();
    let issued = tokens::issue_token_set(
        ctx.issuer,
        user_id,
        client_id,
        scope,
        nonce,
        wants_id_token,
        &key_id,
        &signing_key,
        ctx.lifetimes,
        clock,
        auth_methods,
        user_email,
    )
    .await?;

    let issued_at = clock.now();
    let row_id = tokens::random_token(16);
    // The first token of a family names the family after itself.
    let family = match family_id {
        Some(existing) => existing,
        None => row_id.clone(),
    };

    let refresh_row = RefreshTokenRow {
        id: row_id,
        token_plain: Some(issued.refresh_token.clone()),
        user_id,
        client_id,
        scope: scope.to_owned(),
        expires_at: issued_at + Duration::seconds(ctx.lifetimes.refresh_secs),
        revoked_at: None,
        created_at: issued_at,
        auth_methods: auth_methods.to_vec(),
        family_id: family,
    };
    refresh_tokens::insert(db, &refresh_row).await?;

    Ok(issued)
}
/// Property tests pinning the exact-match contract of
/// `is_redirect_uri_registered`.
#[cfg(test)]
mod redirect_uri_tests {
    use super::is_redirect_uri_registered;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 256,
            ..ProptestConfig::default()
        })]

        // A URI that is in the registry is accepted verbatim.
        #[test]
        fn registered_uri_is_always_accepted(
            uri in "[A-Za-z0-9:/._~?&=#@%-]{1,256}",
        ) {
            let registered = vec![uri.clone()];
            prop_assert!(is_redirect_uri_registered(&registered, &uri));
        }

        // Changing a single byte of a registered URI makes it unregistered.
        #[test]
        fn one_byte_off_uri_is_rejected(
            base in "[A-Za-z0-9:/._~?&=-]{8,128}",
            mutation_index in any::<usize>(),
        ) {
            let mut submitted = base.clone().into_bytes();
            let i = mutation_index % submitted.len();
            // Swap in a byte that is guaranteed to differ from the current one.
            submitted[i] = if submitted[i] == b'X' { b'Y' } else { b'X' };
            let submitted = String::from_utf8(submitted).unwrap();
            prop_assume!(submitted != base);
            let registered = vec![base];
            prop_assert!(!is_redirect_uri_registered(&registered, &submitted));
        }

        // Matching is case-sensitive.
        #[test]
        fn case_difference_is_not_folded(
            stem in "[a-z]{4,16}",
        ) {
            let lower = format!("https://example.com/{stem}");
            let upper = format!("https://example.com/{}", stem.to_uppercase());
            prop_assume!(lower != upper);
            let registered = vec![lower.clone()];
            prop_assert!(is_redirect_uri_registered(&registered, &lower));
            prop_assert!(!is_redirect_uri_registered(&registered, &upper));
        }

        // A registered URI is not treated as a prefix.
        #[test]
        fn prefix_extension_is_rejected(
            base in "[A-Za-z0-9:/._~-]{8,64}",
            suffix in "[A-Za-z0-9/.-]{1,32}",
        ) {
            let registered = vec![base.clone()];
            let submitted = format!("{base}{suffix}");
            prop_assume!(submitted != base);
            prop_assert!(!is_redirect_uri_registered(&registered, &submitted));
        }

        // With several registered URIs, each member matches and an outsider does not.
        #[test]
        fn multi_registry_matches_each_member_and_only_them(
            uris in proptest::collection::vec("[A-Za-z0-9:/._~-]{8,64}", 1..6),
            outsider in "[A-Za-z0-9:/._~-]{8,64}",
        ) {
            prop_assume!(!uris.iter().any(|u| u == &outsider));
            for u in &uris {
                prop_assert!(is_redirect_uri_registered(&uris, u));
            }
            prop_assert!(!is_redirect_uri_registered(&uris, &outsider));
        }
    }
}