use std::{
fs, io,
path::{Path, PathBuf},
sync::{Arc, Mutex},
};
/// Process-wide trust runtime installed by `set_global_runtime`; `None` until one is set.
static GLOBAL_TRUST: Mutex<Option<Arc<GlobalTrustRuntime>>> = Mutex::new(None);
use crate::crypto::pqc::types::{MlDsaPublicKey, MlDsaSecretKey, MlDsaSignature};
use crate::crypto::raw_public_keys::pqc::{
ML_DSA_65_SIGNATURE_SIZE, extract_public_key_from_spki, sign_with_ml_dsa, verify_with_ml_dsa,
};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use tokio::io::AsyncWriteExt as _;
use crate::{high_level::Connection, nat_traversal_api::PeerId};
use thiserror::Error;
/// Errors produced by pin storage, key rotation and channel-binding checks.
#[derive(Error, Debug)]
pub enum TrustError {
/// Filesystem I/O failed (e.g. reading or writing a pin file).
#[error("I/O error: {0}")]
Io(#[from] io::Error),
/// A pin record could not be serialized to / deserialized from JSON.
#[error("serialization error: {0}")]
Serde(#[from] serde_json::Error),
/// A pin already exists for this peer.
#[error("already pinned")]
AlreadyPinned,
/// No pin exists for this peer.
#[error("not pinned yet")]
NotPinned,
/// A rotation was attempted without a continuity signature of the expected size.
#[error("continuity signature required")]
ContinuityRequired,
/// A rotation was rejected by a continuity check, e.g. the supplied old
/// fingerprint is not the currently pinned one.
#[error("continuity signature invalid")]
ContinuityInvalid,
/// Channel binding failed, or a policy check refused the peer; the payload
/// names the step that failed.
#[error("channel binding failed: {0}")]
ChannelBinding(&'static str),
}
/// Persisted pin state for one peer.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct PinRecord {
/// Fingerprint currently trusted for the peer (in this module, the SHA-256 of its SPKI).
pub current_fingerprint: [u8; 32],
/// Fingerprint replaced by the most recent rotation; `None` until the first rotation.
pub previous_fingerprint: Option<[u8; 32]>,
}
/// Storage backend for peer pins, shareable across threads.
pub trait PinStore: Send + Sync {
/// Returns the pin record for `peer`, or `None` if the peer has never been pinned.
fn load(&self, peer: &PeerId) -> Result<Option<PinRecord>, TrustError>;
/// Records `fpr` as the first pin for `peer`. Implementations are expected to fail
/// with `TrustError::AlreadyPinned` if a pin already exists.
fn save_first(&self, peer: &PeerId, fpr: [u8; 32]) -> Result<(), TrustError>;
/// Replaces the pinned fingerprint `old` with `new`, remembering `old` as the
/// previous fingerprint.
fn rotate(&self, peer: &PeerId, old: [u8; 32], new: [u8; 32]) -> Result<(), TrustError>;
}
/// A [`PinStore`] that keeps one pretty-printed JSON file per peer in a directory.
#[derive(Clone)]
pub struct FsPinStore {
// Directory holding the `<hex peer id>.json` files; behind `Arc` so clones share it.
dir: Arc<PathBuf>,
}
impl FsPinStore {
    /// Creates a store rooted at `dir`.
    ///
    /// The directory is created on a best-effort basis. `new` is infallible, so a
    /// creation failure is not reported here. Later writes into the directory will
    /// then fail with an I/O error.
    pub fn new(dir: &Path) -> Self {
        fs::create_dir_all(dir).ok();
        let dir = Arc::new(PathBuf::from(dir));
        Self { dir }
    }

    /// Location of the pin file for `peer`: `<dir>/<hex-encoded peer id>.json`.
    fn path_for(&self, peer: &PeerId) -> PathBuf {
        let file_name = format!("{}.json", hex::encode(peer.0));
        self.dir.join(file_name)
    }
}
impl PinStore for FsPinStore {
fn load(&self, peer: &PeerId) -> Result<Option<PinRecord>, TrustError> {
let path = self.path_for(peer);
if !path.exists() {
return Ok(None);
}
let data = fs::read(path)?;
Ok(Some(serde_json::from_slice(&data)?))
}
fn save_first(&self, peer: &PeerId, fpr: [u8; 32]) -> Result<(), TrustError> {
if self.load(peer)?.is_some() {
return Err(TrustError::AlreadyPinned);
}
let rec = PinRecord {
current_fingerprint: fpr,
previous_fingerprint: None,
};
let data = serde_json::to_vec_pretty(&rec)?;
fs::write(self.path_for(peer), data)?;
Ok(())
}
fn rotate(&self, peer: &PeerId, old: [u8; 32], new: [u8; 32]) -> Result<(), TrustError> {
let path = self.path_for(peer);
let Some(mut rec) = self.load(peer)? else {
return Err(TrustError::NotPinned);
};
if rec.current_fingerprint != old {
return Err(TrustError::ContinuityInvalid);
}
rec.previous_fingerprint = Some(rec.current_fingerprint);
rec.current_fingerprint = new;
fs::write(path, serde_json::to_vec_pretty(&rec)?)?;
Ok(())
}
}
/// Observer for trust events. Every hook defaults to a no-op, so implementors
/// override only the events they care about.
pub trait EventSink: Send + Sync {
/// Called after a previously unknown peer has been pinned on first sight.
fn on_first_seen(&self, _peer: &PeerId, _fpr: &[u8; 32]) {}
/// Called after a pinned fingerprint has been rotated from `_old` to `_new`.
fn on_rotation(&self, _old: &[u8; 32], _new: &[u8; 32]) {}
/// Called when channel binding is reported as verified. NOTE(review): some callers
/// in this module pass an exporter-derived value rather than a real peer id.
fn on_binding_verified(&self, _peer: &PeerId) {}
}
/// An [`EventSink`] that records the most recent event of each kind so it can be
/// queried afterwards.
#[derive(Default)]
pub struct EventCollector {
inner: Mutex<CollectorState>,
}
/// Mutable state behind [`EventCollector`]'s lock.
#[derive(Default)]
struct CollectorState {
// Arguments of the last `on_first_seen` call.
first_seen: Option<(PeerId, [u8; 32])>,
// `(old, new)` arguments of the last `on_rotation` call.
rotation: Option<([u8; 32], [u8; 32])>,
// Set once `on_binding_verified` has been called; never reset.
binding_verified: bool,
}
impl EventCollector {
    /// Returns `true` if the last `on_first_seen` call had exactly `p` and `f`.
    ///
    /// A poisoned lock is reported as `false`.
    pub fn first_seen_called_with(&self, p: &PeerId, f: &[u8; 32]) -> bool {
        self.inner
            .lock()
            .map(|s| matches!(&s.first_seen, Some((pp, ff)) if pp == p && ff == f))
            .unwrap_or(false)
    }

    /// Returns `true` if the last `on_rotation` call had exactly `old` and `new`.
    ///
    /// A poisoned lock is reported as `false`.
    pub fn rotation_called_with(&self, old: &[u8; 32], new: &[u8; 32]) -> bool {
        self.inner
            .lock()
            .map(|s| matches!(&s.rotation, Some((o, n)) if o == old && n == new))
            .unwrap_or(false)
    }

    /// Returns `true` once `on_binding_verified` has been called at least once.
    ///
    /// A poisoned lock is reported as `false`.
    pub fn binding_verified_called(&self) -> bool {
        self.inner
            .lock()
            .map(|s| s.binding_verified)
            .unwrap_or(false)
    }
}
impl EventSink for EventCollector {
    /// Remembers `(peer, fpr)` as the latest first-seen event.
    fn on_first_seen(&self, peer: &PeerId, fpr: &[u8; 32]) {
        // If the lock is poisoned, drop the event instead of panicking.
        let Ok(mut state) = self.inner.lock() else {
            return;
        };
        state.first_seen = Some((*peer, *fpr));
    }

    /// Remembers `(old, new)` as the latest rotation event.
    fn on_rotation(&self, old: &[u8; 32], new: &[u8; 32]) {
        let Ok(mut state) = self.inner.lock() else {
            return;
        };
        state.rotation = Some((*old, *new));
    }

    /// Flags that a binding verification was reported (the peer id is not kept).
    fn on_binding_verified(&self, _peer: &PeerId) {
        let Ok(mut state) = self.inner.lock() else {
            return;
        };
        state.binding_verified = true;
    }
}
/// Trust policy switches plus an optional event sink.
/// Build it from `Default` and the `with_*` methods.
#[derive(Clone)]
pub struct TransportPolicy {
// Allow pinning an unknown peer on first sight (trust-on-first-use).
allow_tofu: bool,
// Require a continuity signature when rotating a pinned key (see `register_rotation`).
require_continuity: bool,
// When `false`, `perform_channel_binding` returns `Ok(())` immediately.
enable_channel_binding: bool,
// Optional receiver for trust events.
sink: Option<Arc<dyn EventSink>>,
}
impl Default for TransportPolicy {
fn default() -> Self {
Self {
allow_tofu: true,
require_continuity: true,
enable_channel_binding: true,
sink: None,
}
}
}
impl TransportPolicy {
    /// Returns the policy with trust-on-first-use set to `v`.
    pub fn with_allow_tofu(self, v: bool) -> Self {
        Self {
            allow_tofu: v,
            ..self
        }
    }

    /// Returns the policy with the continuity-signature requirement set to `v`.
    pub fn with_require_continuity(self, v: bool) -> Self {
        Self {
            require_continuity: v,
            ..self
        }
    }

    /// Returns the policy with channel binding enabled or disabled according to `v`.
    pub fn with_enable_channel_binding(self, v: bool) -> Self {
        Self {
            enable_channel_binding: v,
            ..self
        }
    }

    /// Returns the policy with `sink` installed as the event receiver, replacing any previous one.
    pub fn with_event_sink(self, sink: Arc<dyn EventSink>) -> Self {
        Self {
            sink: Some(sink),
            ..self
        }
    }
}
/// Bundle of pin store, policy and this node's ML-DSA key material, shared
/// process-wide via `set_global_runtime` / `global_runtime`.
#[derive(Clone)]
pub struct GlobalTrustRuntime {
/// Pin store consulted for peer fingerprints.
pub store: Arc<dyn PinStore>,
/// Trust policy to apply.
pub policy: TransportPolicy,
/// This node's ML-DSA public key.
pub local_public_key: Arc<MlDsaPublicKey>,
/// This node's ML-DSA secret key; clones of the runtime share it through the `Arc`.
pub local_secret_key: Arc<MlDsaSecretKey>,
/// This node's SPKI bytes (presumably the encoding of `local_public_key`; TODO confirm).
pub local_spki: Arc<Vec<u8>>,
}
/// Installs `rt` as the process-wide trust runtime, replacing any previous one.
pub fn set_global_runtime(rt: Arc<GlobalTrustRuntime>) {
    // The guarded value is a plain `Option<Arc<_>>` that is only ever replaced whole,
    // so a poisoned lock cannot hold a half-updated value: recover it instead of panicking.
    *GLOBAL_TRUST
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(rt);
}
/// Returns the process-wide trust runtime, or `None` if none has been installed.
pub fn global_runtime() -> Option<Arc<GlobalTrustRuntime>> {
    // Poisoning cannot leave the `Option<Arc<_>>` inconsistent, so recover instead of panicking.
    GLOBAL_TRUST
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .clone()
}
/// Test helper: clears the process-wide trust runtime.
#[cfg(test)]
pub fn reset_global_runtime() {
    // Same poison recovery as `set_global_runtime`: a panic in one test must not
    // cascade into every later test that resets the runtime.
    *GLOBAL_TRUST
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner) = None;
}
/// SHA-256 digest of the raw SPKI bytes, used as both pin fingerprint and peer id.
fn fingerprint_spki(spki: &[u8]) -> [u8; 32] {
    let digest = Sha256::digest(spki);
    let mut fingerprint = [0u8; 32];
    fingerprint.copy_from_slice(&digest);
    fingerprint
}
/// Peer id for an SPKI: its SHA-256 fingerprint wrapped in `PeerId`.
fn peer_id_from_spki(spki: &[u8]) -> PeerId {
    let fingerprint = fingerprint_spki(spki);
    PeerId(fingerprint)
}
pub fn register_first_seen(
store: &dyn PinStore,
policy: &TransportPolicy,
spki: &[u8],
) -> Result<PeerId, TrustError> {
let peer = peer_id_from_spki(spki);
let fpr = fingerprint_spki(spki);
match store.load(&peer)? {
Some(_) => Ok(peer),
None => {
if !policy.allow_tofu {
return Err(TrustError::ChannelBinding("TOFU disallowed"));
}
store.save_first(&peer, fpr)?;
if let Some(sink) = &policy.sink {
sink.on_first_seen(&peer, &fpr);
}
Ok(peer)
}
}
}
/// Signs the new key's fingerprint with the outgoing secret key, producing the
/// continuity signature that accompanies a key rotation.
///
/// Returns the raw signature bytes, or an **empty** `Vec` if signing fails. There is
/// no other failure signal, so callers should treat an empty result as an error.
pub fn sign_continuity(old_sk: &MlDsaSecretKey, new_fpr: &[u8; 32]) -> Vec<u8> {
    sign_with_ml_dsa(old_sk, new_fpr)
        .map(|sig| sig.as_bytes().to_vec())
        .unwrap_or_default()
}
pub fn register_rotation(
store: &dyn PinStore,
policy: &TransportPolicy,
peer: &PeerId,
old_fpr: &[u8; 32],
new_spki: &[u8],
continuity_sig: &[u8],
) -> Result<(), TrustError> {
let new_fpr = fingerprint_spki(new_spki);
if policy.require_continuity {
if continuity_sig.len() != ML_DSA_65_SIGNATURE_SIZE {
return Err(TrustError::ContinuityRequired);
}
}
store.rotate(peer, *old_fpr, new_fpr)?;
if let Some(sink) = &policy.sink {
sink.on_rotation(old_fpr, &new_fpr);
}
Ok(())
}
/// Derives 32 bytes of TLS keying material for channel binding from `conn`.
///
/// Uses label `ant-quic/pq-binding/v1` and context `binding`.
///
/// # Errors
/// `TrustError::ChannelBinding("exporter")` if the connection cannot export keying material.
pub fn derive_exporter(conn: &Connection) -> Result<[u8; 32], TrustError> {
    const LABEL: &[u8] = b"ant-quic/pq-binding/v1";
    const CONTEXT: &[u8] = b"binding";
    let mut keying_material = [0u8; 32];
    match conn.export_keying_material(&mut keying_material, LABEL, CONTEXT) {
        Ok(_) => Ok(keying_material),
        Err(_) => Err(TrustError::ChannelBinding("exporter")),
    }
}
/// Signs the channel-binding exporter value with `sk`.
///
/// # Errors
/// `TrustError::ChannelBinding("ML-DSA sign failed")` if signing fails.
pub fn sign_exporter(
    sk: &MlDsaSecretKey,
    exporter: &[u8; 32],
) -> Result<MlDsaSignature, TrustError> {
    let signature = sign_with_ml_dsa(sk, exporter)
        .map_err(|_| TrustError::ChannelBinding("ML-DSA sign failed"))?;
    Ok(signature)
}
pub fn verify_binding(
store: &dyn PinStore,
policy: &TransportPolicy,
spki: &[u8],
exporter: &[u8; 32],
signature: &[u8],
) -> Result<PeerId, TrustError> {
let peer = peer_id_from_spki(spki);
let fpr = fingerprint_spki(spki);
let Some(rec) = store.load(&peer)? else {
return Err(TrustError::NotPinned);
};
if rec.current_fingerprint != fpr {
return Err(TrustError::ChannelBinding("fingerprint mismatch"));
}
let pk = extract_public_key_from_spki(spki)
.map_err(|_| TrustError::ChannelBinding("spki invalid"))?;
let sig = MlDsaSignature::from_bytes(signature)
.map_err(|_| TrustError::ChannelBinding("invalid signature format"))?;
verify_with_ml_dsa(&pk, exporter, &sig)
.map_err(|_| TrustError::ChannelBinding("sig verify"))?;
if let Some(sink) = &policy.sink {
sink.on_binding_verified(&peer);
}
Ok(peer)
}
pub async fn perform_channel_binding(
conn: &Connection,
store: &dyn PinStore,
policy: &TransportPolicy,
) -> Result<(), TrustError> {
if !policy.enable_channel_binding {
return Ok(());
}
let mut out = [0u8; 32];
let label = b"ant-quic exporter v1";
let context = b"binding";
conn.export_keying_material(&mut out, label, context)
.map_err(|_| TrustError::ChannelBinding("exporter"))?;
if let Some(sink) = &policy.sink {
let peer = PeerId(out);
sink.on_binding_verified(&peer);
}
let _ = store; Ok(())
}
/// Reports an already-derived exporter to the event sink as a verified binding.
///
/// The exporter bytes themselves are passed as the `PeerId`; nothing is verified
/// here. Always returns `Ok(())`.
pub fn perform_channel_binding_from_exporter(
    exporter: &[u8; 32],
    policy: &TransportPolicy,
) -> Result<(), TrustError> {
    let Some(sink) = policy.sink.as_ref() else {
        return Ok(());
    };
    sink.on_binding_verified(&PeerId(*exporter));
    Ok(())
}
/// Sends this endpoint's channel-binding proof on a new unidirectional stream.
///
/// Frame layout (lengths big-endian):
/// `spki_len: u16 | sig_len: u16 | exporter: [u8; 32] | signature | spki`,
/// where `signature` is `signer`'s ML-DSA signature over `exporter`. The stream is
/// shut down after the frame is written.
///
/// # Errors
/// `TrustError::ChannelBinding` naming the failing step (`"open_uni"`, signing,
/// `"spki too large"`, `"sig too large"`, a write step, or `"finish"`).
pub async fn send_binding(
    conn: &Connection,
    exporter: &[u8; 32],
    signer: &MlDsaSecretKey,
    spki: &[u8],
) -> Result<(), TrustError> {
    let mut stream = match conn.open_uni().await {
        Ok(s) => s,
        Err(_) => return Err(TrustError::ChannelBinding("open_uni")),
    };

    let signature = sign_exporter(signer, exporter)?;
    let sig_bytes = signature.as_bytes();

    let Ok(spki_len) = u16::try_from(spki.len()) else {
        return Err(TrustError::ChannelBinding("spki too large"));
    };
    let Ok(sig_len) = u16::try_from(sig_bytes.len()) else {
        return Err(TrustError::ChannelBinding("sig too large"));
    };

    let mut header = [0u8; 2 + 2 + 32];
    header[..2].copy_from_slice(&spki_len.to_be_bytes());
    header[2..4].copy_from_slice(&sig_len.to_be_bytes());
    header[4..].copy_from_slice(exporter);

    // Written in order; each part has its own error label.
    let frames: [(&[u8], &'static str); 3] = [
        (&header[..], "write header"),
        (&sig_bytes[..], "write sig"),
        (spki, "write spki"),
    ];
    for (chunk, step) in frames {
        stream
            .write_all(chunk)
            .await
            .map_err(|_| TrustError::ChannelBinding(step))?;
    }

    stream
        .shutdown()
        .await
        .map_err(|_| TrustError::ChannelBinding("finish"))?;
    Ok(())
}
pub async fn recv_verify_binding(
conn: &Connection,
store: &dyn PinStore,
policy: &TransportPolicy,
) -> Result<PeerId, TrustError> {
let mut stream = conn
.accept_uni()
.await
.map_err(|_| TrustError::ChannelBinding("accept_uni"))?;
let mut header = [0u8; 2 + 2 + 32];
stream
.read_exact(&mut header)
.await
.map_err(|_| TrustError::ChannelBinding("read header"))?;
let spki_len = u16::from_be_bytes([header[0], header[1]]) as usize;
let sig_len = u16::from_be_bytes([header[2], header[3]]) as usize;
let mut exporter = [0u8; 32];
exporter.copy_from_slice(&header[4..36]);
let mut sig = vec![0u8; sig_len];
stream
.read_exact(&mut sig)
.await
.map_err(|_| TrustError::ChannelBinding("read sig"))?;
let mut spki = vec![0u8; spki_len];
stream
.read_exact(&mut spki)
.await
.map_err(|_| TrustError::ChannelBinding("read spki"))?;
verify_binding(store, policy, &spki, &exporter, &sig)
}