use std::{sync::Arc, task::Waker};
use maybenot::{MachineId, TriggerEvent};
use smallvec::SmallVec;
use web_time_compat::InstantExt;
use super::{Bypass, Duration, Instant, PerHopPaddingEvent, PerHopPaddingEventVec, Replace};
/// Random number generator handed to the maybenot framework.
type Rng = ThisThreadRng;
/// The maybenot framework as instantiated by this backend: machines shared via
/// `Arc<[maybenot::Machine]>`, our [`Rng`], and our `Instant` type.
type Framework = maybenot::Framework<Arc<[maybenot::Machine]>, Rng, Instant>;
/// An action emitted by the framework, expressed in our `Instant` type.
type TriggerAction = maybenot::TriggerAction<Instant>;
/// Buffer for trigger events generated while handling framework actions,
/// with inline room for one event before it spills to the heap.
type TriggerEventsOutVec = SmallVec<[TriggerEvent; 1]>;
/// An action that a machine has scheduled to happen when its action timer
/// expires.
#[derive(Clone, Debug)]
enum ScheduledAction {
    /// Send padding.
    SendPadding {
        /// Forwarded as the `bypass` flag of the resulting padding event.
        bypass: bool,
        /// Forwarded as the `replace` flag of the resulting padding event.
        replace: bool,
    },
    /// Start (or adjust) a period during which outgoing traffic is blocked.
    Block {
        /// Whether the blocking may be bypassed; reported as `is_bypassable`
        /// when a new blocking period starts.
        bypass: bool,
        /// If true, any current blocking period is reset to end `duration`
        /// from now; if false, it is only ever extended.
        replace: bool,
        /// How long the blocking lasts, measured from when the action fires.
        duration: Duration,
    },
}
/// Timer state for a single maybenot machine.
#[derive(Default, Clone, Debug)]
struct MachineState {
    /// When this machine's internal timer expires, if it is running.
    /// Expiry is reported to the framework as a `TimerEnd` event.
    internal_timer_expires: Option<Instant>,
    /// When this machine's action timer expires, together with the action
    /// to perform at that time.
    action_timer_expires: Option<(Instant, ScheduledAction)>,
}
impl MachineState {
    /// Return whichever of this machine's timers (internal or action) will
    /// fire first, or `None` if neither timer is armed.
    fn next_expiration(&self) -> Option<Instant> {
        let action_deadline = self
            .action_timer_expires
            .as_ref()
            .map(|(deadline, _)| *deadline);
        self.internal_timer_expires
            .into_iter()
            .chain(action_deadline)
            .min()
    }
}
/// Timer state for every machine in a framework.
struct PadderState<const N: usize> {
    /// One entry per machine, indexed by the machine's raw `MachineId`.
    /// `N` is the number of entries stored inline before spilling to the heap.
    state: SmallVec<[MachineState; N]>,
}
impl<const N: usize> PadderState<N> {
    /// Return the mutable timer state for machine `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not refer to one of the machines this state was
    /// built for.
    fn state_mut(&mut self, id: MachineId) -> &mut MachineState {
        &mut self.state[id.into_raw()]
    }

    /// Apply a single `action` produced by the framework at time `now`.
    ///
    /// Follow-up events that must be fed back into the framework (currently
    /// only `TimerBegin`) are appended to `events_out`.
    ///
    /// Returns true if any machine timer was set or cleared, meaning the
    /// caller must recompute its next wakeup time.
    fn trigger_action(
        &mut self,
        action: &TriggerAction,
        now: Instant,
        events_out: &mut TriggerEventsOutVec,
    ) -> bool {
        use maybenot::Timer as T;
        use maybenot::TriggerAction as A;
        let mut timer_changed = false;
        match action {
            A::Cancel { machine, timer } => {
                let st = self.state_mut(*machine);
                match timer {
                    T::Action => st.action_timer_expires = None,
                    T::Internal => st.internal_timer_expires = None,
                    T::All => {
                        st.action_timer_expires = None;
                        st.internal_timer_expires = None;
                    }
                };
                timer_changed = true;
            }
            A::SendPadding {
                timeout,
                bypass,
                replace,
                machine,
            } => {
                // Any previously scheduled action for this machine is replaced.
                let st = self.state_mut(*machine);
                st.action_timer_expires = Some((
                    now + *timeout,
                    ScheduledAction::SendPadding {
                        bypass: *bypass,
                        replace: *replace,
                    },
                ));
                timer_changed = true;
            }
            A::BlockOutgoing {
                timeout,
                duration,
                bypass,
                replace,
                machine,
            } => {
                // Any previously scheduled action for this machine is replaced.
                let st = self.state_mut(*machine);
                st.action_timer_expires = Some((
                    now + *timeout,
                    ScheduledAction::Block {
                        bypass: *bypass,
                        replace: *replace,
                        duration: *duration,
                    },
                ));
                timer_changed = true;
            }
            A::UpdateTimer {
                duration,
                replace,
                machine,
            } => {
                let st = self.state_mut(*machine);
                let new_expiry = now + *duration;
                // With `replace` the internal timer is reset unconditionally;
                // otherwise it is only ever extended.
                let update_timer = match (replace, st.internal_timer_expires) {
                    (_, None) => true,
                    (true, Some(_)) => true,
                    (false, Some(cur)) if new_expiry > cur => true,
                    (false, Some(_)) => false,
                };
                if update_timer {
                    st.internal_timer_expires = Some(new_expiry);
                    timer_changed = true;
                    // Only report TimerBegin when the timer was actually
                    // (re)started. Reporting it for a no-op update would be
                    // wrong, and would let machines that react to TimerBegin
                    // re-trigger themselves.
                    events_out.push(TriggerEvent::TimerBegin { machine: *machine });
                }
            }
        }
        timer_changed
    }

    /// Return the earliest expiration over all machines' timers, or `None`
    /// if no timer is armed.
    fn next_expiration(&self) -> Option<Instant> {
        self.state
            .iter()
            .filter_map(MachineState::next_expiration)
            .min()
    }
}
/// The next instant at which the padder wants to be woken, plus the waker to
/// notify when that instant becomes earlier than the caller's scheduled wakeup.
#[derive(Clone, Debug)]
struct Timer {
    /// When we next need to be woken, if ever.
    next_expiration: Option<Instant>,
    /// The most recently registered waker (a no-op waker until one is
    /// registered).
    waker: Waker,
}
impl Timer {
    /// Create a timer with nothing scheduled and a waker that does nothing.
    fn new() -> Self {
        let waker = Waker::noop().clone();
        Self {
            waker,
            next_expiration: None,
        }
    }

    /// Remember `waker` for later notification and report the currently
    /// scheduled expiration.
    fn get_expiration(&mut self, waker: &Waker) -> Option<Instant> {
        self.waker.clone_from(waker);
        self.next_expiration
    }

    /// Replace the scheduled expiration with `new_expiration`.
    ///
    /// The stored waker is woken when there is a new expiration and it
    /// falls strictly before `next_scheduled_wakeup`, or when no wakeup is
    /// currently scheduled at all.
    fn set_expiration(
        &mut self,
        new_expiration: Option<Instant>,
        next_scheduled_wakeup: Option<Instant>,
    ) {
        self.next_expiration = new_expiration;
        let earlier_than_scheduled = new_expiration.is_some_and(|candidate| {
            next_scheduled_wakeup.is_none_or(|scheduled| candidate < scheduled)
        });
        if earlier_than_scheduled {
            self.waker.wake_by_ref();
        }
    }
}
/// Information about a blocking period that is currently in effect.
#[derive(Debug)]
struct BlockingState {
    /// When the blocking period ends.
    expiration: Instant,
}
/// A [`PaddingBackend`] built on a [`maybenot::Framework`].
///
/// `N` is the inline capacity used for per-machine storage.
pub(super) struct MaybenotPadder<const N: usize> {
    /// The underlying maybenot framework, which decides which actions to take.
    framework: Framework,
    /// Per-machine timer state used to carry out the framework's actions.
    state: PadderState<N>,
    /// Our next wakeup time, and the waker to notify when it moves earlier.
    timer: Timer,
    /// If outgoing traffic is currently blocked, when that blocking ends.
    blocking: Option<BlockingState>,
}
impl<const N: usize> MaybenotPadder<N> {
pub(super) fn from_framework_rules(
rules: &super::PaddingRules,
) -> Result<Self, maybenot::Error> {
let framework = maybenot::Framework::new(
rules.machines.clone(),
rules.max_outbound_padding_frac,
rules.max_outbound_blocking_frac,
Instant::get(),
ThisThreadRng,
)?;
Ok(Self::from_framework(framework))
}
pub(super) fn from_framework(framework: Framework) -> Self {
let n = framework.num_machines();
let state = PadderState {
state: smallvec::smallvec![MachineState::default(); n],
};
Self {
framework,
state,
timer: Timer::new(),
blocking: None,
}
}
pub(super) fn get_expiration(&mut self, waker: &Waker) -> Option<Instant> {
self.timer.get_expiration(waker)
}
pub(super) fn trigger_events_at(
&mut self,
events: &[TriggerEvent],
now: Instant,
next_scheduled_wakeup: Option<Instant>,
) {
let mut timer_changed = false;
let (mut e1, mut e2) = (TriggerEventsOutVec::new(), TriggerEventsOutVec::new());
let (mut processing, mut pending) = (&mut e1, &mut e2);
let mut events = events;
const MAX_LOOPS: usize = 4;
let finished_normally = 'finished: {
for _ in 0..MAX_LOOPS {
pending.clear();
for action in self.framework.trigger_events(events, now) {
timer_changed |= self.state.trigger_action(action, now, pending);
}
if pending.is_empty() {
break 'finished true;
} else {
std::mem::swap(&mut processing, &mut pending);
events = &processing[..];
}
}
break 'finished false;
};
if !finished_normally {
}
if timer_changed {
self.timer
.set_expiration(self.state.next_expiration(), next_scheduled_wakeup);
}
}
fn take_actions_at(
&mut self,
now: Instant,
next_scheduled_wakeup: Option<Instant>,
) -> PerHopPaddingEventVec {
let mut e: SmallVec<[TriggerEvent; N]> = SmallVec::default();
let mut return_events = PerHopPaddingEventVec::default();
let mut timer_changed = false;
if let Some(blocking) = &self.blocking {
if blocking.expiration <= now {
timer_changed = true;
self.blocking = None;
e.push(TriggerEvent::BlockingEnd);
return_events.push(PerHopPaddingEvent::StopBlocking);
}
}
for (idx, st) in self.state.state.iter_mut().enumerate() {
match st.internal_timer_expires {
Some(t) if t <= now => {
st.internal_timer_expires = None;
timer_changed = true;
e.push(TriggerEvent::TimerEnd {
machine: MachineId::from_raw(idx),
});
}
None | Some(_) => {}
}
match &st.action_timer_expires {
Some((t, _)) if *t <= now => {
use ScheduledAction as SA;
let action = st
.action_timer_expires
.take()
.expect("It was Some a minute ago!")
.1;
timer_changed = true;
match action {
SA::SendPadding { bypass, replace } => {
return_events.push(PerHopPaddingEvent::SendPadding {
machine: MachineId::from_raw(idx),
replace: Replace::from_bool(replace),
bypass: Bypass::from_bool(bypass),
});
}
SA::Block {
bypass,
replace,
duration,
} => {
let new_expiry = now + duration;
if self.blocking.is_none() {
return_events.push(PerHopPaddingEvent::StartBlocking {
is_bypassable: bypass,
});
}
let replace = match &self.blocking {
None => true,
Some(b) if replace || b.expiration < new_expiry => true,
Some(_) => false,
};
if replace {
self.blocking = Some(BlockingState {
expiration: new_expiry,
});
}
e.push(TriggerEvent::BlockingBegin {
machine: MachineId::from_raw(idx),
});
}
}
}
None | Some(_) => {}
}
}
if timer_changed {
self.timer
.set_expiration(self.state.next_expiration(), next_scheduled_wakeup);
}
self.trigger_events_at(&e[..], now, next_scheduled_wakeup);
return_events
}
}
/// A stateless [`rand::RngCore`] that forwards every request to the calling
/// thread's generator, obtained afresh from `rand::rng()` on each call.
#[derive(Clone, Debug)]
pub(super) struct ThisThreadRng;
impl rand::RngCore for ThisThreadRng {
    fn next_u32(&mut self) -> u32 {
        let mut thread_rng = rand::rng();
        rand::RngCore::next_u32(&mut thread_rng)
    }
    fn next_u64(&mut self) -> u64 {
        let mut thread_rng = rand::rng();
        rand::RngCore::next_u64(&mut thread_rng)
    }
    fn fill_bytes(&mut self, dst: &mut [u8]) {
        let mut thread_rng = rand::rng();
        rand::RngCore::fill_bytes(&mut thread_rng, dst);
    }
}
/// A source of padding decisions: it is told about traffic events and asked,
/// when its wakeup time arrives, which padding or blocking actions to take.
pub(super) trait PaddingBackend: Send + Sync {
    /// Report that `events` occurred at `now`.
    ///
    /// `next_scheduled_wakeup` is the caller's currently scheduled wakeup, if
    /// any. In the [`MaybenotPadder`] implementation it is used to decide
    /// whether the registered waker must be woken early.
    fn report_events_at(
        &mut self,
        events: &[maybenot::TriggerEvent],
        now: Instant,
        next_scheduled_wakeup: Option<Instant>,
    );
    /// Process everything that has become due by `now`, and return the
    /// padding events that the caller should act on.
    ///
    /// `next_scheduled_wakeup` has the same meaning as in
    /// [`PaddingBackend::report_events_at`].
    fn take_padding_events_at(
        &mut self,
        now: Instant,
        next_scheduled_wakeup: Option<Instant>,
    ) -> PerHopPaddingEventVec;
    /// Register `waker` and return the next instant (if any) at which
    /// [`PaddingBackend::take_padding_events_at`] should be called.
    fn next_wakeup(&mut self, waker: &Waker) -> Option<Instant>;
}
/// Every trait method delegates directly to the matching inherent method.
impl<const N: usize> PaddingBackend for MaybenotPadder<N> {
    fn report_events_at(
        &mut self,
        events: &[maybenot::TriggerEvent],
        now: Instant,
        next_scheduled_wakeup: Option<Instant>,
    ) {
        Self::trigger_events_at(self, events, now, next_scheduled_wakeup)
    }
    fn take_padding_events_at(
        &mut self,
        now: Instant,
        next_scheduled_wakeup: Option<Instant>,
    ) -> PerHopPaddingEventVec {
        Self::take_actions_at(self, now, next_scheduled_wakeup)
    }
    fn next_wakeup(&mut self, waker: &Waker) -> Option<Instant> {
        Self::get_expiration(self, waker)
    }
}