1use rcgen::{Certificate, CertificateParams};
7use std::fs;
8use std::net::SocketAddr;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::error::{M2MError, Result};
14
/// Source of the TLS certificate chain and private key used by the QUIC server.
///
/// `Debug` is implemented by hand so that private-key material held in
/// [`CertConfig::Raw`] is never written to logs.
#[derive(Clone)]
pub enum CertConfig {
    /// Generate an ephemeral self-signed certificate when loaded (development only).
    SelfSigned {
        /// Subject common name of the generated certificate.
        common_name: String,
    },
    /// Read a PEM-encoded certificate chain and private key from disk when loaded.
    Files {
        /// Path to a PEM file containing one or more certificates.
        cert_path: PathBuf,
        /// Path to a PEM file containing the private key.
        key_path: PathBuf,
    },
    /// Use an in-memory DER-encoded certificate chain and private key.
    Raw {
        /// DER-encoded certificates; rustls expects the end-entity certificate first.
        cert_der: Vec<Vec<u8>>,
        /// DER-encoded private key. Redacted in `Debug` output.
        key_der: Vec<u8>,
    },
}

impl std::fmt::Debug for CertConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SelfSigned { common_name } => f
                .debug_struct("SelfSigned")
                .field("common_name", common_name)
                .finish(),
            Self::Files {
                cert_path,
                key_path,
            } => f
                .debug_struct("Files")
                .field("cert_path", cert_path)
                .field("key_path", key_path)
                .finish(),
            // Never print key bytes; only sizes are useful for diagnostics.
            Self::Raw { cert_der, key_der } => f
                .debug_struct("Raw")
                .field(
                    "cert_der",
                    &format_args!("[{} certificate(s)]", cert_der.len()),
                )
                .field("key_der", &format_args!("<redacted {} bytes>", key_der.len()))
                .finish(),
        }
    }
}
38
39impl Default for CertConfig {
40 fn default() -> Self {
41 Self::SelfSigned {
42 common_name: "localhost".to_string(),
43 }
44 }
45}
46
47impl CertConfig {
48 pub fn development() -> Self {
50 Self::SelfSigned {
51 common_name: "localhost".to_string(),
52 }
53 }
54
55 pub fn from_files(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
57 Self::Files {
58 cert_path: cert_path.into(),
59 key_path: key_path.into(),
60 }
61 }
62
63 pub fn load(&self) -> Result<(Vec<rustls::Certificate>, rustls::PrivateKey)> {
65 match self {
66 Self::SelfSigned { common_name } => {
67 tracing::warn!(
68 "Using self-signed certificate for '{}' - NOT FOR PRODUCTION",
69 common_name
70 );
71
72 let mut params = CertificateParams::new(vec![
73 common_name.clone(),
74 "127.0.0.1".to_string(),
75 "::1".to_string(),
76 ]);
77
78 params.distinguished_name = rcgen::DistinguishedName::new();
79 params
80 .distinguished_name
81 .push(rcgen::DnType::CommonName, common_name.clone());
82
83 let cert = Certificate::from_params(params).map_err(|e| {
85 M2MError::Config(format!("Failed to generate self-signed cert: {}", e))
86 })?;
87
88 let cert_der =
89 rustls::Certificate(cert.serialize_der().map_err(|e| {
90 M2MError::Config(format!("Failed to serialize cert: {}", e))
91 })?);
92 let key_der = rustls::PrivateKey(cert.serialize_private_key_der());
93
94 Ok((vec![cert_der], key_der))
95 },
96 Self::Files {
97 cert_path,
98 key_path,
99 } => {
100 let cert_pem = fs::read(cert_path).map_err(|e| {
101 M2MError::Config(format!(
102 "Failed to read cert file {}: {}",
103 cert_path.display(),
104 e
105 ))
106 })?;
107
108 let key_pem = fs::read(key_path).map_err(|e| {
109 M2MError::Config(format!(
110 "Failed to read key file {}: {}",
111 key_path.display(),
112 e
113 ))
114 })?;
115
116 let certs: Vec<rustls::Certificate> =
117 rustls_pemfile::certs(&mut cert_pem.as_slice())
118 .map_err(|e| M2MError::Config(format!("Failed to parse cert PEM: {}", e)))?
119 .into_iter()
120 .map(rustls::Certificate)
121 .collect();
122
123 if certs.is_empty() {
124 return Err(M2MError::Config(
125 "No certificates found in PEM file".to_string(),
126 ));
127 }
128
129 let key = rustls_pemfile::pkcs8_private_keys(&mut key_pem.as_slice())
131 .map_err(|e| M2MError::Config(format!("Failed to parse key PEM: {}", e)))?
132 .into_iter()
133 .next()
134 .map(rustls::PrivateKey)
135 .or_else(|| {
136 rustls_pemfile::rsa_private_keys(&mut key_pem.as_slice())
137 .ok()?
138 .into_iter()
139 .next()
140 .map(rustls::PrivateKey)
141 })
142 .ok_or_else(|| {
143 M2MError::Config("No private key found in PEM file".to_string())
144 })?;
145
146 Ok((certs, key))
147 },
148 Self::Raw { cert_der, key_der } => {
149 let certs = cert_der
150 .iter()
151 .map(|c| rustls::Certificate(c.clone()))
152 .collect();
153 let key = rustls::PrivateKey(key_der.clone());
154 Ok((certs, key))
155 },
156 }
157 }
158}
159
160#[derive(Debug, Clone)]
162pub struct TlsConfig {
163 pub cert: CertConfig,
165 pub alpn_protocols: Vec<Vec<u8>>,
167}
168
169impl Default for TlsConfig {
170 fn default() -> Self {
171 Self {
172 cert: CertConfig::default(),
173 alpn_protocols: vec![b"h3".to_vec()],
174 }
175 }
176}
177
178impl TlsConfig {
179 pub fn development() -> Self {
181 Self {
182 cert: CertConfig::development(),
183 alpn_protocols: vec![b"h3".to_vec()],
184 }
185 }
186
187 pub fn production(cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
189 Self {
190 cert: CertConfig::from_files(cert_path, key_path),
191 alpn_protocols: vec![b"h3".to_vec()],
192 }
193 }
194}
195
196#[derive(Debug, Clone)]
198pub struct QuicTransportConfig {
199 pub listen_addr: SocketAddr,
201 pub tls: TlsConfig,
203 pub enable_0rtt: bool,
205 pub max_idle_timeout: Duration,
207 pub max_concurrent_bidi_streams: u32,
209 pub max_concurrent_uni_streams: u32,
211 pub use_bbr: bool,
213}
214
215impl Default for QuicTransportConfig {
216 fn default() -> Self {
217 Self {
218 listen_addr: "127.0.0.1:8443".parse().unwrap(),
219 tls: TlsConfig::default(),
220 enable_0rtt: true,
221 max_idle_timeout: Duration::from_secs(30),
222 max_concurrent_bidi_streams: 100,
223 max_concurrent_uni_streams: 100,
224 use_bbr: true,
225 }
226 }
227}
228
229impl QuicTransportConfig {
230 pub fn development() -> Self {
232 Self {
233 tls: TlsConfig::development(),
234 ..Default::default()
235 }
236 }
237
238 pub fn production(
240 listen_addr: SocketAddr,
241 cert_path: impl Into<PathBuf>,
242 key_path: impl Into<PathBuf>,
243 ) -> Self {
244 Self {
245 listen_addr,
246 tls: TlsConfig::production(cert_path, key_path),
247 ..Default::default()
248 }
249 }
250
251 pub fn with_listen_addr(mut self, addr: SocketAddr) -> Self {
253 self.listen_addr = addr;
254 self
255 }
256
257 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
259 self.max_idle_timeout = timeout;
260 self
261 }
262
263 pub fn build_quinn_config(&self) -> Result<quinn::ServerConfig> {
265 let (certs, key) = self.tls.cert.load()?;
266
267 let mut rustls_config = rustls::ServerConfig::builder()
268 .with_safe_defaults()
269 .with_no_client_auth()
270 .with_single_cert(certs, key)
271 .map_err(|e| M2MError::Config(format!("Failed to build TLS config: {}", e)))?;
272
273 rustls_config.alpn_protocols = self.tls.alpn_protocols.clone();
274 rustls_config.max_early_data_size = u32::MAX; let mut transport_config = quinn::TransportConfig::default();
277 transport_config.max_idle_timeout(Some(
278 self.max_idle_timeout
279 .try_into()
280 .unwrap_or(quinn::IdleTimeout::try_from(Duration::from_secs(30)).unwrap()),
281 ));
282 transport_config.max_concurrent_bidi_streams(self.max_concurrent_bidi_streams.into());
283 transport_config.max_concurrent_uni_streams(self.max_concurrent_uni_streams.into());
284
285 if self.use_bbr {
287 transport_config
288 .congestion_controller_factory(Arc::new(quinn::congestion::BbrConfig::default()));
289 }
290
291 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(rustls_config));
292 server_config.transport_config(Arc::new(transport_config));
293
294 Ok(server_config)
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cert_config_self_signed() {
        let config = CertConfig::development();
        let result = config.load();
        assert!(result.is_ok());

        let (certs, key) = result.unwrap();
        assert_eq!(certs.len(), 1);
        assert!(!key.0.is_empty());
    }

    #[test]
    fn test_cert_config_raw_round_trips_der() {
        let config = CertConfig::Raw {
            cert_der: vec![vec![1, 2, 3], vec![4, 5]],
            key_der: vec![9, 9],
        };
        let (certs, key) = config.load().expect("raw config never touches the filesystem");
        let cert_bytes: Vec<Vec<u8>> = certs.into_iter().map(|c| c.0).collect();
        assert_eq!(cert_bytes, vec![vec![1, 2, 3], vec![4, 5]]);
        assert_eq!(key.0, vec![9, 9]);
    }

    #[test]
    fn test_cert_config_missing_files_is_error() {
        let config = CertConfig::from_files(
            "/nonexistent/m2m-test/cert.pem",
            "/nonexistent/m2m-test/key.pem",
        );
        assert!(config.load().is_err());
    }

    #[test]
    fn test_quic_config_default() {
        let config = QuicTransportConfig::default();
        assert_eq!(config.listen_addr.port(), 8443);
        assert!(config.enable_0rtt);
        assert!(config.use_bbr);
    }

    #[test]
    fn test_default_alpn_is_h3() {
        assert_eq!(TlsConfig::default().alpn_protocols, vec![b"h3".to_vec()]);
    }

    #[test]
    fn test_builders_override_fields() {
        let addr: SocketAddr = "0.0.0.0:9443".parse().expect("valid socket address literal");
        let config = QuicTransportConfig::default()
            .with_listen_addr(addr)
            .with_idle_timeout(Duration::from_secs(5));
        assert_eq!(config.listen_addr, addr);
        assert_eq!(config.max_idle_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_build_quinn_config_development() {
        let mut config = QuicTransportConfig::development();
        assert!(config.build_quinn_config().is_ok());

        config.enable_0rtt = false;
        config.use_bbr = false;
        assert!(config.build_quinn_config().is_ok());
    }
}