Skip to main content

atomr_remote/
tls.rs

1//! TLS settings + helpers.
2//!
3//! The actual handshake plugs into the `TcpTransport` through the
4//! `TlsConfig` shape; that wiring lands once the reader/writer split (5.D)
5//! is in.
6//!
//! For now we ship the typed configuration plus a lightweight helper
//! that extracts the DER bytes from PEM-encoded cert/key text (the form
//! `rustls` expects). The `rustls`-dependent pieces are deferred to keep
//! the slim build dependency-free.
11
12use std::path::PathBuf;
13
14use thiserror::Error;
15
16#[derive(Debug, Error)]
17#[non_exhaustive]
18pub enum TlsError {
19    #[error("io error reading `{path}`: {source}")]
20    Io {
21        path: String,
22        #[source]
23        source: std::io::Error,
24    },
25    #[error("invalid PEM input: {0}")]
26    Pem(String),
27}
28
/// TLS configuration knobs surfaced on `RemoteSettings`.
///
/// Start from [`TlsConfig::default`] (every path unset, every flag `false`)
/// and chain the `with_*` builders. The type is `#[non_exhaustive]`, so code
/// outside this crate cannot use struct-literal syntax to build it.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TlsConfig {
    /// Path to the PEM-encoded certificate chain.
    pub cert_path: Option<PathBuf>,
    /// Path to the PEM-encoded private key.
    pub key_path: Option<PathBuf>,
    /// Path to PEM-encoded trust roots (defaults to the OS root store).
    pub ca_path: Option<PathBuf>,
    /// Require client cert verification (mTLS).
    pub require_client_auth: bool,
    /// SNI hostname when initiating outbound connections.
    pub server_name: Option<String>,
    /// Allow unknown roots (dev / self-signed). MUST be `false` in
    /// production.
    pub insecure_accept_any_cert: bool,
}

impl TlsConfig {
    /// Returns `true` when both a certificate chain and a private key path
    /// are configured.
    ///
    /// Only the presence of the paths is checked; the files are not read.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }

    /// Sets the path of the PEM-encoded certificate chain.
    // Builders consume `self`: dropping the result silently loses the
    // setting, hence `#[must_use]`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_cert(mut self, p: impl Into<PathBuf>) -> Self {
        self.cert_path = Some(p.into());
        self
    }

    /// Sets the path of the PEM-encoded private key.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_key(mut self, p: impl Into<PathBuf>) -> Self {
        self.key_path = Some(p.into());
        self
    }

    /// Sets the path of the PEM-encoded trust roots.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_ca(mut self, p: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(p.into());
        self
    }

    /// Sets the SNI hostname used for outbound connections.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Enables or disables mandatory client-certificate verification (mTLS).
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn with_client_auth(mut self, on: bool) -> Self {
        self.require_client_auth = on;
        self
    }
}
78
79/// Best-effort PEM block extraction. Returns the DER bytes of every
80/// block whose header matches `expected_label` (e.g. "CERTIFICATE",
81/// "PRIVATE KEY"). Used as a base for the eventual `rustls`
82/// integration without taking the dep here.
83pub fn parse_pem_blocks(text: &str, expected_label: &str) -> Result<Vec<Vec<u8>>, TlsError> {
84    let begin = format!("-----BEGIN {expected_label}-----");
85    let end = format!("-----END {expected_label}-----");
86    let mut out = Vec::new();
87    let mut iter = text.split(&begin[..]);
88    let _ = iter.next(); // discard preamble
89    for block in iter {
90        let Some(end_idx) = block.find(&end[..]) else {
91            return Err(TlsError::Pem(format!("missing {end}")));
92        };
93        let body: String = block[..end_idx].chars().filter(|c| !c.is_whitespace()).collect();
94        let bytes = base64_decode(&body).map_err(|e| TlsError::Pem(format!("base64: {e}")))?;
95        out.push(bytes);
96    }
97    Ok(out)
98}
99
/// Tiny standard-alphabet (RFC 4648 §4) base64 decoder.
///
/// ASCII whitespace anywhere in `s` is ignored. Trailing `=` padding is
/// optional. When present, it must be one or two characters that complete
/// the final 4-character group.
///
/// # Errors
///
/// Returns a string describing the problem when `s` contains any of:
/// - a character outside the alphabet, including `=` anywhere but the end;
/// - more than two padding characters;
/// - padding that does not complete a 4-character group;
/// - a dangling single character that cannot encode a whole byte.
///
/// Reported offsets are byte positions in `s` itself.
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
    fn val(c: u8) -> Option<u8> {
        Some(match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        })
    }

    // Significant bytes paired with their offset in `s`, so error messages
    // point at the character the caller actually supplied.
    let mut symbols: Vec<(usize, u8)> = s
        .bytes()
        .enumerate()
        .filter(|&(_, b)| !b.is_ascii_whitespace())
        .collect();

    // Padding may only appear at the very end: peel it off and validate it.
    // Any `=` left afterwards is rejected by `val` below.
    let mut padding = 0usize;
    while matches!(symbols.last(), Some(&(_, b'='))) {
        symbols.pop();
        padding += 1;
    }
    if padding > 2 {
        return Err(format!("too much padding: {padding} `=` characters"));
    }
    // A lone trailing symbol carries only 6 bits, which is not a whole byte.
    // The input was truncated or corrupted, so refuse it instead of silently
    // dropping those bits.
    if symbols.len() % 4 == 1 {
        return Err(format!(
            "truncated input: {} symbols leave a dangling character",
            symbols.len()
        ));
    }
    if padding > 0 && (symbols.len() + padding) % 4 != 0 {
        return Err(format!(
            "padding of {padding} does not complete a 4-character group"
        ));
    }

    let mut out = Vec::with_capacity(symbols.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    for &(i, b) in &symbols {
        let v = val(b).ok_or_else(|| format!("bad char at {i}: {b:#x}"))?;
        // Bits shifted past the top of `buf` were already emitted; only the
        // low `bits` (< 14) bits are ever read back.
        buf = (buf << 6) | u32::from(v);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xff) as u8);
        }
    }
    Ok(out)
}
128
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enabled_requires_both_cert_and_key() {
        let mut t = TlsConfig::default();
        assert!(!t.enabled());
        t = t.with_cert("/etc/cert.pem");
        assert!(!t.enabled());
        t = t.with_key("/etc/key.pem");
        assert!(t.enabled());
    }

    #[test]
    fn builders_chain() {
        let t = TlsConfig::default()
            .with_cert("/c")
            .with_key("/k")
            .with_ca("/ca")
            .with_server_name("example.com")
            .with_client_auth(true);
        assert!(t.enabled());
        assert_eq!(t.server_name.as_deref(), Some("example.com"));
        assert!(t.require_client_auth);
    }

    #[test]
    fn default_config_is_disabled_and_secure() {
        // Insecure cert acceptance and mTLS must both be opt-in.
        let t = TlsConfig::default();
        assert!(!t.enabled());
        assert!(!t.require_client_auth);
        assert!(!t.insecure_accept_any_cert);
        assert!(t.ca_path.is_none());
        assert!(t.server_name.is_none());
    }

    #[test]
    fn parse_pem_extracts_certificate_block() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8gd29ybGQh
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0], b"Hello world!");
    }

    #[test]
    fn parse_pem_handles_multiple_blocks() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
V29ybGQ=
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0], b"Hello");
        assert_eq!(blocks[1], b"World");
    }

    #[test]
    fn parse_pem_skips_blocks_with_other_labels() {
        let pem = "\
-----BEGIN PRIVATE KEY-----
S2V5
-----END PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
Q2VydA==
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks, vec![b"Cert".to_vec()]);
    }

    #[test]
    fn parse_pem_without_matching_blocks_is_empty() {
        assert!(parse_pem_blocks("", "CERTIFICATE").unwrap().is_empty());
        assert!(parse_pem_blocks("no pem here", "CERTIFICATE").unwrap().is_empty());
    }

    #[test]
    fn parse_pem_missing_end_errors() {
        let pem = "-----BEGIN CERTIFICATE-----\nSGV=\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(_))));
    }

    #[test]
    fn parse_pem_invalid_base64_errors() {
        let pem = "-----BEGIN CERTIFICATE-----\nSGV$bG8=\n-----END CERTIFICATE-----\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(ref msg)) if msg.starts_with("base64: ")));
    }

    #[test]
    fn pem_error_display() {
        assert_eq!(
            TlsError::Pem("oops".into()).to_string(),
            "invalid PEM input: oops"
        );
    }
}