1use std::path::PathBuf;
13
14use thiserror::Error;
15
/// Errors from reading and parsing TLS inputs (files on disk and PEM text).
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TlsError {
    /// Reading the file named by `path` failed with an I/O error.
    ///
    /// `path` is stored as a `String`, not a `PathBuf`. The underlying
    /// `std::io::Error` is interpolated into the `Display` text and is also
    /// returned from `std::error::Error::source` (the `#[source]` field).
    // NOTE(review): because the source error is both printed inline and
    // exposed via `source()`, reporters that walk the cause chain will print
    // the I/O error twice. Consider keeping only one of the two.
    #[error("io error reading `{path}`: {source}")]
    Io {
        /// The file that was being read.
        path: String,
        /// The underlying I/O error.
        #[source]
        source: std::io::Error,
    },
    /// The input is not well-formed PEM. The payload describes what went
    /// wrong, e.g. a missing `-----END ...-----` marker or invalid base64 in
    /// a block body.
    #[error("invalid PEM input: {0}")]
    Pem(String),
}
28
/// TLS settings, built from [`Default`] plus the `with_*` builder methods.
///
/// The default value has every path and the server name unset and both flags
/// `false`, so [`TlsConfig::enabled`] returns `false` for it. The struct is
/// `#[non_exhaustive]`: other crates cannot build it with a struct literal,
/// but they can start from `TlsConfig::default()` and assign the public fields.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TlsConfig {
    /// Certificate file path. TLS counts as enabled only when this and
    /// `key_path` are both set (see [`TlsConfig::enabled`]).
    pub cert_path: Option<PathBuf>,
    /// Key file path, presumably the private key matching `cert_path`.
    /// Required, together with `cert_path`, for [`TlsConfig::enabled`].
    pub key_path: Option<PathBuf>,
    /// CA file path. Presumably the trust anchors used to verify the peer
    /// (NOTE(review): confirm at the use site).
    pub ca_path: Option<PathBuf>,
    /// Whether client-certificate authentication is required. Set with
    /// [`TlsConfig::with_client_auth`].
    pub require_client_auth: bool,
    /// Expected peer name. Presumably used for SNI / hostname verification
    /// (NOTE(review): confirm at the use site).
    pub server_name: Option<String>,
    /// NOTE(review): by its name this disables peer-certificate verification.
    /// Nothing visible in this module reads it, and no builder method sets it,
    /// so it can only be turned on by assigning the field directly.
    pub insecure_accept_any_cert: bool,
}

impl TlsConfig {
    /// Returns `true` when both [`cert_path`](Self::cert_path) and
    /// [`key_path`](Self::key_path) are set.
    ///
    /// Only presence is checked. The files are not opened, and the CA, server
    /// name and flag fields play no part.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }

    /// Sets [`cert_path`](Self::cert_path), replacing any previous value.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_cert(mut self, p: impl Into<PathBuf>) -> Self {
        self.cert_path = Some(p.into());
        self
    }

    /// Sets [`key_path`](Self::key_path), replacing any previous value.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_key(mut self, p: impl Into<PathBuf>) -> Self {
        self.key_path = Some(p.into());
        self
    }

    /// Sets [`ca_path`](Self::ca_path), replacing any previous value.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_ca(mut self, p: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(p.into());
        self
    }

    /// Sets [`server_name`](Self::server_name), replacing any previous value.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Sets [`require_client_auth`](Self::require_client_auth) to `on`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_client_auth(mut self, on: bool) -> Self {
        self.require_client_auth = on;
        self
    }
}
78
79pub fn parse_pem_blocks(text: &str, expected_label: &str) -> Result<Vec<Vec<u8>>, TlsError> {
84 let begin = format!("-----BEGIN {expected_label}-----");
85 let end = format!("-----END {expected_label}-----");
86 let mut out = Vec::new();
87 let mut iter = text.split(&begin[..]);
88 let _ = iter.next(); for block in iter {
90 let Some(end_idx) = block.find(&end[..]) else {
91 return Err(TlsError::Pem(format!("missing {end}")));
92 };
93 let body: String = block[..end_idx].chars().filter(|c| !c.is_whitespace()).collect();
94 let bytes = base64_decode(&body).map_err(|e| TlsError::Pem(format!("base64: {e}")))?;
95 out.push(bytes);
96 }
97 Ok(out)
98}
99
/// Decodes standard base64 text (RFC 4648 §4 alphabet, with `+` and `/`).
///
/// ASCII whitespace anywhere in `s` is skipped. Trailing `=` padding is
/// optional. When present, it must be exactly the amount needed to complete
/// the final 4-character group (one or two `=`). Unpadded input is accepted
/// as long as its length could have come from an encoder.
///
/// # Errors
///
/// Returns a human-readable description when:
/// - `s` contains a character outside the alphabet, including `=` anywhere
///   other than at the end;
/// - the data length leaves a single dangling character, which cannot encode
///   a whole byte;
/// - the padding does not match the data length.
///
/// Positions reported in messages are byte offsets into `s`.
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
    fn val(c: u8) -> Option<u8> {
        Some(match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        })
    }

    // Padding may only appear at the very end, possibly mixed with line-wrap
    // whitespace. Split it off, so any `=` left in `data` is rejected below
    // as an invalid character instead of being silently skipped.
    let data = s.trim_end_matches(|c: char| c == '=' || c.is_ascii_whitespace());
    let padding = s[data.len()..].bytes().filter(|&b| b == b'=').count();

    let mut out = Vec::with_capacity(data.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    let mut digits = 0usize;
    // `data` is a prefix of `s`, so `i` is also a byte offset into `s`.
    for (i, b) in data.bytes().enumerate() {
        if b.is_ascii_whitespace() {
            continue;
        }
        let v = val(b).ok_or_else(|| format!("bad char at {i}: {b:#x}"))?;
        digits += 1;
        // At most 13 meaningful bits are pending here. Bits shifted out of the
        // top of `buf` were already emitted, so discarding them is harmless.
        buf = (buf << 6) | u32::from(v);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xff) as u8);
        }
    }

    // Group sizes map as 4 chars -> 3 bytes, 3 -> 2, 2 -> 1. A lone leftover
    // character carries only 6 bits, so no encoder can produce it.
    if digits % 4 == 1 {
        return Err(format!("truncated input: {digits} base64 characters"));
    }
    if padding > 2 || (padding > 0 && (digits + padding) % 4 != 0) {
        return Err(format!(
            "invalid padding: {padding} '=' after {digits} base64 characters"
        ));
    }
    Ok(out)
}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enabled_requires_both_cert_and_key() {
        let mut t = TlsConfig::default();
        assert!(!t.enabled());
        t = t.with_cert("/etc/cert.pem");
        assert!(!t.enabled());
        t = t.with_key("/etc/key.pem");
        assert!(t.enabled());
    }

    #[test]
    fn builders_chain() {
        let t = TlsConfig::default()
            .with_cert("/c")
            .with_key("/k")
            .with_ca("/ca")
            .with_server_name("example.com")
            .with_client_auth(true);
        assert!(t.enabled());
        assert_eq!(t.server_name.as_deref(), Some("example.com"));
        assert!(t.require_client_auth);
    }

    #[test]
    fn parse_pem_extracts_certificate_block() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8gd29ybGQh
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0], b"Hello world!");
    }

    #[test]
    fn parse_pem_handles_multiple_blocks() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
V29ybGQ=
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0], b"Hello");
        assert_eq!(blocks[1], b"World");
    }

    #[test]
    fn parse_pem_missing_end_errors() {
        let pem = "-----BEGIN CERTIFICATE-----\nSGV=\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(_))));
    }

    // Text before the first matching BEGIN marker, and blocks carrying a
    // different label, must not leak into the result.
    #[test]
    fn parse_pem_skips_preamble_and_other_labels() {
        let pem = "\
preamble text
-----BEGIN PRIVATE KEY-----
AAEC
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
AAEC
-----END CERTIFICATE-----
";
        let certs = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(certs, vec![vec![0u8, 1, 2]]);
        let keys = parse_pem_blocks(pem, "PRIVATE KEY").unwrap();
        assert_eq!(keys, vec![vec![0u8, 1, 2]]);
    }

    #[test]
    fn parse_pem_without_matching_markers_yields_no_blocks() {
        assert!(parse_pem_blocks("", "CERTIFICATE").unwrap().is_empty());
        // The label must match exactly: the marker includes the trailing dashes.
        let csr = "-----BEGIN CERTIFICATE REQUEST-----\nAAEC\n-----END CERTIFICATE REQUEST-----\n";
        assert!(parse_pem_blocks(csr, "CERTIFICATE").unwrap().is_empty());
    }

    #[test]
    fn parse_pem_invalid_base64_errors() {
        let pem = "-----BEGIN CERTIFICATE-----\nnot*base64\n-----END CERTIFICATE-----\n";
        assert!(matches!(
            parse_pem_blocks(pem, "CERTIFICATE"),
            Err(TlsError::Pem(_))
        ));
    }
}