Skip to main content

durable_streams_server/
startup.rs

1//! Startup preflight, phase-aware diagnostics, and typed error mapping.
2//!
3//! The server startup sequence is divided into explicit phases. Each phase
4//! produces a typed error that preserves the underlying cause chain and
5//! identifies the failing phase so operators can act on the first failure
6//! without guessing which stage broke.
7
8use crate::config::{Config, ConfigLoadError, ConfigValidationError, TlsVersion, TransportMode};
9use rustls::RootCertStore;
10use rustls::server::WebPkiClientVerifier;
11use std::fmt;
12use std::io;
13use std::net::SocketAddr;
14use std::path::Path;
15use std::sync::Arc;
16use thiserror::Error;
17
18// ── Startup phases ─────────────────────────────────────────────────
19
/// The individual steps that make up server startup.
///
/// A failing step is recorded on the resulting [`StartupError`], so the
/// error itself names the stage that broke.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StartupPhase {
    /// Configuration sources are being loaded and merged.
    LoadConfig,
    /// The merged configuration is being validated.
    ValidateConfig,
    /// The transport mode (HTTP / TLS / mTLS) is being resolved.
    ResolveTransport,
    /// Configured TLS files are being checked for presence and readability.
    CheckTlsFiles,
    /// The rustls `ServerConfig` is being built.
    BuildTlsContext,
    /// The TCP listener is being bound.
    BindListener,
    /// The server is running after a successful bind.
    StartServer,
}
42
43impl fmt::Display for StartupPhase {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        f.write_str(match self {
46            Self::LoadConfig => "load_config",
47            Self::ValidateConfig => "validate_config",
48            Self::ResolveTransport => "resolve_transport",
49            Self::CheckTlsFiles => "check_tls_files",
50            Self::BuildTlsContext => "build_tls_context",
51            Self::BindListener => "bind_listener",
52            Self::StartServer => "start_server",
53        })
54    }
55}
56
57// ── Startup errors ─────────────────────────────────────────────────
58
59/// Top-level startup error carrying phase context and a typed cause.
60#[derive(Debug, Error)]
61#[error("[{phase}] {kind}")]
62pub struct StartupError {
63    /// The phase that failed.
64    pub phase: StartupPhase,
65    /// The typed cause.
66    pub kind: StartupErrorKind,
67}
68
69impl StartupError {
70    #[must_use]
71    pub fn new(phase: StartupPhase, kind: StartupErrorKind) -> Self {
72        Self { phase, kind }
73    }
74}
75
76/// Typed cause attached to a [`StartupError`].
77#[derive(Debug, Error)]
78pub enum StartupErrorKind {
79    /// Configuration file could not be loaded or parsed.
80    #[error("config load failed: {0}")]
81    ConfigLoad(#[from] ConfigLoadError),
82
83    /// Merged configuration failed validation.
84    #[error("config validation failed: {0}")]
85    ConfigValidation(#[from] ConfigValidationError),
86
87    /// A required TLS file is missing from the filesystem.
88    #[error("TLS file not found: {path}")]
89    TlsFileNotFound { path: String },
90
91    /// A TLS file exists but is not a regular file (e.g. directory, symlink to dir).
92    #[error("TLS path is not a regular file: {path}")]
93    TlsFileNotRegular { path: String },
94
95    /// A TLS file exists but could not be read (permissions, I/O).
96    #[error("TLS file is not readable: {path}: {reason}")]
97    TlsFileNotReadable { path: String, reason: String },
98
99    /// The rustls `ServerConfig` could not be built from the provided PEM files.
100    #[error("failed to build TLS context: {0}")]
101    TlsContext(String),
102
103    /// The TCP listener could not bind to the configured address.
104    #[error("failed to bind {addr}: {source}")]
105    Bind { addr: SocketAddr, source: io::Error },
106
107    /// A runtime error after the server began accepting connections.
108    #[error("server error: {0}")]
109    Runtime(String),
110}
111
112// ── TLS file preflight ─────────────────────────────────────────────
113
/// Outcome of checking one TLS file path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TlsFileStatus {
    /// The path is a regular file that could be opened for reading.
    Ok,
    /// Nothing exists at the path.
    NotFound,
    /// Something exists at the path, but it is not a regular file.
    NotRegular,
    /// The path could not be inspected or opened; carries the I/O error text.
    NotReadable(String),
}
126
127/// Check whether a single path refers to a readable regular file.
128#[must_use]
129pub fn check_tls_file(path: &str) -> TlsFileStatus {
130    let p = Path::new(path);
131    let metadata = match std::fs::metadata(p) {
132        Ok(m) => m,
133        Err(e) if e.kind() == io::ErrorKind::NotFound => return TlsFileStatus::NotFound,
134        Err(e) => return TlsFileStatus::NotReadable(e.to_string()),
135    };
136    if !metadata.is_file() {
137        return TlsFileStatus::NotRegular;
138    }
139    // Attempt to open for reading to verify permissions.
140    if let Err(e) = std::fs::File::open(p) {
141        return TlsFileStatus::NotReadable(e.to_string());
142    }
143    TlsFileStatus::Ok
144}
145
146/// Run preflight checks on all configured TLS file paths.
147///
148/// Returns `Ok(())` if all present paths pass, or the first error found.
149///
150/// # Errors
151///
152/// Returns a [`StartupError`] with phase [`StartupPhase::CheckTlsFiles`] when
153/// a configured TLS file is missing, not a regular file, or not readable.
154pub fn preflight_tls_files(config: &Config) -> Result<(), StartupError> {
155    let files: Vec<(&str, &str)> = [
156        ("cert_path", config.transport.tls.cert_path.as_deref()),
157        ("key_path", config.transport.tls.key_path.as_deref()),
158        (
159            "client_ca_path",
160            config.transport.tls.client_ca_path.as_deref(),
161        ),
162    ]
163    .into_iter()
164    .filter_map(|(label, path)| path.map(|p| (label, p)))
165    .collect();
166
167    for (_label, path) in &files {
168        match check_tls_file(path) {
169            TlsFileStatus::Ok => {}
170            TlsFileStatus::NotFound => {
171                return Err(StartupError::new(
172                    StartupPhase::CheckTlsFiles,
173                    StartupErrorKind::TlsFileNotFound {
174                        path: (*path).to_string(),
175                    },
176                ));
177            }
178            TlsFileStatus::NotRegular => {
179                return Err(StartupError::new(
180                    StartupPhase::CheckTlsFiles,
181                    StartupErrorKind::TlsFileNotRegular {
182                        path: (*path).to_string(),
183                    },
184                ));
185            }
186            TlsFileStatus::NotReadable(reason) => {
187                return Err(StartupError::new(
188                    StartupPhase::CheckTlsFiles,
189                    StartupErrorKind::TlsFileNotReadable {
190                        path: (*path).to_string(),
191                        reason,
192                    },
193                ));
194            }
195        }
196    }
197    Ok(())
198}
199
200/// Bind a TCP listener and classify failures under the bind phase.
201///
202/// # Errors
203///
204/// Returns a [`StartupError`] with phase [`StartupPhase::BindListener`] when
205/// the address cannot be bound or non-blocking mode cannot be enabled.
206pub fn bind_tcp_listener(addr: SocketAddr) -> Result<std::net::TcpListener, StartupError> {
207    let listener =
208        std::net::TcpListener::bind(addr).map_err(|source| StartupError::bind(addr, source))?;
209    listener
210        .set_nonblocking(true)
211        .map_err(|source| StartupError::bind(addr, source))?;
212    Ok(listener)
213}
214
215// ── TLS server config builder ──────────────────────────────────────
216
217/// Build a [`rustls::ServerConfig`] from the resolved [`Config`].
218///
219/// Handles:
220/// - TLS protocol version selection (1.2, 1.3, or both)
221/// - ALPN protocol negotiation
222/// - Server certificate and private key loading
223/// - Client certificate verification for mTLS mode
224///
225/// # Errors
226///
227/// Returns a [`StartupError`] in the [`StartupPhase::BuildTlsContext`] phase
228/// when certificate/key loading or verifier construction fails.
229///
230/// # Panics
231///
232/// Panics if `cert_path` or `key_path` (or `client_ca_path` for mTLS) are
233/// `None`. Callers must run [`Config::validate`] before calling this function.
234pub fn build_tls_server_config(config: &Config) -> Result<rustls::ServerConfig, StartupError> {
235    let cert_path = config
236        .transport
237        .tls
238        .cert_path
239        .as_deref()
240        .expect("cert_path validated present before build_tls_server_config");
241    let key_path = config
242        .transport
243        .tls
244        .key_path
245        .as_deref()
246        .expect("key_path validated present before build_tls_server_config");
247
248    // ── TLS version selection ──────────────────────────────────────
249    let versions = tls_protocol_versions(
250        config.transport.tls.min_version,
251        config.transport.tls.max_version,
252    );
253
254    // ── Certificate chain and private key ──────────────────────────
255    let certs = load_pem_certs(cert_path)?;
256    let key = load_pem_private_key(key_path)?;
257
258    // ── Build ServerConfig with or without client verification ─────
259    let mut server_config = match config.transport.mode {
260        TransportMode::Mtls => {
261            let client_ca_path =
262                config.transport.tls.client_ca_path.as_deref().expect(
263                    "client_ca_path validated present for mTLS before build_tls_server_config",
264                );
265            let client_roots = load_root_store(client_ca_path)?;
266            let verifier = WebPkiClientVerifier::builder(Arc::new(client_roots))
267                .build()
268                .map_err(|e| {
269                    StartupError::tls_context(format!("failed to build client cert verifier: {e}"))
270                })?;
271            rustls::ServerConfig::builder_with_protocol_versions(&versions)
272                .with_client_cert_verifier(verifier)
273                .with_single_cert(certs, key)
274                .map_err(|e| {
275                    StartupError::tls_context(format!("failed to build mTLS server config: {e}"))
276                })?
277        }
278        _ => rustls::ServerConfig::builder_with_protocol_versions(&versions)
279            .with_no_client_auth()
280            .with_single_cert(certs, key)
281            .map_err(|e| {
282                StartupError::tls_context(format!("failed to build TLS server config: {e}"))
283            })?,
284    };
285
286    // ── ALPN protocols ─────────────────────────────────────────────
287    server_config.alpn_protocols = config
288        .transport
289        .tls
290        .alpn_protocols
291        .iter()
292        .map(|a| a.as_str().as_bytes().to_vec())
293        .collect();
294
295    Ok(server_config)
296}
297
298/// Map config TLS version range to rustls protocol versions.
299fn tls_protocol_versions(
300    min: TlsVersion,
301    max: TlsVersion,
302) -> Vec<&'static rustls::SupportedProtocolVersion> {
303    let mut versions = Vec::with_capacity(2);
304    if min <= TlsVersion::V1_2 && max >= TlsVersion::V1_2 {
305        versions.push(&rustls::version::TLS12);
306    }
307    if min <= TlsVersion::V1_3 && max >= TlsVersion::V1_3 {
308        versions.push(&rustls::version::TLS13);
309    }
310    versions
311}
312
313/// Load PEM certificate chain from a file path.
314fn load_pem_certs(
315    path: &str,
316) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, StartupError> {
317    let data = std::fs::read(path).map_err(|e| {
318        StartupError::tls_context(format!("failed to read cert file '{path}': {e}"))
319    })?;
320    let certs: Vec<_> = rustls_pemfile::certs(&mut data.as_slice())
321        .collect::<Result<Vec<_>, _>>()
322        .map_err(|e| {
323            StartupError::tls_context(format!("failed to parse PEM certs from '{path}': {e}"))
324        })?;
325    if certs.is_empty() {
326        return Err(StartupError::tls_context(format!(
327            "no certificates found in '{path}'"
328        )));
329    }
330    Ok(certs)
331}
332
333/// Load a PEM private key from a file path.
334fn load_pem_private_key(
335    path: &str,
336) -> Result<rustls::pki_types::PrivateKeyDer<'static>, StartupError> {
337    let data = std::fs::read(path)
338        .map_err(|e| StartupError::tls_context(format!("failed to read key file '{path}': {e}")))?;
339    rustls_pemfile::private_key(&mut data.as_slice())
340        .map_err(|e| {
341            StartupError::tls_context(format!("failed to parse PEM key from '{path}': {e}"))
342        })?
343        .ok_or_else(|| StartupError::tls_context(format!("no private key found in '{path}'")))
344}
345
346/// Load a root certificate store from a PEM CA bundle.
347fn load_root_store(path: &str) -> Result<RootCertStore, StartupError> {
348    let data = std::fs::read(path)
349        .map_err(|e| StartupError::tls_context(format!("failed to read CA file '{path}': {e}")))?;
350    let certs: Vec<_> = rustls_pemfile::certs(&mut data.as_slice())
351        .collect::<Result<Vec<_>, _>>()
352        .map_err(|e| {
353            StartupError::tls_context(format!("failed to parse PEM CA certs from '{path}': {e}"))
354        })?;
355    if certs.is_empty() {
356        return Err(StartupError::tls_context(format!(
357            "no CA certificates found in '{path}'"
358        )));
359    }
360    let mut store = RootCertStore::empty();
361    for cert in certs {
362        store.add(cert).map_err(|e| {
363            StartupError::tls_context(format!("failed to add CA cert to trust store: {e}"))
364        })?;
365    }
366    Ok(store)
367}
368
369// ── Structured startup logging ─────────────────────────────────────
370
371/// Log the startup phase transition at `INFO` level.
372pub fn log_phase(phase: StartupPhase) {
373    tracing::info!(startup_phase = %phase, "entering startup phase");
374}
375
376/// Log transport diagnostics after config validation succeeds.
377pub fn log_transport_summary(config: &Config) {
378    let transport = config.transport.mode.as_str();
379    let versions: Vec<&str> = config
380        .transport
381        .http
382        .versions
383        .iter()
384        .map(|v| v.as_str())
385        .collect();
386
387    tracing::info!(
388        transport.mode = transport,
389        http.versions = ?versions,
390        "transport resolved"
391    );
392
393    if config.transport.mode.uses_tls() {
394        let alpn: Vec<&str> = config
395            .transport
396            .tls
397            .alpn_protocols
398            .iter()
399            .map(|a| a.as_str())
400            .collect();
401        tracing::info!(
402            tls.min_version = config.transport.tls.min_version.as_str(),
403            tls.max_version = config.transport.tls.max_version.as_str(),
404            tls.alpn = ?alpn,
405            tls.has_client_ca = config.transport.tls.client_ca_path.is_some(),
406            "TLS configuration"
407        );
408    }
409
410    tracing::info!(
411        proxy.enabled = config.proxy.enabled,
412        proxy.forwarded_headers = ?config.proxy.forwarded_headers,
413        proxy.trusted_proxy_count = config.proxy.trusted_proxies.len(),
414        proxy.identity_mode = ?config.proxy.identity.mode,
415        "proxy trust state"
416    );
417}
418
419/// Log a startup failure at `ERROR` level with phase context.
420pub fn log_startup_failure(error: &StartupError) {
421    tracing::error!(
422        startup_phase = %error.phase,
423        error = %error.kind,
424        "startup failed"
425    );
426}
427
428// ── Helper constructors ────────────────────────────────────────────
429
430impl StartupError {
431    #[must_use]
432    pub fn config_load(source: ConfigLoadError) -> Self {
433        Self::new(StartupPhase::LoadConfig, source.into())
434    }
435
436    #[must_use]
437    pub fn config_validation(source: ConfigValidationError) -> Self {
438        Self::new(StartupPhase::ValidateConfig, source.into())
439    }
440
441    pub fn tls_context(message: impl Into<String>) -> Self {
442        Self::new(
443            StartupPhase::BuildTlsContext,
444            StartupErrorKind::TlsContext(message.into()),
445        )
446    }
447
448    #[must_use]
449    pub fn bind(addr: SocketAddr, source: io::Error) -> Self {
450        Self::new(
451            StartupPhase::BindListener,
452            StartupErrorKind::Bind { addr, source },
453        )
454    }
455
456    pub fn runtime(message: impl Into<String>) -> Self {
457        Self::new(
458            StartupPhase::StartServer,
459            StartupErrorKind::Runtime(message.into()),
460        )
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// An existing regular file is classified as `Ok` regardless of content.
    #[test]
    fn check_tls_file_ok() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("cert.pem");
        fs::write(&file, b"not-a-real-cert").unwrap();
        assert_eq!(check_tls_file(file.to_str().unwrap()), TlsFileStatus::Ok);
    }

    /// A path inside a fresh temp dir is guaranteed not to exist, which keeps
    /// the test independent of whatever happens to live under a fixed path.
    #[test]
    fn check_tls_file_not_found() {
        let dir = TempDir::new().unwrap();
        let missing = dir.path().join("does-not-exist.pem");
        assert_eq!(
            check_tls_file(missing.to_str().unwrap()),
            TlsFileStatus::NotFound
        );
    }

    #[test]
    fn check_tls_file_not_regular() {
        // A directory is not a regular file.
        let dir = TempDir::new().unwrap();
        assert_eq!(
            check_tls_file(dir.path().to_str().unwrap()),
            TlsFileStatus::NotRegular
        );
    }

    #[test]
    fn preflight_passes_for_http_mode() {
        let config = Config::default(); // transport.mode = Http, no TLS paths
        assert!(preflight_tls_files(&config).is_ok());
    }

    /// Uses a missing file inside a fresh temp dir rather than a fixed path.
    #[test]
    fn preflight_fails_for_missing_cert() {
        let dir = TempDir::new().unwrap();
        let missing = dir.path().join("ds-nonexistent-cert.pem");
        let mut config = Config::default();
        config.transport.tls.cert_path = Some(missing.to_str().unwrap().to_string());
        let err = preflight_tls_files(&config).unwrap_err();
        assert_eq!(err.phase, StartupPhase::CheckTlsFiles);
        assert!(
            matches!(&err.kind, StartupErrorKind::TlsFileNotFound { path }
                if path.contains("ds-nonexistent-cert.pem"))
        );
    }

    #[test]
    fn preflight_fails_for_directory_as_cert() {
        let dir = TempDir::new().unwrap();
        let mut config = Config::default();
        config.transport.tls.cert_path = Some(dir.path().to_str().unwrap().to_string());
        let err = preflight_tls_files(&config).unwrap_err();
        assert_eq!(err.phase, StartupPhase::CheckTlsFiles);
        assert!(matches!(
            &err.kind,
            StartupErrorKind::TlsFileNotRegular { .. }
        ));
    }

    #[test]
    fn startup_error_display_includes_phase() {
        let err = StartupError::new(
            StartupPhase::CheckTlsFiles,
            StartupErrorKind::TlsFileNotFound {
                path: "/etc/ssl/missing.pem".to_string(),
            },
        );
        let msg = err.to_string();
        assert!(msg.contains("check_tls_files"), "got: {msg}");
        assert!(msg.contains("missing.pem"), "got: {msg}");
    }

    #[test]
    fn startup_error_preserves_config_validation_cause() {
        let validation_err = ConfigValidationError::MaxMemoryBytesTooSmall;
        let err = StartupError::config_validation(validation_err);
        assert_eq!(err.phase, StartupPhase::ValidateConfig);
        let msg = err.to_string();
        assert!(msg.contains("validate_config"), "got: {msg}");
        assert!(msg.contains("max_memory_bytes"), "got: {msg}");
    }

    #[test]
    fn startup_phase_display() {
        assert_eq!(StartupPhase::LoadConfig.to_string(), "load_config");
        assert_eq!(StartupPhase::ValidateConfig.to_string(), "validate_config");
        assert_eq!(
            StartupPhase::ResolveTransport.to_string(),
            "resolve_transport"
        );
        assert_eq!(StartupPhase::CheckTlsFiles.to_string(), "check_tls_files");
        assert_eq!(
            StartupPhase::BuildTlsContext.to_string(),
            "build_tls_context"
        );
        assert_eq!(StartupPhase::BindListener.to_string(), "bind_listener");
        assert_eq!(StartupPhase::StartServer.to_string(), "start_server");
    }

    /// Binding an address that is already held by a live listener must fail
    /// and be reported under the bind phase with the same address.
    #[test]
    fn bind_tcp_listener_returns_bind_phase_error() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let err = bind_tcp_listener(addr).unwrap_err();
        assert_eq!(err.phase, StartupPhase::BindListener);
        assert!(matches!(&err.kind, StartupErrorKind::Bind { addr: bound, .. } if *bound == addr));
    }
}