spiffe-rs 0.1.0

Rust port of spiffe-go with SPIFFE IDs, bundles, SVIDs, Workload API client, federation helpers, and rustls-based SPIFFE TLS utilities.
Documentation
use crate::bundle::jwtbundle;
use crate::bundle::jwtbundle::JwtKey;
use crate::spiffeid::ID;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use p256::ecdsa::{signature::Verifier, Signature as P256Signature, VerifyingKey as P256VerifyingKey};
use p384::ecdsa::{Signature as P384Signature, VerifyingKey as P384VerifyingKey};
use p521::ecdsa::{Signature as P521Signature, VerifyingKey as P521VerifyingKey};
use rsa::pkcs1v15::{Signature as RsaSignature, VerifyingKey as RsaVerifyingKey};
use rsa::pss::{Signature as RsaPssSignature, VerifyingKey as RsaPssVerifyingKey};
use rsa::RsaPublicKey;
use serde_json::{Map, Value};
use sha2::{Digest, Sha256, Sha384, Sha512};
use pkcs8::AssociatedOid;
use rsa::signature::digest::FixedOutputReset;
use std::collections::HashMap;
use std::time::{Duration, SystemTime};

/// Error returned by the JWT-SVID functions in this module.
///
/// Wraps a single human-readable message; the message is what `Display`
/// prints.
#[derive(Debug, Clone)]
pub struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call on purpose: method syntax (`self.0.fmt(f)`)
        // only resolves when `Display` is imported into this module, and it
        // becomes ambiguous as soon as `Debug` is imported as well.
        std::fmt::Display::fmt(&self.0, f)
    }
}

impl std::error::Error for Error {}

/// Result alias used throughout this module, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;

/// Builds an [`Error`] whose message is `message` prefixed with `jwtsvid: `.
fn wrap_error(message: impl std::fmt::Display) -> Error {
    Error(format!("jwtsvid: {}", message))
}

/// A parsed JWT-SVID.
///
/// Instances are produced by this module's `parse` function (via
/// `parse_and_validate` / `parse_insecure`).
#[derive(Debug, Clone)]
pub struct SVID {
    /// SPIFFE ID parsed from the token's `sub` claim.
    pub id: ID,
    /// Audience values taken from the token's `aud` claim.
    pub audience: Vec<String>,
    /// Expiration time derived from the token's `exp` claim (seconds since the Unix epoch).
    pub expiry: SystemTime,
    /// All claims from the token payload, keyed by claim name.
    pub claims: HashMap<String, Value>,
    /// Hint string; the parser in this module always sets it to an empty string.
    pub hint: String,
    /// The serialized token exactly as it was passed to the parser; returned by `marshal`.
    token: String,
}

/// A subject SPIFFE ID together with one mandatory audience and any number
/// of additional audiences.
#[derive(Debug, Clone)]
pub struct Params {
    /// The SPIFFE ID these parameters refer to.
    pub subject: ID,
    /// The primary audience; always first in [`Params::audience_list`].
    pub audience: String,
    /// Further audiences, in the order they were added.
    pub extra_audiences: Vec<String>,
}

impl Params {
    /// Creates parameters for `subject` with a single `audience` and no
    /// extra audiences.
    pub fn new(subject: ID, audience: impl Into<String>) -> Self {
        let audience = audience.into();
        let extra_audiences = Vec::new();
        Self {
            subject,
            audience,
            extra_audiences,
        }
    }

    /// Consumes `self` and returns it with `audience` appended to the extra
    /// audiences.
    pub fn with_extra_audience(self, audience: impl Into<String>) -> Self {
        let Self {
            subject,
            audience: primary,
            mut extra_audiences,
        } = self;
        extra_audiences.push(audience.into());
        Self {
            subject,
            audience: primary,
            extra_audiences,
        }
    }

    /// Returns every audience as owned strings: the primary audience first,
    /// followed by the extra audiences in insertion order.
    pub fn audience_list(&self) -> Vec<String> {
        std::iter::once(&self.audience)
            .chain(&self.extra_audiences)
            .cloned()
            .collect()
    }
}

impl SVID {
    /// Returns an owned copy of the serialized token stored in this SVID.
    pub fn marshal(&self) -> String {
        String::from(self.token.as_str())
    }
}

/// Parses `token`, verifies its signature against `bundles`, and validates
/// its claims.
///
/// The signing key is looked up by the header's key id in the JWT bundle of
/// the trust domain taken from the token's subject. Expiry and audience
/// checks are performed by `parse` after the signature has been verified.
///
/// # Errors
///
/// Besides the errors produced by `parse`, this fails when the header has no
/// key id, when no bundle is available for the trust domain, when the bundle
/// has no authority with that key id, or when signature verification fails.
pub fn parse_and_validate(
    token: &str,
    bundles: &dyn jwtbundle::Source,
    audience: &[String],
) -> Result<SVID> {
    parse(token, audience, |header, signing_input, signature, trust_domain| {
        let key_id = match header.kid.as_deref() {
            Some(kid) => kid,
            None => return Err(wrap_error("token header missing key id")),
        };

        let bundle = match bundles.get_jwt_bundle_for_trust_domain(trust_domain.clone()) {
            Ok(found) => found,
            Err(_) => {
                return Err(wrap_error(format!(
                    "no bundle found for trust domain \"{}\"",
                    trust_domain
                )))
            }
        };

        let authority = match bundle.find_jwt_authority(key_id) {
            Some(found) => found,
            None => {
                return Err(wrap_error(format!(
                    "no JWT authority \"{}\" found for trust domain \"{}\"",
                    key_id, trust_domain
                )))
            }
        };

        // The underlying verification error is deliberately replaced by a
        // single fixed message.
        if verify_signature(&header.alg, &authority, signing_input, signature).is_err() {
            return Err(wrap_error(
                "unable to get claims from token: go-jose/go-jose: error in cryptographic primitive",
            ));
        }
        Ok(())
    })
}

/// Parses `token` WITHOUT verifying its signature.
///
/// The verification callback handed to `parse` accepts everything, so only
/// the structural checks and the expiry/audience validation done by `parse`
/// apply.
pub fn parse_insecure(token: &str, audience: &[String]) -> Result<SVID> {
    // Signature check is a no-op here by design.
    parse(token, audience, |_, _, _, _| Ok(()))
}

fn parse<F>(token: &str, audience: &[String], verify: F) -> Result<SVID>
where
    F: Fn(&Header, &str, &[u8], &crate::spiffeid::TrustDomain) -> Result<()>,
{
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        return Err(wrap_error("unable to parse JWT token"));
    }
    let header_bytes = URL_SAFE_NO_PAD
        .decode(parts[0].as_bytes())
        .map_err(|_| wrap_error("unable to parse JWT token"))?;
    let payload_bytes = URL_SAFE_NO_PAD
        .decode(parts[1].as_bytes())
        .map_err(|_| wrap_error("unable to parse JWT token"))?;
    let signature = URL_SAFE_NO_PAD
        .decode(parts[2].as_bytes())
        .map_err(|_| wrap_error("unable to parse JWT token"))?;

    let header: Header =
        serde_json::from_slice(&header_bytes).map_err(|_| wrap_error("unable to parse JWT token"))?;

    if !is_allowed_alg(&header.alg) {
        return Err(wrap_error("unable to parse JWT token"));
    }
    if let Some(typ) = header.typ.as_deref() {
        if typ != "JWT" && typ != "JOSE" {
            return Err(wrap_error("token header type not equal to either JWT or JOSE"));
        }
    }

    let claims: Map<String, Value> =
        serde_json::from_slice(&payload_bytes).map_err(|_| wrap_error("unable to parse JWT token"))?;
    let subject = claims
        .get("sub")
        .and_then(|v| v.as_str())
        .ok_or_else(|| wrap_error("token missing subject claim"))?;
    let expiry = claims
        .get("exp")
        .and_then(|v| v.as_i64().or_else(|| v.as_f64().map(|v| v as i64)))
        .ok_or_else(|| wrap_error("token missing exp claim"))?;
    let aud = extract_audience(&claims);

    let id = ID::from_string(subject)
        .map_err(|err| wrap_error(format!("token has an invalid subject claim: {}", err)))?;

    let trust_domain = id.trust_domain();
    verify(&header, &format!("{}.{}", parts[0], parts[1]), &signature, &trust_domain)?;

    validate_claims(expiry, &aud, audience)?;

    Ok(SVID {
        id,
        audience: aud,
        expiry: SystemTime::UNIX_EPOCH + Duration::from_secs(expiry as u64),
        claims: claims.into_iter().collect::<HashMap<_, _>>(),
        hint: String::new(),
        token: token.to_string(),
    })
}

/// Checks the expiry and audience of a token.
///
/// * `expiry` – the `exp` claim in seconds since the Unix epoch; it must be
///   strictly greater than the current time.
/// * `audience` – the audiences carried by the token.
/// * `expected` – audiences acceptable to the caller. When empty, the
///   audience check is skipped; otherwise at least one entry must appear in
///   `audience`.
///
/// # Errors
///
/// Returns "token has expired" when the token is not valid any more (or the
/// system clock reads earlier than the Unix epoch), and an
/// "expected audience in ..." error when no expected audience matches.
fn validate_claims(expiry: i64, audience: &[String], expected: &[String]) -> Result<()> {
    let since_epoch = match SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => return Err(wrap_error("token has expired")),
    };
    let now = since_epoch.as_secs() as i64;
    if now >= expiry {
        return Err(wrap_error("token has expired"));
    }

    // No expectations means any audience is acceptable.
    if expected.is_empty() {
        return Ok(());
    }

    let has_match = expected.iter().any(|wanted| audience.contains(wanted));
    if has_match {
        Ok(())
    } else {
        Err(wrap_error(format!(
            "expected audience in {:?} (audience={:?})",
            expected, audience
        )))
    }
}

/// Reads the `aud` claim as a list of strings.
///
/// A string value yields a one-element list; an array yields its string
/// elements (non-string elements are skipped); anything else, including a
/// missing claim, yields an empty list.
fn extract_audience(claims: &Map<String, Value>) -> Vec<String> {
    let aud = match claims.get("aud") {
        Some(value) => value,
        None => return Vec::new(),
    };

    if let Some(single) = aud.as_str() {
        return vec![single.to_owned()];
    }

    match aud.as_array() {
        Some(entries) => entries
            .iter()
            .filter_map(Value::as_str)
            .map(str::to_owned)
            .collect(),
        None => Vec::new(),
    }
}

/// Reports whether `alg` is one of the signature algorithm names this module
/// accepts. The comparison is exact (case-sensitive).
fn is_allowed_alg(alg: &str) -> bool {
    const ALLOWED: [&str; 9] = [
        "RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512",
    ];
    ALLOWED.contains(&alg)
}

fn verify_signature(alg: &str, key: &JwtKey, signing_input: &str, signature: &[u8]) -> Result<()> {
    match (alg, key) {
        ("RS256", JwtKey::Rsa { n, e }) => verify_rsa_pkcs1::<Sha256>(n, e, signing_input, signature),
        ("RS384", JwtKey::Rsa { n, e }) => verify_rsa_pkcs1::<Sha384>(n, e, signing_input, signature),
        ("RS512", JwtKey::Rsa { n, e }) => verify_rsa_pkcs1::<Sha512>(n, e, signing_input, signature),
        ("PS256", JwtKey::Rsa { n, e }) => verify_rsa_pss::<Sha256>(n, e, signing_input, signature),
        ("PS384", JwtKey::Rsa { n, e }) => verify_rsa_pss::<Sha384>(n, e, signing_input, signature),
        ("PS512", JwtKey::Rsa { n, e }) => verify_rsa_pss::<Sha512>(n, e, signing_input, signature),
        ("ES256", JwtKey::Ec { x, y, .. }) => verify_ecdsa_p256(x, y, signing_input, signature),
        ("ES384", JwtKey::Ec { x, y, .. }) => verify_ecdsa_p384(x, y, signing_input, signature),
        ("ES512", JwtKey::Ec { x, y, .. }) => verify_ecdsa_p521(x, y, signing_input, signature),
        _ => Err(wrap_error("unable to parse JWT token")),
    }
}

/// Verifies an RSA PKCS#1 v1.5 signature over the bytes of `signing_input`.
///
/// `n` and `e` are the public modulus and exponent, handed to
/// `rsa_public_key`; `D` selects the digest used by the verifying key.
///
/// # Errors
///
/// Returns "invalid RSA key" (from `rsa_public_key`) or "invalid signature"
/// when the signature cannot be parsed or does not verify.
fn verify_rsa_pkcs1<D>(
    n: &[u8],
    e: &[u8],
    signing_input: &str,
    signature: &[u8],
) -> Result<()>
where
    D: Digest + AssociatedOid,
{
    let verifier = RsaVerifyingKey::<D>::new(rsa_public_key(n, e)?);
    let parsed = match RsaSignature::try_from(signature) {
        Ok(sig) => sig,
        Err(_) => return Err(wrap_error("invalid signature")),
    };
    verifier
        .verify(signing_input.as_bytes(), &parsed)
        .map_err(|_| wrap_error("invalid signature"))
}

/// Verifies an RSA-PSS signature over the bytes of `signing_input`.
///
/// `n` and `e` are the public modulus and exponent, handed to
/// `rsa_public_key`; `D` selects the digest used by the verifying key.
///
/// # Errors
///
/// Returns "invalid RSA key" (from `rsa_public_key`) or "invalid signature"
/// when the signature cannot be parsed or does not verify.
fn verify_rsa_pss<D>(
    n: &[u8],
    e: &[u8],
    signing_input: &str,
    signature: &[u8],
) -> Result<()>
where
    D: Digest + FixedOutputReset,
{
    let verifier = RsaPssVerifyingKey::<D>::new(rsa_public_key(n, e)?);
    let parsed = match RsaPssSignature::try_from(signature) {
        Ok(sig) => sig,
        Err(_) => return Err(wrap_error("invalid signature")),
    };
    verifier
        .verify(signing_input.as_bytes(), &parsed)
        .map_err(|_| wrap_error("invalid signature"))
}

/// Builds an RSA public key from big-endian modulus (`n`) and exponent (`e`)
/// bytes.
///
/// # Errors
///
/// Returns "invalid RSA key" when `RsaPublicKey::new` rejects the values.
fn rsa_public_key(n: &[u8], e: &[u8]) -> Result<RsaPublicKey> {
    let modulus = rsa::BigUint::from_bytes_be(n);
    let exponent = rsa::BigUint::from_bytes_be(e);
    match RsaPublicKey::new(modulus, exponent) {
        Ok(key) => Ok(key),
        Err(_) => Err(wrap_error("invalid RSA key")),
    }
}

/// Verifies a P-256 ECDSA signature over the bytes of `signing_input`.
///
/// `x` and `y` are the public point coordinates; they are combined by
/// `ecdsa_public_key` and parsed with `from_sec1_bytes`.
///
/// # Errors
///
/// Returns "invalid EC key" when the point is rejected, and
/// "invalid signature" when the signature cannot be parsed or does not verify.
fn verify_ecdsa_p256(
    x: &[u8],
    y: &[u8],
    signing_input: &str,
    signature: &[u8],
) -> Result<()> {
    let sec1_point = ecdsa_public_key(x, y)?;
    let verifier = match P256VerifyingKey::from_sec1_bytes(&sec1_point) {
        Ok(key) => key,
        Err(_) => return Err(wrap_error("invalid EC key")),
    };
    let parsed = match P256Signature::from_slice(signature) {
        Ok(sig) => sig,
        Err(_) => return Err(wrap_error("invalid signature")),
    };
    verifier
        .verify(signing_input.as_bytes(), &parsed)
        .map_err(|_| wrap_error("invalid signature"))
}

/// Verifies a P-384 ECDSA signature over the bytes of `signing_input`.
///
/// `x` and `y` are the public point coordinates; they are combined by
/// `ecdsa_public_key` and parsed with `from_sec1_bytes`.
///
/// # Errors
///
/// Returns "invalid EC key" when the point is rejected, and
/// "invalid signature" when the signature cannot be parsed or does not verify.
fn verify_ecdsa_p384(
    x: &[u8],
    y: &[u8],
    signing_input: &str,
    signature: &[u8],
) -> Result<()> {
    let sec1_point = ecdsa_public_key(x, y)?;
    let verifier = match P384VerifyingKey::from_sec1_bytes(&sec1_point) {
        Ok(key) => key,
        Err(_) => return Err(wrap_error("invalid EC key")),
    };
    let parsed = match P384Signature::from_slice(signature) {
        Ok(sig) => sig,
        Err(_) => return Err(wrap_error("invalid signature")),
    };
    verifier
        .verify(signing_input.as_bytes(), &parsed)
        .map_err(|_| wrap_error("invalid signature"))
}

/// Verifies a P-521 ECDSA signature over the bytes of `signing_input`.
///
/// `x` and `y` are the public point coordinates; they are combined by
/// `ecdsa_public_key` and parsed with `from_sec1_bytes`.
///
/// # Errors
///
/// Returns "invalid EC key" when the point is rejected, and
/// "invalid signature" when the signature cannot be parsed or does not verify.
fn verify_ecdsa_p521(
    x: &[u8],
    y: &[u8],
    signing_input: &str,
    signature: &[u8],
) -> Result<()> {
    let sec1_point = ecdsa_public_key(x, y)?;
    let verifier = match P521VerifyingKey::from_sec1_bytes(&sec1_point) {
        Ok(key) => key,
        Err(_) => return Err(wrap_error("invalid EC key")),
    };
    let parsed = match P521Signature::from_slice(signature) {
        Ok(sig) => sig,
        Err(_) => return Err(wrap_error("invalid signature")),
    };
    verifier
        .verify(signing_input.as_bytes(), &parsed)
        .map_err(|_| wrap_error("invalid signature"))
}

/// Concatenates the byte `0x04`, then `x`, then `y` into one buffer.
///
/// The callers in this module pass the result to `from_sec1_bytes`. This
/// function always returns `Ok`; it neither checks nor pads the coordinate
/// lengths.
// NOTE(review): coordinates shorter than the curve's field size are passed
// through as-is — confirm upstream always supplies fixed-width values.
fn ecdsa_public_key(x: &[u8], y: &[u8]) -> Result<Vec<u8>> {
    const POINT_PREFIX: [u8; 1] = [0x04];
    Ok([&POINT_PREFIX[..], x, y].concat())
}

/// Header fields deserialized from the first segment of a token.
#[derive(Debug, serde::Deserialize)]
struct Header {
    /// Signature algorithm name (`alg`); declared without a serde default.
    alg: String,
    /// Key identifier (`kid`), used to look up the signing key; optional.
    #[serde(default)]
    kid: Option<String>,
    /// Token type (`typ`); optional. When present, the parser only accepts
    /// `JWT` or `JOSE`.
    // NOTE(review): the rename is identical to the field name, so it looks
    // redundant — confirm before removing.
    #[serde(rename = "typ")]
    #[serde(default)]
    typ: Option<String>,
}