use std::path::PathBuf;
use thiserror::Error;
/// Errors returned by the TLS helpers in this module.
///
/// Marked `#[non_exhaustive]` so new variants can be added without a breaking
/// change; `match`es outside this crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TlsError {
    /// An I/O failure involving the file at `path`.
    ///
    /// NOTE(review): nothing visible in this chunk constructs this variant;
    /// presumably it wraps reads of the configured cert/key/CA paths — confirm.
    #[error("io error reading `{path}`: {source}")]
    Io {
        /// The file that was being accessed, as a display string.
        path: String,
        /// The underlying I/O error; `#[source]` exposes it via `Error::source`.
        #[source]
        source: std::io::Error,
    },
    /// Malformed PEM input. The payload is a human-readable description, such
    /// as a missing END marker or a base64 decoding failure.
    #[error("invalid PEM input: {0}")]
    Pem(String),
}
/// TLS settings, built from [`Default`] plus the `with_*` builder methods.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; start from `TlsConfig::default()` instead. Every
/// field is public and may also be assigned directly.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TlsConfig {
    /// Path to the local certificate. Set by [`with_cert`](Self::with_cert).
    ///
    /// NOTE(review): the on-disk format (PEM vs DER) is decided by whatever
    /// loads this config; it is not visible here.
    pub cert_path: Option<PathBuf>,
    /// Path to the private key, presumably the one matching `cert_path`.
    /// Set by [`with_key`](Self::with_key).
    pub key_path: Option<PathBuf>,
    /// Path to CA certificate(s); presumably used to verify the peer — confirm
    /// with the consumer of this config. Set by [`with_ca`](Self::with_ca).
    pub ca_path: Option<PathBuf>,
    /// Whether client authentication is required. Defaults to `false`.
    /// Set by [`with_client_auth`](Self::with_client_auth).
    pub require_client_auth: bool,
    /// Expected server name. Set by [`with_server_name`](Self::with_server_name).
    ///
    /// NOTE(review): presumably used for SNI / hostname verification on the
    /// client side — confirm with the consumer.
    pub server_name: Option<String>,
    /// Opt-out flag whose name indicates that peer certificates are accepted
    /// without verification. Defaults to `false`. No builder method exists for
    /// it, so enabling it requires assigning the field explicitly.
    pub insecure_accept_any_cert: bool,
}

impl TlsConfig {
    /// Returns `true` when both [`cert_path`](Self::cert_path) and
    /// [`key_path`](Self::key_path) are set.
    ///
    /// No other field is consulted, and the paths are not checked for existence.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }

    /// Sets [`cert_path`](Self::cert_path), replacing any previous value.
    // `#[must_use]` on every builder: they consume `self` and return the updated
    // config, so discarding the result silently discards the change.
    #[must_use]
    pub fn with_cert(mut self, p: impl Into<PathBuf>) -> Self {
        self.cert_path = Some(p.into());
        self
    }

    /// Sets [`key_path`](Self::key_path), replacing any previous value.
    #[must_use]
    pub fn with_key(mut self, p: impl Into<PathBuf>) -> Self {
        self.key_path = Some(p.into());
        self
    }

    /// Sets [`ca_path`](Self::ca_path), replacing any previous value.
    #[must_use]
    pub fn with_ca(mut self, p: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(p.into());
        self
    }

    /// Sets [`server_name`](Self::server_name), replacing any previous value.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Sets [`require_client_auth`](Self::require_client_auth) to `on`.
    #[must_use]
    pub fn with_client_auth(mut self, on: bool) -> Self {
        self.require_client_auth = on;
        self
    }
}
/// Extracts and base64-decodes every `-----BEGIN {expected_label}-----` …
/// `-----END {expected_label}-----` block in `text`, in order of appearance.
///
/// Text before the first BEGIN marker, and text between an END marker and the
/// next BEGIN marker, is ignored. If no BEGIN marker is present, an empty
/// `Vec` is returned. Unicode whitespace inside a block body is removed before
/// decoding.
///
/// # Errors
///
/// Returns [`TlsError::Pem`] for the first block, in order, that either has no
/// END marker before the next BEGIN marker (or the end of input), or whose body
/// is rejected by `base64_decode`.
pub fn parse_pem_blocks(text: &str, expected_label: &str) -> Result<Vec<Vec<u8>>, TlsError> {
    let begin_marker = format!("-----BEGIN {expected_label}-----");
    let end_marker = format!("-----END {expected_label}-----");

    text.split(begin_marker.as_str())
        // The first piece is whatever precedes the first BEGIN marker (the whole
        // input when there is none); it never contains a block body.
        .skip(1)
        .map(|segment| -> Result<Vec<u8>, TlsError> {
            let end_idx = segment
                .find(end_marker.as_str())
                .ok_or_else(|| TlsError::Pem(format!("missing {end_marker}")))?;
            let payload: String = segment[..end_idx]
                .chars()
                .filter(|ch| !ch.is_whitespace())
                .collect();
            base64_decode(&payload).map_err(|e| TlsError::Pem(format!("base64: {e}")))
        })
        // Collecting into `Result` stops at the first failing block.
        .collect()
}
/// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`).
///
/// ASCII whitespace anywhere in `s` is ignored. Trailing `=` padding is
/// optional, but when present it must be well-formed: at most two `=`, only at
/// the very end, and completing the input to a multiple of four symbols.
///
/// # Errors
///
/// Returns a human-readable description of the first problem found:
/// - a character outside the alphabet, including an `=` that is not part of the
///   trailing padding; the reported index counts non-whitespace characters;
/// - malformed padding;
/// - a symbol count that leaves a lone trailing symbol, which no encoder
///   produces and which usually means truncated input.
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
    /// Maps one base64 symbol to its 6-bit value, or `None` if it is not in the
    /// standard alphabet.
    fn val(c: u8) -> Option<u8> {
        Some(match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        })
    }

    let symbols: Vec<u8> = s.bytes().filter(|b| !b.is_ascii_whitespace()).collect();

    // Split off the trailing run of `=`. Any `=` before that run stays in `data`
    // and is reported below as a bad character instead of being silently dropped.
    let data_len = symbols.iter().rposition(|&b| b != b'=').map_or(0, |p| p + 1);
    let (data, padding) = symbols.split_at(data_len);
    if padding.len() > 2 || (!padding.is_empty() && symbols.len() % 4 != 0) {
        return Err(format!(
            "invalid padding: {} `=` after {} symbols",
            padding.len(),
            data.len()
        ));
    }
    // A final group of one symbol carries only 6 bits, too few for a byte.
    if data.len() % 4 == 1 {
        return Err(format!("invalid length: {} symbols", data.len()));
    }

    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    // `acc` holds pending bits in its low `bits` bits. Higher bits are stale and
    // either shifted out (harmless for `<<`) or masked off when a byte is emitted.
    let mut acc = 0u32;
    let mut bits = 0u32;
    for (i, &b) in data.iter().enumerate() {
        let v = val(b).ok_or_else(|| format!("bad char at {i}: {b:#x}"))?;
        acc = (acc << 6) | u32::from(v);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((acc >> bits) & 0xff) as u8);
        }
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enabled_requires_both_cert_and_key() {
        // Check each builder stage in turn: nothing set, cert only, cert + key.
        let empty = TlsConfig::default();
        assert!(!empty.enabled());

        let cert_only = empty.with_cert("/etc/cert.pem");
        assert!(!cert_only.enabled());

        let cert_and_key = cert_only.with_key("/etc/key.pem");
        assert!(cert_and_key.enabled());
    }

    #[test]
    fn builders_chain() {
        let cfg = TlsConfig::default()
            .with_cert("/c")
            .with_key("/k")
            .with_ca("/ca")
            .with_server_name("example.com")
            .with_client_auth(true);

        assert!(cfg.enabled());
        assert!(cfg.require_client_auth);
        assert_eq!(cfg.server_name.as_deref(), Some("example.com"));
    }

    #[test]
    fn parse_pem_extracts_certificate_block() {
        let pem = concat!(
            "-----BEGIN CERTIFICATE-----\n",
            "SGVsbG8gd29ybGQh\n",
            "-----END CERTIFICATE-----\n"
        );
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").expect("well-formed PEM should parse");
        assert_eq!(blocks, vec![b"Hello world!".to_vec()]);
    }

    #[test]
    fn parse_pem_handles_multiple_blocks() {
        let pem = concat!(
            "-----BEGIN CERTIFICATE-----\n",
            "SGVsbG8=\n",
            "-----END CERTIFICATE-----\n",
            "-----BEGIN CERTIFICATE-----\n",
            "V29ybGQ=\n",
            "-----END CERTIFICATE-----\n"
        );
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").expect("well-formed PEM should parse");
        assert_eq!(blocks, vec![b"Hello".to_vec(), b"World".to_vec()]);
    }

    #[test]
    fn parse_pem_missing_end_errors() {
        // BEGIN marker with no matching END marker anywhere after it.
        let pem = "-----BEGIN CERTIFICATE-----\nSGV=\n";
        let result = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(
            matches!(result, Err(TlsError::Pem(_))),
            "expected a PEM error, got {result:?}"
        );
    }
}