use getrandom;
use crate::errors::{CoreError, CoreResult};
use crate::jwt;
use crate::time::SharedClock;
use base64ct::{Base64UrlUnpadded, Encoding};
use ed25519_dalek::SigningKey;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use sui_id_shared::ids::{ClientId, UserId};
/// Claim set carried by a signed access token.
///
/// `issue_token_set` in this file builds one of these and signs it with
/// `jwt::sign`; `verify_access_token` and `verify_access_token_cached`
/// decode it again and compare `exp` against the clock.
#[derive(Debug, Serialize, Deserialize)]
pub struct AccessTokenClaims {
/// Issuer; `issue_token_set` copies its `issuer` argument here.
pub iss: String,
/// Subject; `issue_token_set` stores the user id rendered as a string.
pub sub: String,
/// Audience; `issue_token_set` stores the client id rendered as a string.
pub aud: String,
/// Issued-at time, taken from the clock's `timestamp()` at issuance.
pub iat: i64,
/// Expiry time; issued as `iat` plus the configured access-token lifetime.
pub exp: i64,
/// Granted scope string, stored verbatim as passed to `issue_token_set`.
pub scope: String,
/// Token identifier; issued as `random_token(16)`.
pub jti: String,
}
/// Claim set carried by a signed ID token.
///
/// Every `Option` field is omitted from the serialized output when it is
/// `None` (see the `skip_serializing_if` attributes).
#[derive(Debug, Serialize, Deserialize)]
pub struct IdTokenClaims {
/// Issuer; `issue_token_set` copies its `issuer` argument here.
pub iss: String,
/// Subject; `issue_token_set` stores the user id rendered as a string.
pub sub: String,
/// Audience; `issue_token_set` stores the client id rendered as a string.
pub aud: String,
/// Issued-at time, taken from the clock's `timestamp()` at issuance.
pub iat: i64,
/// Expiry time; issued as `iat` plus the configured ID-token lifetime.
pub exp: i64,
/// Caller-supplied nonce, copied through unchanged when one is provided.
#[serde(skip_serializing_if = "Option::is_none")]
pub nonce: Option<String>,
/// Token identifier; issued as `random_token(16)`.
pub jti: String,
/// Value derived from the authentication methods via
/// `sui_id_shared::acr_from_methods`; `None` when no methods were supplied.
#[serde(skip_serializing_if = "Option::is_none")]
pub acr: Option<String>,
/// Values derived from the authentication methods via
/// `sui_id_shared::amr_from_methods`; `None` when no methods were supplied.
#[serde(skip_serializing_if = "Option::is_none")]
pub amr: Option<Vec<String>>,
/// User e-mail address; only set by `issue_token_set` when the scope
/// contains the `email` entry and an address was supplied.
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
/// Verification flag accompanying `email`; set under the same conditions.
#[serde(skip_serializing_if = "Option::is_none")]
pub email_verified: Option<bool>,
}
/// Token lifetimes, expressed in seconds.
#[derive(Debug, Clone, Copy)]
pub struct TokenLifetimes {
    /// Added to `iat` to form the access token's `exp`, and reported back to
    /// the caller as `TokenSet::access_expires_in`.
    pub access_secs: i64,
    /// Added to `iat` to form the ID token's `exp`.
    pub id_secs: i64,
    /// Refresh-token lifetime. Not consumed by the issuance code in this
    /// file; presumably applied where the refresh token is persisted.
    pub refresh_secs: i64,
}

impl Default for TokenLifetimes {
    /// 15 minutes for access and ID tokens, 14 days for refresh tokens.
    fn default() -> Self {
        const MINUTE: i64 = 60;
        const DAY: i64 = 24 * 60 * MINUTE;

        Self {
            access_secs: 15 * MINUTE,
            id_secs: 15 * MINUTE,
            refresh_secs: 14 * DAY,
        }
    }
}
/// Result of `issue_token_set`: the freshly minted tokens handed back to
/// the caller.
pub struct TokenSet {
/// Signed access token as produced by `jwt::sign`.
pub access_token: String,
/// Signed ID token; `None` when the caller did not request one.
pub id_token: Option<String>,
/// Random refresh token generated by `random_token(32)` (not a JWT).
pub refresh_token: String,
/// Access-token lifetime in seconds, copied from `TokenLifetimes::access_secs`.
pub access_expires_in: i64,
}
/// Mints a new [`TokenSet`] for `user` and `client`.
///
/// An access token is always signed. An ID token is signed only when
/// `include_id_token` is true. The refresh token is a random string from
/// [`random_token`] and is not signed.
///
/// * `nonce` is copied into the ID token when present.
/// * `auth_methods` feeds the `acr`/`amr` ID-token claims; an empty slice
///   leaves both unset.
/// * `user_email` (address, verified flag) is written into the ID token only
///   when `scope` contains the whitespace-separated entry `email`.
///
/// # Errors
///
/// Propagates any error returned by `jwt::sign`.
// The parameter list mirrors everything needed to build both claim sets.
#[allow(clippy::too_many_arguments)]
pub async fn issue_token_set(
    issuer: &str,
    user: UserId,
    client: ClientId,
    scope: &str,
    nonce: Option<&str>,
    include_id_token: bool,
    kid: &str,
    signing_key: &SigningKey,
    lifetimes: TokenLifetimes,
    clock: &SharedClock,
    auth_methods: &[sui_id_shared::AuthMethod],
    user_email: Option<(&str, bool)>,
) -> CoreResult<TokenSet> {
    // Both tokens share a single issued-at instant.
    let iat = clock.now().timestamp();

    let access_token = jwt::sign(
        kid,
        signing_key,
        &AccessTokenClaims {
            iss: issuer.to_owned(),
            sub: user.to_string(),
            aud: client.to_string(),
            iat,
            exp: iat + lifetimes.access_secs,
            scope: scope.to_owned(),
            jti: random_token(16),
        },
    )?;

    let id_token = if include_id_token {
        // No recorded authentication methods means no acr/amr claims at all.
        let (acr, amr) = match auth_methods {
            [] => (None, None),
            methods => (
                Some(sui_id_shared::acr_from_methods(methods).to_string()),
                Some(sui_id_shared::amr_from_methods(methods)),
            ),
        };

        // E-mail claims are released only when the `email` scope was granted.
        let email_granted = scope.split_whitespace().any(|entry| entry == "email");
        let email_info = if email_granted { user_email } else { None };

        let id_claims = IdTokenClaims {
            iss: issuer.to_owned(),
            sub: user.to_string(),
            aud: client.to_string(),
            iat,
            exp: iat + lifetimes.id_secs,
            nonce: nonce.map(str::to_owned),
            jti: random_token(16),
            acr,
            amr,
            email: email_info.map(|(addr, _)| addr.to_owned()),
            email_verified: email_info.map(|(_, verified)| verified),
        };
        Some(jwt::sign(kid, signing_key, &id_claims)?)
    } else {
        None
    };

    Ok(TokenSet {
        access_token,
        id_token,
        refresh_token: random_token(32),
        access_expires_in: lifetimes.access_secs,
    })
}
/// Returns `byte_len` bytes from the system RNG, encoded as unpadded
/// base64url text.
///
/// # Panics
///
/// Panics if the system RNG is unavailable, or if the encoder rejects the
/// output buffer. The latter indicates a bug: the buffer is always larger
/// than the encoded length. Failing loudly is deliberate, because silently
/// returning an empty or short token would hand out a guessable secret.
pub fn random_token(byte_len: usize) -> String {
    let mut buf = vec![0u8; byte_len];
    getrandom::fill(&mut buf).expect("system RNG unavailable");

    // Unpadded base64 needs ceil(4n/3) bytes; 2n + 4 is a safe upper bound.
    let mut out = vec![0u8; byte_len * 2 + 4];
    let written = Base64UrlUnpadded::encode(&buf, &mut out)
        .expect("output buffer is sized for the base64url encoding")
        .len();
    out.truncate(written);
    String::from_utf8(out).expect("base64url is ascii")
}
/// Returns the SHA-256 digest of `s` (its UTF-8 bytes) as lowercase hex.
pub fn sha256_hex(s: &str) -> String {
    use std::fmt::Write;

    let hash = Sha256::digest(s.as_bytes());
    hash.iter()
        .fold(String::with_capacity(hash.len() * 2), |mut hex, byte| {
            // The write result is discarded, as before.
            let _ = write!(hex, "{byte:02x}");
            hex
        })
}
/// Checks a PKCE `verifier` against the stored `expected_challenge`.
///
/// Only the `S256` method is accepted: the challenge is recomputed as the
/// unpadded base64url encoding of `SHA-256(verifier)` and compared with
/// `expected_challenge` in constant time.
///
/// # Errors
///
/// * `CoreError::Protocol` with `InvalidGrant` when `method` is not `S256`
///   or the recomputed challenge does not match.
/// * `CoreError::Internal` if the challenge cannot be encoded. Previously
///   this case fell back to an empty computed challenge, which would have
///   matched an empty `expected_challenge`.
pub fn verify_pkce(method: &str, verifier: &str, expected_challenge: &str) -> CoreResult<()> {
    use subtle::ConstantTimeEq;

    if method != "S256" {
        return Err(CoreError::Protocol {
            code: crate::errors::ProtocolError::InvalidGrant,
            description: format!("unsupported code_challenge_method: {method}"),
        });
    }

    let digest = Sha256::digest(verifier.as_bytes());
    // A 32-byte digest encodes to 43 characters; 64 leaves headroom.
    let mut encoded = [0u8; 64];
    let computed =
        Base64UrlUnpadded::encode(&digest, &mut encoded).map_err(|_| CoreError::Internal)?;

    let matches: bool = computed
        .as_bytes()
        .ct_eq(expected_challenge.as_bytes())
        .into();
    if matches {
        Ok(())
    } else {
        Err(CoreError::Protocol {
            code: crate::errors::ProtocolError::InvalidGrant,
            description: "PKCE verification failed".into(),
        })
    }
}
/// Verifies `token` with `crate::jwt::verify`, resolving the signing key
/// from the in-memory `keys` slice.
///
/// The resolver yields `None` when no entry has the requested `kid`, when
/// the stored public key is not exactly 32 bytes long, or when the bytes
/// are rejected by `VerifyingKey::from_bytes`.
fn verify_from_snapshot<C: serde::de::DeserializeOwned>(
    keys: &[crate::cache::CachedSigningKey],
    token: &str,
) -> crate::CoreResult<crate::jwt::Decoded<C>> {
    use ed25519_dalek::VerifyingKey;

    let lookup = |kid: &str| -> Option<VerifyingKey> {
        keys.iter()
            .find(|candidate| candidate.kid == kid)
            .and_then(|candidate| {
                <[u8; 32]>::try_from(candidate.public_key_bytes.as_slice()).ok()
            })
            .and_then(|raw| VerifyingKey::from_bytes(&raw).ok())
    };

    crate::jwt::verify(token, lookup)
}
/// Converts signing-key rows from the store into the cache representation,
/// preserving order. The row id becomes the `kid` string.
fn published_to_cached(
    rows: Vec<sui_id_store::models::SigningKeyRow>,
) -> Vec<crate::cache::CachedSigningKey> {
    let mut cached = Vec::with_capacity(rows.len());
    for row in rows {
        cached.push(crate::cache::CachedSigningKey {
            kid: row.id.to_string(),
            algorithm: row.algorithm,
            public_key_bytes: row.public_key,
        });
    }
    cached
}
/// Verifies an ID token against the published signing keys loaded from `db`
/// and returns its claims.
///
/// When `accept_expired` is true the `exp` claim is not checked and the
/// clock is not consulted.
///
/// NOTE(review): a failure to list the published keys is swallowed and
/// treated as an empty key set, so a store outage surfaces as a token
/// verification failure rather than a store error. Confirm this is intended.
///
/// # Errors
///
/// Propagates errors from `crate::jwt::verify`. Returns
/// `CoreError::Unauthenticated` when `exp` is earlier than the current time
/// and `accept_expired` is false.
pub async fn verify_id_token(
    db: &sui_id_store::Database,
    clock: &crate::time::SharedClock,
    token: &str,
    accept_expired: bool,
) -> crate::CoreResult<IdTokenClaims> {
    let rows = sui_id_store::repos::signing_keys::list_published(db)
        .await
        .unwrap_or_default();
    let keys = published_to_cached(rows);

    let claims = verify_from_snapshot::<IdTokenClaims>(&keys, token)?.claims;

    // Short-circuit keeps the clock untouched when expiry is waived.
    if accept_expired || clock.now().timestamp() <= claims.exp {
        Ok(claims)
    } else {
        Err(crate::CoreError::Unauthenticated)
    }
}
/// Verifies an ID token against the key snapshot held in `caches.jwks` and
/// returns its claims.
///
/// When `accept_expired` is true the `exp` claim is not checked and the
/// clock is not consulted.
///
/// # Errors
///
/// Propagates errors from `crate::jwt::verify`. Returns
/// `CoreError::Unauthenticated` when `exp` is earlier than the current time
/// and `accept_expired` is false.
pub async fn verify_id_token_cached(
    caches: &crate::cache::Caches,
    clock: &crate::time::SharedClock,
    token: &str,
    accept_expired: bool,
) -> crate::CoreResult<IdTokenClaims> {
    let keys = caches.jwks.snapshot().await;
    let claims = verify_from_snapshot::<IdTokenClaims>(&keys, token)?.claims;

    // Short-circuit keeps the clock untouched when expiry is waived.
    if accept_expired || clock.now().timestamp() <= claims.exp {
        Ok(claims)
    } else {
        Err(crate::CoreError::Unauthenticated)
    }
}
/// Verifies an access token against the published signing keys loaded from
/// `db` and returns its claims.
///
/// NOTE(review): a failure to list the published keys is swallowed and
/// treated as an empty key set, so a store outage surfaces as a token
/// verification failure rather than a store error. Confirm this is intended.
///
/// # Errors
///
/// Propagates errors from `crate::jwt::verify`. Returns
/// `CoreError::Unauthenticated` when `exp` is earlier than the current time.
pub async fn verify_access_token(
    db: &sui_id_store::Database,
    clock: &crate::time::SharedClock,
    token: &str,
) -> crate::CoreResult<AccessTokenClaims> {
    let rows = sui_id_store::repos::signing_keys::list_published(db)
        .await
        .unwrap_or_default();
    let keys = published_to_cached(rows);

    let claims = verify_from_snapshot::<AccessTokenClaims>(&keys, token)?.claims;

    if clock.now().timestamp() <= claims.exp {
        Ok(claims)
    } else {
        Err(crate::CoreError::Unauthenticated)
    }
}
/// Verifies an access token against the key snapshot held in `caches.jwks`
/// and returns its claims.
///
/// # Errors
///
/// Propagates errors from `crate::jwt::verify`. Returns
/// `CoreError::Unauthenticated` when `exp` is earlier than the current time.
pub async fn verify_access_token_cached(
    caches: &crate::cache::Caches,
    clock: &crate::time::SharedClock,
    token: &str,
) -> crate::CoreResult<AccessTokenClaims> {
    let keys = caches.jwks.snapshot().await;
    let claims = verify_from_snapshot::<AccessTokenClaims>(&keys, token)?.claims;

    if clock.now().timestamp() <= claims.exp {
        Ok(claims)
    } else {
        Err(crate::CoreError::Unauthenticated)
    }
}
#[cfg(test)]
mod tests;