use super::builder::ResourceLimits;
use super::errors::{LimitKind, MetricsErrorKind, X509SourceError};
use super::metrics::MetricsRecorder;
use crate::workload_api::x509_context::X509Context;
use crate::x509_source::types::SvidPicker;
use std::sync::Arc;
pub(super) fn validate_limits(
ctx: &X509Context,
limits: ResourceLimits,
) -> Result<(), X509SourceError> {
if let Some(max_svids) = limits.max_svids {
let actual = ctx.svids().len();
if actual > max_svids {
return Err(X509SourceError::ResourceLimitExceeded {
kind: LimitKind::MaxSvids,
limit: max_svids,
actual,
});
}
}
if let Some(max_bundles) = limits.max_bundles {
let actual = ctx.bundle_set().len();
if actual > max_bundles {
return Err(X509SourceError::ResourceLimitExceeded {
kind: LimitKind::MaxBundles,
limit: max_bundles,
actual,
});
}
}
if let Some(max_bundle_der_bytes) = limits.max_bundle_der_bytes {
for (_, bundle) in ctx.bundle_set().iter() {
let actual: usize = bundle
.authorities()
.iter()
.map(|cert| cert.as_bytes().len())
.sum();
if actual > max_bundle_der_bytes {
return Err(X509SourceError::ResourceLimitExceeded {
kind: LimitKind::MaxBundleDerBytes,
limit: max_bundle_der_bytes,
actual,
});
}
}
}
Ok(())
}
/// Runs [`validate_limits`] and, when a limit is violated, records the
/// matching metric before returning the error.
///
/// A metric is recorded only when `metrics` is `Some` and the error is
/// [`X509SourceError::ResourceLimitExceeded`]. The metric kind comes from
/// [`metric_kind_for_limit`]. The validation result is returned unchanged.
///
/// # Errors
///
/// Propagates the error returned by [`validate_limits`].
pub(super) fn validate_limits_and_record_metric(
    ctx: &X509Context,
    limits: ResourceLimits,
    metrics: Option<&dyn MetricsRecorder>,
) -> Result<(), X509SourceError> {
    let outcome = validate_limits(ctx, limits);

    if let (Err(X509SourceError::ResourceLimitExceeded { kind, .. }), Some(recorder)) =
        (&outcome, metrics)
    {
        recorder.record_error(metric_kind_for_limit(*kind));
    }

    outcome
}
/// Maps a [`LimitKind`] to the [`MetricsErrorKind`] recorded when that limit
/// is exceeded.
///
/// The match is exhaustive with no wildcard arm, so adding a new `LimitKind`
/// variant fails to compile until this mapping is extended.
pub(super) const fn metric_kind_for_limit(kind: LimitKind) -> MetricsErrorKind {
    match kind {
        LimitKind::MaxBundleDerBytes => MetricsErrorKind::LimitMaxBundleDerBytes,
        LimitKind::MaxBundles => MetricsErrorKind::LimitMaxBundles,
        LimitKind::MaxSvids => MetricsErrorKind::LimitMaxSvids,
    }
}
/// Chooses an SVID from `ctx`.
///
/// When a `picker` is supplied, it is given the context's SVIDs and may
/// return an index into them. If it returns `None`, or an index outside the
/// list, no SVID is selected. Without a picker, the context's default SVID
/// is used.
///
/// Returns a cloned `Arc` handle to the selected SVID, or `None`.
pub(super) fn select_svid(
    ctx: &X509Context,
    picker: Option<&dyn SvidPicker>,
) -> Option<Arc<crate::X509Svid>> {
    match picker {
        None => ctx.default_svid().cloned(),
        Some(chooser) => {
            let candidates = ctx.svids();
            let chosen = chooser.pick_svid(candidates)?;
            candidates.get(chosen).cloned()
        }
    }
}
/// Validates `ctx` against `limits` and checks that an SVID can be selected
/// from it using `picker`.
///
/// Failures are recorded on `metrics` when it is `Some`. Limit violations
/// are recorded via [`validate_limits_and_record_metric`]. A failed SVID
/// selection is recorded as [`MetricsErrorKind::NoSuitableSvid`].
///
/// # Errors
///
/// - Any error from [`validate_limits_and_record_metric`]. Limits are
///   checked before SVID selection.
/// - [`X509SourceError::NoSuitableSvid`] if [`select_svid`] returns `None`.
pub(super) fn validate_context(
    ctx: &X509Context,
    picker: Option<&dyn SvidPicker>,
    limits: ResourceLimits,
    metrics: Option<&dyn MetricsRecorder>,
) -> Result<(), X509SourceError> {
    validate_limits_and_record_metric(ctx, limits, metrics)?;

    // Only whether a selection exists matters here; the SVID itself is dropped.
    if select_svid(ctx, picker).is_some() {
        return Ok(());
    }

    if let Some(recorder) = metrics {
        recorder.record_error(MetricsErrorKind::NoSuitableSvid);
    }
    Err(X509SourceError::NoSuitableSvid)
}