harn_cli/commands/orchestrator/
tls.rs1use super::errors::OrchestratorError;
2use std::net::{SocketAddr, TcpListener};
3use std::path::{Path, PathBuf};
4use std::sync::Once;
5use std::time::Duration;
6
7use axum::Router;
8use axum_server::tls_rustls::RustlsConfig;
9use axum_server::Handle;
10
11#[derive(Clone, Debug, PartialEq, Eq)]
12pub struct TlsFiles {
13 pub(crate) cert: PathBuf,
14 pub(crate) key: PathBuf,
15}
16
17impl TlsFiles {
18 pub fn new(cert: PathBuf, key: PathBuf) -> Self {
21 Self { cert, key }
22 }
23
24 pub(crate) fn from_args(
25 cert: Option<PathBuf>,
26 key: Option<PathBuf>,
27 ) -> Result<Option<Self>, OrchestratorError> {
28 match (cert, key) {
29 (None, None) => Ok(None),
30 (Some(cert), Some(key)) => Ok(Some(Self { cert, key })),
31 (Some(_), None) => Err("`--cert` requires `--key`".to_string().into()),
32 (None, Some(_)) => Err("`--key` requires `--cert`".to_string().into()),
33 }
34 }
35}
36
37pub(crate) struct ServerRuntime {
38 local_addr: SocketAddr,
39 handle: Handle<SocketAddr>,
40 task: tokio::task::JoinHandle<Result<(), OrchestratorError>>,
41 tls_enabled: bool,
42}
43
44impl ServerRuntime {
45 pub(crate) async fn start(
46 bind: SocketAddr,
47 app: Router,
48 tls: Option<&TlsFiles>,
49 ) -> Result<Self, OrchestratorError> {
50 let listener = bind_listener(bind)?;
51 let local_addr = listener
52 .local_addr()
53 .map_err(|error| format!("failed to inspect listener address: {error}"))?;
54 let handle = Handle::new();
55 let handle_for_task = handle.clone();
56
57 let task = if let Some(tls) = tls {
58 let rustls = load_rustls_config(&tls.cert, &tls.key).await?;
59 tokio::spawn(async move {
60 axum_server::from_tcp_rustls(listener, rustls)
61 .map_err(|error| format!("HTTPS listener setup failed: {error}"))?
62 .handle(handle_for_task)
63 .serve(app.into_make_service())
64 .await
65 .map_err(|error| {
66 OrchestratorError::Tls(format!("HTTPS listener failed: {error}"))
67 })
68 })
69 } else {
70 tokio::spawn(async move {
71 axum_server::from_tcp(listener)
72 .map_err(|error| format!("HTTP listener setup failed: {error}"))?
73 .handle(handle_for_task)
74 .serve(app.into_make_service())
75 .await
76 .map_err(|error| {
77 OrchestratorError::Tls(format!("HTTP listener failed: {error}"))
78 })
79 })
80 };
81
82 Ok(Self {
83 local_addr,
84 handle,
85 task,
86 tls_enabled: tls.is_some(),
87 })
88 }
89
90 pub(crate) fn local_addr(&self) -> SocketAddr {
91 self.local_addr
92 }
93
94 pub(crate) fn tls_enabled(&self) -> bool {
95 self.tls_enabled
96 }
97
98 pub(crate) async fn shutdown(self, timeout: Duration) -> Result<(), OrchestratorError> {
99 self.handle.graceful_shutdown(Some(timeout));
100 match self.task.await {
101 Ok(result) => result,
102 Err(error) => Err(format!("listener task join failed: {error}").into()),
103 }
104 }
105}
106
107async fn load_rustls_config(cert: &Path, key: &Path) -> Result<RustlsConfig, OrchestratorError> {
108 install_crypto_provider();
109 if !cert.is_file() {
110 return Err(format!("TLS certificate not found: {}", cert.display()).into());
111 }
112 if !key.is_file() {
113 return Err(format!("TLS private key not found: {}", key.display()).into());
114 }
115
116 RustlsConfig::from_pem_file(cert.to_path_buf(), key.to_path_buf())
117 .await
118 .map_err(|error| {
119 OrchestratorError::Tls({
120 format!(
121 "failed to load TLS certificate {} and key {}: {error}",
122 cert.display(),
123 key.display()
124 )
125 })
126 })
127}
128
129fn install_crypto_provider() {
130 static INSTALL: Once = Once::new();
131 INSTALL.call_once(|| {
132 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
133 });
134}
135
136fn bind_listener(bind: SocketAddr) -> Result<TcpListener, OrchestratorError> {
137 let listener = TcpListener::bind(bind)
138 .map_err(|error| format!("failed to bind listener on {bind}: {error}"))?;
139 listener
140 .set_nonblocking(true)
141 .map_err(|error| format!("failed to enable nonblocking listener mode: {error}"))?;
142 Ok(listener)
143}