rkik_nts/
nts_ke.rs

1//! NTS Key Exchange (NTS-KE) implementation using ntp-proto.
2//!
3//! This module wraps ntp-proto's KeyExchangeClient to provide an async interface.
4
5use std::io::Write;
6use std::net::{SocketAddr, ToSocketAddrs};
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9
10use ntp_proto::{KeyExchangeClient, KeyExchangeError, KeyExchangeResult, ProtocolVersion};
11use rustls::pki_types::{CertificateDer, ServerName as RustlsServerName, UnixTime};
12use sha2::{Digest, Sha256};
13use tracing::{debug, info, warn};
14use x509_parser::prelude::*;
15
16use crate::config::NtsClientConfig;
17use crate::error::{Error, Result};
18use crate::types::{CertificateInfo, NtsKeResult};
19
20/// Perform NTS-KE using ntp-proto's KeyExchangeClient
21pub(crate) async fn perform_nts_ke(config: &NtsClientConfig) -> Result<NtsKeResult> {
22    let ke_start = std::time::Instant::now();
23
24    info!(
25        "Starting NTS-KE with {}:{}",
26        config.nts_ke_server, config.nts_ke_port
27    );
28
29    // Resolve server address
30    let server_addr = resolve_server(&config.nts_ke_server, config.nts_ke_port).await?;
31    debug!("Resolved server address: {}", server_addr);
32
33    // Build TLS config with certificate capturing
34    let (tls_config, captured_certs) = build_tls_config(config)?;
35
36    // Determine protocol version (always V4 for now)
37    let protocol_version = ProtocolVersion::V4;
38
39    // Perform key exchange in a blocking task since KeyExchangeClient uses sync I/O
40    let server_name = config.nts_ke_server.clone();
41    let timeout_duration = config.timeout;
42
43    let result = tokio::task::spawn_blocking(move || {
44        perform_nts_ke_blocking(
45            server_addr,
46            server_name,
47            tls_config,
48            protocol_version,
49            timeout_duration,
50        )
51    })
52    .await
53    .map_err(|e| Error::KeyExchange(format!("Task join error: {}", e)))??;
54
55    let ke_duration = ke_start.elapsed();
56    debug!("NTS-KE completed in {:?}", ke_duration);
57
58    // Extract certificate information after successful handshake
59    let certificate = {
60        let certs = captured_certs.lock().unwrap();
61        if !certs.is_empty() {
62            extract_certificate_info(&certs)
63        } else {
64            None
65        }
66    };
67
68    if let Some(ref cert) = certificate {
69        debug!(
70            "Captured certificate: subject={}, issuer={}",
71            cert.subject, cert.issuer
72        );
73    }
74
75    // Convert KeyExchangeResult to NtsKeResult
76    convert_ke_result(result, ke_duration, certificate)
77}
78
79/// Perform NTS-KE in a blocking context
80fn perform_nts_ke_blocking(
81    server_addr: SocketAddr,
82    server_name: String,
83    tls_config: ntp_proto::tls_utils::ClientConfig,
84    protocol_version: ProtocolVersion,
85    timeout_duration: Duration,
86) -> Result<KeyExchangeResult> {
87    // Connect TCP socket (blocking)
88    let mut socket =
89        std::net::TcpStream::connect_timeout(&server_addr, timeout_duration).map_err(Error::Io)?;
90
91    socket.set_nonblocking(true).map_err(Error::Io)?;
92
93    debug!("TCP connection established");
94
95    // Create KeyExchangeClient
96    let mut ke_client = KeyExchangeClient::new(
97        server_name,
98        tls_config,
99        protocol_version,
100        Vec::<String>::new(), // no denied servers
101    )
102    .map_err(Error::from)?;
103
104    debug!("KeyExchangeClient created");
105
106    // Run the state machine
107    let start = std::time::Instant::now();
108    loop {
109        if start.elapsed() > timeout_duration {
110            return Err(Error::Timeout);
111        }
112
113        // Write any pending TLS data to socket
114        if ke_client.wants_write() {
115            match ke_client.write_socket(&mut socket) {
116                Ok(n) => {
117                    if n > 0 {
118                        debug!("Wrote {} bytes to socket", n);
119                    }
120                }
121                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
122                Err(e) => return Err(Error::Io(e)),
123            }
124        }
125
126        // Read any available data from socket
127        if ke_client.wants_read() {
128            match ke_client.read_socket(&mut socket) {
129                Ok(n) => {
130                    if n > 0 {
131                        debug!("Read {} bytes from socket", n);
132                    }
133                }
134                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
135                Err(e) => return Err(Error::Io(e)),
136            }
137        }
138
139        // Progress the state machine
140        match ke_client.progress() {
141            std::ops::ControlFlow::Break(Ok(result)) => {
142                debug!("NTS-KE succeeded");
143                return Ok(result);
144            }
145            std::ops::ControlFlow::Break(Err(e)) => {
146                return Err(Error::from(e));
147            }
148            std::ops::ControlFlow::Continue(client) => {
149                ke_client = client;
150                // Small sleep to avoid busy-waiting
151                std::thread::sleep(std::time::Duration::from_millis(10));
152            }
153        }
154    }
155}
156
157/// Extract certificate information from the peer certificate
158fn extract_certificate_info(certs: &[CertificateDer<'_>]) -> Option<CertificateInfo> {
159    // Get the first certificate (server certificate)
160    let cert_der = certs.first()?;
161
162    // Parse the certificate using x509-parser
163    let (_, cert) = X509Certificate::from_der(cert_der.as_ref()).ok()?;
164
165    // Extract subject
166    let subject = cert.subject().to_string();
167
168    // Extract issuer
169    let issuer = cert.issuer().to_string();
170
171    // Extract validity period and convert to RFC3339-like format
172    let valid_from = format!("{}", cert.validity().not_before);
173    let valid_until = format!("{}", cert.validity().not_after);
174
175    // Extract serial number as hex string
176    let serial_number = format!("{:x}", cert.serial);
177
178    // Extract SANs (Subject Alternative Names)
179    let san_dns_names = cert
180        .subject_alternative_name()
181        .ok()
182        .flatten()
183        .map(|san| {
184            san.value
185                .general_names
186                .iter()
187                .filter_map(|gn| match gn {
188                    GeneralName::DNSName(name) => Some(name.to_string()),
189                    _ => None,
190                })
191                .collect::<Vec<_>>()
192        })
193        .unwrap_or_default();
194
195    // Extract signature algorithm
196    let signature_algorithm = cert.signature_algorithm.algorithm.to_string();
197
198    // Extract public key algorithm
199    let public_key_algorithm = cert.public_key().algorithm.algorithm.to_string();
200
201    // Calculate SHA-256 fingerprint
202    let mut hasher = Sha256::new();
203    hasher.update(cert_der.as_ref());
204    let fingerprint_sha256 = format!("{:x}", hasher.finalize());
205
206    // Check if self-signed (simple check: subject == issuer)
207    let is_self_signed = cert.subject() == cert.issuer();
208
209    Some(CertificateInfo {
210        subject,
211        issuer,
212        valid_from,
213        valid_until,
214        serial_number,
215        san_dns_names,
216        signature_algorithm,
217        public_key_algorithm,
218        fingerprint_sha256,
219        is_self_signed,
220    })
221}
222
223/// Custom certificate verifier that captures the certificate chain
224#[derive(Debug)]
225struct CapturingVerifier {
226    inner: Arc<dyn rustls::client::danger::ServerCertVerifier>,
227    captured_certs: Arc<Mutex<Vec<CertificateDer<'static>>>>,
228}
229
230impl rustls::client::danger::ServerCertVerifier for CapturingVerifier {
231    fn verify_server_cert(
232        &self,
233        end_entity: &CertificateDer<'_>,
234        intermediates: &[CertificateDer<'_>],
235        server_name: &RustlsServerName<'_>,
236        ocsp_response: &[u8],
237        now: UnixTime,
238    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
239        // Capture the certificates
240        let mut certs = self.captured_certs.lock().unwrap();
241        certs.push(end_entity.clone().into_owned());
242        for cert in intermediates {
243            certs.push(cert.clone().into_owned());
244        }
245
246        // Delegate to the real verifier
247        self.inner
248            .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
249    }
250
251    fn verify_tls12_signature(
252        &self,
253        message: &[u8],
254        cert: &CertificateDer<'_>,
255        dss: &rustls::DigitallySignedStruct,
256    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
257        self.inner.verify_tls12_signature(message, cert, dss)
258    }
259
260    fn verify_tls13_signature(
261        &self,
262        message: &[u8],
263        cert: &CertificateDer<'_>,
264        dss: &rustls::DigitallySignedStruct,
265    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
266        self.inner.verify_tls13_signature(message, cert, dss)
267    }
268
269    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
270        self.inner.supported_verify_schemes()
271    }
272}
273
274/// Build TLS config for NTS-KE with certificate capturing
275fn build_tls_config(
276    config: &NtsClientConfig,
277) -> Result<(
278    ntp_proto::tls_utils::ClientConfig,
279    Arc<Mutex<Vec<CertificateDer<'static>>>>,
280)> {
281    use ntp_proto::tls_utils::{self};
282
283    // Ensure a default crypto provider is installed
284    // This is safe to call multiple times - it will only install once
285    let _ = rustls::crypto::ring::default_provider().install_default();
286
287    // Enable TLS keylog for Wireshark decryption if SSLKEYLOGFILE is set
288    let key_log = std::env::var("SSLKEYLOGFILE")
289        .ok()
290        .and_then(|path| {
291            debug!("Enabling TLS keylog to: {}", path);
292            std::fs::OpenOptions::new()
293                .create(true)
294                .append(true)
295                .open(&path)
296                .ok()
297        })
298        .map(|file| Arc::new(KeyLogFile(Mutex::new(file))) as Arc<dyn rustls::KeyLog>);
299
300    // Create container for captured certificates
301    let captured_certs = Arc::new(Mutex::new(Vec::new()));
302
303    if config.verify_tls_cert {
304        // Normal verification with system certificates
305        let builder = tls_utils::client_config_builder_with_protocol_versions(&[&tls_utils::TLS13]);
306        let provider = builder.crypto_provider().clone();
307
308        let platform_verifier = tls_utils::PlatformVerifier::new().with_provider(provider);
309
310        // Wrap with capturing verifier
311        let capturing_verifier = CapturingVerifier {
312            inner: Arc::new(platform_verifier),
313            captured_certs: captured_certs.clone(),
314        };
315
316        let mut tls_config = builder
317            .dangerous()
318            .with_custom_certificate_verifier(Arc::new(capturing_verifier))
319            .with_no_client_auth();
320
321        if let Some(kl) = key_log {
322            tls_config.key_log = kl;
323        }
324
325        Ok((tls_config, captured_certs))
326    } else {
327        // No verification mode (for self-signed certificates)
328        warn!("TLS certificate verification is disabled!");
329
330        let builder = tls_utils::client_config_builder_with_protocol_versions(&[&tls_utils::TLS13]);
331        let provider = builder.crypto_provider().clone();
332
333        // Use NoVerification verifier wrapped with capturing
334        let no_verification = NoVerification { provider };
335
336        let capturing_verifier = CapturingVerifier {
337            inner: Arc::new(no_verification),
338            captured_certs: captured_certs.clone(),
339        };
340
341        let mut tls_config = builder
342            .dangerous()
343            .with_custom_certificate_verifier(Arc::new(capturing_verifier))
344            .with_no_client_auth();
345
346        if let Some(kl) = key_log {
347            tls_config.key_log = kl;
348        }
349
350        Ok((tls_config, captured_certs))
351    }
352}
353
354/// A certificate verifier that accepts all certificates (for testing only!)
355#[derive(Debug)]
356struct NoVerification {
357    provider: Arc<rustls::crypto::CryptoProvider>,
358}
359
360impl rustls::client::danger::ServerCertVerifier for NoVerification {
361    fn verify_server_cert(
362        &self,
363        _end_entity: &rustls::pki_types::CertificateDer<'_>,
364        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
365        _server_name: &rustls::pki_types::ServerName<'_>,
366        _ocsp_response: &[u8],
367        _now: rustls::pki_types::UnixTime,
368    ) -> std::result::Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
369        Ok(rustls::client::danger::ServerCertVerified::assertion())
370    }
371
372    fn verify_tls12_signature(
373        &self,
374        _message: &[u8],
375        _cert: &rustls::pki_types::CertificateDer<'_>,
376        _dss: &rustls::DigitallySignedStruct,
377    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
378        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
379    }
380
381    fn verify_tls13_signature(
382        &self,
383        _message: &[u8],
384        _cert: &rustls::pki_types::CertificateDer<'_>,
385        _dss: &rustls::DigitallySignedStruct,
386    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
387        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
388    }
389
390    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
391        self.provider
392            .signature_verification_algorithms
393            .supported_schemes()
394    }
395}
396
397/// Resolve server address
398async fn resolve_server(server: &str, port: u16) -> Result<SocketAddr> {
399    let addrs = format!("{}:{}", server, port)
400        .to_socket_addrs()
401        .map_err(|e| Error::ServerUnavailable(format!("DNS resolution failed: {}", e)))?;
402
403    addrs
404        .into_iter()
405        .next()
406        .ok_or_else(|| Error::ServerUnavailable("No addresses resolved".to_string()))
407}
408
409/// Convert ntp-proto's KeyExchangeResult to our NtsKeResult
410fn convert_ke_result(
411    mut result: KeyExchangeResult,
412    ke_duration: Duration,
413    certificate: Option<CertificateInfo>,
414) -> std::result::Result<NtsKeResult, Error> {
415    // Try to parse the remote as an IP address first, otherwise resolve it
416    let ntp_server = if let Ok(ip_addr) = result.remote.parse() {
417        SocketAddr::new(ip_addr, result.port)
418    } else {
419        // If not an IP, try to resolve the hostname
420        let addr_str = format!("{}:{}", result.remote, result.port);
421        addr_str
422            .to_socket_addrs()
423            .ok()
424            .and_then(|mut addrs| addrs.next())
425            .ok_or_else(|| {
426                Error::Other(format!(
427                    "Failed to resolve NTP server address: {}:{}. DNS resolution returned no results.",
428                    result.remote, result.port
429                ))
430            })?
431    };
432
433    // Extract cookies from the CookieStash by consuming them using the public API
434    // CookieStash is not Clone, so we need to extract all cookies into a Vec
435    let mut cookies = Vec::new();
436    while let Some(cookie) = result.nts.get_cookie() {
437        cookies.push(cookie);
438    }
439
440    debug!("Extracted {} cookies from NTS-KE", cookies.len());
441
442    // Extract the ciphers from SourceNtsData using get_keys()
443    // This consumes the SourceNtsData and returns (c2s, s2c) ciphers
444    let (c2s, s2c) = result.nts.get_keys();
445
446    debug!("Extracted NTS ciphers for authenticated NTP");
447
448    let aead_algorithm = match c2s.key_bytes().len() {
449        32 => "AEAD_AES_SIV_CMAC_256".to_string(),
450        64 => "AEAD_AES_SIV_CMAC_512".to_string(),
451        other => format!("UNKNOWN_KEY_LEN_{}", other),
452    };
453
454    Ok(NtsKeResult::new(
455        ntp_server,
456        aead_algorithm,
457        cookies,
458        ke_duration,
459        c2s,
460        s2c,
461        certificate,
462    ))
463}
464
465/// Convert KeyExchangeError to our Error type
466impl From<KeyExchangeError> for Error {
467    fn from(err: KeyExchangeError) -> Self {
468        match err {
469            KeyExchangeError::UnrecognizedCriticalRecord => {
470                Error::KeyExchange("Unrecognized critical NTS record".to_string())
471            }
472            KeyExchangeError::BadRequest => Error::KeyExchange("Bad request".to_string()),
473            KeyExchangeError::InternalServerError => {
474                Error::KeyExchange("Internal server error".to_string())
475            }
476            KeyExchangeError::UnknownErrorCode(code) => {
477                Error::KeyExchange(format!("Unknown error code: {}", code))
478            }
479            KeyExchangeError::BadResponse => Error::KeyExchange("Bad response".to_string()),
480            KeyExchangeError::NoValidProtocol => {
481                Error::KeyExchange("No valid protocol negotiated".to_string())
482            }
483            KeyExchangeError::NoValidAlgorithm => {
484                Error::KeyExchange("No valid AEAD algorithm negotiated".to_string())
485            }
486            KeyExchangeError::InvalidFixedKeyLength => {
487                Error::KeyExchange("Invalid fixed key length".to_string())
488            }
489            KeyExchangeError::NoCookies => Error::KeyExchange("No cookies received".to_string()),
490            KeyExchangeError::CookiesTooBig => Error::KeyExchange("Cookies too big".to_string()),
491            KeyExchangeError::Io(e) => Error::Io(e),
492            KeyExchangeError::Tls(e) => Error::Tls(format!("TLS error: {:?}", e)),
493            KeyExchangeError::Certificate(e) => Error::Tls(format!("Certificate error: {:?}", e)),
494            KeyExchangeError::DnsName(e) => Error::Tls(format!("DNS name error: {:?}", e)),
495            KeyExchangeError::IncompleteResponse => {
496                Error::KeyExchange("Incomplete NTS-KE response".to_string())
497            }
498        }
499    }
500}
501
502/// KeyLog handler for writing TLS secrets to file (for Wireshark decryption)
503#[derive(Debug)]
504struct KeyLogFile(Mutex<std::fs::File>);
505
506impl rustls::KeyLog for KeyLogFile {
507    fn log(&self, label: &str, client_random: &[u8], secret: &[u8]) {
508        if let Ok(mut file) = self.0.lock() {
509            let _ = writeln!(
510                file,
511                "{} {} {}",
512                label,
513                to_hex(client_random),
514                to_hex(secret)
515            );
516            let _ = file.flush();
517        }
518    }
519}
520
/// Render `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// An empty slice yields an empty string.
fn to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}