1use crate::config::{Config, ConfigLoadError, ConfigValidationError, TlsVersion, TransportMode};
9use rustls::RootCertStore;
10use rustls::server::WebPkiClientVerifier;
11use std::fmt;
12use std::io;
13use std::net::SocketAddr;
14use std::path::Path;
15use std::sync::Arc;
16use thiserror::Error;
17
/// Named stages of server startup.
///
/// Every [`StartupError`] records the phase that was active when it was
/// created, so a failure report shows how far startup progressed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StartupPhase {
    /// Loading the configuration.
    LoadConfig,
    /// Validating the loaded configuration.
    ValidateConfig,
    /// Resolving which transport mode to serve.
    ResolveTransport,
    /// Checking that configured TLS files exist and are readable.
    CheckTlsFiles,
    /// Building the rustls server configuration.
    BuildTlsContext,
    /// Binding the listening socket.
    BindListener,
    /// Starting the server itself.
    StartServer,
}
42
43impl fmt::Display for StartupPhase {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 f.write_str(match self {
46 Self::LoadConfig => "load_config",
47 Self::ValidateConfig => "validate_config",
48 Self::ResolveTransport => "resolve_transport",
49 Self::CheckTlsFiles => "check_tls_files",
50 Self::BuildTlsContext => "build_tls_context",
51 Self::BindListener => "bind_listener",
52 Self::StartServer => "start_server",
53 })
54 }
55}
56
57#[derive(Debug, Error)]
61#[error("[{phase}] {kind}")]
62pub struct StartupError {
63 pub phase: StartupPhase,
65 pub kind: StartupErrorKind,
67}
68
69impl StartupError {
70 #[must_use]
71 pub fn new(phase: StartupPhase, kind: StartupErrorKind) -> Self {
72 Self { phase, kind }
73 }
74}
75
76#[derive(Debug, Error)]
78pub enum StartupErrorKind {
79 #[error("config load failed: {0}")]
81 ConfigLoad(#[from] ConfigLoadError),
82
83 #[error("config validation failed: {0}")]
85 ConfigValidation(#[from] ConfigValidationError),
86
87 #[error("TLS file not found: {path}")]
89 TlsFileNotFound { path: String },
90
91 #[error("TLS path is not a regular file: {path}")]
93 TlsFileNotRegular { path: String },
94
95 #[error("TLS file is not readable: {path}: {reason}")]
97 TlsFileNotReadable { path: String, reason: String },
98
99 #[error("failed to build TLS context: {0}")]
101 TlsContext(String),
102
103 #[error("failed to bind {addr}: {source}")]
105 Bind { addr: SocketAddr, source: io::Error },
106
107 #[error("server error: {0}")]
109 Runtime(String),
110}
111
/// Result of checking a single TLS file path with [`check_tls_file`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TlsFileStatus {
    /// The path is a regular file that could be opened.
    Ok,
    /// Nothing exists at the path.
    NotFound,
    /// Something exists at the path, but it is not a regular file.
    NotRegular,
    /// The path could not be inspected or opened; carries the OS error text.
    NotReadable(String),
}
126
127#[must_use]
129pub fn check_tls_file(path: &str) -> TlsFileStatus {
130 let p = Path::new(path);
131 let metadata = match std::fs::metadata(p) {
132 Ok(m) => m,
133 Err(e) if e.kind() == io::ErrorKind::NotFound => return TlsFileStatus::NotFound,
134 Err(e) => return TlsFileStatus::NotReadable(e.to_string()),
135 };
136 if !metadata.is_file() {
137 return TlsFileStatus::NotRegular;
138 }
139 if let Err(e) = std::fs::File::open(p) {
141 return TlsFileStatus::NotReadable(e.to_string());
142 }
143 TlsFileStatus::Ok
144}
145
146pub fn preflight_tls_files(config: &Config) -> Result<(), StartupError> {
155 let files: Vec<(&str, &str)> = [
156 ("cert_path", config.transport.tls.cert_path.as_deref()),
157 ("key_path", config.transport.tls.key_path.as_deref()),
158 (
159 "client_ca_path",
160 config.transport.tls.client_ca_path.as_deref(),
161 ),
162 ]
163 .into_iter()
164 .filter_map(|(label, path)| path.map(|p| (label, p)))
165 .collect();
166
167 for (_label, path) in &files {
168 match check_tls_file(path) {
169 TlsFileStatus::Ok => {}
170 TlsFileStatus::NotFound => {
171 return Err(StartupError::new(
172 StartupPhase::CheckTlsFiles,
173 StartupErrorKind::TlsFileNotFound {
174 path: (*path).to_string(),
175 },
176 ));
177 }
178 TlsFileStatus::NotRegular => {
179 return Err(StartupError::new(
180 StartupPhase::CheckTlsFiles,
181 StartupErrorKind::TlsFileNotRegular {
182 path: (*path).to_string(),
183 },
184 ));
185 }
186 TlsFileStatus::NotReadable(reason) => {
187 return Err(StartupError::new(
188 StartupPhase::CheckTlsFiles,
189 StartupErrorKind::TlsFileNotReadable {
190 path: (*path).to_string(),
191 reason,
192 },
193 ));
194 }
195 }
196 }
197 Ok(())
198}
199
200pub fn bind_tcp_listener(addr: SocketAddr) -> Result<std::net::TcpListener, StartupError> {
207 let listener =
208 std::net::TcpListener::bind(addr).map_err(|source| StartupError::bind(addr, source))?;
209 listener
210 .set_nonblocking(true)
211 .map_err(|source| StartupError::bind(addr, source))?;
212 Ok(listener)
213}
214
215pub fn build_tls_server_config(config: &Config) -> Result<rustls::ServerConfig, StartupError> {
235 let cert_path = config
236 .transport
237 .tls
238 .cert_path
239 .as_deref()
240 .expect("cert_path validated present before build_tls_server_config");
241 let key_path = config
242 .transport
243 .tls
244 .key_path
245 .as_deref()
246 .expect("key_path validated present before build_tls_server_config");
247
248 let versions = tls_protocol_versions(
250 config.transport.tls.min_version,
251 config.transport.tls.max_version,
252 );
253
254 let certs = load_pem_certs(cert_path)?;
256 let key = load_pem_private_key(key_path)?;
257
258 let mut server_config = match config.transport.mode {
260 TransportMode::Mtls => {
261 let client_ca_path =
262 config.transport.tls.client_ca_path.as_deref().expect(
263 "client_ca_path validated present for mTLS before build_tls_server_config",
264 );
265 let client_roots = load_root_store(client_ca_path)?;
266 let verifier = WebPkiClientVerifier::builder(Arc::new(client_roots))
267 .build()
268 .map_err(|e| {
269 StartupError::tls_context(format!("failed to build client cert verifier: {e}"))
270 })?;
271 rustls::ServerConfig::builder_with_protocol_versions(&versions)
272 .with_client_cert_verifier(verifier)
273 .with_single_cert(certs, key)
274 .map_err(|e| {
275 StartupError::tls_context(format!("failed to build mTLS server config: {e}"))
276 })?
277 }
278 _ => rustls::ServerConfig::builder_with_protocol_versions(&versions)
279 .with_no_client_auth()
280 .with_single_cert(certs, key)
281 .map_err(|e| {
282 StartupError::tls_context(format!("failed to build TLS server config: {e}"))
283 })?,
284 };
285
286 server_config.alpn_protocols = config
288 .transport
289 .tls
290 .alpn_protocols
291 .iter()
292 .map(|a| a.as_str().as_bytes().to_vec())
293 .collect();
294
295 Ok(server_config)
296}
297
298fn tls_protocol_versions(
300 min: TlsVersion,
301 max: TlsVersion,
302) -> Vec<&'static rustls::SupportedProtocolVersion> {
303 let mut versions = Vec::with_capacity(2);
304 if min <= TlsVersion::V1_2 && max >= TlsVersion::V1_2 {
305 versions.push(&rustls::version::TLS12);
306 }
307 if min <= TlsVersion::V1_3 && max >= TlsVersion::V1_3 {
308 versions.push(&rustls::version::TLS13);
309 }
310 versions
311}
312
313fn load_pem_certs(
315 path: &str,
316) -> Result<Vec<rustls::pki_types::CertificateDer<'static>>, StartupError> {
317 let data = std::fs::read(path).map_err(|e| {
318 StartupError::tls_context(format!("failed to read cert file '{path}': {e}"))
319 })?;
320 let certs: Vec<_> = rustls_pemfile::certs(&mut data.as_slice())
321 .collect::<Result<Vec<_>, _>>()
322 .map_err(|e| {
323 StartupError::tls_context(format!("failed to parse PEM certs from '{path}': {e}"))
324 })?;
325 if certs.is_empty() {
326 return Err(StartupError::tls_context(format!(
327 "no certificates found in '{path}'"
328 )));
329 }
330 Ok(certs)
331}
332
333fn load_pem_private_key(
335 path: &str,
336) -> Result<rustls::pki_types::PrivateKeyDer<'static>, StartupError> {
337 let data = std::fs::read(path)
338 .map_err(|e| StartupError::tls_context(format!("failed to read key file '{path}': {e}")))?;
339 rustls_pemfile::private_key(&mut data.as_slice())
340 .map_err(|e| {
341 StartupError::tls_context(format!("failed to parse PEM key from '{path}': {e}"))
342 })?
343 .ok_or_else(|| StartupError::tls_context(format!("no private key found in '{path}'")))
344}
345
346fn load_root_store(path: &str) -> Result<RootCertStore, StartupError> {
348 let data = std::fs::read(path)
349 .map_err(|e| StartupError::tls_context(format!("failed to read CA file '{path}': {e}")))?;
350 let certs: Vec<_> = rustls_pemfile::certs(&mut data.as_slice())
351 .collect::<Result<Vec<_>, _>>()
352 .map_err(|e| {
353 StartupError::tls_context(format!("failed to parse PEM CA certs from '{path}': {e}"))
354 })?;
355 if certs.is_empty() {
356 return Err(StartupError::tls_context(format!(
357 "no CA certificates found in '{path}'"
358 )));
359 }
360 let mut store = RootCertStore::empty();
361 for cert in certs {
362 store.add(cert).map_err(|e| {
363 StartupError::tls_context(format!("failed to add CA cert to trust store: {e}"))
364 })?;
365 }
366 Ok(store)
367}
368
369pub fn log_phase(phase: StartupPhase) {
373 tracing::info!(startup_phase = %phase, "entering startup phase");
374}
375
376pub fn log_transport_summary(config: &Config) {
378 let transport = config.transport.mode.as_str();
379 let versions: Vec<&str> = config
380 .transport
381 .http
382 .versions
383 .iter()
384 .map(|v| v.as_str())
385 .collect();
386
387 tracing::info!(
388 transport.mode = transport,
389 http.versions = ?versions,
390 "transport resolved"
391 );
392
393 if config.transport.mode.uses_tls() {
394 let alpn: Vec<&str> = config
395 .transport
396 .tls
397 .alpn_protocols
398 .iter()
399 .map(|a| a.as_str())
400 .collect();
401 tracing::info!(
402 tls.min_version = config.transport.tls.min_version.as_str(),
403 tls.max_version = config.transport.tls.max_version.as_str(),
404 tls.alpn = ?alpn,
405 tls.has_client_ca = config.transport.tls.client_ca_path.is_some(),
406 "TLS configuration"
407 );
408 }
409
410 tracing::info!(
411 proxy.enabled = config.proxy.enabled,
412 proxy.forwarded_headers = ?config.proxy.forwarded_headers,
413 proxy.trusted_proxy_count = config.proxy.trusted_proxies.len(),
414 proxy.identity_mode = ?config.proxy.identity.mode,
415 "proxy trust state"
416 );
417}
418
419pub fn log_startup_failure(error: &StartupError) {
421 tracing::error!(
422 startup_phase = %error.phase,
423 error = %error.kind,
424 "startup failed"
425 );
426}
427
428impl StartupError {
431 #[must_use]
432 pub fn config_load(source: ConfigLoadError) -> Self {
433 Self::new(StartupPhase::LoadConfig, source.into())
434 }
435
436 #[must_use]
437 pub fn config_validation(source: ConfigValidationError) -> Self {
438 Self::new(StartupPhase::ValidateConfig, source.into())
439 }
440
441 pub fn tls_context(message: impl Into<String>) -> Self {
442 Self::new(
443 StartupPhase::BuildTlsContext,
444 StartupErrorKind::TlsContext(message.into()),
445 )
446 }
447
448 #[must_use]
449 pub fn bind(addr: SocketAddr, source: io::Error) -> Self {
450 Self::new(
451 StartupPhase::BindListener,
452 StartupErrorKind::Bind { addr, source },
453 )
454 }
455
456 pub fn runtime(message: impl Into<String>) -> Self {
457 Self::new(
458 StartupPhase::StartServer,
459 StartupErrorKind::Runtime(message.into()),
460 )
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Returns a path inside `dir` that does not exist. Scoping it to a fresh
    /// temp dir keeps the test independent of shared locations such as `/tmp`.
    fn missing_path(dir: &TempDir, name: &str) -> String {
        dir.path().join(name).to_str().unwrap().to_string()
    }

    #[test]
    fn check_tls_file_ok() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("cert.pem");
        fs::write(&file, b"not-a-real-cert").unwrap();
        assert_eq!(check_tls_file(file.to_str().unwrap()), TlsFileStatus::Ok);
    }

    #[test]
    fn check_tls_file_not_found() {
        let dir = TempDir::new().unwrap();
        assert_eq!(
            check_tls_file(&missing_path(&dir, "does-not-exist.pem")),
            TlsFileStatus::NotFound
        );
    }

    #[test]
    fn check_tls_file_not_regular() {
        let dir = TempDir::new().unwrap();
        assert_eq!(
            check_tls_file(dir.path().to_str().unwrap()),
            TlsFileStatus::NotRegular
        );
    }

    #[test]
    fn preflight_passes_for_http_mode() {
        let config = Config::default();
        assert!(preflight_tls_files(&config).is_ok());
    }

    #[test]
    fn preflight_passes_for_existing_cert() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("cert.pem");
        fs::write(&file, b"not-a-real-cert").unwrap();
        let mut config = Config::default();
        config.transport.tls.cert_path = Some(file.to_str().unwrap().to_string());
        assert!(preflight_tls_files(&config).is_ok());
    }

    #[test]
    fn preflight_fails_for_missing_cert() {
        let dir = TempDir::new().unwrap();
        let mut config = Config::default();
        config.transport.tls.cert_path = Some(missing_path(&dir, "nonexistent-cert.pem"));
        let err = preflight_tls_files(&config).unwrap_err();
        assert_eq!(err.phase, StartupPhase::CheckTlsFiles);
        assert!(
            matches!(&err.kind, StartupErrorKind::TlsFileNotFound { path }
                if path.contains("nonexistent"))
        );
    }

    #[test]
    fn preflight_checks_key_path_after_valid_cert() {
        let dir = TempDir::new().unwrap();
        let cert = dir.path().join("cert.pem");
        fs::write(&cert, b"not-a-real-cert").unwrap();
        let mut config = Config::default();
        config.transport.tls.cert_path = Some(cert.to_str().unwrap().to_string());
        config.transport.tls.key_path = Some(missing_path(&dir, "missing-key.pem"));
        let err = preflight_tls_files(&config).unwrap_err();
        assert_eq!(err.phase, StartupPhase::CheckTlsFiles);
        assert!(
            matches!(&err.kind, StartupErrorKind::TlsFileNotFound { path }
                if path.ends_with("missing-key.pem"))
        );
    }

    #[test]
    fn preflight_fails_for_directory_as_cert() {
        let dir = TempDir::new().unwrap();
        let mut config = Config::default();
        config.transport.tls.cert_path = Some(dir.path().to_str().unwrap().to_string());
        let err = preflight_tls_files(&config).unwrap_err();
        assert_eq!(err.phase, StartupPhase::CheckTlsFiles);
        assert!(matches!(
            &err.kind,
            StartupErrorKind::TlsFileNotRegular { .. }
        ));
    }

    #[test]
    fn startup_error_display_includes_phase() {
        let err = StartupError::new(
            StartupPhase::CheckTlsFiles,
            StartupErrorKind::TlsFileNotFound {
                path: "/etc/ssl/missing.pem".to_string(),
            },
        );
        let msg = err.to_string();
        assert!(msg.contains("check_tls_files"), "got: {msg}");
        assert!(msg.contains("missing.pem"), "got: {msg}");
    }

    #[test]
    fn startup_error_preserves_config_validation_cause() {
        let validation_err = ConfigValidationError::MaxMemoryBytesTooSmall;
        let err = StartupError::config_validation(validation_err);
        assert_eq!(err.phase, StartupPhase::ValidateConfig);
        let msg = err.to_string();
        assert!(msg.contains("validate_config"), "got: {msg}");
        assert!(msg.contains("max_memory_bytes"), "got: {msg}");
    }

    #[test]
    fn startup_phase_display() {
        assert_eq!(StartupPhase::LoadConfig.to_string(), "load_config");
        assert_eq!(StartupPhase::ValidateConfig.to_string(), "validate_config");
        assert_eq!(
            StartupPhase::ResolveTransport.to_string(),
            "resolve_transport"
        );
        assert_eq!(StartupPhase::CheckTlsFiles.to_string(), "check_tls_files");
        assert_eq!(
            StartupPhase::BuildTlsContext.to_string(),
            "build_tls_context"
        );
        assert_eq!(StartupPhase::BindListener.to_string(), "bind_listener");
        assert_eq!(StartupPhase::StartServer.to_string(), "start_server");
    }

    #[test]
    fn tls_protocol_versions_respects_bounds() {
        let both = tls_protocol_versions(TlsVersion::V1_2, TlsVersion::V1_3);
        assert_eq!(both.len(), 2);
        assert!(std::ptr::eq(both[0], &rustls::version::TLS12));
        assert!(std::ptr::eq(both[1], &rustls::version::TLS13));

        let only_13 = tls_protocol_versions(TlsVersion::V1_3, TlsVersion::V1_3);
        assert_eq!(only_13.len(), 1);
        assert!(std::ptr::eq(only_13[0], &rustls::version::TLS13));

        assert!(tls_protocol_versions(TlsVersion::V1_3, TlsVersion::V1_2).is_empty());
    }

    #[test]
    fn bind_tcp_listener_returns_bind_phase_error() {
        // Keep the first listener alive so the port stays occupied.
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let err = bind_tcp_listener(addr).unwrap_err();
        assert_eq!(err.phase, StartupPhase::BindListener);
        assert!(matches!(&err.kind, StartupErrorKind::Bind { addr: bound, .. } if *bound == addr));
    }
}