1use std::io::Write;
6use std::net::{SocketAddr, ToSocketAddrs};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9
10use ntp_proto::{KeyExchangeClient, KeyExchangeError, KeyExchangeResult, ProtocolVersion};
11use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName, UnixTime};
12use sha2::{Digest, Sha256};
13use tracing::{debug, info, warn};
14use x509_parser::prelude::*;
15
16use crate::config::NtsClientConfig;
17use crate::error::{Error, Result};
18use crate::types::{CertificateInfo, NtsKeResult};
19
20pub(crate) async fn perform_nts_ke(config: &NtsClientConfig) -> Result<NtsKeResult> {
22 let ke_start = std::time::Instant::now();
23
24 info!(
25 "Starting NTS-KE with {}:{}",
26 config.nts_ke_server, config.nts_ke_port
27 );
28
29 let server_addr = resolve_server(&config.nts_ke_server, config.nts_ke_port).await?;
31 debug!("Resolved server address: {}", server_addr);
32
33 let (tls_config, captured_certs) = build_tls_config(config)?;
35
36 let protocol_version = ProtocolVersion::V4;
38
39 let server_name = config.nts_ke_server.clone();
41 let timeout_duration = config.timeout;
42
43 let result = tokio::task::spawn_blocking(move || {
44 perform_nts_ke_blocking(
45 server_addr,
46 server_name,
47 tls_config,
48 protocol_version,
49 timeout_duration,
50 )
51 })
52 .await
53 .map_err(|e| Error::KeyExchange(format!("Task join error: {}", e)))??;
54
55 let ke_duration = ke_start.elapsed();
56 debug!("NTS-KE completed in {:?}", ke_duration);
57
58 let certificate = {
60 let certs = captured_certs.lock().unwrap();
61 if !certs.is_empty() {
62 extract_certificate_info(&certs)
63 } else {
64 None
65 }
66 };
67
68 if let Some(ref cert) = certificate {
69 debug!(
70 "Captured certificate: subject={}, issuer={}",
71 cert.subject, cert.issuer
72 );
73 }
74
75 convert_ke_result(result, ke_duration, certificate)
77}
78
79fn perform_nts_ke_blocking(
81 server_addr: SocketAddr,
82 server_name: String,
83 tls_config: ntp_proto::tls_utils::ClientConfig,
84 protocol_version: ProtocolVersion,
85 timeout_duration: Duration,
86) -> Result<KeyExchangeResult> {
87 let mut socket =
89 std::net::TcpStream::connect_timeout(&server_addr, timeout_duration).map_err(Error::Io)?;
90
91 socket.set_nonblocking(true).map_err(Error::Io)?;
92
93 debug!("TCP connection established");
94
95 let mut ke_client = KeyExchangeClient::new(
97 server_name,
98 tls_config,
99 protocol_version,
100 Vec::<String>::new(), )
102 .map_err(Error::from)?;
103
104 debug!("KeyExchangeClient created");
105
106 let start = std::time::Instant::now();
108 loop {
109 if start.elapsed() > timeout_duration {
110 return Err(Error::Timeout);
111 }
112
113 if ke_client.wants_write() {
115 match ke_client.write_socket(&mut socket) {
116 Ok(n) => {
117 if n > 0 {
118 debug!("Wrote {} bytes to socket", n);
119 }
120 }
121 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
122 Err(e) => return Err(Error::Io(e)),
123 }
124 }
125
126 if ke_client.wants_read() {
128 match ke_client.read_socket(&mut socket) {
129 Ok(n) => {
130 if n > 0 {
131 debug!("Read {} bytes from socket", n);
132 }
133 }
134 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
135 Err(e) => return Err(Error::Io(e)),
136 }
137 }
138
139 match ke_client.progress() {
141 std::ops::ControlFlow::Break(Ok(result)) => {
142 debug!("NTS-KE succeeded");
143 return Ok(result);
144 }
145 std::ops::ControlFlow::Break(Err(e)) => {
146 return Err(Error::from(e));
147 }
148 std::ops::ControlFlow::Continue(client) => {
149 ke_client = client;
150 std::thread::sleep(std::time::Duration::from_millis(10));
152 }
153 }
154 }
155}
156
157fn extract_certificate_info(certs: &[CertificateDer<'_>]) -> Option<CertificateInfo> {
159 let cert_der = certs.first()?;
161
162 let (_, cert) = X509Certificate::from_der(cert_der.as_ref()).ok()?;
164
165 let subject = cert.subject().to_string();
167
168 let issuer = cert.issuer().to_string();
170
171 let valid_from = format!("{}", cert.validity().not_before);
173 let valid_until = format!("{}", cert.validity().not_after);
174
175 let serial_number = format!("{:x}", cert.serial);
177
178 let san_dns_names = cert
180 .subject_alternative_name()
181 .ok()
182 .flatten()
183 .map(|san| {
184 san.value
185 .general_names
186 .iter()
187 .filter_map(|gn| match gn {
188 GeneralName::DNSName(name) => Some(name.to_string()),
189 _ => None,
190 })
191 .collect::<Vec<_>>()
192 })
193 .unwrap_or_default();
194
195 let signature_algorithm = cert.signature_algorithm.algorithm.to_string();
197
198 let public_key_algorithm = cert.public_key().algorithm.algorithm.to_string();
200
201 let mut hasher = Sha256::new();
203 hasher.update(cert_der.as_ref());
204 let fingerprint_sha256 = format!("{:x}", hasher.finalize());
205
206 let is_self_signed = cert.subject() == cert.issuer();
208
209 Some(CertificateInfo {
210 subject,
211 issuer,
212 valid_from,
213 valid_until,
214 serial_number,
215 san_dns_names,
216 signature_algorithm,
217 public_key_algorithm,
218 fingerprint_sha256,
219 is_self_signed,
220 })
221}
222
223#[derive(Debug)]
225struct CapturingVerifier {
226 inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
227 captured_certs: Arc<Mutex<Vec<CertificateDer<'static>>>>,
228}
229
230impl rustls::client::danger::ServerCertVerifier for CapturingVerifier {
231 fn verify_server_cert(
232 &self,
233 end_entity: &CertificateDer<'_>,
234 intermediates: &[CertificateDer<'_>],
235 server_name: &RustlsServerName<'_>,
236 ocsp_response: &[u8],
237 now: UnixTime,
238 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
239 let mut certs = self.captured_certs.lock().unwrap();
241 certs.push(end_entity.clone().into_owned());
242 for cert in intermediates {
243 certs.push(cert.clone().into_owned());
244 }
245
246 self.inner
248 .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
249 }
250
251 fn verify_tls12_signature(
252 &self,
253 message: &[u8],
254 cert: &CertificateDer<'_>,
255 dss: &rustls::DigitallySignedStruct,
256 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
257 self.inner.verify_tls12_signature(message, cert, dss)
258 }
259
260 fn verify_tls13_signature(
261 &self,
262 message: &[u8],
263 cert: &CertificateDer<'_>,
264 dss: &rustls::DigitallySignedStruct,
265 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
266 self.inner.verify_tls13_signature(message, cert, dss)
267 }
268
269 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
270 self.inner.supported_verify_schemes()
271 }
272}
273
274fn build_tls_config(
276 config: &NtsClientConfig,
277) -> Result<(
278 ntp_proto::tls_utils::ClientConfig,
279 Arc<Mutex<Vec<CertificateDer<'static>>>>,
280)> {
281 use ntp_proto::tls_utils::{self};
282
283 let _ = rustls::crypto::ring::default_provider().install_default();
286
287 let key_log = std::env::var("SSLKEYLOGFILE")
289 .ok()
290 .and_then(|path| {
291 debug!("Enabling TLS keylog to: {}", path);
292 std::fs::OpenOptions::new()
293 .create(true)
294 .append(true)
295 .open(&path)
296 .ok()
297 })
298 .map(|file| Arc::new(KeyLogFile(Mutex::new(file))) as Arc<dyn rustls::KeyLog>);
299
300 let captured_certs = Arc::new(Mutex::new(Vec::new()));
302
303 if config.verify_tls_cert {
304 let builder = tls_utils::client_config_builder_with_protocol_versions(&[&tls_utils::TLS13]);
306 let provider = builder.crypto_provider().clone();
307
308 let platform_verifier = tls_utils::PlatformVerifier::new().with_provider(provider);
309
310 let capturing_verifier = CapturingVerifier {
312 inner: Arc::new(platform_verifier),
313 captured_certs: captured_certs.clone(),
314 };
315
316 let mut tls_config = builder
317 .dangerous()
318 .with_custom_certificate_verifier(Arc::new(capturing_verifier))
319 .with_no_client_auth();
320
321 if let Some(kl) = key_log {
322 tls_config.key_log = kl;
323 }
324
325 Ok((tls_config, captured_certs))
326 } else {
327 warn!("TLS certificate verification is disabled!");
329
330 let builder = tls_utils::client_config_builder_with_protocol_versions(&[&tls_utils::TLS13]);
331 let provider = builder.crypto_provider().clone();
332
333 let no_verification = NoVerification { provider };
335
336 let capturing_verifier = CapturingVerifier {
337 inner: Arc::new(no_verification),
338 captured_certs: captured_certs.clone(),
339 };
340
341 let mut tls_config = builder
342 .dangerous()
343 .with_custom_certificate_verifier(Arc::new(capturing_verifier))
344 .with_no_client_auth();
345
346 if let Some(kl) = key_log {
347 tls_config.key_log = kl;
348 }
349
350 Ok((tls_config, captured_certs))
351 }
352}
353
354#[derive(Debug)]
356struct NoVerification {
357 provider: Arc<rustls::crypto::CryptoProvider>,
358}
359
360impl rustls::client::danger::ServerCertVerifier for NoVerification {
361 fn verify_server_cert(
362 &self,
363 _end_entity: &rustls::pki_types::CertificateDer<'_>,
364 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
365 _server_name: &rustls::pki_types::ServerName<'_>,
366 _ocsp_response: &[u8],
367 _now: rustls::pki_types::UnixTime,
368 ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
369 Ok(rustls::client::danger::ServerCertVerified::assertion())
370 }
371
372 fn verify_tls12_signature(
373 &self,
374 _message: &[u8],
375 _cert: &rustls::pki_types::CertificateDer<'_>,
376 _dss: &rustls::DigitallySignedStruct,
377 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
378 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
379 }
380
381 fn verify_tls13_signature(
382 &self,
383 _message: &[u8],
384 _cert: &rustls::pki_types::CertificateDer<'_>,
385 _dss: &rustls::DigitallySignedStruct,
386 ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
387 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
388 }
389
390 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
391 self.provider
392 .signature_verification_algorithms
393 .supported_schemes()
394 }
395}
396
397async fn resolve_server(server: &str, port: u16) -> Result<SocketAddr> {
399 let addrs = format!("{}:{}", server, port)
400 .to_socket_addrs()
401 .map_err(|e| Error::ServerUnavailable(format!("DNS resolution failed: {}", e)))?;
402
403 addrs
404 .into_iter()
405 .next()
406 .ok_or_else(|| Error::ServerUnavailable("No addresses resolved".to_string()))
407}
408
409fn convert_ke_result(
411 mut result: KeyExchangeResult,
412 ke_duration: Duration,
413 certificate: Option<CertificateInfo>,
414) -> std::result::Result<NtsKeResult, Error> {
415 let ntp_server = if let Ok(ip_addr) = result.remote.parse() {
417 SocketAddr::new(ip_addr, result.port)
418 } else {
419 let addr_str = format!("{}:{}", result.remote, result.port);
421 addr_str
422 .to_socket_addrs()
423 .ok()
424 .and_then(|mut addrs| addrs.next())
425 .ok_or_else(|| {
426 Error::Other(format!(
427 "Failed to resolve NTP server address: {}:{}. DNS resolution returned no results.",
428 result.remote, result.port
429 ))
430 })?
431 };
432
433 let mut cookies = Vec::new();
436 while let Some(cookie) = result.nts.get_cookie() {
437 cookies.push(cookie);
438 }
439
440 debug!("Extracted {} cookies from NTS-KE", cookies.len());
441
442 let (c2s, s2c) = result.nts.get_keys();
445
446 debug!("Extracted NTS ciphers for authenticated NTP");
447
448 let aead_algorithm = match c2s.key_bytes().len() {
449 32 => "AEAD_AES_SIV_CMAC_256".to_string(),
450 64 => "AEAD_AES_SIV_CMAC_512".to_string(),
451 other => format!("UNKNOWN_KEY_LEN_{}", other),
452 };
453
454 Ok(NtsKeResult::new(
455 ntp_server,
456 aead_algorithm,
457 cookies,
458 ke_duration,
459 c2s,
460 s2c,
461 certificate,
462 ))
463}
464
465impl From<KeyExchangeError> for Error {
467 fn from(err: KeyExchangeError) -> Self {
468 match err {
469 KeyExchangeError::UnrecognizedCriticalRecord => {
470 Error::KeyExchange("Unrecognized critical NTS record".to_string())
471 }
472 KeyExchangeError::BadRequest => Error::KeyExchange("Bad request".to_string()),
473 KeyExchangeError::InternalServerError => {
474 Error::KeyExchange("Internal server error".to_string())
475 }
476 KeyExchangeError::UnknownErrorCode(code) => {
477 Error::KeyExchange(format!("Unknown error code: {}", code))
478 }
479 KeyExchangeError::BadResponse => Error::KeyExchange("Bad response".to_string()),
480 KeyExchangeError::NoValidProtocol => {
481 Error::KeyExchange("No valid protocol negotiated".to_string())
482 }
483 KeyExchangeError::NoValidAlgorithm => {
484 Error::KeyExchange("No valid AEAD algorithm negotiated".to_string())
485 }
486 KeyExchangeError::InvalidFixedKeyLength => {
487 Error::KeyExchange("Invalid fixed key length".to_string())
488 }
489 KeyExchangeError::NoCookies => Error::KeyExchange("No cookies received".to_string()),
490 KeyExchangeError::CookiesTooBig => Error::KeyExchange("Cookies too big".to_string()),
491 KeyExchangeError::Io(e) => Error::Io(e),
492 KeyExchangeError::Tls(e) => Error::Tls(format!("TLS error: {:?}", e)),
493 KeyExchangeError::Certificate(e) => Error::Tls(format!("Certificate error: {:?}", e)),
494 KeyExchangeError::DnsName(e) => Error::Tls(format!("DNS name error: {:?}", e)),
495 KeyExchangeError::IncompleteResponse => {
496 Error::KeyExchange("Incomplete NTS-KE response".to_string())
497 }
498 }
499 }
500}
501
502#[derive(Debug)]
504struct KeyLogFile(Mutex<std::fs::File>);
505
506impl rustls::KeyLog for KeyLogFile {
507 fn log(&self, label: &str, client_random: &[u8], secret: &[u8]) {
508 if let Ok(mut file) = self.0.lock() {
509 let _ = writeln!(
510 file,
511 "{} {} {}",
512 label,
513 to_hex(client_random),
514 to_hex(secret)
515 );
516 let _ = file.flush();
517 }
518 }
519}
520
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte, no separators.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(DIGITS[usize::from(byte >> 4)] as char);
        encoded.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    encoded
}