1use std::path::PathBuf;
14
15use thiserror::Error;
16
/// Errors produced while reading or parsing TLS material.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TlsError {
    /// An I/O error occurred while reading `path`; the underlying
    /// [`std::io::Error`] is exposed as the error's source.
    #[error("io error reading `{path}`: {source}")]
    Io {
        path: String,
        #[source]
        source: std::io::Error,
    },
    /// The input was not well-formed PEM, e.g. a BEGIN marker without a
    /// matching END marker, or a block body that failed base64 decoding.
    /// The string describes the specific problem.
    #[error("invalid PEM input: {0}")]
    Pem(String),
}
29
/// TLS settings: certificate/key material, an optional CA, and peer-verification
/// options.
///
/// Start from [`TlsConfig::default`] (every path `None`, every flag `false`) and
/// chain the `with_*` builders. How each setting is ultimately applied is up to
/// the code that consumes this config, which is not visible here.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct TlsConfig {
    /// Path to the certificate file (presumably PEM — confirm with the loader).
    pub cert_path: Option<PathBuf>,
    /// Path to the private key (presumably the one matching `cert_path`).
    pub key_path: Option<PathBuf>,
    /// Path to CA certificate(s); presumably used to verify peers — TODO confirm.
    pub ca_path: Option<PathBuf>,
    /// Whether a client certificate is required (presumably only meaningful
    /// when acting as a server — confirm with the consumer).
    pub require_client_auth: bool,
    /// Expected peer name; presumably used for SNI / hostname checks — TODO confirm.
    pub server_name: Option<String>,
    /// NOTE(review): by its name this disables peer-certificate verification.
    /// It is not read anywhere in this file and has no builder; confirm its
    /// semantics with the consumer before enabling it.
    pub insecure_accept_any_cert: bool,
}

impl TlsConfig {
    /// Returns `true` when both a certificate path and a key path are set.
    ///
    /// `ca_path`, `server_name` and the boolean flags do not affect the result.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }

    /// Returns the config with `cert_path` set to `p`.
    #[must_use]
    pub fn with_cert(mut self, p: impl Into<PathBuf>) -> Self {
        self.cert_path = Some(p.into());
        self
    }

    /// Returns the config with `key_path` set to `p`.
    #[must_use]
    pub fn with_key(mut self, p: impl Into<PathBuf>) -> Self {
        self.key_path = Some(p.into());
        self
    }

    /// Returns the config with `ca_path` set to `p`.
    #[must_use]
    pub fn with_ca(mut self, p: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(p.into());
        self
    }

    /// Returns the config with `server_name` set to `name`.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Returns the config with `require_client_auth` set to `on`.
    #[must_use]
    pub fn with_client_auth(mut self, on: bool) -> Self {
        self.require_client_auth = on;
        self
    }
}
79
80pub fn parse_pem_blocks(text: &str, expected_label: &str) -> Result<Vec<Vec<u8>>, TlsError> {
85 let begin = format!("-----BEGIN {expected_label}-----");
86 let end = format!("-----END {expected_label}-----");
87 let mut out = Vec::new();
88 let mut iter = text.split(&begin[..]);
89 let _ = iter.next(); for block in iter {
91 let Some(end_idx) = block.find(&end[..]) else {
92 return Err(TlsError::Pem(format!("missing {end}")));
93 };
94 let body: String = block[..end_idx].chars().filter(|c| !c.is_whitespace()).collect();
95 let bytes = base64_decode(&body).map_err(|e| TlsError::Pem(format!("base64: {e}")))?;
96 out.push(bytes);
97 }
98 Ok(out)
99}
100
/// Decodes standard-alphabet base64 (`A–Z a–z 0–9 + /`).
///
/// ASCII whitespace anywhere in `s` is skipped. Padding is optional, but when
/// present it must be a trailing run of one or two `=` that brings the total
/// length to a multiple of four; an `=` anywhere else is rejected as a bad
/// character. Unused low bits in the final character are not checked.
///
/// Returns a human-readable message describing the first problem found.
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
    /// Maps a base64 alphabet byte to its 6-bit value, or `None` if the byte is
    /// not part of the standard alphabet.
    fn val(c: u8) -> Option<u8> {
        Some(match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        })
    }

    let bytes: Vec<u8> = s.bytes().filter(|b| !b.is_ascii_whitespace()).collect();

    // Split off the trailing `=` run; any `=` left inside `data` is rejected
    // by `val` in the decode loop below.
    let data_len = bytes.iter().rposition(|&b| b != b'=').map_or(0, |p| p + 1);
    let (data, padding) = bytes.split_at(data_len);
    let pad = padding.len();
    if pad > 2 {
        return Err(format!("too much padding: {pad} '=' characters"));
    }
    if pad > 0 && (data.len() + pad) % 4 != 0 {
        return Err(format!(
            "incorrect padding: {pad} '=' after {} characters",
            data.len()
        ));
    }
    // A lone trailing character carries only 6 bits, not enough for a byte,
    // so such input is truncated or corrupt rather than silently shortened.
    if data.len() % 4 == 1 {
        return Err(format!("invalid length: {} base64 characters", data.len()));
    }

    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    for (i, &b) in data.iter().enumerate() {
        let v = val(b).ok_or_else(|| format!("bad char at {i}: {b:#x}"))?;
        // At most 13 low bits of `buf` are pending at any time; stale higher
        // bits are masked off by `& 0xff` and bits shifted past bit 31 are
        // dropped (`<<` by 6 never panics on a u32).
        buf = (buf << 6) | u32::from(v);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xff) as u8);
        }
    }
    Ok(out)
}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enabled_requires_both_cert_and_key() {
        let mut t = TlsConfig::default();
        assert!(!t.enabled());
        t = t.with_cert("/etc/cert.pem");
        assert!(!t.enabled());
        t = t.with_key("/etc/key.pem");
        assert!(t.enabled());
    }

    #[test]
    fn enabled_ignores_ca_and_other_options() {
        let t = TlsConfig::default()
            .with_ca("/ca")
            .with_server_name("example.com")
            .with_client_auth(true);
        assert!(!t.enabled());
    }

    #[test]
    fn builders_chain() {
        let t = TlsConfig::default()
            .with_cert("/c")
            .with_key("/k")
            .with_ca("/ca")
            .with_server_name("example.com")
            .with_client_auth(true);
        assert!(t.enabled());
        assert_eq!(t.cert_path.as_deref(), Some(std::path::Path::new("/c")));
        assert_eq!(t.ca_path.as_deref(), Some(std::path::Path::new("/ca")));
        assert_eq!(t.server_name.as_deref(), Some("example.com"));
        assert!(t.require_client_auth);
    }

    #[test]
    fn parse_pem_extracts_certificate_block() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8gd29ybGQh
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0], b"Hello world!");
    }

    #[test]
    fn parse_pem_handles_multiple_blocks() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
V29ybGQ=
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0], b"Hello");
        assert_eq!(blocks[1], b"World");
    }

    #[test]
    fn parse_pem_ignores_surrounding_text() {
        let pem = "subject=CN=example\n\
                   -----BEGIN CERTIFICATE-----\n\
                   SGVs\n\
                   bG8=\n\
                   -----END CERTIFICATE-----\n\
                   trailer\n";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0], b"Hello");
    }

    #[test]
    fn parse_pem_without_matching_label_returns_empty() {
        let pem = "-----BEGIN PRIVATE KEY-----\nSGVsbG8=\n-----END PRIVATE KEY-----\n";
        assert!(parse_pem_blocks(pem, "CERTIFICATE").unwrap().is_empty());
        assert!(parse_pem_blocks("", "CERTIFICATE").unwrap().is_empty());
    }

    #[test]
    fn parse_pem_missing_end_errors() {
        let pem = "-----BEGIN CERTIFICATE-----\nSGV=\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(_))));
    }

    #[test]
    fn parse_pem_rejects_invalid_base64() {
        let pem = "-----BEGIN CERTIFICATE-----\nSGV$\n-----END CERTIFICATE-----\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(_))));
    }
}