use futures::{FutureExt, future::pending};
use maybenot::MachineId;
use std::{
collections::VecDeque,
sync::{
Arc,
atomic::{self, AtomicU32},
},
};
use tokio::{
sync::{
Notify, RwLock,
mpsc::{self, error::TrySendError},
},
time::Instant,
};
use zerocopy::{Immutable, IntoBytes, KnownLayout, TryFromBytes, Unaligned, big_endian};
use crate::packet::{Packet, WgData};
/// Tells the caller how to react to a failed operation.
///
/// This is the error type of the crate-visible `Result` alias declared in this file.
// NOTE(review): the per-variant descriptions are inferred from the variant names only;
// confirm against the call sites that match on this enum.
pub(crate) enum ErrorAction {
/// Presumably: the failure is fatal and whatever produced it should be closed.
Close,
/// The failure may be skipped; the payload states why.
Ignore(IgnoreReason),
}
/// Explains why a failed operation can be ignored rather than acted upon.
pub(crate) enum IgnoreReason {
    /// Rendered as `"No endpoint"`.
    NoEndpoint,
    /// Rendered as `"No session"`.
    NoSession,
}

impl std::fmt::Display for IgnoreReason {
    /// Writes a short, fixed, human-readable description of the reason.
    ///
    /// Formatter flags such as width or fill are not applied; the text is written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::NoEndpoint => "No endpoint",
            Self::NoSession => "No session",
        };
        f.write_str(description)
    }
}
/// Crate-visible result alias whose error type is [`ErrorAction`].
pub(crate) type Result<T> = std::result::Result<T, ErrorAction>;
/// A decoy packet: a fixed-size [`DecoyHeader`] followed by a variable-length payload.
///
/// The type is unsized (it ends in `[u8]`) and `#[repr(C)]`. It derives zerocopy's
/// `TryFromBytes` but not `IntoBytes`.
#[derive(TryFromBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)]
#[repr(C)]
pub(crate) struct DecoyPacket {
/// Fixed header at the start of the packet.
pub(crate) header: DecoyHeader,
/// All bytes that follow the header.
// NOTE(review): presumably `header.length` describes this payload — confirm with the
// code that builds and parses decoy packets.
pub(crate) payload: [u8],
}
/// Header placed at the start of a [`DecoyPacket`].
///
/// `#[repr(C, packed)]` with fields of 1 + 1 + 2 bytes, so the header occupies four bytes
/// with no padding.
#[derive(
TryFromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq, Clone, Copy,
)]
#[repr(C, packed)]
pub(crate) struct DecoyHeader {
/// Marker byte identifying the packet as a decoy.
pub marker: DecoyMarker,
/// Reserved byte; `DecoyHeader::new` always sets it to 0.
_reserved: u8,
/// 16-bit big-endian length field.
// NOTE(review): what exactly is measured (payload only vs. whole packet) is not visible
// in this file — confirm.
pub length: big_endian::U16,
}
/// One-byte marker stored in [`DecoyHeader::marker`].
///
/// `#[repr(u8)]` with a single variant, so `0xff` is the only value this type can hold.
#[derive(
TryFromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq, Clone, Copy,
)]
#[repr(u8)]
pub(crate) enum DecoyMarker {
/// The decoy marker value, `0xff`.
Decoy = 0xff,
}
impl DecoyHeader {
    /// Builds a decoy header carrying `length`.
    ///
    /// The marker is always [`DecoyMarker::Decoy`] and the reserved byte is always zero.
    pub(crate) const fn new(length: big_endian::U16) -> Self {
        let marker = DecoyMarker::Decoy;
        let _reserved = 0;
        Self {
            marker,
            _reserved,
            length,
        }
    }
}
/// Thread-safe counter of outbound normal packets.
#[derive(Default)]
pub(crate) struct PacketCount {
    /// Current count; zero when created through `Default`.
    outbound_normal: AtomicU32,
}

impl PacketCount {
    /// Memory ordering used for every access to the counter.
    const ORDERING: atomic::Ordering = atomic::Ordering::SeqCst;

    /// Subtracts `amount` from the counter, wrapping around on underflow.
    pub fn dec(&self, amount: u32) {
        let counter = &self.outbound_normal;
        counter.fetch_sub(amount, Self::ORDERING);
    }

    /// Adds `amount` to the counter, wrapping around on overflow.
    pub fn inc(&self, amount: u32) {
        let counter = &self.outbound_normal;
        counter.fetch_add(amount, Self::ORDERING);
    }

    /// Returns the current value of the counter.
    pub fn outbound(&self) -> u32 {
        let counter = &self.outbound_normal;
        counter.load(Self::ORDERING)
    }
}
/// Whether outgoing packets are currently being delayed.
#[derive(Clone, Copy)]
pub(crate) enum DelayState {
    /// No delay is in effect.
    Inactive,
    /// A delay is in effect until `expires_at`.
    // NOTE(review): `bypass` is only stored here; its meaning is defined by whoever reads
    // it — confirm against the consumer.
    Active { bypass: bool, expires_at: Instant },
}

impl DelayState {
    /// Returns `true` for [`DelayState::Active`] and `false` for [`DelayState::Inactive`].
    #[must_use]
    pub(crate) fn is_active(&self) -> bool {
        match self {
            Self::Active { .. } => true,
            Self::Inactive => false,
        }
    }
}
/// Cloneable handle that decides whether an outgoing packet is queued for delay and lets a
/// task wait for the current delay to end.
///
/// The delay state and the abort signal live behind `Arc`s, so clones share them.
#[derive(Clone)]
pub struct DelayWatcher {
/// Sending half of the queue that receives packets while a delay is active.
pub(super) delay_queue_tx: mpsc::Sender<Packet<WgData>>,
/// Current delay state, shared between clones.
pub(super) delay_state: Arc<RwLock<DelayState>>,
/// Notified by `maybe_delay_packet` to end an active delay early.
delay_abort: Arc<Notify>,
/// If the queue has fewer free slots than this, the active delay is asked to abort.
min_delay_capacity: usize,
}
impl DelayWatcher {
    /// Creates a watcher that starts in [`DelayState::Inactive`].
    ///
    /// * `delay_queue_tx` - queue that delayed packets are pushed into.
    /// * `min_delay_capacity` - free-slot threshold of that queue; once fewer slots remain,
    ///   an active delay is asked to abort.
    pub fn new(delay_queue_tx: mpsc::Sender<Packet<WgData>>, min_delay_capacity: usize) -> Self {
        Self {
            delay_queue_tx,
            delay_state: Arc::new(RwLock::new(DelayState::Inactive)),
            delay_abort: Arc::new(Notify::const_new()),
            min_delay_capacity,
        }
    }

    /// Completes when the currently active delay expires or is aborted.
    ///
    /// If no delay is active when the state is read, the returned future never completes.
    /// This method does not modify the delay state.
    pub async fn wait_delay_ended(&self) {
        // The read guard is a temporary of the `if let` scrutinee, so it stays alive (and
        // keeps writers out) for the whole wait in the active branch. The `if let` form is
        // kept on purpose so that guard scoping is unchanged.
        // NOTE(review): under Rust 2024 `if let` scoping the guard is released before the
        // `else` branch runs — confirm the crate edition.
        if let DelayState::Active { expires_at, .. } = &*self.delay_state.read().await {
            let deadline = *expires_at;
            futures::select! {
                () = tokio::time::sleep_until(deadline).fuse() => {},
                () = self.delay_abort.notified().fuse() => {
                    log::trace!("Delay aborted with remaining capacity {}", self.delay_queue_tx.capacity());
                },
            }
        } else {
            pending::<()>().await
        }
    }

    /// Tries to queue `packet` for delayed sending.
    ///
    /// Returns `None` when the packet was handed to the delay queue. Returns the packet
    /// back in `Some` when it was not queued: the state lock could not be taken without
    /// waiting, no delay is active, or the queue is full or closed.
    ///
    /// While a delay is active and the queue has fewer than `min_delay_capacity` free
    /// slots, the abort signal is raised before the packet is enqueued.
    pub fn maybe_delay_packet(&self, packet: Packet<WgData>) -> Option<Packet<WgData>> {
        // Held until the function returns, as in the original active branch.
        let Ok(state) = self.delay_state.try_read() else {
            return Some(packet);
        };
        if !state.is_active() {
            return Some(packet);
        }

        let running_low = self.delay_queue_tx.capacity() < self.min_delay_capacity;
        if running_low {
            self.delay_abort.notify_one();
        }

        match self.delay_queue_tx.try_send(packet) {
            Ok(()) => None,
            Err(TrySendError::Closed(rejected) | TrySendError::Full(rejected)) => {
                log::trace!("Packet sent as it couldn't be delayed");
                Some(rejected)
            }
        }
    }
}
/// Pending machine timers as `(expiry, machine, timer kind)` entries.
///
/// The scheduling methods keep the queue sorted by ascending expiry (checked with
/// `debug_assert!`), so the front entry is always the next timer to fire.
pub(crate) struct MachineTimers(VecDeque<(Instant, MachineId, MachineTimer)>);
/// Action carried by a [`MachineTimer::Action`] timer.
// NOTE(review): `replace` and `bypass` are only stored and passed through in this file;
// presumably they mirror the maybenot action flags of the same names — confirm.
#[derive(Clone, Copy, Debug)]
pub(crate) enum Action {
/// Decoy action, scheduled through `MachineTimers::schedule_decoy`.
Decoy {
replace: bool,
bypass: bool,
},
/// Delay action, scheduled through `MachineTimers::schedule_delay`.
Delay {
replace: bool,
bypass: bool,
/// Duration supplied when the delay was scheduled, separate from the timer's timeout.
duration: std::time::Duration,
},
}
/// Kind of timer stored in [`MachineTimers`].
///
/// The two kinds are tracked and removed independently for each machine.
#[derive(Clone, Copy, Debug)]
pub(crate) enum MachineTimer {
/// The machine's internal timer.
Internal,
/// An action timer, carrying the action to perform when it fires.
Action(Action),
}
impl MachineTimers {
pub(crate) fn new(cap: usize) -> Self {
Self(VecDeque::with_capacity(cap))
}
pub(crate) fn remove_action(&mut self, machine: &MachineId) {
self.0
.retain(|&(_, m, t)| !(m == *machine && matches!(t, MachineTimer::Action(_))));
}
pub(crate) fn remove_internal(&mut self, machine: &MachineId) {
self.0
.retain(|&(_, m, t)| !(m == *machine && matches!(t, MachineTimer::Internal)));
}
pub(crate) fn remove_all(&mut self, machine: &MachineId) {
self.0.retain(|&(_, m, _)| m != *machine);
}
pub(crate) fn schedule_decoy(
&mut self,
machine: MachineId,
timeout: std::time::Duration,
replace: bool,
bypass: bool,
) {
self.remove_action(&machine);
let expiration_time = Instant::now() + timeout;
let insert_at = self
.0
.binary_search_by_key(&expiration_time, |&(time, _, _)| time)
.unwrap_or_else(|e| e);
self.0.insert(
insert_at,
(
expiration_time,
machine,
MachineTimer::Action(Action::Decoy { replace, bypass }),
),
);
debug_assert!(self.0.iter().is_sorted_by_key(|(time, _, _)| *time));
}
pub(crate) fn schedule_delay(
&mut self,
machine: MachineId,
timeout: std::time::Duration,
duration: std::time::Duration,
replace: bool,
bypass: bool,
) {
self.remove_action(&machine);
let expiration_time = Instant::now() + timeout;
let insert_at = self
.0
.binary_search_by_key(&expiration_time, |&(time, _, _)| time)
.unwrap_or_else(|e| e);
self.0.insert(
insert_at,
(
expiration_time,
machine,
MachineTimer::Action(Action::Delay {
replace,
bypass,
duration,
}),
),
);
debug_assert!(self.0.iter().is_sorted_by_key(|(time, _, _)| *time));
}
pub(crate) fn schedule_internal_timer(
&mut self,
machine: MachineId,
duration: std::time::Duration,
replace: bool,
) -> bool {
let expiry = Instant::now() + duration;
let idx = self
.0
.iter()
.position(|&(_, m, t)| m == machine && matches!(t, MachineTimer::Internal));
let should_update = match idx {
Some(i) => {
let (cur_expiry, _, _) = self.0[i];
if replace || expiry > cur_expiry {
self.0.remove(i);
true
} else {
false
}
}
None => true,
};
if should_update {
let insert_at = self
.0
.binary_search_by_key(&expiry, |&(time, _, _)| time)
.unwrap_or_else(|e| e);
self.0
.insert(insert_at, (expiry, machine, MachineTimer::Internal));
debug_assert!(self.0.iter().is_sorted_by_key(|(time, _, _)| *time));
}
should_update
}
pub(crate) async fn wait_next_timer(&mut self) -> (MachineId, MachineTimer) {
if let Some((time, _, _)) = self.0.front() {
tokio::time::sleep_until(*time).await;
self.0
.pop_front()
.map(|(_, m, t)| (m, t))
.expect("Front exists because we peeked it")
} else {
futures::future::pending().await
}
}
}
// Unit tests for `MachineTimers`. They inspect the inner queue directly and compare
// expiries against `Instant::now()`, so the assertions are timing-sensitive.
#[cfg(test)]
mod tests {
use super::*;
use crate::device::daita::types;
// Action and internal timers of the same machine are counted and removed independently.
#[test]
fn test_machine_timers_schedule_and_remove() {
let mut timers = types::MachineTimers::new(4);
let machine = MachineId::from_raw(1);
timers.schedule_decoy(machine, std::time::Duration::from_secs(1), false, false);
assert_eq!(timers.0.len(), 1);
// An internal timer does not displace the action timer.
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(1), false);
assert_eq!(timers.0.len(), 2);
timers.remove_action(&machine);
assert_eq!(timers.0.len(), 1);
timers.remove_internal(&machine);
assert_eq!(timers.0.len(), 0);
}
// Without `replace` the later-expiring internal timer is kept; with `replace` the most
// recently scheduled one is kept. There is never more than one internal timer per machine.
#[test]
fn test_internal_machine_timer_replace() {
let mut timers = types::MachineTimers::new(4);
let machine = MachineId::from_raw(1);
// Short then long, no replace: the long one supersedes the short one.
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(1), false);
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(2), false);
assert_eq!(timers.0.len(), 1);
let i = timers.0.front().unwrap().0;
assert!(
i.duration_since(Instant::now()) > std::time::Duration::from_secs(1),
"The longer timer should be kept"
);
// Long then short, no replace: the short one is rejected.
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(2), false);
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(1), false);
assert_eq!(timers.0.len(), 1);
let i = timers.0.front().unwrap().0;
assert!(
i.duration_since(Instant::now()) > std::time::Duration::from_secs(1),
"The longer timer should be kept"
);
// With replace, the last scheduled timer wins even though it is shorter.
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(2), true);
timers.schedule_internal_timer(machine, std::time::Duration::from_secs(1), true);
assert_eq!(timers.0.len(), 1);
let i = timers.0.front().unwrap().0;
assert!(
i.duration_since(Instant::now()) < std::time::Duration::from_secs(2),
"The last timer should be kept"
);
}
}