use std::{collections::HashMap, convert::Infallible, ops::DerefMut, sync::atomic::AtomicU32};
use tokio::sync::{Mutex, mpsc};
use ts_packet::PacketMut;
use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId};
use ts_tunnel::NodeKeyPair;
use crate::{EventResult, InboundResult, OutboundResult};
pub type DataplaneToOverlay = mpsc::UnboundedSender<Vec<PacketMut>>;
pub type DataplaneFromOverlay = mpsc::UnboundedReceiver<Vec<PacketMut>>;
pub type DataplaneToUnderlay = mpsc::UnboundedSender<(PeerId, Vec<PacketMut>)>;
pub type DataplaneFromUnderlay = mpsc::UnboundedReceiver<(PeerId, Vec<PacketMut>)>;
/// Async driver around the synchronous `crate::DataPlane`.
///
/// Overlay and underlay transports exchange packet batches with it over
/// unbounded channels; `step`/`run` multiplex those channels with the sync
/// dataplane's next-event timer.
pub struct DataPlane {
/// The sync dataplane plus the per-transport output queues. Locked by `step`
/// and by the transport-registration and `inner()` methods.
core_state: Mutex<CoreState>,
/// Receiving halves of the shared input channels; within this file only
/// `step` locks it.
poll_state: Mutex<PollState>,
/// Notified when a transport is registered, to interrupt a waiting `step`.
transports_changed: tokio::sync::Notify,
/// Cloned out to every underlay transport; `step` reads the other end as
/// inbound traffic. Holding it here keeps that channel open.
underlay_down: DataplaneToUnderlay,
/// Cloned out to every overlay transport; `step` reads the other end as
/// outbound traffic. Holding it here keeps that channel open.
overlay_up: DataplaneToOverlay,
/// Counter from which new underlay transport ids are allocated.
next_underlay_transport: AtomicU32,
/// Counter from which new overlay transport ids are allocated.
next_overlay_transport: AtomicU32,
}
/// State shared between the step loop and API callers, guarded by one lock.
struct CoreState {
/// The wrapped synchronous dataplane that does the actual packet processing.
sync: crate::DataPlane,
/// Per-transport queues used to deliver packets to overlay transports.
overlay_transports: HashMap<OverlayTransportId, DataplaneToOverlay>,
/// Per-transport queues used to deliver `(peer, packets)` batches to underlay transports.
underlay_transports: HashMap<UnderlayTransportId, DataplaneToUnderlay>,
}
/// Receivers polled by `step`, kept under their own lock so that waiting on
/// them does not hold `core_state`.
struct PollState {
/// Packets submitted by overlay transports (processed as outbound).
from_overlay: DataplaneFromOverlay,
/// Packets submitted by underlay transports (processed as inbound).
from_underlay: DataplaneFromUnderlay,
}
impl DataPlane {
    /// Build an async dataplane around a new `crate::DataPlane` for `my_key`.
    ///
    /// No transports are registered initially; attach them with
    /// [`DataPlane::new_underlay_transport`] / [`DataPlane::new_overlay_transport`]
    /// and drive the dataplane with [`DataPlane::run`] or [`DataPlane::step`].
    pub fn new(my_key: NodeKeyPair) -> Self {
        // Overlay transports send through clones of `overlay_up`; `step` drains
        // the receiving half and processes it as outbound traffic.
        let (overlay_up, overlay_down) = mpsc::unbounded_channel();
        // Underlay transports send through clones of `underlay_down`; `step`
        // drains the receiving half and processes it as inbound traffic.
        let (underlay_down, underlay_up) = mpsc::unbounded_channel();
        let sync = crate::DataPlane::new(my_key);
        Self {
            underlay_down,
            overlay_up,
            next_overlay_transport: Default::default(),
            next_underlay_transport: Default::default(),
            transports_changed: tokio::sync::Notify::new(),
            core_state: Mutex::new(CoreState {
                sync,
                overlay_transports: Default::default(),
                underlay_transports: Default::default(),
            }),
            poll_state: Mutex::new(PollState {
                from_overlay: overlay_down,
                from_underlay: underlay_up,
            }),
        }
    }

    /// Register a new underlay transport.
    ///
    /// Returns three things:
    /// - the transport's id;
    /// - the receiver on which the dataplane delivers `(peer, packets)` batches
    ///   that should be sent via this transport;
    /// - a sender for handing packets received from peers back to the dataplane.
    pub async fn new_underlay_transport(
        &self,
    ) -> (
        UnderlayTransportId,
        DataplaneFromUnderlay,
        DataplaneToUnderlay,
    ) {
        // Ids only need to be distinct; the counter wraps after u32::MAX registrations.
        let id = self
            .next_underlay_transport
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            .into();
        let (tx, rx) = mpsc::unbounded_channel();
        {
            let mut rest = self.core_state.lock().await;
            rest.underlay_transports.insert(id, tx);
        }
        // Interrupt a `step` that is currently waiting in its select.
        self.transports_changed.notify_waiters();
        (id, rx, self.underlay_down.clone())
    }

    /// Register a new overlay transport.
    ///
    /// Returns three things:
    /// - the transport's id;
    /// - a sender for submitting packets to the dataplane as outbound traffic;
    /// - the receiver on which the dataplane delivers packets destined for this
    ///   transport.
    pub async fn new_overlay_transport(
        &self,
    ) -> (OverlayTransportId, DataplaneToOverlay, DataplaneFromOverlay) {
        // Ids only need to be distinct; the counter wraps after u32::MAX registrations.
        let id = self
            .next_overlay_transport
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            .into();
        let (tx, rx) = mpsc::unbounded_channel();
        {
            let mut rest = self.core_state.lock().await;
            rest.overlay_transports.insert(id, tx);
        }
        // Interrupt a `step` that is currently waiting in its select.
        self.transports_changed.notify_waiters();
        (id, self.overlay_up.clone(), rx)
    }

    /// Drive the dataplane forever by calling [`DataPlane::step`] in a loop.
    pub async fn run(&self) -> Infallible {
        loop {
            self.step().await;
        }
    }

    /// Wait for one unit of work, process it, and forward the results.
    ///
    /// The step waits for whichever of these comes first:
    /// - a batch from an overlay transport, processed as outbound;
    /// - a batch from an underlay transport, processed as inbound;
    /// - a transport registration;
    /// - the sync dataplane's next event deadline, or `MAX_IDLE_SLEEP` from
    ///   now when no event is scheduled.
    ///
    /// Packets produced by the sync dataplane are then written to the
    /// matching underlay and overlay transport queues.
    #[tracing::instrument(skip_all)]
    pub async fn step(&self) {
        enum SelectResult {
            OverlayDown(Vec<PacketMut>),
            UnderlayUp(PeerId, Vec<PacketMut>),
            TransportsChanged,
            Event,
        }

        let select_result = {
            // Read the deadline under a short-lived lock so `core_state` is not
            // held while this step waits for input.
            let next_event = {
                let state = self.core_state.lock().await;
                state.sync.next_event()
            };
            let mut poll_state = self.poll_state.lock().await;
            let PollState {
                from_overlay: overlay_down,
                from_underlay: underlay_up,
            } = &mut *poll_state;
            tokio::select! {
                overlay_pkts = overlay_down.recv() => {
                    // `self.overlay_up` keeps a sender alive, so this channel never closes.
                    let overlay_pkts = overlay_pkts.expect(
                        "overlay input channel closed while `self.overlay_up` is alive",
                    );
                    tracing::trace!(n_overlay_pkts = overlay_pkts.len());
                    SelectResult::OverlayDown(overlay_pkts)
                }
                underlay_pkts = underlay_up.recv() => {
                    // `self.underlay_down` keeps a sender alive, so this channel never closes.
                    let (peer_id, underlay_pkts) = underlay_pkts.expect(
                        "underlay input channel closed while `self.underlay_down` is alive",
                    );
                    tracing::trace!(%peer_id, n_underlay_pkts = underlay_pkts.len());
                    SelectResult::UnderlayUp(peer_id, underlay_pkts)
                }
                _ = self.transports_changed.notified() => {
                    tracing::trace!("transports changed");
                    SelectResult::TransportsChanged
                }
                _ = sleep_until_event(next_event.map(Into::into)) => {
                    tracing::trace!("event");
                    SelectResult::Event
                }
            }
        };

        let mut core = self.core_state.lock().await;
        let (to_peers, to_local) = match select_result {
            SelectResult::OverlayDown(overlay_down) => {
                let OutboundResult { to_peers, loopback } =
                    core.sync.process_outbound(overlay_down);
                (Some(to_peers), Some(loopback))
            }
            SelectResult::UnderlayUp(_peer_id, underlay_up) => {
                let InboundResult { to_local, to_peers } = core.sync.process_inbound(underlay_up);
                (Some(to_peers), Some(to_local))
            }
            SelectResult::Event => {
                let EventResult { to_peers } = core.sync.process_events();
                (Some(to_peers), None)
            }
            // Nothing to process; the next step re-reads the deadline.
            SelectResult::TransportsChanged => (None, None),
        };
        if let Some(to_peers) = to_peers {
            write_to_underlay(&core, to_peers).await;
        }
        if let Some(to_local) = to_local {
            write_to_overlay(&core, to_local).await;
        }
    }

    /// Lock and mutably borrow the wrapped synchronous `crate::DataPlane`.
    ///
    /// The returned guard holds the `core_state` lock. Until it is dropped,
    /// `step` cannot process results or read the next deadline, so keep it
    /// short-lived.
    ///
    /// NOTE(review): acquiring this guard does not wake a `step` that is already
    /// waiting. That wait uses the deadline read before it started. If changes
    /// made through this guard can move the next event earlier, confirm that
    /// waiting for the next packet or stale deadline is acceptable.
    pub async fn inner(&self) -> impl DerefMut<Target = crate::DataPlane> {
        let core = self.core_state.lock().await;
        tokio::sync::MutexGuard::map(core, |x| &mut x.sync)
    }
}
/// Deliver per-transport packet batches to the registered overlay transports.
///
/// Batches addressed to an unknown `OverlayTransportId` are dropped. If a
/// transport's receiver has been dropped, its batch is discarded and logged
/// instead of panicking, so a single departed transport cannot take down the
/// whole dataplane loop.
async fn write_to_overlay(slf: &CoreState, packets: HashMap<OverlayTransportId, Vec<PacketMut>>) {
    for (id, packets) in packets {
        let Some(queue) = slf.overlay_transports.get(&id) else {
            tracing::trace!(
                overlay_id = ?id,
                n_packets = packets.len(),
                "unknown overlay transport; dropping packets"
            );
            continue;
        };
        tracing::trace!(overlay_id = ?id, n_packets = packets.len());
        if let Err(err) = queue.send(packets) {
            // The receiver is gone; the unsent batch comes back inside the error.
            tracing::debug!(
                overlay_id = ?id,
                n_packets = err.0.len(),
                "overlay transport receiver closed; dropping packets"
            );
        }
    }
}
/// Deliver `(transport, peer)`-keyed packet batches to the registered underlay transports.
///
/// Batches for an unknown `UnderlayTransportId` are dropped. If a transport's
/// receiver has been dropped, its batch is discarded and logged instead of
/// panicking, so a single departed transport cannot take down the whole
/// dataplane loop.
async fn write_to_underlay(
    slf: &CoreState,
    packets: impl IntoIterator<Item = ((UnderlayTransportId, PeerId), Vec<PacketMut>)>,
) {
    for ((tid, peer_id), packets) in packets {
        tracing::trace!(underlay_id = ?tid, %peer_id, n_packets = packets.len());
        let Some(queue) = slf.underlay_transports.get(&tid) else {
            tracing::trace!(
                underlay_id = ?tid,
                %peer_id,
                "unknown underlay transport; dropping packets"
            );
            continue;
        };
        if let Err(err) = queue.send((peer_id, packets)) {
            // The receiver is gone; the unsent message comes back inside the error.
            let (peer_id, packets) = err.0;
            tracing::debug!(
                underlay_id = ?tid,
                %peer_id,
                n_packets = packets.len(),
                "underlay transport receiver closed; dropping packets"
            );
        }
    }
}
/// How long the step loop sleeps when the sync dataplane has no scheduled event.
const MAX_IDLE_SLEEP: core::time::Duration = core::time::Duration::from_millis(5_000);
/// Sleep until `deadline`, or for `MAX_IDLE_SLEEP` when no deadline is given.
///
/// The wake-up instant is chosen by `next_wakeup`, measured from the current
/// tokio clock reading.
async fn sleep_until_event(deadline: Option<tokio::time::Instant>) {
    let now = tokio::time::Instant::now();
    let wake_at = next_wakeup(deadline, now, MAX_IDLE_SLEEP);
    tokio::time::sleep_until(wake_at).await
}
/// Pick the instant at which the dataplane loop should next wake up.
///
/// A scheduled `deadline` is returned unchanged, however near or far it is.
/// Without one, the result is `now + max_idle_sleep`, so the loop never
/// blocks indefinitely.
///
/// The function is generic over the instant type. It is called with
/// `tokio::time::Instant` in `sleep_until_event` and with
/// `std::time::Instant` in the unit tests.
fn next_wakeup<I>(deadline: Option<I>, now: I, max_idle_sleep: core::time::Duration) -> I
where
    I: core::ops::Add<core::time::Duration, Output = I> + Copy,
{
    deadline.unwrap_or_else(|| now + max_idle_sleep)
}
#[cfg(test)]
mod tests {
    //! Unit tests for `next_wakeup`, the pure policy that decides how long the
    //! step loop sleeps between events.
    use super::*;

    #[test]
    fn no_scheduled_event_still_wakes_within_floor() {
        let now = std::time::Instant::now();
        let woke = next_wakeup(None, now, MAX_IDLE_SLEEP);
        assert_eq!(
            woke,
            now + MAX_IDLE_SLEEP,
            "a None deadline must collapse to the bounded floor, never block forever"
        );
    }

    #[test]
    fn near_event_is_honored_exactly() {
        let now = std::time::Instant::now();
        let soon = now + core::time::Duration::from_millis(50);
        assert_eq!(
            next_wakeup(Some(soon), now, MAX_IDLE_SLEEP),
            soon,
            "an event sooner than the floor must wake exactly on its deadline"
        );
    }

    #[test]
    fn far_event_is_honored_not_clamped() {
        let now = std::time::Instant::now();
        let far = now + core::time::Duration::from_secs(3600);
        assert_eq!(
            next_wakeup(Some(far), now, MAX_IDLE_SLEEP),
            far,
            "a far-off scheduled event must be honored exactly, not clamped to the idle floor"
        );
    }

    #[test]
    fn overdue_event_is_returned_unchanged() {
        // Build "now" forward from a base instant to avoid subtracting from an
        // Instant, which can underflow on some platforms.
        let base = std::time::Instant::now();
        let now = base + core::time::Duration::from_secs(1);
        assert_eq!(
            next_wakeup(Some(base), now, MAX_IDLE_SLEEP),
            base,
            "an already-overdue deadline must be returned as-is, not pushed out to the idle floor"
        );
    }

    #[test]
    fn keepalive_in_25s_sleeps_to_the_deadline_not_the_floor() {
        let now = std::time::Instant::now();
        let keepalive_due = now + core::time::Duration::from_secs(25);
        let woke = next_wakeup(Some(keepalive_due), now, MAX_IDLE_SLEEP);
        assert_eq!(
            woke, keepalive_due,
            "a 25s keepalive deadline must be slept to directly (one wakeup), not capped at the idle floor"
        );
        assert!(
            woke > now + MAX_IDLE_SLEEP,
            "the wakeup must be well past the idle floor: the floor must not shorten a real deadline"
        );
    }

    #[test]
    fn idle_wakeup_is_coarse_and_never_busy_spins() {
        assert!(
            MAX_IDLE_SLEEP > core::time::Duration::ZERO,
            "the idle floor must be a positive cadence, else step() would busy-spin"
        );
        let base = std::time::Instant::now();
        for offset_ms in [0u64, 1, 250, 5_000, 60_000] {
            let now = base + core::time::Duration::from_millis(offset_ms);
            let woke = next_wakeup(None, now, MAX_IDLE_SLEEP);
            assert!(
                woke > now,
                "idle wakeup must be strictly after now (no busy-spin); got {woke:?} <= {now:?}"
            );
            assert_eq!(
                woke,
                now + MAX_IDLE_SLEEP,
                "idle wakeup must land on the bounded coarse floor, never sooner and never later"
            );
        }
    }
}