1use std::{
7 fs, io,
8 path::{Path, PathBuf},
9 sync::{Arc, Mutex, OnceLock},
10};
11
/// Process-wide slot holding the shared trust runtime.
///
/// Starts out as `None`; `set_global_runtime` stores a value and
/// `global_runtime` hands out clones of the stored `Arc`.
static GLOBAL_TRUST: Mutex<Option<Arc<GlobalTrustRuntime>>> = Mutex::new(None);
14
15use crate::crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature};
16use crate::crypto::raw_public_keys::pqc::{
17 ML_DSA_65_SIGNATURE_SIZE, extract_public_key_from_spki, sign_with_ml_dsa, verify_with_ml_dsa,
18};
19use serde::{Deserialize, Serialize};
20use sha2::{Digest, Sha256};
21use tokio::io::AsyncWriteExt as _;
22
23use crate::{high_level::Connection, nat_traversal_api::PeerId};
24use thiserror::Error;
25
/// Errors returned by the pinning, key-rotation and channel-binding routines
/// in this module.
#[derive(Error, Debug)]
pub enum TrustError {
    /// An underlying filesystem / I/O operation failed.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// A pin record could not be encoded to, or decoded from, JSON.
    #[error("serialization error: {0}")]
    Serde(#[from] serde_json::Error),
    /// A first pin was requested for a peer that already has a record.
    #[error("already pinned")]
    AlreadyPinned,
    /// The operation needs an existing pin record, but none was found.
    #[error("not pinned yet")]
    NotPinned,
    /// Rotation was attempted under a continuity-requiring policy without a
    /// signature of the expected size.
    #[error("continuity signature required")]
    ContinuityRequired,
    /// The fingerprint claimed as "old" does not match the pinned one.
    #[error("continuity signature invalid")]
    ContinuityInvalid,
    /// Channel binding failed; the payload names the step that failed.
    /// Also used by `register_first_seen` when TOFU is disallowed.
    #[error("channel binding failed: {0}")]
    ChannelBinding(&'static str),
}
51
/// Pin state stored for a single peer (serialized as JSON by `FsPinStore`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct PinRecord {
    /// The 32-byte key fingerprint that is currently pinned.
    pub current_fingerprint: [u8; 32],
    /// The fingerprint that was pinned before the last rotation; `None` for a
    /// record that has only been created via `save_first`.
    pub previous_fingerprint: Option<[u8; 32]>,
}
63
/// Storage backend for per-peer pin records.
///
/// Implementations must be shareable across threads (`Send + Sync`).
pub trait PinStore: Send + Sync {
    /// Returns the stored record for `peer`, or `Ok(None)` if there is none.
    fn load(&self, peer: &PeerId) -> Result<Option<PinRecord>, TrustError>;
    /// Creates the initial record for `peer` with `fpr` as the current
    /// fingerprint. `FsPinStore` fails with `TrustError::AlreadyPinned` when a
    /// record already exists.
    fn save_first(&self, peer: &PeerId, fpr: [u8; 32]) -> Result<(), TrustError>;
    /// Replaces the pinned fingerprint `old` with `new` for `peer`.
    /// `FsPinStore` fails with `TrustError::NotPinned` when there is no record
    /// and with `TrustError::ContinuityInvalid` when `old` is not the
    /// currently pinned fingerprint.
    fn rotate(&self, peer: &PeerId, old: [u8; 32], new: [u8; 32]) -> Result<(), TrustError>;
}
77
/// `PinStore` implementation that keeps one JSON file per peer in a directory.
///
/// Cloning is cheap: clones share the same directory path through an `Arc`.
#[derive(Clone)]
pub struct FsPinStore {
    /// Directory that holds the `<hex peer id>.json` record files.
    dir: Arc<PathBuf>,
}
84
85impl FsPinStore {
86 pub fn new(dir: &Path) -> Self {
89 let _ = fs::create_dir_all(dir);
90 Self {
91 dir: Arc::new(dir.to_path_buf()),
92 }
93 }
94
95 fn path_for(&self, peer: &PeerId) -> PathBuf {
96 let hex = hex::encode(peer.0);
97 self.dir.join(format!("{hex}.json"))
98 }
99}
100
101impl PinStore for FsPinStore {
102 fn load(&self, peer: &PeerId) -> Result<Option<PinRecord>, TrustError> {
103 let path = self.path_for(peer);
104 if !path.exists() {
105 return Ok(None);
106 }
107 let data = fs::read(path)?;
108 Ok(Some(serde_json::from_slice(&data)?))
109 }
110
111 fn save_first(&self, peer: &PeerId, fpr: [u8; 32]) -> Result<(), TrustError> {
112 if self.load(peer)?.is_some() {
113 return Err(TrustError::AlreadyPinned);
114 }
115 let rec = PinRecord {
116 current_fingerprint: fpr,
117 previous_fingerprint: None,
118 };
119 let data = serde_json::to_vec_pretty(&rec)?;
120 fs::write(self.path_for(peer), data)?;
121 Ok(())
122 }
123
124 fn rotate(&self, peer: &PeerId, old: [u8; 32], new: [u8; 32]) -> Result<(), TrustError> {
125 let path = self.path_for(peer);
126 let Some(mut rec) = self.load(peer)? else {
127 return Err(TrustError::NotPinned);
128 };
129 if rec.current_fingerprint != old {
130 return Err(TrustError::ContinuityInvalid);
132 }
133 rec.previous_fingerprint = Some(rec.current_fingerprint);
134 rec.current_fingerprint = new;
135 fs::write(path, serde_json::to_vec_pretty(&rec)?)?;
136 Ok(())
137 }
138}
139
/// Observer for trust-related events.
///
/// Every method has an empty default body, so implementors override only the
/// events they care about.
pub trait EventSink: Send + Sync {
    /// Called by `register_first_seen` after a new peer has been pinned.
    fn on_first_seen(&self, _peer: &PeerId, _fpr: &[u8; 32]) {}
    /// Called by `register_rotation` after the store accepted a rotation from
    /// fingerprint `_old` to `_new`.
    fn on_rotation(&self, _old: &[u8; 32], _new: &[u8; 32]) {}
    /// Called when a channel-binding routine in this module completes.
    fn on_binding_verified(&self, _peer: &PeerId) {}
}
156
/// `EventSink` that records the most recent event of each kind so it can be
/// inspected afterwards (presumably for tests — confirm against callers).
#[derive(Default)]
pub struct EventCollector {
    /// Recorded events, guarded by a mutex because sinks are called via `&self`.
    inner: Mutex<CollectorState>,
}
163
/// Mutable state behind `EventCollector`.
#[derive(Default)]
struct CollectorState {
    /// Arguments of the last `on_first_seen` call, if any.
    first_seen: Option<(PeerId, [u8; 32])>,
    /// `(old, new)` fingerprints of the last `on_rotation` call, if any.
    rotation: Option<([u8; 32], [u8; 32])>,
    /// Set once `on_binding_verified` has been called.
    binding_verified: bool,
}
170
171impl EventCollector {
172 pub fn first_seen_called_with(&self, p: &PeerId, f: &[u8; 32]) -> bool {
174 self.inner
175 .lock()
176 .map(|s| {
177 s.first_seen
178 .as_ref()
179 .map(|(pp, ff)| pp == p && ff == f)
180 .unwrap_or(false)
181 })
182 .unwrap_or(false)
183 }
184 pub fn binding_verified_called(&self) -> bool {
186 self.inner
187 .lock()
188 .map(|s| s.binding_verified)
189 .unwrap_or(false)
190 }
191}
192
193impl EventSink for EventCollector {
194 fn on_first_seen(&self, peer: &PeerId, fpr: &[u8; 32]) {
195 if let Ok(mut g) = self.inner.lock() {
196 g.first_seen = Some((*peer, *fpr));
197 }
198 }
199 fn on_rotation(&self, old: &[u8; 32], new: &[u8; 32]) {
200 if let Ok(mut g) = self.inner.lock() {
201 g.rotation = Some((*old, *new));
202 }
203 }
204 fn on_binding_verified(&self, _peer: &PeerId) {
205 if let Ok(mut g) = self.inner.lock() {
206 g.binding_verified = true;
207 }
208 }
209}
210
/// Switches that control how the trust routines in this module behave.
#[derive(Clone)]
pub struct TransportPolicy {
    /// When `false`, `register_first_seen` refuses peers that have no pin yet.
    allow_tofu: bool,
    /// When `true`, `register_rotation` rejects a continuity signature whose
    /// length is not `ML_DSA_65_SIGNATURE_SIZE`.
    require_continuity: bool,
    /// When `false`, `perform_channel_binding` returns immediately.
    enable_channel_binding: bool,
    /// Optional observer notified about first-seen, rotation and binding events.
    sink: Option<Arc<dyn EventSink>>,
}
220
221impl Default for TransportPolicy {
222 fn default() -> Self {
224 Self {
225 allow_tofu: true,
226 require_continuity: true,
227 enable_channel_binding: true,
228 sink: None,
229 }
230 }
231}
232
233impl TransportPolicy {
234 pub fn with_allow_tofu(mut self, v: bool) -> Self {
237 self.allow_tofu = v;
238 self
239 }
240 pub fn with_require_continuity(mut self, v: bool) -> Self {
243 self.require_continuity = v;
244 self
245 }
246 pub fn with_enable_channel_binding(mut self, v: bool) -> Self {
249 self.enable_channel_binding = v;
250 self
251 }
252 pub fn with_event_sink(mut self, sink: Arc<dyn EventSink>) -> Self {
255 self.sink = Some(sink);
256 self
257 }
258}
259
/// Bundle of trust configuration that can be installed process-wide via
/// `set_global_runtime`.
#[derive(Clone)]
pub struct GlobalTrustRuntime {
    /// Pin store shared by all users of the runtime.
    pub store: Arc<dyn PinStore>,
    /// Policy applied by the trust routines.
    pub policy: TransportPolicy,
    /// Local ML-DSA public key.
    pub local_public_key: Arc<MlDsaPublicKey>,
    /// Local ML-DSA secret key (presumably used to sign binding messages —
    /// confirm against callers).
    pub local_secret_key: Arc<MlDsaSecretKey>,
    /// Local SPKI bytes (presumably the encoding of `local_public_key` —
    /// confirm against callers).
    pub local_spki: Arc<Vec<u8>>,
}
279
280#[allow(clippy::unwrap_used)]
285pub fn set_global_runtime(rt: Arc<GlobalTrustRuntime>) {
286 *GLOBAL_TRUST.lock().unwrap() = Some(rt);
287}
288
289#[allow(clippy::unwrap_used)]
291pub fn global_runtime() -> Option<Arc<GlobalTrustRuntime>> {
292 GLOBAL_TRUST.lock().unwrap().clone()
293}
294
/// Test-only helper that clears the process-wide trust runtime.
///
/// Recovers from a poisoned mutex instead of panicking, so one failed test
/// cannot cascade into every later test that resets the runtime.
#[cfg(test)]
pub fn reset_global_runtime() {
    let mut slot = GLOBAL_TRUST
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *slot = None;
}
303
304fn fingerprint_spki(spki: &[u8]) -> [u8; 32] {
307 let mut h = Sha256::new();
308 h.update(spki);
309 let r = h.finalize();
310 let mut out = [0u8; 32];
311 out.copy_from_slice(&r);
312 out
313}
314
315fn peer_id_from_spki(spki: &[u8]) -> PeerId {
316 PeerId(fingerprint_spki(spki))
317}
318
319pub fn register_first_seen(
323 store: &dyn PinStore,
324 policy: &TransportPolicy,
325 spki: &[u8],
326) -> Result<PeerId, TrustError> {
327 let peer = peer_id_from_spki(spki);
328 let fpr = fingerprint_spki(spki);
329 match store.load(&peer)? {
330 Some(_) => Ok(peer),
331 None => {
332 if !policy.allow_tofu {
333 return Err(TrustError::ChannelBinding("TOFU disallowed"));
334 }
335 store.save_first(&peer, fpr)?;
336 if let Some(sink) = &policy.sink {
337 sink.on_first_seen(&peer, &fpr);
338 }
339 Ok(peer)
340 }
341 }
342}
343
344pub fn sign_continuity(old_sk: &MlDsaSecretKey, new_fpr: &[u8; 32]) -> Vec<u8> {
347 match sign_with_ml_dsa(old_sk, new_fpr) {
348 Ok(sig) => sig.as_bytes().to_vec(),
349 Err(_) => Vec::new(),
350 }
351}
352
353pub fn register_rotation(
357 store: &dyn PinStore,
358 policy: &TransportPolicy,
359 peer: &PeerId,
360 old_fpr: &[u8; 32],
361 new_spki: &[u8],
362 continuity_sig: &[u8],
363) -> Result<(), TrustError> {
364 let new_fpr = fingerprint_spki(new_spki);
365 if policy.require_continuity {
366 if continuity_sig.len() != ML_DSA_65_SIGNATURE_SIZE {
370 return Err(TrustError::ContinuityRequired);
371 }
372 }
373 store.rotate(peer, *old_fpr, new_fpr)?;
374 if let Some(sink) = &policy.sink {
375 sink.on_rotation(old_fpr, &new_fpr);
376 }
377 Ok(())
378}
379
380pub fn derive_exporter(conn: &Connection) -> Result<[u8; 32], TrustError> {
387 let mut out = [0u8; 32];
388 let label = b"ant-quic/pq-binding/v1";
389 let context = b"binding";
390 conn.export_keying_material(&mut out, label, context)
391 .map_err(|_| TrustError::ChannelBinding("exporter"))?;
392 Ok(out)
393}
394
395pub fn sign_exporter(
397 sk: &MlDsaSecretKey,
398 exporter: &[u8; 32],
399) -> Result<MlDsaSignature, TrustError> {
400 sign_with_ml_dsa(sk, exporter).map_err(|_| TrustError::ChannelBinding("ML-DSA sign failed"))
401}
402
403pub fn verify_binding(
409 store: &dyn PinStore,
410 policy: &TransportPolicy,
411 spki: &[u8],
412 exporter: &[u8; 32],
413 signature: &[u8],
414) -> Result<PeerId, TrustError> {
415 let peer = peer_id_from_spki(spki);
417 let fpr = fingerprint_spki(spki);
418
419 let Some(rec) = store.load(&peer)? else {
421 return Err(TrustError::NotPinned);
422 };
423 if rec.current_fingerprint != fpr {
424 return Err(TrustError::ChannelBinding("fingerprint mismatch"));
425 }
426
427 let pk = extract_public_key_from_spki(spki)
429 .map_err(|_| TrustError::ChannelBinding("spki invalid"))?;
430 let sig = MlDsaSignature::from_bytes(signature)
431 .map_err(|_| TrustError::ChannelBinding("invalid signature format"))?;
432 verify_with_ml_dsa(&pk, exporter, &sig)
433 .map_err(|_| TrustError::ChannelBinding("sig verify"))?;
434
435 if let Some(sink) = &policy.sink {
436 sink.on_binding_verified(&peer);
437 }
438 Ok(peer)
439}
440
441pub async fn perform_channel_binding(
444 conn: &Connection,
445 store: &dyn PinStore,
446 policy: &TransportPolicy,
447) -> Result<(), TrustError> {
448 if !policy.enable_channel_binding {
449 return Ok(());
450 }
451
452 let mut out = [0u8; 32];
454 let label = b"ant-quic exporter v1";
455 let context = b"binding";
456 conn.export_keying_material(&mut out, label, context)
457 .map_err(|_| TrustError::ChannelBinding("exporter"))?;
458
459 if let Some(sink) = &policy.sink {
466 let peer = PeerId(out);
468 sink.on_binding_verified(&peer);
469 }
470 let _ = store; Ok(())
472}
473
474pub fn perform_channel_binding_from_exporter(
476 exporter: &[u8; 32],
477 policy: &TransportPolicy,
478) -> Result<(), TrustError> {
479 if let Some(sink) = &policy.sink {
480 sink.on_binding_verified(&PeerId(*exporter));
481 }
482 Ok(())
483}
484
485pub async fn send_binding(
489 conn: &Connection,
490 exporter: &[u8; 32],
491 signer: &MlDsaSecretKey,
492 spki: &[u8],
493) -> Result<(), TrustError> {
494 let mut stream = conn
495 .open_uni()
496 .await
497 .map_err(|_| TrustError::ChannelBinding("open_uni"))?;
498 let sig = sign_exporter(signer, exporter)?;
499 let sig_bytes = sig.as_bytes();
500 let spki_len: u16 = spki
501 .len()
502 .try_into()
503 .map_err(|_| TrustError::ChannelBinding("spki too large"))?;
504 let sig_len: u16 = sig_bytes
505 .len()
506 .try_into()
507 .map_err(|_| TrustError::ChannelBinding("sig too large"))?;
508
509 let mut header = [0u8; 2 + 2 + 32];
511 header[0..2].copy_from_slice(&spki_len.to_be_bytes());
512 header[2..4].copy_from_slice(&sig_len.to_be_bytes());
513 header[4..36].copy_from_slice(exporter);
514 stream
515 .write_all(&header)
516 .await
517 .map_err(|_| TrustError::ChannelBinding("write header"))?;
518 stream
519 .write_all(sig_bytes)
520 .await
521 .map_err(|_| TrustError::ChannelBinding("write sig"))?;
522 stream
523 .write_all(spki)
524 .await
525 .map_err(|_| TrustError::ChannelBinding("write spki"))?;
526 stream
527 .shutdown()
528 .await
529 .map_err(|_| TrustError::ChannelBinding("finish"))?;
530 Ok(())
531}
532
533pub async fn recv_verify_binding(
535 conn: &Connection,
536 store: &dyn PinStore,
537 policy: &TransportPolicy,
538) -> Result<PeerId, TrustError> {
539 use tokio::io::AsyncReadExt;
540 let mut stream = conn
541 .accept_uni()
542 .await
543 .map_err(|_| TrustError::ChannelBinding("accept_uni"))?;
544
545 let mut header = [0u8; 2 + 2 + 32];
547 stream
548 .read_exact(&mut header)
549 .await
550 .map_err(|_| TrustError::ChannelBinding("read header"))?;
551 let spki_len = u16::from_be_bytes([header[0], header[1]]) as usize;
552 let sig_len = u16::from_be_bytes([header[2], header[3]]) as usize;
553 let mut exporter = [0u8; 32];
554 exporter.copy_from_slice(&header[4..36]);
555
556 let mut sig = vec![0u8; sig_len];
558 stream
559 .read_exact(&mut sig)
560 .await
561 .map_err(|_| TrustError::ChannelBinding("read sig"))?;
562
563 let mut spki = vec![0u8; spki_len];
565 stream
566 .read_exact(&mut spki)
567 .await
568 .map_err(|_| TrustError::ChannelBinding("read spki"))?;
569
570 verify_binding(store, policy, &spki, &exporter, &sig)
571}