/// Identifies a TLS protocol version.
///
/// The enum is `#[non_exhaustive]`, so more named versions can be added later
/// without a breaking change. Versions that have no dedicated variant are
/// carried as a raw number in [`ProtocolVersion::Other`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
// Keeps the `TLSv1_x`-style variant spelling free of camel-case lint noise.
// NOTE(review): rustc accepts an underscore between two digits as camel case,
// so this allow may no longer be needed — confirm before removing.
#[allow(non_camel_case_types)]
pub enum ProtocolVersion {
    /// TLS 1.2.
    TLSv1_2,
    /// TLS 1.3.
    TLSv1_3,
    /// Any other version, identified by a raw `u16` code (presumably the
    /// on-the-wire version number — confirm with the producer of this value).
    Other(u16),
}
/// Inputs handed to a [`VerifyCallback`] when a certificate chain is checked.
///
/// Every field is a borrow, so the struct is `Copy`. It also implements
/// `Debug` and `Eq`, so it can be logged and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CertVerify<'a> {
    /// Name of the server whose certificate is being verified.
    pub server_name: &'a str,
    /// Certificate chain, one encoded certificate per entry. The entries are
    /// presumably DER-encoded (per the field name), and the first entry is
    /// presumably the leaf — NOTE(review): confirm ordering with the caller.
    pub chain_der: &'a [Vec<u8>],
}
/// Outcome of a certificate-verification callback.
///
/// Marked `#[must_use]`: discarding a verdict in a verification path turns
/// the check into a no-op, so the compiler warns when one is ignored.
#[must_use = "a certificate verdict must be acted upon; ignoring it skips verification"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertVerdict {
    /// The certificate chain is accepted.
    Accept,
    /// The certificate chain is rejected.
    Reject,
}
/// User-supplied certificate-verification hook.
///
/// The closure is stored behind an `std::sync::Arc`, so cloning a
/// `VerifyCallback` is cheap and every clone invokes the same closure. The
/// closure must be `Send + Sync`, which lets the callback be shared across threads.
#[derive(Clone)]
pub struct VerifyCallback(
    std::sync::Arc<dyn Fn(&CertVerify<'_>) -> CertVerdict + Send + Sync>,
);

impl VerifyCallback {
    /// Wraps `f` as a verification callback.
    ///
    /// `f` is called with a [`CertVerify`] describing the server name and
    /// certificate chain, and returns the [`CertVerdict`] for them.
    pub fn new(f: impl Fn(&CertVerify<'_>) -> CertVerdict + Send + Sync + 'static) -> Self {
        let shared = std::sync::Arc::new(f);
        Self(shared)
    }

    /// Runs the wrapped closure on `v` and returns its verdict unchanged.
    pub fn call(&self, v: &CertVerify<'_>) -> CertVerdict {
        let Self(inner) = self;
        inner(v)
    }
}

impl std::fmt::Debug for VerifyCallback {
    /// Closures have no useful `Debug` form, so a fixed placeholder is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("VerifyCallback(..)")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a `CertVerify` for `server_name` over `chain`.
    fn verify<'a>(server_name: &'a str, chain: &'a [Vec<u8>]) -> CertVerify<'a> {
        CertVerify {
            server_name,
            chain_der: chain,
        }
    }

    #[test]
    fn verify_callback_invokes_closure_with_chain() {
        let chain = vec![vec![1u8, 2, 3]];
        let cb = VerifyCallback::new(|v: &CertVerify<'_>| {
            // Compare the leaf as a slice; no need to allocate a Vec for the check.
            if v.server_name == "example.com"
                && v.chain_der.first().map(Vec::as_slice) == Some(&[1u8, 2, 3][..])
            {
                CertVerdict::Accept
            } else {
                CertVerdict::Reject
            }
        });
        assert_eq!(cb.call(&verify("example.com", &chain)), CertVerdict::Accept);
        assert_eq!(cb.call(&verify("evil.com", &chain)), CertVerdict::Reject);
        // An empty chain has no leaf, so this closure must reject it.
        assert_eq!(cb.call(&verify("example.com", &[])), CertVerdict::Reject);
        assert_eq!(format!("{cb:?}"), "VerifyCallback(..)");
    }

    #[test]
    fn verify_callback_clones_share_the_same_closure() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let cb = VerifyCallback::new(move |_: &CertVerify<'_>| {
            counter.fetch_add(1, Ordering::SeqCst);
            CertVerdict::Reject
        });
        let copy = cb.clone();
        let chain: Vec<Vec<u8>> = Vec::new();
        assert_eq!(cb.call(&verify("a", &chain)), CertVerdict::Reject);
        assert_eq!(copy.call(&verify("b", &chain)), CertVerdict::Reject);
        // Both calls must hit the one shared closure.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn protocol_version_variants_are_distinct() {
        assert_ne!(ProtocolVersion::TLSv1_2, ProtocolVersion::TLSv1_3);
        assert_eq!(ProtocolVersion::Other(0x0305), ProtocolVersion::Other(0x0305));
        assert_ne!(ProtocolVersion::Other(0x0305), ProtocolVersion::Other(0x0306));
    }
}