Skip to main content

harn_cli/commands/orchestrator/
tls.rs

1use super::errors::OrchestratorError;
2use std::net::{SocketAddr, TcpListener};
3use std::path::{Path, PathBuf};
4use std::sync::Once;
5use std::time::Duration;
6
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_server::Handle;
10
/// Pair of filesystem paths to the certificate and private key used when the
/// orchestrator serves HTTPS.
///
/// Both files are read as PEM by `load_rustls_config`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsFiles {
    /// Path to the PEM certificate file.
    pub(crate) cert: PathBuf,
    /// Path to the PEM private-key file.
    pub(crate) key: PathBuf,
}
16
17impl TlsFiles {
18    /// Constructor used by tests and by harness consumers that supply
19    /// the cert/key paths directly.
20    pub fn new(cert: PathBuf, key: PathBuf) -> Self {
21        Self { cert, key }
22    }
23
24    pub(crate) fn from_args(
25        cert: Option<PathBuf>,
26        key: Option<PathBuf>,
27    ) -> Result<Option<Self>, OrchestratorError> {
28        match (cert, key) {
29            (None, None) => Ok(None),
30            (Some(cert), Some(key)) => Ok(Some(Self { cert, key })),
31            (Some(_), None) => Err("`--cert` requires `--key`".to_string().into()),
32            (None, Some(_)) => Err("`--key` requires `--cert`".to_string().into()),
33        }
34    }
35}
36
37pub(crate) struct ServerRuntime {
38    local_addr: SocketAddr,
39    handle: Handle<SocketAddr>,
40    task: tokio::task::JoinHandle<Result<(), OrchestratorError>>,
41    tls_enabled: bool,
42}
43
44impl ServerRuntime {
45    pub(crate) async fn start(
46        bind: SocketAddr,
47        app: Router,
48        tls: Option<&TlsFiles>,
49    ) -> Result<Self, OrchestratorError> {
50        let listener = bind_listener(bind)?;
51        let local_addr = listener
52            .local_addr()
53            .map_err(|error| format!("failed to inspect listener address: {error}"))?;
54        let handle = Handle::new();
55        let handle_for_task = handle.clone();
56
57        let task = if let Some(tls) = tls {
58            let rustls = load_rustls_config(&tls.cert, &tls.key).await?;
59            tokio::spawn(async move {
60                axum_server::from_tcp_rustls(listener, rustls)
61                    .map_err(|error| format!("HTTPS listener setup failed: {error}"))?
62                    .handle(handle_for_task)
63                    .serve(app.into_make_service())
64                    .await
65                    .map_err(|error| {
66                        OrchestratorError::Tls(format!("HTTPS listener failed: {error}"))
67                    })
68            })
69        } else {
70            tokio::spawn(async move {
71                axum_server::from_tcp(listener)
72                    .map_err(|error| format!("HTTP listener setup failed: {error}"))?
73                    .handle(handle_for_task)
74                    .serve(app.into_make_service())
75                    .await
76                    .map_err(|error| {
77                        OrchestratorError::Tls(format!("HTTP listener failed: {error}"))
78                    })
79            })
80        };
81
82        Ok(Self {
83            local_addr,
84            handle,
85            task,
86            tls_enabled: tls.is_some(),
87        })
88    }
89
90    pub(crate) fn local_addr(&self) -> SocketAddr {
91        self.local_addr
92    }
93
94    pub(crate) fn tls_enabled(&self) -> bool {
95        self.tls_enabled
96    }
97
98    pub(crate) async fn shutdown(self, timeout: Duration) -> Result<(), OrchestratorError> {
99        self.handle.graceful_shutdown(Some(timeout));
100        match self.task.await {
101            Ok(result) => result,
102            Err(error) => Err(format!("listener task join failed: {error}").into()),
103        }
104    }
105}
106
107async fn load_rustls_config(cert: &Path, key: &Path) -> Result<RustlsConfig, OrchestratorError> {
108    install_crypto_provider();
109    if !cert.is_file() {
110        return Err(format!("TLS certificate not found: {}", cert.display()).into());
111    }
112    if !key.is_file() {
113        return Err(format!("TLS private key not found: {}", key.display()).into());
114    }
115
116    RustlsConfig::from_pem_file(cert.to_path_buf(), key.to_path_buf())
117        .await
118        .map_err(|error| {
119            OrchestratorError::Tls({
120                format!(
121                    "failed to load TLS certificate {} and key {}: {error}",
122                    cert.display(),
123                    key.display()
124                )
125            })
126        })
127}
128
129fn install_crypto_provider() {
130    static INSTALL: Once = Once::new();
131    INSTALL.call_once(|| {
132        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
133    });
134}
135
136fn bind_listener(bind: SocketAddr) -> Result<TcpListener, OrchestratorError> {
137    let listener = TcpListener::bind(bind)
138        .map_err(|error| format!("failed to bind listener on {bind}: {error}"))?;
139    listener
140        .set_nonblocking(true)
141        .map_err(|error| format!("failed to enable nonblocking listener mode: {error}"))?;
142    Ok(listener)
143}