use crate::bundle::BundleSource;
use crate::cert::error::CertificateError;
use crate::cert::parsing::to_certificate_vec_unbounded;
use crate::cert::Certificate;
use crate::spiffe_id::TrustDomain;
use std::collections::BTreeMap;
use std::convert::Infallible;
use std::sync::Arc;
/// X.509 authority certificates associated with a single trust domain.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct X509Bundle {
    /// The trust domain this bundle belongs to.
    trust_domain: TrustDomain,
    /// Authority certificates held by the bundle; `add_authority` appends to the end.
    x509_authorities: Vec<Certificate>,
}
/// A collection of [`X509Bundle`]s keyed by trust domain, holding at most one bundle
/// per trust domain.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct X509BundleSet {
    /// Bundles are stored behind `Arc` so lookups can hand out shared handles without
    /// copying the bundle; the `BTreeMap` keeps entries ordered by trust domain.
    bundles: BTreeMap<TrustDomain, Arc<X509Bundle>>,
}
/// Errors produced while building or extending an [`X509Bundle`].
#[derive(Debug, thiserror::Error, PartialEq)]
#[non_exhaustive]
pub enum X509BundleError {
    /// Wraps a [`CertificateError`] raised while turning bytes into a certificate.
    /// `#[error(transparent)]` forwards its `Display` and `source` unchanged, and
    /// `#[from]` lets `?` convert a `CertificateError` into this variant.
    #[error(transparent)]
    Certificate(#[from] CertificateError),
}
impl X509Bundle {
    /// Creates a bundle for `trust_domain` that holds no authorities yet.
    pub const fn new(trust_domain: TrustDomain) -> Self {
        let x509_authorities = Vec::new();
        Self {
            trust_domain,
            x509_authorities,
        }
    }

    /// Builds a bundle for `trust_domain` from a list of DER-encoded authority certificates,
    /// keeping them in the order given.
    ///
    /// # Errors
    ///
    /// Returns an error as soon as an entry fails to convert into a [`Certificate`];
    /// entries after the failing one are not examined.
    pub fn from_x509_authorities(
        trust_domain: TrustDomain,
        authorities: &[&[u8]],
    ) -> Result<Self, X509BundleError> {
        let mut x509_authorities = Vec::with_capacity(authorities.len());
        for &der in authorities {
            x509_authorities.push(Certificate::try_from(der)?);
        }
        Ok(Self {
            trust_domain,
            x509_authorities,
        })
    }

    /// Parses a bundle for `trust_domain` from a single byte buffer.
    ///
    /// The buffer is handed to `to_certificate_vec_unbounded`; presumably it contains one or
    /// more concatenated DER certificates — confirm against that helper.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer cannot be parsed into certificates.
    pub fn parse_from_der(
        trust_domain: TrustDomain,
        bundle_der: &[u8],
    ) -> Result<Self, X509BundleError> {
        let parsed = to_certificate_vec_unbounded(bundle_der)?;
        Ok(Self {
            trust_domain,
            x509_authorities: parsed,
        })
    }

    /// Parses `authority_bytes` as a certificate and appends it to this bundle.
    ///
    /// # Errors
    ///
    /// Returns an error if the bytes do not convert into a [`Certificate`]; the bundle is
    /// left unchanged in that case.
    pub fn add_authority(&mut self, authority_bytes: &[u8]) -> Result<(), X509BundleError> {
        self.x509_authorities
            .push(Certificate::try_from(authority_bytes)?);
        Ok(())
    }

    /// The trust domain this bundle belongs to.
    pub const fn trust_domain(&self) -> &TrustDomain {
        &self.trust_domain
    }

    /// All authority certificates currently held by the bundle.
    pub fn authorities(&self) -> &[Certificate] {
        self.x509_authorities.as_slice()
    }
}
impl X509BundleSet {
    /// Creates an empty bundle set.
    pub const fn new() -> Self {
        Self {
            bundles: BTreeMap::new(),
        }
    }

    /// Adds `bundle`, keyed by its own trust domain.
    ///
    /// If the set already holds a bundle for that trust domain, the old one is replaced.
    pub fn add_bundle(&mut self, bundle: X509Bundle) {
        let trust_domain = bundle.trust_domain().clone();
        self.bundles.insert(trust_domain, Arc::new(bundle));
    }

    /// Returns a shared handle (a cloned `Arc`) to the bundle for `trust_domain`, if present.
    pub fn get(&self, trust_domain: &TrustDomain) -> Option<Arc<X509Bundle>> {
        self.get_ref(trust_domain).cloned()
    }

    /// Borrows the bundle for `trust_domain`, if present, without cloning the `Arc`.
    pub fn get_ref(&self, trust_domain: &TrustDomain) -> Option<&Arc<X509Bundle>> {
        self.bundles.get(trust_domain)
    }

    /// Iterates over `(trust domain, bundle)` pairs in ascending trust-domain order.
    pub fn iter(&self) -> impl Iterator<Item = (&TrustDomain, &Arc<X509Bundle>)> {
        self.bundles.iter()
    }

    /// Number of trust domains that have a bundle in this set.
    pub fn len(&self) -> usize {
        self.bundles.len()
    }

    /// Returns `true` if the set holds no bundles.
    pub fn is_empty(&self) -> bool {
        self.bundles.is_empty()
    }

    /// Borrows the bundle for `trust_domain`, if present.
    ///
    /// Deprecated: [`X509BundleSet::get_ref`] has the same return type and is a drop-in
    /// replacement; [`X509BundleSet::get`] returns an owned `Arc` instead of a reference.
    #[deprecated(
        since = "0.9.0",
        note = "Use `X509BundleSet::get_ref` (same return type) or `X509BundleSet::get` instead."
    )]
    pub fn bundle_for(&self, trust_domain: &TrustDomain) -> Option<&Arc<X509Bundle>> {
        self.get_ref(trust_domain)
    }
}
impl Default for X509BundleSet {
fn default() -> Self {
Self::new()
}
}
impl BundleSource for X509BundleSet {
    type Item = X509Bundle;
    type Error = Infallible;

    /// Looks up the bundle for `trust_domain` in this set.
    ///
    /// Cannot fail (the error type is [`Infallible`]); an unknown trust domain yields `Ok(None)`.
    fn bundle_for_trust_domain(
        &self,
        trust_domain: &TrustDomain,
    ) -> Result<Option<Arc<Self::Item>>, Self::Error> {
        let found = self.bundles.get(trust_domain).map(Arc::clone);
        Ok(found)
    }
}
impl Extend<X509Bundle> for X509BundleSet {
    /// Adds every bundle yielded by `iter` through [`X509BundleSet::add_bundle`], so a later
    /// bundle replaces an earlier one with the same trust domain.
    fn extend<T: IntoIterator<Item = X509Bundle>>(&mut self, iter: T) {
        iter.into_iter()
            .for_each(|bundle| self.add_bundle(bundle));
    }
}
impl FromIterator<X509Bundle> for X509BundleSet {
    /// Builds a set by adding each bundle in turn; a later bundle replaces an earlier one
    /// with the same trust domain.
    fn from_iter<T: IntoIterator<Item = X509Bundle>>(iter: T) -> Self {
        iter.into_iter().fold(Self::new(), |mut set, bundle| {
            set.add_bundle(bundle);
            set
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bundle_set_add_bundle_replaces_existing_for_same_trust_domain() {
        let td = TrustDomain::new("example.org").unwrap();
        let mut set = X509BundleSet::new();

        set.add_bundle(X509Bundle::new(td.clone()));
        assert_eq!(set.len(), 1);
        let first = set.get(&td).expect("bundle was just added");

        set.add_bundle(X509Bundle::new(td.clone()));
        assert_eq!(set.len(), 1, "should replace bundle for same trust domain");
        let second = set.get(&td).expect("bundle was just added");

        // Both bundles are empty and compare equal, so value equality cannot show that a
        // replacement happened. `first` keeps the old allocation alive, so a freshly inserted
        // `Arc` must point to a different allocation.
        assert!(!std::sync::Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn bundle_set_extend_and_from_iter_work() {
        let td1 = TrustDomain::new("example.org").unwrap();
        let td2 = TrustDomain::new("example2.org").unwrap();
        let b1 = X509Bundle::new(td1.clone());
        let b2 = X509Bundle::new(td2.clone());

        let mut set = X509BundleSet::new();
        assert!(set.is_empty());
        set.extend([b1.clone(), b2.clone()]);
        assert_eq!(set.len(), 2);
        assert_eq!(set.get(&td1).as_deref(), Some(&b1));
        assert_eq!(set.get(&td2).as_deref(), Some(&b2));

        let set2: X509BundleSet = [b1, b2].into_iter().collect();
        assert_eq!(set2.len(), 2);
        assert_eq!(set2, set);
    }

    #[test]
    fn bundle_set_bundle_source_impl_matches_get() {
        let td = TrustDomain::new("example.org").unwrap();
        let mut set = X509BundleSet::new();
        set.add_bundle(X509Bundle::new(td.clone()));

        let via_get = set.get(&td).unwrap();
        let via_trait = set.bundle_for_trust_domain(&td).unwrap().unwrap();
        assert_eq!(via_get, via_trait);

        let missing = TrustDomain::new("other.org").unwrap();
        assert!(set.get(&missing).is_none());
        assert!(set.bundle_for_trust_domain(&missing).unwrap().is_none());
    }

    #[test]
    fn bundle_set_get_ref_matches_get() {
        let td = TrustDomain::new("example.org").unwrap();
        let mut set = X509BundleSet::new();
        set.add_bundle(X509Bundle::new(td.clone()));

        let owned = set.get(&td).unwrap();
        let borrowed = set.get_ref(&td).unwrap();
        assert!(std::sync::Arc::ptr_eq(&owned, borrowed));
    }
}