ant_quic/trust/
mod.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// Trust module: TOFU pinning, continuity-checked rotations, channel binding hooks,
4// and event/policy surfaces.
5
6use std::{
7    fs, io,
8    path::{Path, PathBuf},
9    sync::{Arc, Mutex, OnceLock},
10};
11
12/// Global trust runtime storage that allows resetting for tests
13static GLOBAL_TRUST: Mutex<Option<Arc<GlobalTrustRuntime>>> = Mutex::new(None);
14
15use crate::crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature};
16use crate::crypto::raw_public_keys::pqc::{
17    ML_DSA_65_SIGNATURE_SIZE, extract_public_key_from_spki, sign_with_ml_dsa, verify_with_ml_dsa,
18};
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use tokio::io::AsyncWriteExt as _;
22
23use crate::{high_level::Connection, nat_traversal_api::PeerId};
24use thiserror::Error;
25
26/// Errors that can occur during trust operations such as pinning, rotation, and channel binding.
27#[derive(Error, Debug)]
28pub enum TrustError {
29    /// I/O error during trust operations.
30    #[error("I/O error: {0}")]
31    Io(#[from] io::Error),
32    /// Serialization/deserialization error.
33    #[error("serialization error: {0}")]
34    Serde(#[from] serde_json::Error),
35    /// Peer is already pinned and cannot be pinned again.
36    #[error("already pinned")]
37    AlreadyPinned,
38    /// Peer is not pinned yet and operation requires pinning.
39    #[error("not pinned yet")]
40    NotPinned,
41    /// Continuity signature is required but not provided.
42    #[error("continuity signature required")]
43    ContinuityRequired,
44    /// Continuity signature is invalid.
45    #[error("continuity signature invalid")]
46    ContinuityInvalid,
47    /// Channel binding operation failed.
48    #[error("channel binding failed: {0}")]
49    ChannelBinding(&'static str),
50}
51
52// ===================== Pin store =====================
53
54/// A record of pinned fingerprints for a peer, supporting key rotation with continuity.
55/// Contains the current fingerprint and optionally the previous one for continuity validation.
56#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
57pub struct PinRecord {
58    /// The current SHA-256 fingerprint of the peer's public key (SPKI).
59    pub current_fingerprint: [u8; 32],
60    /// The previous SHA-256 fingerprint if the key has been rotated, used for continuity validation.
61    pub previous_fingerprint: Option<[u8; 32]>,
62}
63
64/// A trait for storing and retrieving pinned peer fingerprints.
65/// Implementations must be thread-safe (Send + Sync) for concurrent access.
66pub trait PinStore: Send + Sync {
67    /// Load the pin record for a given peer, if one exists.
68    /// Returns None if the peer has not been pinned yet.
69    fn load(&self, peer: &PeerId) -> Result<Option<PinRecord>, TrustError>;
70    /// Save the first (initial) fingerprint for a peer.
71    /// Fails if the peer is already pinned.
72    fn save_first(&self, peer: &PeerId, fpr: [u8; 32]) -> Result<(), TrustError>;
73    /// Rotate a peer's fingerprint from old to new, updating the pin record.
74    /// Validates that the old fingerprint matches the current one.
75    fn rotate(&self, peer: &PeerId, old: [u8; 32], new: [u8; 32]) -> Result<(), TrustError>;
76}
77
/// [`PinStore`] backed by a directory of JSON files, one `<hex peer id>.json` per peer.
///
/// Cloning is cheap: clones share the same directory path through an `Arc`.
#[derive(Clone)]
pub struct FsPinStore {
    /// Directory containing the per-peer JSON records.
    dir: Arc<PathBuf>,
}
84
85impl FsPinStore {
86    /// Create a new filesystem pin store that stores records in the given directory.
87    /// The directory will be created if it doesn't exist.
88    pub fn new(dir: &Path) -> Self {
89        let _ = fs::create_dir_all(dir);
90        Self {
91            dir: Arc::new(dir.to_path_buf()),
92        }
93    }
94
95    fn path_for(&self, peer: &PeerId) -> PathBuf {
96        let hex = hex::encode(peer.0);
97        self.dir.join(format!("{hex}.json"))
98    }
99}
100
101impl PinStore for FsPinStore {
102    fn load(&self, peer: &PeerId) -> Result<Option<PinRecord>, TrustError> {
103        let path = self.path_for(peer);
104        if !path.exists() {
105            return Ok(None);
106        }
107        let data = fs::read(path)?;
108        Ok(Some(serde_json::from_slice(&data)?))
109    }
110
111    fn save_first(&self, peer: &PeerId, fpr: [u8; 32]) -> Result<(), TrustError> {
112        if self.load(peer)?.is_some() {
113            return Err(TrustError::AlreadyPinned);
114        }
115        let rec = PinRecord {
116            current_fingerprint: fpr,
117            previous_fingerprint: None,
118        };
119        let data = serde_json::to_vec_pretty(&rec)?;
120        fs::write(self.path_for(peer), data)?;
121        Ok(())
122    }
123
124    fn rotate(&self, peer: &PeerId, old: [u8; 32], new: [u8; 32]) -> Result<(), TrustError> {
125        let path = self.path_for(peer);
126        let Some(mut rec) = self.load(peer)? else {
127            return Err(TrustError::NotPinned);
128        };
129        if rec.current_fingerprint != old {
130            // Treat as invalid rotation attempt; keep state unchanged
131            return Err(TrustError::ContinuityInvalid);
132        }
133        rec.previous_fingerprint = Some(rec.current_fingerprint);
134        rec.current_fingerprint = new;
135        fs::write(path, serde_json::to_vec_pretty(&rec)?)?;
136        Ok(())
137    }
138}
139
140// ===================== Events & Policy =====================
141
142/// A trait for receiving notifications about trust-related events.
143/// Implementations can be used to monitor pinning, rotation, and channel binding operations.
144/// All methods have default empty implementations for optional overriding.
145pub trait EventSink: Send + Sync {
146    /// Called when a peer is first seen and pinned (TOFU operation).
147    /// Provides the peer ID and their initial fingerprint.
148    fn on_first_seen(&self, _peer: &PeerId, _fpr: &[u8; 32]) {}
149    /// Called when a peer's key is rotated from old to new fingerprint.
150    /// Provides both the old and new fingerprints.
151    fn on_rotation(&self, _old: &[u8; 32], _new: &[u8; 32]) {}
152    /// Called when channel binding verification succeeds for a peer.
153    /// Provides the peer ID that was successfully verified.
154    fn on_binding_verified(&self, _peer: &PeerId) {}
155}
156
157/// A test utility that collects and records trust-related events for verification.
158/// Useful in tests to assert that expected events were triggered.
159#[derive(Default)]
160pub struct EventCollector {
161    inner: Mutex<CollectorState>,
162}
163
164#[derive(Default)]
165struct CollectorState {
166    first_seen: Option<(PeerId, [u8; 32])>,
167    rotation: Option<([u8; 32], [u8; 32])>,
168    binding_verified: bool,
169}
170
171impl EventCollector {
172    /// Check if the `on_first_seen` event was called with the specified peer and fingerprint.
173    pub fn first_seen_called_with(&self, p: &PeerId, f: &[u8; 32]) -> bool {
174        self.inner
175            .lock()
176            .map(|s| {
177                s.first_seen
178                    .as_ref()
179                    .map(|(pp, ff)| pp == p && ff == f)
180                    .unwrap_or(false)
181            })
182            .unwrap_or(false)
183    }
184    /// Check if the `on_binding_verified` event was called.
185    pub fn binding_verified_called(&self) -> bool {
186        self.inner
187            .lock()
188            .map(|s| s.binding_verified)
189            .unwrap_or(false)
190    }
191}
192
193impl EventSink for EventCollector {
194    fn on_first_seen(&self, peer: &PeerId, fpr: &[u8; 32]) {
195        if let Ok(mut g) = self.inner.lock() {
196            g.first_seen = Some((*peer, *fpr));
197        }
198    }
199    fn on_rotation(&self, old: &[u8; 32], new: &[u8; 32]) {
200        if let Ok(mut g) = self.inner.lock() {
201            g.rotation = Some((*old, *new));
202        }
203    }
204    fn on_binding_verified(&self, _peer: &PeerId) {
205        if let Ok(mut g) = self.inner.lock() {
206            g.binding_verified = true;
207        }
208    }
209}
210
211/// Configuration policy for trust operations including TOFU, continuity, and channel binding.
212/// Provides a builder pattern for configuring trust behavior.
213#[derive(Clone)]
214pub struct TransportPolicy {
215    allow_tofu: bool,
216    require_continuity: bool,
217    enable_channel_binding: bool,
218    sink: Option<Arc<dyn EventSink>>,
219}
220
221impl Default for TransportPolicy {
222    /// Create a default policy that allows TOFU, requires continuity, enables channel binding, and has no event sink.
223    fn default() -> Self {
224        Self {
225            allow_tofu: true,
226            require_continuity: true,
227            enable_channel_binding: true,
228            sink: None,
229        }
230    }
231}
232
233impl TransportPolicy {
234    /// Configure whether Trust-On-First-Use (TOFU) pinning is allowed.
235    /// When true, unknown peers can be automatically pinned on first connection.
236    pub fn with_allow_tofu(mut self, v: bool) -> Self {
237        self.allow_tofu = v;
238        self
239    }
240    /// Configure whether key rotation continuity validation is required.
241    /// When true, key rotations must provide valid continuity signatures.
242    pub fn with_require_continuity(mut self, v: bool) -> Self {
243        self.require_continuity = v;
244        self
245    }
246    /// Configure whether channel binding verification is enabled.
247    /// When true, connections will perform channel binding checks.
248    pub fn with_enable_channel_binding(mut self, v: bool) -> Self {
249        self.enable_channel_binding = v;
250        self
251    }
252    /// Set an event sink to receive notifications about trust operations.
253    /// The sink will be called for pinning, rotation, and binding events.
254    pub fn with_event_sink(mut self, sink: Arc<dyn EventSink>) -> Self {
255        self.sink = Some(sink);
256        self
257    }
258}
259
260// ===================== Global runtime (test/integration hook) =====================
261
262/// Global trust runtime used by integration glue to perform automatic
263/// channel binding and event emission. This is intentionally simple and
264/// primarily for tests and early integration; production deployments
265/// should provide explicit wiring.
266#[derive(Clone)]
267pub struct GlobalTrustRuntime {
268    /// The pin store for managing peer fingerprints and key rotation
269    pub store: Arc<dyn PinStore>,
270    /// The trust policy configuration for TOFU, continuity, and channel binding
271    pub policy: TransportPolicy,
272    /// The local ML-DSA-65 public key for trust operations
273    pub local_public_key: Arc<MlDsaPublicKey>,
274    /// The local ML-DSA-65 secret key for trust operations
275    pub local_secret_key: Arc<MlDsaSecretKey>,
276    /// The local Subject Public Key Info (SPKI) for trust operations
277    pub local_spki: Arc<Vec<u8>>,
278}
279
280/// Install a global trust runtime used by automatic binding integration.
281///
282/// This is safe to call multiple times across tests in a single process.
283/// Each call will replace the previous runtime, allowing tests to reset state.
284#[allow(clippy::unwrap_used)]
285pub fn set_global_runtime(rt: Arc<GlobalTrustRuntime>) {
286    *GLOBAL_TRUST.lock().unwrap() = Some(rt);
287}
288
289/// Get the global trust runtime, if one was installed.
290#[allow(clippy::unwrap_used)]
291pub fn global_runtime() -> Option<Arc<GlobalTrustRuntime>> {
292    GLOBAL_TRUST.lock().unwrap().clone()
293}
294
295/// Reset the global trust runtime to None.
296///
297/// This is primarily used in tests to clean up between test runs.
298/// Production code should not call this function.
299#[cfg(test)]
300pub fn reset_global_runtime() {
301    *GLOBAL_TRUST.lock().unwrap() = None;
302}
303
304// ===================== Registration & Rotation =====================
305
306fn fingerprint_spki(spki: &[u8]) -> [u8; 32] {
307    let mut h = Sha256::new();
308    h.update(spki);
309    let r = h.finalize();
310    let mut out = [0u8; 32];
311    out.copy_from_slice(&r);
312    out
313}
314
315fn peer_id_from_spki(spki: &[u8]) -> PeerId {
316    PeerId(fingerprint_spki(spki))
317}
318
319/// Register a peer for the first time, performing TOFU pinning if allowed by policy.
320/// Computes the peer ID from the SPKI fingerprint and either loads existing pin or creates new one.
321/// Returns the peer ID regardless of whether pinning occurred.
322pub fn register_first_seen(
323    store: &dyn PinStore,
324    policy: &TransportPolicy,
325    spki: &[u8],
326) -> Result<PeerId, TrustError> {
327    let peer = peer_id_from_spki(spki);
328    let fpr = fingerprint_spki(spki);
329    match store.load(&peer)? {
330        Some(_) => Ok(peer),
331        None => {
332            if !policy.allow_tofu {
333                return Err(TrustError::ChannelBinding("TOFU disallowed"));
334            }
335            store.save_first(&peer, fpr)?;
336            if let Some(sink) = &policy.sink {
337                sink.on_first_seen(&peer, &fpr);
338            }
339            Ok(peer)
340        }
341    }
342}
343
344/// Sign a new fingerprint with the old private key to prove continuity during key rotation.
345/// Returns the ML-DSA-65 signature as bytes, which can be verified with the old public key.
346pub fn sign_continuity(old_sk: &MlDsaSecretKey, new_fpr: &[u8; 32]) -> Vec<u8> {
347    match sign_with_ml_dsa(old_sk, new_fpr) {
348        Ok(sig) => sig.as_bytes().to_vec(),
349        Err(_) => Vec::new(),
350    }
351}
352
353/// Register a key rotation for a peer, validating continuity if required by policy.
354/// Updates the pin record with the new fingerprint and triggers rotation events.
355/// Validates the old fingerprint matches the current pin and checks continuity signature if required.
356pub fn register_rotation(
357    store: &dyn PinStore,
358    policy: &TransportPolicy,
359    peer: &PeerId,
360    old_fpr: &[u8; 32],
361    new_spki: &[u8],
362    continuity_sig: &[u8],
363) -> Result<(), TrustError> {
364    let new_fpr = fingerprint_spki(new_spki);
365    if policy.require_continuity {
366        // Continuity: signature of new_fpr by old key. We cannot recover the old key here; this
367        // is validated at a higher layer with the old SPKI. For now, enforce signature presence
368        // and length (ML-DSA-65) as a minimal check.
369        if continuity_sig.len() != ML_DSA_65_SIGNATURE_SIZE {
370            return Err(TrustError::ContinuityRequired);
371        }
372    }
373    store.rotate(peer, *old_fpr, new_fpr)?;
374    if let Some(sink) = &policy.sink {
375        sink.on_rotation(old_fpr, &new_fpr);
376    }
377    Ok(())
378}
379
380// ===================== Channel binding =====================
381
382/// Derive a fixed-size exporter key from the TLS session for binding.
383///
384/// Both peers derive the same 32-byte value when using identical
385/// label/context. This value is then signed and verified for binding.
386pub fn derive_exporter(conn: &Connection) -> Result<[u8; 32], TrustError> {
387    let mut out = [0u8; 32];
388    let label = b"ant-quic/pq-binding/v1";
389    let context = b"binding";
390    conn.export_keying_material(&mut out, label, context)
391        .map_err(|_| TrustError::ChannelBinding("exporter"))?;
392    Ok(out)
393}
394
395/// Sign the exporter with an ML-DSA-65 private key.
396pub fn sign_exporter(
397    sk: &MlDsaSecretKey,
398    exporter: &[u8; 32],
399) -> Result<MlDsaSignature, TrustError> {
400    sign_with_ml_dsa(sk, exporter).map_err(|_| TrustError::ChannelBinding("ML-DSA sign failed"))
401}
402
403/// Verify a binding signature against a pinned SubjectPublicKeyInfo (SPKI).
404///
405/// - Validates the SPKI matches the current pin for the derived peer ID.
406/// - Verifies the ML-DSA-65 signature over the exporter using the SPKI's key.
407/// - Emits `OnBindingVerified` on success and returns the `PeerId`.
408pub fn verify_binding(
409    store: &dyn PinStore,
410    policy: &TransportPolicy,
411    spki: &[u8],
412    exporter: &[u8; 32],
413    signature: &[u8],
414) -> Result<PeerId, TrustError> {
415    // Compute IDs/fingerprints
416    let peer = peer_id_from_spki(spki);
417    let fpr = fingerprint_spki(spki);
418
419    // Check pin
420    let Some(rec) = store.load(&peer)? else {
421        return Err(TrustError::NotPinned);
422    };
423    if rec.current_fingerprint != fpr {
424        return Err(TrustError::ChannelBinding("fingerprint mismatch"));
425    }
426
427    // Extract public key from SPKI and verify signature
428    let pk = extract_public_key_from_spki(spki)
429        .map_err(|_| TrustError::ChannelBinding("spki invalid"))?;
430    let sig = MlDsaSignature::from_bytes(signature)
431        .map_err(|_| TrustError::ChannelBinding("invalid signature format"))?;
432    verify_with_ml_dsa(&pk, exporter, &sig)
433        .map_err(|_| TrustError::ChannelBinding("sig verify"))?;
434
435    if let Some(sink) = &policy.sink {
436        sink.on_binding_verified(&peer);
437    }
438    Ok(peer)
439}
440
441/// Perform a simple exporter-based channel binding. Minimal stub that derives exporter
442/// and marks success via event sink. Future work will add signature exchange and pin check.
443pub async fn perform_channel_binding(
444    conn: &Connection,
445    store: &dyn PinStore,
446    policy: &TransportPolicy,
447) -> Result<(), TrustError> {
448    if !policy.enable_channel_binding {
449        return Ok(());
450    }
451
452    // Derive exporter bytes deterministically; size and label are fixed.
453    let mut out = [0u8; 32];
454    let label = b"ant-quic exporter v1";
455    let context = b"binding";
456    conn.export_keying_material(&mut out, label, context)
457        .map_err(|_| TrustError::ChannelBinding("exporter"))?;
458
459    // In a complete implementation, we would:
460    // - extract peer SPKI from the session
461    // - compute PeerId and check PinStore
462    // - exchange signatures over the exporter using ML-DSA/Ed25519
463    // - verify signature against pinned SPKI
464    // For now, we simply signal success if exporter is derivable.
465    if let Some(sink) = &policy.sink {
466        // Best-effort: derive a pseudo PeerId from exporter for event association in tests
467        let peer = PeerId(out);
468        sink.on_binding_verified(&peer);
469    }
470    let _ = store; // placeholder; real check will consult pins
471    Ok(())
472}
473
474/// Test-only helper: perform channel binding from provided exporter bytes.
475pub fn perform_channel_binding_from_exporter(
476    exporter: &[u8; 32],
477    policy: &TransportPolicy,
478) -> Result<(), TrustError> {
479    if let Some(sink) = &policy.sink {
480        sink.on_binding_verified(&PeerId(*exporter));
481    }
482    Ok(())
483}
484
485/// Send a binding message over a unidirectional stream using ML-DSA-65.
486///
487/// Format: `u16 spki_len | u16 sig_len | exporter[32] | sig bytes | spki bytes`.
488pub async fn send_binding(
489    conn: &Connection,
490    exporter: &[u8; 32],
491    signer: &MlDsaSecretKey,
492    spki: &[u8],
493) -> Result<(), TrustError> {
494    let mut stream = conn
495        .open_uni()
496        .await
497        .map_err(|_| TrustError::ChannelBinding("open_uni"))?;
498    let sig = sign_exporter(signer, exporter)?;
499    let sig_bytes = sig.as_bytes();
500    let spki_len: u16 = spki
501        .len()
502        .try_into()
503        .map_err(|_| TrustError::ChannelBinding("spki too large"))?;
504    let sig_len: u16 = sig_bytes
505        .len()
506        .try_into()
507        .map_err(|_| TrustError::ChannelBinding("sig too large"))?;
508
509    // Header: spki_len (2) + sig_len (2) + exporter (32)
510    let mut header = [0u8; 2 + 2 + 32];
511    header[0..2].copy_from_slice(&spki_len.to_be_bytes());
512    header[2..4].copy_from_slice(&sig_len.to_be_bytes());
513    header[4..36].copy_from_slice(exporter);
514    stream
515        .write_all(&header)
516        .await
517        .map_err(|_| TrustError::ChannelBinding("write header"))?;
518    stream
519        .write_all(sig_bytes)
520        .await
521        .map_err(|_| TrustError::ChannelBinding("write sig"))?;
522    stream
523        .write_all(spki)
524        .await
525        .map_err(|_| TrustError::ChannelBinding("write spki"))?;
526    stream
527        .shutdown()
528        .await
529        .map_err(|_| TrustError::ChannelBinding("finish"))?;
530    Ok(())
531}
532
533/// Receive and verify a binding message over a unidirectional stream using ML-DSA-65.
534pub async fn recv_verify_binding(
535    conn: &Connection,
536    store: &dyn PinStore,
537    policy: &TransportPolicy,
538) -> Result<PeerId, TrustError> {
539    use tokio::io::AsyncReadExt;
540    let mut stream = conn
541        .accept_uni()
542        .await
543        .map_err(|_| TrustError::ChannelBinding("accept_uni"))?;
544
545    // Read header: spki_len (2) + sig_len (2) + exporter (32)
546    let mut header = [0u8; 2 + 2 + 32];
547    stream
548        .read_exact(&mut header)
549        .await
550        .map_err(|_| TrustError::ChannelBinding("read header"))?;
551    let spki_len = u16::from_be_bytes([header[0], header[1]]) as usize;
552    let sig_len = u16::from_be_bytes([header[2], header[3]]) as usize;
553    let mut exporter = [0u8; 32];
554    exporter.copy_from_slice(&header[4..36]);
555
556    // Read signature
557    let mut sig = vec![0u8; sig_len];
558    stream
559        .read_exact(&mut sig)
560        .await
561        .map_err(|_| TrustError::ChannelBinding("read sig"))?;
562
563    // Read SPKI
564    let mut spki = vec![0u8; spki_len];
565    stream
566        .read_exact(&mut spki)
567        .await
568        .map_err(|_| TrustError::ChannelBinding("read spki"))?;
569
570    verify_binding(store, policy, &spki, &exporter, &sig)
571}