Skip to main content

atomr_remote/
tls.rs

1//! TLS settings + helpers.
2//!
3//! Phase 5.E of `docs/full-port-plan.md`. Akka.NET parity:
4//! `Akka.Remote.SslSettings`. The actual handshake plugs into the
5//! `TcpTransport` through the `TlsConfig` shape; that wiring lands
6//! once the reader/writer split (5.D) is in.
7//!
8//! For now we ship the typed configuration + a lightweight helper
9//! to load PEM-encoded cert/key pairs into the form `rustls`
10//! expects. Both feature-gated bits are deferred to keep the slim
11//! build dep-free.
12
13use std::path::PathBuf;
14
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum TlsError {
20    #[error("io error reading `{path}`: {source}")]
21    Io {
22        path: String,
23        #[source]
24        source: std::io::Error,
25    },
26    #[error("invalid PEM input: {0}")]
27    Pem(String),
28}
29
/// TLS configuration knobs surfaced on `RemoteSettings`.
///
/// Start from [`Default`] (everything unset or `false`) and chain the
/// `with_*` builders. TLS counts as [enabled](TlsConfig::enabled) once both
/// a certificate chain and a private key path are configured.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct TlsConfig {
    /// PEM-encoded certificate chain.
    pub cert_path: Option<PathBuf>,
    /// PEM-encoded private key.
    pub key_path: Option<PathBuf>,
    /// PEM-encoded trust roots (defaults to the OS root store).
    pub ca_path: Option<PathBuf>,
    /// Require client cert verification (mTLS).
    pub require_client_auth: bool,
    /// SNI hostname when initiating outbound connections.
    pub server_name: Option<String>,
    /// Allow unknown roots (dev / self-signed). MUST be `false` in
    /// production.
    pub insecure_accept_any_cert: bool,
}

impl TlsConfig {
    /// Returns `true` when both `cert_path` and `key_path` are set.
    ///
    /// A certificate without a key, or a key without a certificate, does
    /// not count as enabled.
    #[must_use]
    pub fn enabled(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }

    /// Sets the path of the PEM-encoded certificate chain.
    // The builders consume `self` and return the updated config, so
    // discarding the result silently drops the setting. Hence `#[must_use]`.
    #[must_use]
    pub fn with_cert(mut self, p: impl Into<PathBuf>) -> Self {
        self.cert_path = Some(p.into());
        self
    }

    /// Sets the path of the PEM-encoded private key.
    #[must_use]
    pub fn with_key(mut self, p: impl Into<PathBuf>) -> Self {
        self.key_path = Some(p.into());
        self
    }

    /// Sets the path of the PEM-encoded trust roots.
    #[must_use]
    pub fn with_ca(mut self, p: impl Into<PathBuf>) -> Self {
        self.ca_path = Some(p.into());
        self
    }

    /// Sets the SNI hostname used when initiating outbound connections.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Enables or disables mandatory client-certificate verification (mTLS).
    #[must_use]
    pub fn with_client_auth(mut self, on: bool) -> Self {
        self.require_client_auth = on;
        self
    }

    /// Enables or disables accepting any peer certificate (dev / self-signed).
    ///
    /// This is the only knob that weakens security. It MUST stay `false`
    /// in production.
    #[must_use]
    pub fn with_insecure_accept_any_cert(mut self, on: bool) -> Self {
        self.insecure_accept_any_cert = on;
        self
    }
}
79
80/// Best-effort PEM block extraction. Returns the DER bytes of every
81/// block whose header matches `expected_label` (e.g. "CERTIFICATE",
82/// "PRIVATE KEY"). Used as a base for the eventual `rustls`
83/// integration without taking the dep here.
84pub fn parse_pem_blocks(text: &str, expected_label: &str) -> Result<Vec<Vec<u8>>, TlsError> {
85    let begin = format!("-----BEGIN {expected_label}-----");
86    let end = format!("-----END {expected_label}-----");
87    let mut out = Vec::new();
88    let mut iter = text.split(&begin[..]);
89    let _ = iter.next(); // discard preamble
90    for block in iter {
91        let Some(end_idx) = block.find(&end[..]) else {
92            return Err(TlsError::Pem(format!("missing {end}")));
93        };
94        let body: String = block[..end_idx].chars().filter(|c| !c.is_whitespace()).collect();
95        let bytes = base64_decode(&body).map_err(|e| TlsError::Pem(format!("base64: {e}")))?;
96        out.push(bytes);
97    }
98    Ok(out)
99}
100
/// Tiny decoder for the standard base64 alphabet (RFC 4648 §4).
///
/// ASCII whitespace anywhere in `s` is ignored. Trailing `=` padding is
/// optional. When padding is present it must come after all data, be at
/// most two characters, and bring the encoded length to a multiple of four.
///
/// # Errors
///
/// Returns a human-readable description when `s`:
/// - contains a byte outside the alphabet (reported with its byte offset
///   in `s`);
/// - has data after a `=`;
/// - has malformed padding;
/// - ends with a lone character that cannot form a whole byte.
fn base64_decode(s: &str) -> Result<Vec<u8>, String> {
    fn val(c: u8) -> Option<u8> {
        Some(match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'+' => 62,
            b'/' => 63,
            _ => return None,
        })
    }

    let mut out = Vec::with_capacity(s.len() * 3 / 4);
    let mut buf = 0u32;
    let mut bits = 0u32;
    // Count of alphabet characters seen (excludes whitespace and padding).
    let mut data_len = 0usize;
    // Count of `=` characters seen; any data after one is an error.
    let mut pad_len = 0usize;

    for (i, b) in s.bytes().enumerate() {
        if b.is_ascii_whitespace() {
            continue;
        }
        if b == b'=' {
            pad_len += 1;
            continue;
        }
        if pad_len > 0 {
            return Err(format!("unexpected data after padding at {i}: {b:#x}"));
        }
        let v = val(b).ok_or_else(|| format!("bad char at {i}: {b:#x}"))?;
        data_len += 1;
        // At most 12 pending bits are ever read back. Bits shifted out of
        // the top of `buf` are stale and harmless.
        buf = (buf << 6) | u32::from(v);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push(((buf >> bits) & 0xff) as u8);
        }
    }

    // One leftover character carries only 6 bits, which is not a whole byte.
    if data_len % 4 == 1 {
        return Err(format!(
            "truncated input: {data_len} base64 chars leave a dangling partial byte"
        ));
    }
    if pad_len > 2 || (pad_len > 0 && (data_len + pad_len) % 4 != 0) {
        return Err(format!("invalid padding: {pad_len} `=` after {data_len} base64 chars"));
    }
    Ok(out)
}
129
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn enabled_requires_both_cert_and_key() {
        let t = TlsConfig::default();
        assert!(!t.enabled());
        let t = t.with_cert("/etc/cert.pem");
        assert!(!t.enabled());
        let t = t.with_key("/etc/key.pem");
        assert!(t.enabled());
        // A key on its own is not enough either.
        assert!(!TlsConfig::default().with_key("/etc/key.pem").enabled());
    }

    #[test]
    fn default_is_secure() {
        let t = TlsConfig::default();
        assert!(!t.insecure_accept_any_cert);
        assert!(!t.require_client_auth);
        assert!(t.server_name.is_none());
        assert!(t.ca_path.is_none());
    }

    #[test]
    fn builders_chain() {
        let t = TlsConfig::default()
            .with_cert("/c")
            .with_key("/k")
            .with_ca("/ca")
            .with_server_name("example.com")
            .with_client_auth(true);
        assert!(t.enabled());
        assert_eq!(t.ca_path.as_deref(), Some(std::path::Path::new("/ca")));
        assert_eq!(t.server_name.as_deref(), Some("example.com"));
        assert!(t.require_client_auth);
    }

    #[test]
    fn parse_pem_extracts_certificate_block() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8gd29ybGQh
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 1);
        assert_eq!(blocks[0], b"Hello world!");
    }

    #[test]
    fn parse_pem_handles_multiple_blocks() {
        let pem = "\
-----BEGIN CERTIFICATE-----
SGVsbG8=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
V29ybGQ=
-----END CERTIFICATE-----
";
        let blocks = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0], b"Hello");
        assert_eq!(blocks[1], b"World");
    }

    #[test]
    fn parse_pem_selects_blocks_by_label() {
        let pem = "-----BEGIN PRIVATE KEY-----\nSGVsbG8=\n-----END PRIVATE KEY-----\n\
                   -----BEGIN CERTIFICATE-----\nV29ybGQ=\n-----END CERTIFICATE-----\n";
        let certs = parse_pem_blocks(pem, "CERTIFICATE").unwrap();
        assert_eq!(certs, vec![b"World".to_vec()]);
        let keys = parse_pem_blocks(pem, "PRIVATE KEY").unwrap();
        assert_eq!(keys, vec![b"Hello".to_vec()]);
    }

    #[test]
    fn parse_pem_label_match_is_exact() {
        let pem = "-----BEGIN RSA PRIVATE KEY-----\nSGVsbG8=\n-----END RSA PRIVATE KEY-----\n";
        assert!(parse_pem_blocks(pem, "PRIVATE KEY").unwrap().is_empty());
    }

    #[test]
    fn parse_pem_without_blocks_is_empty() {
        assert!(parse_pem_blocks("", "CERTIFICATE").unwrap().is_empty());
        assert!(parse_pem_blocks("no pem here", "CERTIFICATE").unwrap().is_empty());
    }

    #[test]
    fn parse_pem_missing_end_errors() {
        let pem = "-----BEGIN CERTIFICATE-----\nSGV=\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(_))));
    }

    #[test]
    fn parse_pem_rejects_invalid_base64() {
        let pem = "-----BEGIN CERTIFICATE-----\nSG!V\n-----END CERTIFICATE-----\n";
        let r = parse_pem_blocks(pem, "CERTIFICATE");
        assert!(matches!(r, Err(TlsError::Pem(_))));
    }
}
191}