use crate::{AitpSigningKey, AitpVerifyingKey, CryptoError};
use aitp_core::{base64url, jcs, Aid, AidAlgorithm};
use serde::Deserialize;
/// JWS `typ` header value for TCT tokens.
pub const TYP_TCT: &str = "aitp-tct+jwt";
/// JWS `typ` header value for grant-voucher tokens.
pub const TYP_GRANT_VOUCHER: &str = "aitp-grant+jwt";
/// JWS `typ` header value for delegation tokens.
pub const TYP_DELEGATION: &str = "aitp-delegation+jwt";
/// Maps an AID signing algorithm to the JOSE `alg` header value used for it.
///
/// `Ed25519` maps to `EdDSA` and `P256` maps to `ES256`.
///
/// # Errors
///
/// Returns [`CryptoError::KeyParseFailed`] for any other algorithm, since it has
/// no registered JOSE name in this build.
pub fn jose_alg(algorithm: AidAlgorithm) -> Result<&'static str, CryptoError> {
    let name = match algorithm {
        AidAlgorithm::Ed25519 => "EdDSA",
        AidAlgorithm::P256 => "ES256",
        other => {
            return Err(CryptoError::KeyParseFailed(format!(
                "AID algorithm {other:?} has no registered JOSE alg in this build"
            )))
        }
    };
    Ok(name)
}
/// The only protected-header shape that [`verify_compact`] accepts.
///
/// Both `alg` and `typ` are required. `deny_unknown_fields` rejects every other
/// member (`crit`, `kid`, `jwk`, ...), so no extra header parameter can change
/// how a token is processed. Duplicate members are also rejected by the derived
/// deserializer.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct JwsHeader {
    /// JOSE algorithm name. The verifier compares it against the signer's [`jose_alg`].
    alg: String,
    /// Token type. The verifier compares it against the caller's expected `typ`.
    typ: String,
}
/// Signs `claims` and returns a JWS in compact serialization.
///
/// The protected header is always exactly `{"alg":"<alg>","typ":"<typ>"}`.
/// `alg` comes from the key's algorithm via [`jose_alg`]. The payload is the
/// JCS canonical form of `claims`. The returned string is
/// `B64(header) "." B64(payload) "." B64(signature)`, where `B64` is
/// unpadded base64url and the signature covers the first two segments.
///
/// # Errors
///
/// * [`CryptoError::JwsMalformed`] if `typ` contains `"`, `\` or a control
///   byte below 0x20, or if `claims` cannot be canonicalized.
/// * Any error from [`jose_alg`] when the key's algorithm has no JOSE name.
pub fn sign_compact<T: serde::Serialize>(
    key: &AitpSigningKey,
    typ: &str,
    claims: &T,
) -> Result<String, CryptoError> {
    // The header is built by string interpolation rather than a JSON
    // serializer, so any `typ` that would need escaping is refused up front.
    let needs_escaping = typ
        .bytes()
        .any(|byte| matches!(byte, b'"' | b'\\' | 0x00..=0x1f));
    if needs_escaping {
        return Err(CryptoError::JwsMalformed(
            "typ must not require JSON escaping".into(),
        ));
    }

    let alg = jose_alg(key.algorithm())?;
    let header_json = format!("{{\"alg\":\"{alg}\",\"typ\":\"{typ}\"}}");
    let payload_json = jcs::canonicalize_serializable(claims)
        .map_err(|e| CryptoError::JwsMalformed(format!("claims canonicalization: {e}")))?;

    let encoded_header = base64url::encode(header_json.as_bytes());
    let encoded_payload = base64url::encode(&payload_json);
    let signing_input = format!("{encoded_header}.{encoded_payload}");

    let signature = key.sign_raw(signing_input.as_bytes());
    let encoded_signature = base64url::encode(&signature);
    Ok(format!("{signing_input}.{encoded_signature}"))
}
pub fn verify_compact(
signer: &Aid,
expected_typ: &str,
token: &str,
) -> Result<Vec<u8>, CryptoError> {
let (header_b64, payload_b64, sig_b64) = split_strict(token)?;
let header_bytes = base64url::decode_strict(header_b64)
.map_err(|e| CryptoError::JwsMalformed(format!("header segment: {e}")))?;
let payload = base64url::decode_strict(payload_b64)
.map_err(|e| CryptoError::JwsMalformed(format!("payload segment: {e}")))?;
let sig_bytes = base64url::decode_strict(sig_b64)
.map_err(|e| CryptoError::JwsMalformed(format!("signature segment: {e}")))?;
let header: JwsHeader = serde_json::from_slice(&header_bytes)
.map_err(|e| CryptoError::JwsMalformed(format!("protected header: {e}")))?;
if header.typ != expected_typ {
return Err(CryptoError::TypMismatch {
expected: expected_typ.to_string(),
got: header.typ,
});
}
let expected_alg = jose_alg(signer.algorithm())?;
if header.alg != expected_alg {
return Err(CryptoError::AlgMismatch(format!(
"expected {expected_alg} for this AID, got {}",
header.alg
)));
}
let vk = AitpVerifyingKey::from_aid(signer)?;
let signing_input = &token.as_bytes()[..header_b64.len() + 1 + payload_b64.len()];
vk.verify_raw(signing_input, &sig_bytes)?;
match serde_json::from_slice::<serde_json::Value>(&payload) {
Ok(serde_json::Value::Object(_)) => {}
_ => {
return Err(CryptoError::JwsMalformed(
"payload is not a JSON object".into(),
))
}
}
Ok(payload)
}
/// Decodes the payload segment of a compact JWS **without** verifying anything.
///
/// The function checks only the three-segment shape and the payload's
/// base64url encoding. The header and signature segments are never decoded.
/// The returned bytes are untrusted: they are suitable only for peeking at
/// claims (such as `iss`), never as a substitute for [`verify_compact`].
///
/// # Errors
///
/// Returns [`CryptoError::JwsMalformed`] if the token shape is wrong or the
/// payload segment fails strict base64url decoding.
pub fn decode_payload_unverified(token: &str) -> Result<Vec<u8>, CryptoError> {
    split_strict(token).and_then(|(_, encoded_payload, _)| {
        base64url::decode_strict(encoded_payload)
            .map_err(|e| CryptoError::JwsMalformed(format!("payload segment: {e}")))
    })
}
/// Splits a compact JWS into its header, payload and signature segments.
///
/// Succeeds only when `token` contains exactly two `.` separators and all
/// three segments are non-empty. The slices borrow from `token`, and the
/// header is always a prefix of it; `verify_compact` relies on this to
/// recover the signing input.
fn split_strict(token: &str) -> Result<(&str, &str, &str), CryptoError> {
    fn wrong_shape() -> CryptoError {
        CryptoError::JwsMalformed("expected exactly three non-empty segments".into())
    }

    let (header, rest) = token.split_once('.').ok_or_else(wrong_shape)?;
    let (payload, signature) = rest.split_once('.').ok_or_else(wrong_shape)?;

    // A further '.' in the signature means there were four or more segments.
    let well_formed = !signature.contains('.')
        && !header.is_empty()
        && !payload.is_empty()
        && !signature.is_empty();
    if well_formed {
        Ok((header, payload, signature))
    } else {
        Err(wrong_shape())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Representative object-shaped claims used by most tests.
    fn claims() -> serde_json::Value {
        json!({
            "ver": "aitp/0.2",
            "iss": "aid:pubkey:O2onvM62pC1io6jQKm8Nc2UyFXcd4kOmOsBIoYtZ2ik",
            "grants": ["demo.echo"],
        })
    }

    #[test]
    fn sign_verify_round_trip_ed25519() {
        let key = AitpSigningKey::from_seed(&[1u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &claims()).unwrap();
        let payload = verify_compact(key.aid(), TYP_TCT, &token).unwrap();
        let back: serde_json::Value = serde_json::from_slice(&payload).unwrap();
        assert_eq!(back, claims());
    }

    #[test]
    fn sign_verify_round_trip_p256() {
        let key = AitpSigningKey::from_p256_seed(&[5u8; 32]).unwrap();
        let token = sign_compact(&key, TYP_DELEGATION, &claims()).unwrap();
        assert!(token.starts_with(&base64url::encode(
            b"{\"alg\":\"ES256\",\"typ\":\"aitp-delegation+jwt\"}"
        )));
        verify_compact(key.aid(), TYP_DELEGATION, &token).unwrap();
    }

    #[test]
    fn header_bytes_are_exact_two_member_form() {
        let key = AitpSigningKey::from_seed(&[2u8; 32]);
        let token = sign_compact(&key, TYP_GRANT_VOUCHER, &claims()).unwrap();
        let header_b64 = token.split('.').next().unwrap();
        let header = base64url::decode_strict(header_b64).unwrap();
        assert_eq!(header, b"{\"alg\":\"EdDSA\",\"typ\":\"aitp-grant+jwt\"}");
    }

    #[test]
    fn sign_rejects_typ_that_needs_json_escaping() {
        // The header is built by interpolation, so these must never get through.
        let key = AitpSigningKey::from_seed(&[14u8; 32]);
        for bad_typ in ["quote\"typ", "back\\slash", "line\nbreak", "nul\0byte", "\u{1f}"] {
            assert!(
                matches!(
                    sign_compact(&key, bad_typ, &claims()),
                    Err(CryptoError::JwsMalformed(_))
                ),
                "typ {bad_typ:?} must be rejected"
            );
        }
    }

    #[test]
    fn rejects_typ_mismatch() {
        let key = AitpSigningKey::from_seed(&[3u8; 32]);
        let token = sign_compact(&key, TYP_GRANT_VOUCHER, &claims()).unwrap();
        assert!(matches!(
            verify_compact(key.aid(), TYP_TCT, &token),
            Err(CryptoError::TypMismatch { .. })
        ));
    }

    #[test]
    fn rejects_alg_none_and_wrong_alg() {
        let key = AitpSigningKey::from_seed(&[4u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &claims()).unwrap();
        let (_, rest) = token.split_once('.').unwrap();
        for evil_alg in ["none", "None", "NONE", "ES256", "HS256", "RS256"] {
            let evil_header = base64url::encode(
                format!("{{\"alg\":\"{evil_alg}\",\"typ\":\"aitp-tct+jwt\"}}").as_bytes(),
            );
            let evil = format!("{evil_header}.{rest}");
            assert!(
                matches!(
                    verify_compact(key.aid(), TYP_TCT, &evil),
                    Err(CryptoError::AlgMismatch(_))
                ),
                "alg {evil_alg} must be rejected before signature checking"
            );
        }
    }

    #[test]
    fn rejects_extra_header_params() {
        let key = AitpSigningKey::from_seed(&[6u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &claims()).unwrap();
        let (_, rest) = token.split_once('.').unwrap();
        for evil in [
            r#"{"alg":"EdDSA","typ":"aitp-tct+jwt","crit":["exp"]}"#,
            r#"{"alg":"EdDSA","typ":"aitp-tct+jwt","kid":"x"}"#,
            r#"{"alg":"EdDSA","typ":"aitp-tct+jwt","jwk":{}}"#,
            r#"{"alg":"EdDSA"}"#,
            r#"{"typ":"aitp-tct+jwt"}"#,
            r#"{"alg":"EdDSA","alg":"none","typ":"aitp-tct+jwt"}"#,
        ] {
            let evil_token = format!("{}.{rest}", base64url::encode(evil.as_bytes()));
            assert!(
                matches!(
                    verify_compact(key.aid(), TYP_TCT, &evil_token),
                    Err(CryptoError::JwsMalformed(_))
                ),
                "header {evil} must be rejected"
            );
        }
    }

    #[test]
    fn rejects_wrong_segment_shapes() {
        let key = AitpSigningKey::from_seed(&[7u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &claims()).unwrap();
        let (h, rest) = token.split_once('.').unwrap();
        let (p, s) = rest.split_once('.').unwrap();
        for evil in [
            // Too few segments.
            format!("{h}.{p}"),
            // Too many segments.
            format!("{h}.{p}.{s}.x"),
            // Empty payload segment.
            format!("{h}..{s}"),
            // Empty signature segment.
            format!("{h}.{p}."),
            // Padding is not allowed in strict base64url.
            format!("{h}.{p}.{s}="),
            // Character outside the base64url alphabet.
            format!("{h}.{p}.{}!", &s[..s.len() - 1]),
        ] {
            assert!(
                verify_compact(key.aid(), TYP_TCT, &evil).is_err(),
                "shape {evil:?} must be rejected"
            );
        }
    }

    #[test]
    fn rejects_tampered_payload_and_cross_key() {
        let key = AitpSigningKey::from_seed(&[8u8; 32]);
        let other = AitpSigningKey::from_seed(&[9u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &claims()).unwrap();
        assert!(matches!(
            verify_compact(other.aid(), TYP_TCT, &token),
            Err(CryptoError::SignatureInvalid)
        ));
        let (h, rest) = token.split_once('.').unwrap();
        let (_, s) = rest.split_once('.').unwrap();
        let tampered_payload = base64url::encode(
            &jcs::canonicalize_serializable(&json!({"ver": "aitp/0.2"})).unwrap(),
        );
        let tampered = format!("{h}.{tampered_payload}.{s}");
        assert!(matches!(
            verify_compact(key.aid(), TYP_TCT, &tampered),
            Err(CryptoError::SignatureInvalid)
        ));
    }

    #[test]
    fn rejects_non_object_payload() {
        let key = AitpSigningKey::from_seed(&[10u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &json!(["not", "an", "object"])).unwrap();
        assert!(matches!(
            verify_compact(key.aid(), TYP_TCT, &token),
            Err(CryptoError::JwsMalformed(_))
        ));
    }

    #[test]
    fn rejects_validly_signed_non_json_payload() {
        // Build the token by hand: sign_compact always emits JSON, but a
        // correctly signed payload that is not JSON at all must still be refused.
        let key = AitpSigningKey::from_seed(&[15u8; 32]);
        let header = base64url::encode(b"{\"alg\":\"EdDSA\",\"typ\":\"aitp-tct+jwt\"}");
        let payload = base64url::encode(b"not json");
        let signing_input = format!("{header}.{payload}");
        let sig = key.sign_raw(signing_input.as_bytes());
        let token = format!("{signing_input}.{}", base64url::encode(&sig));
        assert!(matches!(
            verify_compact(key.aid(), TYP_TCT, &token),
            Err(CryptoError::JwsMalformed(_))
        ));
    }

    #[test]
    fn p256_key_rejects_eddsa_alg_header() {
        let p256 = AitpSigningKey::from_p256_seed(&[11u8; 32]).unwrap();
        let ed = AitpSigningKey::from_seed(&[12u8; 32]);
        let token = sign_compact(&ed, TYP_TCT, &claims()).unwrap();
        assert!(matches!(
            verify_compact(p256.aid(), TYP_TCT, &token),
            Err(CryptoError::AlgMismatch(_))
        ));
    }

    #[test]
    fn decode_payload_unverified_returns_claims() {
        let key = AitpSigningKey::from_seed(&[13u8; 32]);
        let token = sign_compact(&key, TYP_TCT, &claims()).unwrap();
        let payload = decode_payload_unverified(&token).unwrap();
        let v: serde_json::Value = serde_json::from_slice(&payload).unwrap();
        assert_eq!(v["iss"], claims()["iss"]);
    }

    #[test]
    fn decode_payload_unverified_rejects_bad_shapes() {
        for bad in ["", "a.b", "a..c", "a.b.c.d", "a.b!.c"] {
            assert!(
                matches!(
                    decode_payload_unverified(bad),
                    Err(CryptoError::JwsMalformed(_))
                ),
                "shape {bad:?} must be rejected"
            );
        }
    }
}