use std::sync::atomic::{AtomicU32, AtomicU8, Ordering};
use std::time::Instant;
use dashmap::DashMap;
use parking_lot::{Mutex, RwLock};
use subtle::ConstantTimeEq;
use crate::crypto::rng::{OsRng, RngProvider};
/// Length in bytes of the random challenge produced by [`PathRegistry::issue_challenge`]
/// and of the response accepted by [`PathRegistry::verify_response`].
pub const PATH_CHALLENGE_LEN: usize = 32;
/// Validation state of a single path.
///
/// [`PathState`] stores this value in an `AtomicU8` (via `as u8`) and decodes it by its
/// numeric value, so the discriminants are pinned explicitly: reordering or inserting
/// variants must never change the encoding of existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PathStateKind {
    /// Registered, but no challenge has been issued yet.
    Unvalidated = 0,
    /// A challenge has been issued and a matching response is awaited.
    Validating = 1,
    /// The expected response was received, or the path was registered as validated.
    /// Terminal: no further challenges are issued.
    Validated = 2,
    /// A wrong response was received. Terminal: no further challenges are issued.
    Failed = 3,
}
/// Per-path bookkeeping: validation state, link metrics and the in-flight challenge.
pub struct PathState {
    /// Identifier this path was registered under.
    pub path_id: u8,
    /// Numeric encoding of a [`PathStateKind`]; read it through [`PathState::state`].
    state: AtomicU8,
    /// Round-trip time estimate; starts at 0. The `_ms` suffix suggests milliseconds.
    pub rtt_ms: AtomicU32,
    /// Loss estimate; starts at 0. The `_pct` suffix suggests a percentage.
    pub loss_pct: AtomicU8,
    /// Time of the most recent `mark_seen` call, or `None` if it was never called.
    pub last_packet_seen: RwLock<Option<Instant>>,
    /// Challenge that has been issued but not yet checked against a response.
    pending_challenge: Mutex<Option<[u8; PATH_CHALLENGE_LEN]>>,
}
impl PathState {
    /// Creates the record for `path_id` in the `Unvalidated` state, with zeroed metrics,
    /// no last-seen timestamp and no pending challenge.
    fn new(path_id: u8) -> Self {
        let initial_state = AtomicU8::new(PathStateKind::Unvalidated as u8);
        Self {
            path_id,
            state: initial_state,
            rtt_ms: AtomicU32::new(0),
            loss_pct: AtomicU8::new(0),
            last_packet_seen: RwLock::new(None),
            pending_challenge: Mutex::new(None),
        }
    }

    /// Returns the current validation state.
    ///
    /// Any stored byte outside the known range decodes as [`PathStateKind::Failed`],
    /// so an unexpected encoding fails closed.
    pub fn state(&self) -> PathStateKind {
        // Indexed by the stored byte: 0 => Unvalidated, 1 => Validating,
        // 2 => Validated, 3 => Failed.
        const DECODE: [PathStateKind; 4] = [
            PathStateKind::Unvalidated,
            PathStateKind::Validating,
            PathStateKind::Validated,
            PathStateKind::Failed,
        ];
        let raw = self.state.load(Ordering::Acquire);
        DECODE
            .get(usize::from(raw))
            .copied()
            .unwrap_or(PathStateKind::Failed)
    }

    /// Publishes `next` as the current state. The `Release` store pairs with the
    /// `Acquire` load in [`PathState::state`].
    fn set_state(&self, next: PathStateKind) {
        let encoded = next as u8;
        self.state.store(encoded, Ordering::Release);
    }

    /// Records the current time as the moment a packet was last seen on this path.
    pub fn mark_seen(&self) {
        // Take the timestamp before acquiring the write lock.
        let now = Instant::now();
        let mut slot = self.last_packet_seen.write();
        *slot = Some(now);
    }
}
/// Registry of paths keyed by their `u8` id.
///
/// All operations take `&self`; per-path mutation goes through the atomics and locks
/// inside [`PathState`].
pub struct PathRegistry {
    /// Per-path state, keyed by the id the path was registered under.
    paths: DashMap<u8, PathState>,
}
/// Outcome of [`PathRegistry::register`] and [`PathRegistry::register_validated`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RegistrationResult {
    /// The id was not present, and a new entry was inserted.
    Created,
    /// The id was already registered; the existing entry was left unchanged.
    AlreadyKnown,
}
impl Default for PathRegistry {
fn default() -> Self {
Self::new()
}
}
impl PathRegistry {
    /// Creates an empty registry with no known paths.
    pub fn new() -> Self {
        Self {
            paths: DashMap::new(),
        }
    }

    /// Registers `path_id` in the [`PathStateKind::Unvalidated`] state.
    ///
    /// Returns [`RegistrationResult::Created`] if the id was new. Otherwise it returns
    /// [`RegistrationResult::AlreadyKnown`] and leaves the existing entry, including its
    /// state, untouched.
    pub fn register(&self, path_id: u8) -> RegistrationResult {
        self.register_with_state(path_id, PathStateKind::Unvalidated)
    }

    /// Registers `path_id` directly in the [`PathStateKind::Validated`] state, skipping
    /// the challenge/response exchange.
    ///
    /// An already-registered path is *not* promoted: it keeps its current state, and
    /// [`RegistrationResult::AlreadyKnown`] is returned.
    pub fn register_validated(&self, path_id: u8) -> RegistrationResult {
        self.register_with_state(path_id, PathStateKind::Validated)
    }

    /// Shared insert-if-absent logic: creates a `PathState` in `initial` state only when
    /// `path_id` is not yet present.
    fn register_with_state(&self, path_id: u8, initial: PathStateKind) -> RegistrationResult {
        let mut created = false;
        // `or_insert_with` only runs the closure when the id is vacant, so `created`
        // records whether this call inserted the entry.
        self.paths.entry(path_id).or_insert_with(|| {
            created = true;
            let path = PathState::new(path_id);
            path.set_state(initial);
            path
        });
        if created {
            RegistrationResult::Created
        } else {
            RegistrationResult::AlreadyKnown
        }
    }

    /// Records the current time as the last packet seen on `path_id`.
    ///
    /// Unknown ids are ignored.
    pub fn mark_seen(&self, path_id: u8) {
        if let Some(path) = self.paths.get(&path_id) {
            path.mark_seen();
        }
    }

    /// Issues a validation challenge for `path_id` and moves it to
    /// [`PathStateKind::Validating`].
    ///
    /// Behaviour by case:
    /// - Unknown ids return `None`.
    /// - Paths already `Validated` or `Failed` return `None`.
    /// - If a challenge is already outstanding, that same challenge is returned instead of
    ///   a new one, so an in-flight challenge is never clobbered.
    /// - Otherwise a fresh challenge is filled from `OsRng`, stored, and returned.
    ///
    /// # Concurrency
    /// The state check, challenge creation and state transition all happen while holding
    /// the path's `pending_challenge` lock. This is the same lock `verify_response` holds
    /// for its transitions, so a concurrent successful verification can never be
    /// regressed back to `Validating`.
    pub fn issue_challenge(&self, path_id: u8) -> Option<[u8; PATH_CHALLENGE_LEN]> {
        let path = self.paths.get(&path_id)?;
        let mut pending = path.pending_challenge.lock();
        match path.state() {
            PathStateKind::Unvalidated | PathStateKind::Validating => {}
            PathStateKind::Validated | PathStateKind::Failed => return None,
        }
        if let Some(existing) = *pending {
            return Some(existing);
        }
        let mut challenge = [0u8; PATH_CHALLENGE_LEN];
        OsRng.fill_bytes(&mut challenge);
        *pending = Some(challenge);
        path.set_state(PathStateKind::Validating);
        Some(challenge)
    }

    /// Checks `response` against the outstanding challenge for `path_id`.
    ///
    /// Returns `false` without changing any state when:
    /// - the id is unknown;
    /// - `response` is not exactly [`PATH_CHALLENGE_LEN`] bytes;
    /// - the path is not currently [`PathStateKind::Validating`].
    ///
    /// Otherwise the pending challenge is consumed and compared with
    /// `subtle::ConstantTimeEq`. A match moves the path to `Validated` and returns `true`;
    /// a mismatch moves it to `Failed` and returns `false`. A `Validating` path with no
    /// stored challenge is also marked `Failed`.
    ///
    /// # Concurrency
    /// The state check, the challenge take and the resulting transition all happen under
    /// the path's `pending_challenge` lock. A duplicate response racing a successful one
    /// therefore sees `Validated` and is rejected, rather than flipping the path to `Failed`.
    pub fn verify_response(&self, path_id: u8, response: &[u8]) -> bool {
        let path = match self.paths.get(&path_id) {
            Some(p) => p,
            None => return false,
        };
        if response.len() != PATH_CHALLENGE_LEN {
            return false;
        }
        let mut pending = path.pending_challenge.lock();
        if path.state() != PathStateKind::Validating {
            return false;
        }
        let expected = match pending.take() {
            Some(e) => e,
            None => {
                path.set_state(PathStateKind::Failed);
                return false;
            }
        };
        // Both sides are PATH_CHALLENGE_LEN bytes here, so the slice comparison covers
        // every byte.
        let matched: bool = expected[..].ct_eq(response).into();
        path.set_state(if matched {
            PathStateKind::Validated
        } else {
            PathStateKind::Failed
        });
        matched
    }

    /// Returns the current state of `path_id`, or `None` if it is not registered.
    pub fn state(&self, path_id: u8) -> Option<PathStateKind> {
        self.paths.get(&path_id).map(|p| p.state())
    }

    /// Returns the ids of all paths currently in [`PathStateKind::Validated`].
    ///
    /// The order follows the map's iteration order; callers should not rely on it.
    pub fn validated_paths(&self) -> Vec<u8> {
        self.paths
            .iter()
            .filter(|p| p.state() == PathStateKind::Validated)
            .map(|p| *p.key())
            .collect()
    }

    /// Number of registered paths, in any state.
    pub fn len(&self) -> usize {
        self.paths.len()
    }

    /// `true` when no paths are registered.
    pub fn is_empty(&self) -> bool {
        self.paths.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn register_new_path_returns_created() {
        let r = PathRegistry::new();
        assert_eq!(r.register(7), RegistrationResult::Created);
        assert_eq!(r.register(7), RegistrationResult::AlreadyKnown);
    }

    #[test]
    fn freshly_registered_path_is_unvalidated() {
        let r = PathRegistry::new();
        r.register(1);
        assert_eq!(r.state(1), Some(PathStateKind::Unvalidated));
    }

    #[test]
    fn register_validated_creates_validated_path() {
        let r = PathRegistry::new();
        assert_eq!(r.register_validated(4), RegistrationResult::Created);
        assert_eq!(r.state(4), Some(PathStateKind::Validated));
        assert!(r.issue_challenge(4).is_none());
        assert_eq!(r.validated_paths(), vec![4]);
    }

    #[test]
    fn register_validated_does_not_promote_existing_path() {
        let r = PathRegistry::new();
        r.register(4);
        assert_eq!(r.register_validated(4), RegistrationResult::AlreadyKnown);
        assert_eq!(r.state(4), Some(PathStateKind::Unvalidated));
    }

    #[test]
    fn issue_challenge_transitions_to_validating() {
        let r = PathRegistry::new();
        r.register(1);
        let challenge = r.issue_challenge(1).expect("challenge issued");
        assert_eq!(challenge.len(), PATH_CHALLENGE_LEN);
        assert_eq!(r.state(1), Some(PathStateKind::Validating));
    }

    #[test]
    fn reissue_on_validating_path_returns_same_challenge() {
        let r = PathRegistry::new();
        r.register(1);
        let first = r.issue_challenge(1).expect("first challenge");
        let second = r.issue_challenge(1).expect("re-issue returns existing");
        assert_eq!(
            first, second,
            "re-issue must not clobber the in-flight challenge"
        );
        assert!(r.verify_response(1, &first));
        assert_eq!(r.state(1), Some(PathStateKind::Validated));
    }

    #[test]
    fn matching_response_transitions_to_validated() {
        let r = PathRegistry::new();
        r.register(1);
        let challenge = r.issue_challenge(1).expect("challenge");
        assert!(r.verify_response(1, &challenge));
        assert_eq!(r.state(1), Some(PathStateKind::Validated));
    }

    #[test]
    fn duplicate_response_after_validation_is_ignored() {
        let r = PathRegistry::new();
        r.register(1);
        let challenge = r.issue_challenge(1).expect("challenge");
        assert!(r.verify_response(1, &challenge));
        // A retransmitted response must neither succeed again nor fail the path.
        assert!(!r.verify_response(1, &challenge));
        assert_eq!(r.state(1), Some(PathStateKind::Validated));
    }

    #[test]
    fn mismatched_response_transitions_to_failed() {
        let r = PathRegistry::new();
        r.register(1);
        let mut challenge = r.issue_challenge(1).expect("challenge");
        challenge[0] ^= 0xFF;
        assert!(!r.verify_response(1, &challenge));
        assert_eq!(r.state(1), Some(PathStateKind::Failed));
    }

    #[test]
    fn response_without_challenge_fails() {
        let r = PathRegistry::new();
        r.register(1);
        let zeros = [0u8; PATH_CHALLENGE_LEN];
        assert!(!r.verify_response(1, &zeros));
        assert_eq!(r.state(1), Some(PathStateKind::Unvalidated));
    }

    #[test]
    fn response_for_wrong_length_fails() {
        let r = PathRegistry::new();
        r.register(1);
        let _ = r.issue_challenge(1);
        assert!(!r.verify_response(1, &[0u8; 16]));
        assert_eq!(r.state(1), Some(PathStateKind::Validating));
    }

    #[test]
    fn wrong_length_response_does_not_consume_challenge() {
        let r = PathRegistry::new();
        r.register(1);
        let challenge = r.issue_challenge(1).expect("challenge");
        assert!(!r.verify_response(1, &challenge[..16]));
        assert!(r.verify_response(1, &challenge));
        assert_eq!(r.state(1), Some(PathStateKind::Validated));
    }

    #[test]
    fn issue_challenge_on_unknown_path_returns_none() {
        let r = PathRegistry::new();
        assert!(r.issue_challenge(99).is_none());
    }

    #[test]
    fn verify_on_unknown_path_fails() {
        let r = PathRegistry::new();
        assert!(!r.verify_response(99, &[0u8; PATH_CHALLENGE_LEN]));
        assert!(r.state(99).is_none());
    }

    #[test]
    fn validated_paths_lists_only_validated() {
        let r = PathRegistry::new();
        for p in 0..5 {
            r.register(p);
        }
        for p in [1u8, 3].iter().copied() {
            let c = r.issue_challenge(p).unwrap();
            assert!(r.verify_response(p, &c));
        }
        let mut c = r.issue_challenge(2).unwrap();
        c[0] ^= 1;
        assert!(!r.verify_response(2, &c));
        r.issue_challenge(4);
        let mut validated = r.validated_paths();
        validated.sort();
        assert_eq!(validated, vec![1, 3]);
    }

    #[test]
    fn len_and_is_empty_track_registrations() {
        let r = PathRegistry::new();
        assert!(r.is_empty());
        assert_eq!(r.len(), 0);
        r.register(1);
        r.register(1);
        r.register_validated(2);
        assert!(!r.is_empty());
        assert_eq!(r.len(), 2);
    }

    #[test]
    fn mark_seen_updates_last_packet_timestamp() {
        let r = PathRegistry::new();
        r.register(1);
        let before = Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(2));
        r.mark_seen(1);
        let path = r.paths.get(&1).unwrap();
        let seen = path.last_packet_seen.read().expect("set");
        assert!(seen >= before);
    }

    #[test]
    fn mark_seen_on_unknown_path_is_noop() {
        let r = PathRegistry::new();
        r.mark_seen(42);
        assert!(r.state(42).is_none());
        assert!(r.is_empty());
    }

    #[test]
    fn re_validating_terminal_path_returns_none() {
        let r = PathRegistry::new();
        r.register(1);
        let c = r.issue_challenge(1).unwrap();
        assert!(r.verify_response(1, &c));
        assert!(r.issue_challenge(1).is_none());

        r.register(2);
        let mut c2 = r.issue_challenge(2).unwrap();
        c2[0] ^= 1;
        assert!(!r.verify_response(2, &c2));
        assert!(r.issue_challenge(2).is_none());
    }
}