spiffe-rs 0.1.0

Rust port of spiffe-go with SPIFFE IDs, bundles, SVIDs, Workload API client, federation helpers, and rustls-based SPIFFE TLS utilities.
Documentation
use crate::internal::pemutil;
use crate::internal::x509util;
use crate::spiffeid::TrustDomain;
use std::collections::HashMap;
use std::fs;
use std::io::Read;
use std::sync::RwLock;

#[derive(Debug, Clone)]
pub struct Error(String);

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl std::error::Error for Error {}

impl Error {
    pub fn new(message: impl Into<String>) -> Error {
        Error(message.into())
    }
}

pub type Result<T> = std::result::Result<T, Error>;

fn wrap_error(message: impl std::fmt::Display) -> Error {
    Error(format!("x509bundle: {}", message))
}

#[derive(Debug)]
pub struct Bundle {
    trust_domain: TrustDomain,
    x509_authorities: RwLock<Vec<Vec<u8>>>,
}

impl Bundle {
    pub fn new(trust_domain: TrustDomain) -> Bundle {
        Bundle {
            trust_domain,
            x509_authorities: RwLock::new(Vec::new()),
        }
    }

    pub fn from_x509_authorities(trust_domain: TrustDomain, authorities: &[Vec<u8>]) -> Bundle {
        Bundle {
            trust_domain,
            x509_authorities: RwLock::new(x509util::copy_x509_authorities(authorities)),
        }
    }

    pub fn load(trust_domain: TrustDomain, path: &str) -> Result<Bundle> {
        let bytes = fs::read(path)
            .map_err(|err| wrap_error(format!("unable to load X.509 bundle file: {}", err)))?;
        Bundle::parse(trust_domain, &bytes)
    }

    pub fn read(trust_domain: TrustDomain, reader: &mut dyn Read) -> Result<Bundle> {
        let mut bytes = Vec::new();
        reader
            .read_to_end(&mut bytes)
            .map_err(|err| wrap_error(format!("unable to read X.509 bundle: {}", err)))?;
        Bundle::parse(trust_domain, &bytes)
    }

    pub fn parse(trust_domain: TrustDomain, bytes: &[u8]) -> Result<Bundle> {
        let bundle = Bundle::new(trust_domain);
        if bytes.is_empty() {
            return Ok(bundle);
        }
        let certs = pemutil::parse_certificates(bytes)
            .map_err(|err| wrap_error(format!("cannot parse certificate: {}", err)))?;
        for cert in certs {
            bundle.add_x509_authority(&cert);
        }
        Ok(bundle)
    }

    pub fn parse_raw(trust_domain: TrustDomain, bytes: &[u8]) -> Result<Bundle> {
        let bundle = Bundle::new(trust_domain);
        if bytes.is_empty() {
            return Ok(bundle);
        }
        let certs = parse_raw_certificates(bytes)
            .map_err(|err| wrap_error(format!("cannot parse certificate: {}", err)))?;
        for cert in certs {
            bundle.add_x509_authority(&cert);
        }
        Ok(bundle)
    }

    pub fn trust_domain(&self) -> TrustDomain {
        self.trust_domain.clone()
    }

    pub fn x509_authorities(&self) -> Vec<Vec<u8>> {
        self.x509_authorities
            .read()
            .map(|guard| x509util::copy_x509_authorities(&guard))
            .unwrap_or_default()
    }

    pub fn add_x509_authority(&self, authority: &[u8]) {
        if let Ok(mut guard) = self.x509_authorities.write() {
            if guard.iter().any(|cert| cert == authority) {
                return;
            }
            guard.push(authority.to_vec());
        }
    }

    pub fn remove_x509_authority(&self, authority: &[u8]) {
        if let Ok(mut guard) = self.x509_authorities.write() {
            if let Some(index) = guard.iter().position(|cert| cert == authority) {
                guard.remove(index);
            }
        }
    }

    pub fn has_x509_authority(&self, authority: &[u8]) -> bool {
        self.x509_authorities
            .read()
            .map(|guard| guard.iter().any(|cert| cert == authority))
            .unwrap_or(false)
    }

    pub fn set_x509_authorities(&self, authorities: &[Vec<u8>]) {
        if let Ok(mut guard) = self.x509_authorities.write() {
            *guard = x509util::copy_x509_authorities(authorities);
        }
    }

    pub fn empty(&self) -> bool {
        self.x509_authorities
            .read()
            .map(|guard| guard.is_empty())
            .unwrap_or(true)
    }

    pub fn marshal(&self) -> Result<Vec<u8>> {
        let certs = self.x509_authorities();
        Ok(pemutil::encode_certificates(&certs))
    }

    pub fn equal(&self, other: &Bundle) -> bool {
        self.trust_domain == other.trust_domain
            && x509util::certs_equal(&self.x509_authorities(), &other.x509_authorities())
    }

    pub fn clone_bundle(&self) -> Bundle {
        Bundle::from_x509_authorities(self.trust_domain(), &self.x509_authorities())
    }

    pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
        if self.trust_domain != trust_domain {
            return Err(wrap_error(format!(
                "no X.509 bundle found for trust domain: \"{}\"",
                trust_domain
            )));
        }
        Ok(self.clone_bundle())
    }
}

pub trait Source {
    fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle>;
}

#[derive(Debug)]
pub struct Set {
    bundles: RwLock<HashMap<TrustDomain, Bundle>>,
}

impl Set {
    pub fn new(bundles: &[Bundle]) -> Set {
        let mut map = HashMap::new();
        for bundle in bundles {
            map.insert(bundle.trust_domain(), bundle.clone_bundle());
        }
        Set {
            bundles: RwLock::new(map),
        }
    }

    pub fn add(&self, bundle: &Bundle) {
        if let Ok(mut guard) = self.bundles.write() {
            guard.insert(bundle.trust_domain(), bundle.clone_bundle());
        }
    }

    pub fn remove(&self, trust_domain: TrustDomain) {
        if let Ok(mut guard) = self.bundles.write() {
            guard.remove(&trust_domain);
        }
    }

    pub fn has(&self, trust_domain: TrustDomain) -> bool {
        self.bundles
            .read()
            .map(|guard| guard.contains_key(&trust_domain))
            .unwrap_or(false)
    }

    pub fn get(&self, trust_domain: TrustDomain) -> Option<Bundle> {
        self.bundles
            .read()
            .ok()
            .and_then(|guard| guard.get(&trust_domain).map(|b| b.clone_bundle()))
    }

    pub fn bundles(&self) -> Vec<Bundle> {
        let mut bundles = self
            .bundles
            .read()
            .map(|guard| guard.values().map(|b| b.clone_bundle()).collect::<Vec<_>>())
            .unwrap_or_default();
        bundles.sort_by(|a, b| a.trust_domain().compare(&b.trust_domain()));
        bundles
    }

    pub fn len(&self) -> usize {
        self.bundles.read().map(|guard| guard.len()).unwrap_or(0)
    }

    pub fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
        let guard = self
            .bundles
            .read()
            .map_err(|_| wrap_error("bundle store poisoned"))?;
        let bundle = guard.get(&trust_domain).ok_or_else(|| {
            wrap_error(format!(
                "no X.509 bundle for trust domain \"{}\"",
                trust_domain
            ))
        })?;
        Ok(bundle.clone_bundle())
    }
}

impl Source for Set {
    fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
        self.get_x509_bundle_for_trust_domain(trust_domain)
    }
}

impl Source for Bundle {
    fn get_x509_bundle_for_trust_domain(&self, trust_domain: TrustDomain) -> Result<Bundle> {
        self.get_x509_bundle_for_trust_domain(trust_domain)
    }
}

fn parse_raw_certificates(bytes: &[u8]) -> std::result::Result<Vec<Vec<u8>>, String> {
    let mut remaining = bytes;
    let mut certs = Vec::new();
    while !remaining.is_empty() {
        let (rest, _cert) = x509_parser::parse_x509_certificate(remaining)
            .map_err(|err| err.to_string())?;
        let consumed = remaining
            .len()
            .checked_sub(rest.len())
            .ok_or_else(|| "invalid certificate length".to_string())?;
        certs.push(remaining[..consumed].to_vec());
        remaining = rest;
    }
    Ok(certs)
}