#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![forbid(unsafe_code)]
#![warn(missing_docs, rust_2018_idioms)]
extern crate alloc;
use alloc::string::String;
use alloc::vec::Vec;
/// Errors produced while fetching certificates referenced by AIA URIs.
///
/// Every variant renders a one-line, human-readable message through its
/// `Display` implementation. The enum is `#[non_exhaustive]`, so `match`
/// expressions outside this crate need a wildcard arm.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum AiaError {
    /// AIA fetching is turned off; [`NoAiaFetcher`] returns this for every URI.
    FetchingDisabled,
    /// The fetch returned an HTTP status that is reported as an error; the
    /// payload is that status code.
    HttpStatus(u16),
    /// The AIA response exceeded a size cap.
    ResponseTooLarge {
        /// Maximum permitted response size, in bytes.
        limit: usize,
        /// Observed response size, in bytes.
        actual: usize,
    },
    /// The fetched bytes did not parse as a certificate; the payload is a
    /// human-readable description of the parse failure.
    MalformedCertificate(String),
    /// The fetch timed out.
    Timeout,
    /// The URI's scheme is not supported by the fetcher; the payload is the
    /// offending URI (or scheme) as text.
    UriUnsupported(String),
    /// A lower-level I/O operation failed. Only available with the `std`
    /// feature.
    #[cfg(feature = "std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
    IoFailure {
        /// Category of the I/O failure. With the `serde` feature it is encoded
        /// as its `Debug` name; names the decoder does not recognise read back
        /// as `ErrorKind::Other`.
        #[cfg_attr(feature = "serde", serde(with = "io_error_kind_serde"))]
        kind: std::io::ErrorKind,
        /// Description of the I/O failure.
        message: String,
    },
}
impl core::fmt::Display for AiaError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::FetchingDisabled => f.write_str("AIA fetching is disabled"),
Self::HttpStatus(code) => write!(f, "AIA fetch returned HTTP status {code}"),
Self::ResponseTooLarge { limit, actual } => write!(
f,
"AIA response exceeded size cap: limit {limit} bytes, observed {actual} bytes",
),
Self::MalformedCertificate(msg) => {
write!(f, "AIA-fetched bytes did not parse as a certificate: {msg}")
}
Self::Timeout => f.write_str("AIA fetch timed out"),
Self::UriUnsupported(uri) => write!(f, "AIA URI scheme not supported: {uri}"),
#[cfg(feature = "std")]
Self::IoFailure { kind, message } => {
write!(f, "AIA fetch I/O failure ({kind:?}): {message}")
}
}
}
}
// `AiaError` is a standard error type when `std` is available; its message
// comes from the `Display` impl.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl std::error::Error for AiaError {
    /// Always `None`: no variant stores another error value (I/O failures are
    /// kept as an `ErrorKind` plus a message string), so there is no chained
    /// source to expose.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
/// `serde(with = "io_error_kind_serde")` adapter for the `kind` field of
/// `AiaError::IoFailure`.
///
/// An `ErrorKind` is written as its `Debug` label (for example `"TimedOut"`).
/// Reading maps the labels listed in `kind_for` back to their variants and
/// every other label to `ErrorKind::Other`. As a result, labels for variants
/// this table does not know about (e.g. ones only stable on newer toolchains)
/// degrade instead of failing.
#[cfg(all(feature = "std", feature = "serde"))]
mod io_error_kind_serde {
    use serde::de::{self, Visitor};
    use serde::{Deserializer, Serializer};
    use std::io::ErrorKind;

    /// Serializes `kind` as the string produced by its `Debug` formatting.
    pub(super) fn serialize<S>(kind: &ErrorKind, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use alloc::format;
        // `ErrorKind` exposes no stable name accessor; its `Debug` output is the
        // variant name, which the round-trip tests in this module pin down.
        let label = format!("{kind:?}");
        serializer.serialize_str(&label)
    }

    /// Deserializes a kind from a string label such as the ones written by
    /// [`serialize`].
    ///
    /// The deserializer may hand the label over in three forms: a string
    /// borrowed from the input, a transient `&str`, or an owned `String`. All
    /// three are accepted. Deserializing into `&'de str` would accept only the
    /// borrowed form, and so would reject input the deserializer cannot lend
    /// out (for example, data read from a stream). Unknown labels decode as
    /// `ErrorKind::Other`.
    ///
    /// # Errors
    ///
    /// Propagates the deserializer's error when the input is not a string or
    /// cannot be read.
    pub(super) fn deserialize<'de, D>(deserializer: D) -> Result<ErrorKind, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_str(LabelVisitor)
    }

    /// Visitor that maps any string representation through [`kind_for`].
    struct LabelVisitor;

    impl<'de> Visitor<'de> for LabelVisitor {
        type Value = ErrorKind;

        fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            formatter.write_str("a std::io::ErrorKind variant name")
        }

        // serde's default `visit_borrowed_str` and `visit_string` both forward
        // to `visit_str`, so this single method handles every string form.
        fn visit_str<E>(self, label: &str) -> Result<ErrorKind, E>
        where
            E: de::Error,
        {
            Ok(kind_for(label))
        }
    }

    /// Maps a `Debug` label back to its [`ErrorKind`].
    ///
    /// Only exact matches for the labels below are recognised; any other
    /// label (unknown, newer, padded, or empty) yields `ErrorKind::Other`.
    fn kind_for(label: &str) -> ErrorKind {
        match label {
            "NotFound" => ErrorKind::NotFound,
            "PermissionDenied" => ErrorKind::PermissionDenied,
            "ConnectionRefused" => ErrorKind::ConnectionRefused,
            "ConnectionReset" => ErrorKind::ConnectionReset,
            "ConnectionAborted" => ErrorKind::ConnectionAborted,
            "NotConnected" => ErrorKind::NotConnected,
            "AddrInUse" => ErrorKind::AddrInUse,
            "AddrNotAvailable" => ErrorKind::AddrNotAvailable,
            "BrokenPipe" => ErrorKind::BrokenPipe,
            "AlreadyExists" => ErrorKind::AlreadyExists,
            "WouldBlock" => ErrorKind::WouldBlock,
            "InvalidInput" => ErrorKind::InvalidInput,
            "InvalidData" => ErrorKind::InvalidData,
            "TimedOut" => ErrorKind::TimedOut,
            "WriteZero" => ErrorKind::WriteZero,
            "Interrupted" => ErrorKind::Interrupted,
            "Unsupported" => ErrorKind::Unsupported,
            "UnexpectedEof" => ErrorKind::UnexpectedEof,
            "OutOfMemory" => ErrorKind::OutOfMemory,
            _ => ErrorKind::Other,
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use alloc::format;

        #[test]
        fn debug_label_round_trip_covers_msrv_variants() {
            let cases: &[(ErrorKind, &str)] = &[
                (ErrorKind::NotFound, "NotFound"),
                (ErrorKind::PermissionDenied, "PermissionDenied"),
                (ErrorKind::ConnectionRefused, "ConnectionRefused"),
                (ErrorKind::ConnectionReset, "ConnectionReset"),
                (ErrorKind::ConnectionAborted, "ConnectionAborted"),
                (ErrorKind::NotConnected, "NotConnected"),
                (ErrorKind::AddrInUse, "AddrInUse"),
                (ErrorKind::AddrNotAvailable, "AddrNotAvailable"),
                (ErrorKind::BrokenPipe, "BrokenPipe"),
                (ErrorKind::AlreadyExists, "AlreadyExists"),
                (ErrorKind::WouldBlock, "WouldBlock"),
                (ErrorKind::InvalidInput, "InvalidInput"),
                (ErrorKind::InvalidData, "InvalidData"),
                (ErrorKind::TimedOut, "TimedOut"),
                (ErrorKind::WriteZero, "WriteZero"),
                (ErrorKind::Interrupted, "Interrupted"),
                (ErrorKind::Unsupported, "Unsupported"),
                (ErrorKind::UnexpectedEof, "UnexpectedEof"),
                (ErrorKind::OutOfMemory, "OutOfMemory"),
                (ErrorKind::Other, "Other"),
            ];
            for (kind, expected_label) in cases {
                let debug_label = format!("{kind:?}");
                assert_eq!(
                    debug_label, *expected_label,
                    "Debug format for {kind:?}"
                );
                assert_eq!(
                    kind_for(expected_label),
                    *kind,
                    "kind_for({expected_label:?})"
                );
            }
        }

        #[test]
        fn unknown_label_resolves_to_other() {
            assert_eq!(kind_for("DefinitelyNotAVariant"), ErrorKind::Other);
            assert_eq!(kind_for(""), ErrorKind::Other);
            assert_eq!(kind_for(" NotFound "), ErrorKind::Other);
        }

        #[test]
        fn deserialize_accepts_borrowed_transient_and_owned_labels() {
            use serde::de::value::{
                BorrowedStrDeserializer, Error as ValueError, StrDeserializer, StringDeserializer,
            };
            use serde::de::IntoDeserializer;

            // Borrowed from the input (`visit_borrowed_str`).
            let borrowed = BorrowedStrDeserializer::<ValueError>::new("TimedOut");
            assert_eq!(deserialize(borrowed).expect("borrowed"), ErrorKind::TimedOut);

            // Transient `&str` (`visit_str`): a `&'de str` decoder rejects this.
            let transient: StrDeserializer<'_, ValueError> = "BrokenPipe".into_deserializer();
            assert_eq!(
                deserialize(transient).expect("transient"),
                ErrorKind::BrokenPipe
            );

            // Owned `String` (`visit_string`): a `&'de str` decoder rejects this too.
            let owned: StringDeserializer<ValueError> =
                String::from("NotFound").into_deserializer();
            assert_eq!(deserialize(owned).expect("owned"), ErrorKind::NotFound);

            // Unknown labels still degrade to `Other` instead of erroring.
            let unknown: StringDeserializer<ValueError> =
                String::from("SomethingNew").into_deserializer();
            assert_eq!(deserialize(unknown).expect("unknown"), ErrorKind::Other);
        }

        // NOTE(review): the variants below are only stable on recent toolchains,
        // so this test will not compile on older ones. Confirm it is compatible
        // with the crate's declared MSRV.
        #[test]
        fn post_msrv_variants_serialize_with_real_name() {
            let post_msrv: &[(ErrorKind, &str)] = &[
                (ErrorKind::HostUnreachable, "HostUnreachable"),
                (ErrorKind::NetworkUnreachable, "NetworkUnreachable"),
                (ErrorKind::NetworkDown, "NetworkDown"),
                (ErrorKind::NotADirectory, "NotADirectory"),
                (ErrorKind::IsADirectory, "IsADirectory"),
                (ErrorKind::DirectoryNotEmpty, "DirectoryNotEmpty"),
                (ErrorKind::ReadOnlyFilesystem, "ReadOnlyFilesystem"),
                (ErrorKind::StaleNetworkFileHandle, "StaleNetworkFileHandle"),
                (ErrorKind::StorageFull, "StorageFull"),
                (ErrorKind::NotSeekable, "NotSeekable"),
                (ErrorKind::FileTooLarge, "FileTooLarge"),
                (ErrorKind::ResourceBusy, "ResourceBusy"),
                (ErrorKind::ExecutableFileBusy, "ExecutableFileBusy"),
                (ErrorKind::Deadlock, "Deadlock"),
                (ErrorKind::CrossesDevices, "CrossesDevices"),
                (ErrorKind::TooManyLinks, "TooManyLinks"),
                (ErrorKind::InvalidFilename, "InvalidFilename"),
                (ErrorKind::ArgumentListTooLong, "ArgumentListTooLong"),
            ];
            for (kind, expected_label) in post_msrv {
                let debug_label = format!("{kind:?}");
                assert_eq!(
                    debug_label, *expected_label,
                    "Debug format for post-MSRV {kind:?}"
                );
                assert_eq!(
                    kind_for(expected_label),
                    ErrorKind::Other,
                    "kind_for({expected_label:?}) should gracefully degrade to Other"
                );
            }
        }
    }
}
/// Source of the raw bytes behind AIA URIs.
///
/// Implementors must provide [`fetch`](Self::fetch).
/// [`batch_fetch`](Self::batch_fetch) has a sequential default that
/// implementors may override.
pub trait AiaFetcher {
    /// Fetches the resource identified by `uri` and returns its raw bytes.
    ///
    /// # Errors
    ///
    /// Returns an [`AiaError`] describing why the fetch failed or was refused.
    fn fetch(&self, uri: &str) -> Result<Vec<u8>, AiaError>;

    /// Fetches every URI in `uris` and returns one result per input, in the
    /// same order as the input.
    ///
    /// The default implementation calls [`fetch`](Self::fetch) once per URI,
    /// in order. It does not stop at the first error; each failure is
    /// recorded in its slot.
    fn batch_fetch(&self, uris: &[&str]) -> Vec<Result<Vec<u8>, AiaError>> {
        let mut outcomes = Vec::with_capacity(uris.len());
        for &uri in uris {
            outcomes.push(self.fetch(uri));
        }
        outcomes
    }
}
/// An [`AiaFetcher`] that never fetches anything.
///
/// `fetch` returns [`AiaError::FetchingDisabled`] for every URI, so the
/// inherited `batch_fetch` yields that error once per input. It is a
/// zero-sized, `Copy` unit struct. It also derives `Default`, so
/// `NoAiaFetcher::default()` is equivalent to `NoAiaFetcher`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct NoAiaFetcher;
impl AiaFetcher for NoAiaFetcher {
    /// Refuses every request without looking at the URI.
    ///
    /// # Errors
    ///
    /// Always returns [`AiaError::FetchingDisabled`], whatever the URI is
    /// (including the empty string).
    fn fetch(&self, _uri: &str) -> Result<Vec<u8>, AiaError> {
        // Nothing is attempted: the disabled fetcher short-circuits to the
        // same error for every input.
        let disabled = AiaError::FetchingDisabled;
        Err(disabled)
    }
}
// Compile-time only (the closure is stored in an unnamed const and never
// called). It fails to build if either public type stops being `Send + Sync`.
const _: fn() = || {
    fn require_send_sync<T: ?Sized + Send + Sync>() {}
    require_send_sync::<NoAiaFetcher>();
    require_send_sync::<AiaError>();
};
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "std")]
    use alloc::format;
    use alloc::string::ToString;
    use alloc::vec;
    use core::cell::Cell;

    /// Test double: counts `fetch` calls and echoes `http://` / `https://`
    /// URIs back as their own bytes; any other scheme is reported as
    /// unsupported.
    struct RecordingFetcher {
        call_count: Cell<usize>,
    }

    impl RecordingFetcher {
        fn new() -> Self {
            RecordingFetcher {
                call_count: Cell::new(0),
            }
        }
    }

    impl AiaFetcher for RecordingFetcher {
        fn fetch(&self, uri: &str) -> Result<Vec<u8>, AiaError> {
            self.call_count.set(self.call_count.get() + 1);
            let web_scheme = uri.starts_with("http://") || uri.starts_with("https://");
            if !web_scheme {
                return Err(AiaError::UriUnsupported(uri.into()));
            }
            Ok(uri.as_bytes().to_vec())
        }
    }

    /// Test double whose `batch_fetch` override must take precedence over the
    /// default; its `fetch` panics if it is ever reached.
    struct OverriddenBatchFetcher {
        batch_calls: Cell<usize>,
    }

    impl AiaFetcher for OverriddenBatchFetcher {
        fn fetch(&self, _uri: &str) -> Result<Vec<u8>, AiaError> {
            unreachable!("override should not delegate to fetch")
        }

        fn batch_fetch(&self, uris: &[&str]) -> Vec<Result<Vec<u8>, AiaError>> {
            self.batch_calls.set(self.batch_calls.get() + 1);
            core::iter::repeat_with(|| Err(AiaError::Timeout))
                .take(uris.len())
                .collect()
        }
    }

    #[test]
    fn display_fetching_disabled() {
        let text = AiaError::FetchingDisabled.to_string();
        assert_eq!(text, "AIA fetching is disabled");
    }

    #[test]
    fn display_http_status() {
        let text = AiaError::HttpStatus(503).to_string();
        assert_eq!(text, "AIA fetch returned HTTP status 503");
    }

    #[test]
    fn display_response_too_large() {
        let err = AiaError::ResponseTooLarge {
            limit: 65_536,
            actual: 131_072,
        };
        assert_eq!(
            err.to_string(),
            "AIA response exceeded size cap: limit 65536 bytes, observed 131072 bytes",
        );
    }

    #[test]
    fn display_malformed_certificate() {
        let err = AiaError::MalformedCertificate("expected SEQUENCE got SET".into());
        assert_eq!(
            err.to_string(),
            "AIA-fetched bytes did not parse as a certificate: expected SEQUENCE got SET",
        );
    }

    #[test]
    fn display_timeout() {
        let text = AiaError::Timeout.to_string();
        assert_eq!(text, "AIA fetch timed out");
    }

    #[test]
    fn display_uri_unsupported() {
        let err = AiaError::UriUnsupported("ldap://ca.example.com".into());
        assert_eq!(
            err.to_string(),
            "AIA URI scheme not supported: ldap://ca.example.com",
        );
    }

    #[test]
    #[cfg(feature = "std")]
    fn display_io_failure() {
        let err = AiaError::IoFailure {
            kind: std::io::ErrorKind::ConnectionRefused,
            message: "connection refused by 10.0.0.1:443".into(),
        };
        assert_eq!(
            format!("{}", err),
            "AIA fetch I/O failure (ConnectionRefused): connection refused by 10.0.0.1:443",
        );
    }

    #[test]
    fn clone_and_eq_unit_variants() {
        for unit in [AiaError::FetchingDisabled, AiaError::Timeout] {
            assert_eq!(unit, unit.clone());
        }
    }

    #[test]
    fn clone_and_eq_carrying_variants() {
        let samples = [
            AiaError::HttpStatus(404),
            AiaError::ResponseTooLarge {
                limit: 1024,
                actual: 2048,
            },
            AiaError::MalformedCertificate("parse error at offset 7".into()),
            AiaError::UriUnsupported("ldap".into()),
        ];
        for sample in &samples {
            assert_eq!(*sample, sample.clone());
        }
    }

    #[test]
    fn distinct_variants_are_not_equal() {
        let pairs = [
            (AiaError::FetchingDisabled, AiaError::Timeout),
            (AiaError::HttpStatus(404), AiaError::HttpStatus(503)),
            (
                AiaError::UriUnsupported("ldap".into()),
                AiaError::UriUnsupported("file".into()),
            ),
        ];
        for (left, right) in &pairs {
            assert_ne!(left, right);
        }
    }

    #[test]
    #[cfg(feature = "std")]
    fn io_failure_clone_and_eq() {
        let make = |kind: std::io::ErrorKind, message: &str| AiaError::IoFailure {
            kind,
            message: message.into(),
        };
        let base = make(std::io::ErrorKind::TimedOut, "deadline exceeded");
        assert_eq!(base, base.clone());
        assert_ne!(base, make(std::io::ErrorKind::TimedOut, "different message"));
        assert_ne!(base, make(std::io::ErrorKind::NotFound, "deadline exceeded"));
    }

    #[test]
    fn fetch_records_each_call() {
        let fetcher = RecordingFetcher::new();
        let body = fetcher.fetch("http://ca.example/ca.crt").expect("ok");
        assert_eq!(body, b"http://ca.example/ca.crt".to_vec());
        assert_eq!(fetcher.call_count.get(), 1);
        for _ in 0..2 {
            let _ = fetcher.fetch("http://ca.example/ca.crt");
        }
        assert_eq!(fetcher.call_count.get(), 3);
    }

    #[test]
    fn fetch_classifies_unsupported_scheme() {
        let fetcher = RecordingFetcher::new();
        let outcome = fetcher.fetch("ldap://ca.example/cn=ca");
        let expected = Err(AiaError::UriUnsupported("ldap://ca.example/cn=ca".into()));
        assert_eq!(outcome, expected);
        assert_eq!(fetcher.call_count.get(), 1);
    }

    #[test]
    fn batch_fetch_default_impl_iterates_each_uri() {
        let fetcher = RecordingFetcher::new();
        let outcomes = fetcher.batch_fetch(&[
            "http://ca.example/a.crt",
            "ldap://ca.example/b",
            "https://ca.example/c.crt",
        ]);
        assert_eq!(outcomes.len(), 3);
        assert_eq!(outcomes[0], Ok(b"http://ca.example/a.crt".to_vec()));
        assert_eq!(
            outcomes[1],
            Err(AiaError::UriUnsupported("ldap://ca.example/b".into())),
        );
        assert_eq!(outcomes[2], Ok(b"https://ca.example/c.crt".to_vec()));
        assert_eq!(fetcher.call_count.get(), 3);
    }

    #[test]
    fn batch_fetch_empty_input_returns_empty_output() {
        let fetcher = RecordingFetcher::new();
        let no_uris: &[&str] = &[];
        assert!(fetcher.batch_fetch(no_uris).is_empty());
        assert_eq!(fetcher.call_count.get(), 0);
    }

    #[test]
    fn batch_fetch_preserves_order() {
        let fetcher = RecordingFetcher::new();
        let outcomes = fetcher.batch_fetch(&["http://a", "http://b", "http://c"]);
        let expected = vec![
            Ok(b"http://a".to_vec()),
            Ok(b"http://b".to_vec()),
            Ok(b"http://c".to_vec()),
        ];
        assert_eq!(outcomes, expected);
    }

    #[test]
    fn batch_fetch_override_takes_precedence() {
        let fetcher = OverriddenBatchFetcher {
            batch_calls: Cell::new(0),
        };
        let outcomes = fetcher.batch_fetch(&["http://a", "http://b"]);
        assert_eq!(outcomes.len(), 2);
        for outcome in &outcomes {
            assert_eq!(*outcome, Err(AiaError::Timeout));
        }
        assert_eq!(fetcher.batch_calls.get(), 1);
    }

    #[test]
    fn no_aia_fetcher_fetch_returns_fetching_disabled_for_any_uri() {
        let fetcher = NoAiaFetcher;
        let uris = [
            "http://ca.example/ca.crt",
            "https://ca.example/ca.crt",
            "ldap://ca.example/cn=ca",
            "file:///etc/ssl/ca.pem",
            "",
        ];
        for uri in uris {
            assert_eq!(
                fetcher.fetch(uri),
                Err(AiaError::FetchingDisabled),
                "fetch({uri:?})",
            );
        }
    }

    #[test]
    fn no_aia_fetcher_batch_fetch_returns_fetching_disabled_per_uri() {
        let fetcher = NoAiaFetcher;
        let outcomes = fetcher.batch_fetch(&[
            "http://ca.example/a.crt",
            "http://ca.example/b.crt",
            "http://ca.example/c.crt",
        ]);
        assert_eq!(outcomes.len(), 3);
        for (i, outcome) in outcomes.iter().enumerate() {
            assert_eq!(
                *outcome,
                Err(AiaError::FetchingDisabled),
                "batch_fetch index {i}",
            );
        }
    }

    #[test]
    fn no_aia_fetcher_batch_fetch_empty_input() {
        let no_uris: &[&str] = &[];
        assert!(NoAiaFetcher.batch_fetch(no_uris).is_empty());
    }

    #[test]
    fn no_aia_fetcher_is_zero_sized() {
        let size = core::mem::size_of::<NoAiaFetcher>();
        assert_eq!(size, 0);
    }

    #[test]
    fn no_aia_fetcher_derives_default() {
        let fetcher = NoAiaFetcher::default();
        assert_eq!(fetcher.fetch("http://x"), Err(AiaError::FetchingDisabled));
    }

    #[test]
    fn no_aia_fetcher_is_copy() {
        let original = NoAiaFetcher;
        let copied = original;
        // `original` is still usable after the move above only because the type is `Copy`.
        assert_eq!(original.fetch("http://x"), Err(AiaError::FetchingDisabled));
        assert_eq!(copied.fetch("http://x"), Err(AiaError::FetchingDisabled));
    }
}