use std::{future::Future, pin::Pin, time::Duration};
use std::{
cmp::Ordering,
collections::BinaryHeap,
sync::{
Arc, Mutex,
atomic::{AtomicU64, Ordering as AtomicOrdering},
},
task::{Context, Poll, Waker},
};
/// Opaque identifier for a wakeup registered on a virtual clock.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WakeupId(u64);

impl WakeupId {
    /// Wraps a raw id value.
    fn new(raw: u64) -> Self {
        WakeupId(raw)
    }

    /// Returns the raw numeric value behind this id.
    pub fn as_u64(&self) -> u64 {
        let WakeupId(raw) = *self;
        raw
    }
}
/// A scheduled wakeup: a deadline on the clock's timeline plus the id that names it.
///
/// The ordering is deliberately reversed so that a max-heap (`BinaryHeap<Wakeup>`)
/// pops the earliest deadline first, breaking ties in favour of the lower id.
#[derive(Debug)]
pub struct Wakeup {
    /// Deadline in nanoseconds on the owning clock's timeline.
    pub deadline: u64,
    /// Identifier of this wakeup.
    pub id: WakeupId,
    // NOTE(review): never read in this file; wakers for pending sleeps are kept in a
    // separate list, so this slot always stays `None`.
    #[allow(dead_code)]
    waker: Option<Waker>,
}

impl Wakeup {
    /// Creates a wakeup with no associated waker.
    fn new(deadline: u64, id: WakeupId) -> Self {
        Wakeup {
            id,
            deadline,
            waker: None,
        }
    }
}

impl PartialEq for Wakeup {
    /// Equality considers only `deadline` and `id`; the waker slot is ignored.
    fn eq(&self, other: &Self) -> bool {
        (self.deadline, self.id) == (other.deadline, other.id)
    }
}

impl Eq for Wakeup {}

impl PartialOrd for Wakeup {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Wakeup {
    /// Reverse order on `(deadline, id)`: earlier deadlines (then lower ids) compare greater.
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` is on the left on purpose: this inverts the natural order for the max-heap.
        other
            .deadline
            .cmp(&self.deadline)
            .then_with(|| other.id.0.cmp(&self.id.0))
    }
}
/// Abstraction over a clock so the same code can run against real or virtual time.
///
/// Times are nanoseconds since an implementation-defined epoch.
pub trait TimeSource: Send + Sync + Clone + 'static {
    /// Current time in nanoseconds since this source's epoch.
    fn now_nanos(&self) -> u64;

    /// Current time as a [`Duration`] since this source's epoch.
    fn now(&self) -> Duration {
        let nanos = self.now_nanos();
        Duration::from_nanos(nanos)
    }

    /// Returns a future that completes after `duration` has elapsed on this clock.
    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>>;

    /// Returns a future that completes once this clock reaches `deadline_nanos`.
    fn sleep_until(&self, deadline_nanos: u64) -> Pin<Box<dyn Future<Output = ()> + Send>>;

    /// Runs `future` under a time limit, resolving to `None` if `duration` elapses first.
    fn timeout<F, T>(
        &self,
        duration: Duration,
        future: F,
    ) -> Pin<Box<dyn Future<Output = Option<T>> + Send>>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static;

    /// Connection idle timeout; 120 seconds unless an implementation overrides it.
    fn connection_idle_timeout(&self) -> Duration {
        const DEFAULT_IDLE_SECS: u64 = 120;
        Duration::from_secs(DEFAULT_IDLE_SECS)
    }

    /// Whether keepalive is supported; `true` unless an implementation overrides it.
    fn supports_keepalive(&self) -> bool {
        true
    }
}
/// Wall-clock [`TimeSource`] backed by tokio, measuring time from the moment it was created.
#[derive(Clone)]
pub struct RealTime {
    /// Instant captured at construction; this clock's zero point.
    epoch: tokio::time::Instant,
}

impl Default for RealTime {
    fn default() -> Self {
        RealTime::new()
    }
}

impl RealTime {
    /// Creates a clock whose zero point is the current tokio instant.
    pub fn new() -> Self {
        let epoch = tokio::time::Instant::now();
        RealTime { epoch }
    }
}
impl TimeSource for RealTime {
    /// Nanoseconds elapsed since this clock's epoch.
    fn now_nanos(&self) -> u64 {
        let since_epoch = self.epoch.elapsed();
        since_epoch.as_nanos() as u64
    }

    /// Delegates to `tokio::time::sleep`.
    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        let delay = tokio::time::sleep(duration);
        Box::pin(delay)
    }

    /// Sleeps for the time remaining until `deadline_nanos`.
    ///
    /// A deadline that is not in the future yields an immediately-ready future.
    fn sleep_until(&self, deadline_nanos: u64) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        let now = self.now_nanos();
        if deadline_nanos > now {
            let remaining = Duration::from_nanos(deadline_nanos - now);
            return Box::pin(tokio::time::sleep(remaining));
        }
        Box::pin(std::future::ready(()))
    }

    /// Wraps `tokio::time::timeout`, mapping an elapsed timer to `None`.
    fn timeout<F, T>(
        &self,
        duration: Duration,
        future: F,
    ) -> Pin<Box<dyn Future<Output = Option<T>> + Send>>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        Box::pin(async move {
            match tokio::time::timeout(duration, future).await {
                Ok(output) => Some(output),
                Err(_) => None,
            }
        })
    }

    /// Keepalive is disabled when the simulation transport option is enabled.
    fn supports_keepalive(&self) -> bool {
        let simulated_transport = crate::config::SimulationTransportOpt::is_enabled();
        !simulated_transport
    }

    /// One day when the simulation idle-timeout option is enabled, otherwise 120 seconds.
    fn connection_idle_timeout(&self) -> Duration {
        let secs = if crate::config::SimulationIdleTimeout::is_enabled() {
            86400
        } else {
            120
        };
        Duration::from_secs(secs)
    }
}
/// Shared state behind every clone of a `VirtualTime`.
#[derive(Debug)]
struct VirtualTimeState {
    /// Current virtual time in nanoseconds.
    current_nanos: AtomicU64,
    /// Counter from which new wakeup ids are taken (incremented per registration).
    next_wakeup_id: AtomicU64,
    /// Wakeups that have not fired yet; `Wakeup`'s reversed `Ord` makes the earliest
    /// deadline (then lowest id) pop first.
    pending_wakeups: Mutex<BinaryHeap<Wakeup>>,
    /// Wakers registered by sleeps that are still waiting, paired with their wakeup id.
    pending_wakers: Mutex<Vec<(WakeupId, Waker)>>,
}
/// Deterministic clock for simulations whose time only moves when explicitly advanced.
///
/// Cloning is cheap and every clone shares the same underlying state, so advancing
/// one clone is observed by all of them.
#[derive(Clone)]
pub struct VirtualTime {
    state: Arc<VirtualTimeState>,
}

impl Default for VirtualTime {
    fn default() -> Self {
        VirtualTime::new()
    }
}
impl VirtualTime {
    /// Creates a virtual clock starting at time zero.
    pub fn new() -> Self {
        Self::with_initial_time(0)
    }

    /// Creates a virtual clock whose current time is `initial_nanos`, with no pending wakeups.
    pub fn with_initial_time(initial_nanos: u64) -> Self {
        Self {
            state: Arc::new(VirtualTimeState {
                current_nanos: AtomicU64::new(initial_nanos),
                next_wakeup_id: AtomicU64::new(0),
                pending_wakeups: Mutex::new(BinaryHeap::new()),
                pending_wakers: Mutex::new(Vec::new()),
            }),
        }
    }

    /// Number of registered wakeups that have not fired yet.
    pub fn pending_wakeup_count(&self) -> usize {
        self.state.pending_wakeups.lock().unwrap().len()
    }

    /// Advances the clock by `duration` and fires every wakeup that becomes due.
    ///
    /// Returns the ids of the fired wakeups, earliest deadline first. Durations that do
    /// not fit in `u64` nanoseconds saturate instead of wrapping around.
    pub fn advance(&self, duration: Duration) -> Vec<WakeupId> {
        let target = self
            .now_nanos()
            .saturating_add(Self::duration_nanos(duration));
        self.advance_to(target)
    }

    /// Moves the clock forward to `target_nanos` and fires every wakeup due by then.
    ///
    /// Returns the fired ids, earliest deadline first, with ties in registration order.
    /// If `target_nanos` is not after the current time, the clock is left unchanged and
    /// nothing fires.
    pub fn advance_to(&self, target_nanos: u64) -> Vec<WakeupId> {
        // `fetch_max` makes the compare-and-store one atomic step. Concurrent callers can
        // therefore never move the clock backwards. A separate load and store would let a
        // smaller target overwrite a larger one written in between.
        let previous = self
            .state
            .current_nanos
            .fetch_max(target_nanos, AtomicOrdering::SeqCst);
        if target_nanos <= previous {
            return Vec::new();
        }
        self.fire_due(target_nanos)
    }

    /// Jumps the clock to the earliest pending deadline and fires what is due there.
    ///
    /// Returns the first fired id together with that deadline, or `None` if nothing is
    /// pending.
    pub fn advance_to_next_wakeup(&self) -> Option<(WakeupId, u64)> {
        let deadline = self.next_wakeup_deadline()?;
        let mut triggered = self.advance_to(deadline);
        if triggered.is_empty() {
            // The earliest deadline is already at or before the current time. This can
            // happen when a sleep registers concurrently with an advance, or when another
            // thread advanced past it. Fire whatever is due instead of reporting that
            // nothing is pending.
            triggered = self.trigger_expired();
        }
        triggered.first().map(|&id| (id, deadline))
    }

    /// Deadline of the earliest pending wakeup, if any.
    pub fn next_wakeup_deadline(&self) -> Option<u64> {
        self.state
            .pending_wakeups
            .lock()
            .unwrap()
            .peek()
            .map(|w| w.deadline)
    }

    /// Registers a wakeup at `deadline` and returns its freshly allocated id.
    fn register_wakeup(&self, deadline: u64) -> WakeupId {
        let id = WakeupId::new(
            self.state
                .next_wakeup_id
                .fetch_add(1, AtomicOrdering::SeqCst),
        );
        self.state
            .pending_wakeups
            .lock()
            .unwrap()
            .push(Wakeup::new(deadline, id));
        id
    }

    /// Stores (or replaces) the waker to notify when wakeup `id` fires.
    // Not called anywhere in this file; kept for callers that manage wakers manually.
    #[allow(dead_code)]
    fn register_waker(&self, id: WakeupId, waker: Waker) {
        let mut pending_wakers = self.state.pending_wakers.lock().unwrap();
        if let Some(pos) = pending_wakers.iter().position(|(wid, _)| *wid == id) {
            pending_wakers[pos].1 = waker;
        } else {
            pending_wakers.push((id, waker));
        }
    }

    /// Returns `true` when `id` is no longer pending.
    ///
    /// This is also `true` for ids that were never registered.
    // Not called anywhere in this file.
    #[allow(dead_code)]
    fn is_wakeup_triggered(&self, id: WakeupId) -> bool {
        let pending = self.state.pending_wakeups.lock().unwrap();
        !pending.iter().any(|w| w.id == id)
    }

    /// Fires every wakeup whose deadline is at or before the current time.
    ///
    /// Returns the fired ids, earliest deadline first. The clock itself is not moved.
    pub fn trigger_expired(&self) -> Vec<WakeupId> {
        self.fire_due(self.now_nanos())
    }

    /// Like [`try_auto_advance_bounded`](Self::try_auto_advance_bounded) with a 1-second step.
    pub fn try_auto_advance(&self) -> Option<u64> {
        self.try_auto_advance_bounded(Duration::from_secs(1))
    }

    /// Fires anything already due, then moves the clock toward the next pending deadline.
    ///
    /// The clock moves by at most `max_step`. Returns the new current time, or `None`
    /// when there is no pending wakeup in the future.
    pub fn try_auto_advance_bounded(&self, max_step: Duration) -> Option<u64> {
        self.trigger_expired();
        let deadline = self.next_wakeup_deadline()?;
        let current = self.now_nanos();
        if deadline <= current {
            return None;
        }
        let step_limit = current.saturating_add(Self::duration_nanos(max_step));
        let target = deadline.min(step_limit);
        self.advance_to(target);
        Some(target)
    }

    /// Fires every pending wakeup with `deadline <= up_to`.
    ///
    /// Wakeups are popped in heap order and the wakers registered for them are woken.
    /// Returns the fired ids in that order.
    fn fire_due(&self, up_to: u64) -> Vec<WakeupId> {
        let mut fired = Vec::new();
        let mut wakers_to_wake = Vec::new();
        {
            // Lock order (heap, then wakers) matches `VirtualSleep::poll`.
            let mut pending = self.state.pending_wakeups.lock().unwrap();
            let mut pending_wakers = self.state.pending_wakers.lock().unwrap();
            while let Some(next) = pending.peek() {
                if next.deadline > up_to {
                    break;
                }
                let id = next.id;
                pending.pop();
                fired.push(id);
                if let Some(pos) = pending_wakers.iter().position(|(wid, _)| *wid == id) {
                    wakers_to_wake.push(pending_wakers.swap_remove(pos).1);
                }
            }
        }
        // Wake only after both guards are dropped, so a waker that re-enters this clock
        // cannot deadlock on them.
        for waker in wakers_to_wake {
            waker.wake();
        }
        fired
    }

    /// Converts `duration` to whole nanoseconds, saturating at `u64::MAX` instead of wrapping.
    fn duration_nanos(duration: Duration) -> u64 {
        duration.as_nanos().min(u128::from(u64::MAX)) as u64
    }
}
impl TimeSource for VirtualTime {
    /// Current virtual time in nanoseconds.
    fn now_nanos(&self) -> u64 {
        self.state.current_nanos.load(AtomicOrdering::SeqCst)
    }

    /// Sleeps for `duration` of virtual time.
    ///
    /// Durations that do not fit in `u64` nanoseconds saturate rather than wrapping.
    /// A wrapped value could otherwise turn a very long sleep into an immediate one.
    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        let nanos = duration.as_nanos().min(u128::from(u64::MAX)) as u64;
        self.sleep_until(self.now_nanos().saturating_add(nanos))
    }

    /// Sleeps until the virtual clock reaches `deadline_nanos`.
    ///
    /// A deadline at or before the current time completes immediately and registers
    /// nothing. Otherwise a wakeup is registered, and the returned future stays pending
    /// until the clock fires that wakeup. Dropping the future does not deregister the
    /// wakeup.
    fn sleep_until(&self, deadline_nanos: u64) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        let current = self.now_nanos();
        if deadline_nanos <= current {
            return Box::pin(std::future::ready(()));
        }
        let id = self.register_wakeup(deadline_nanos);
        let state = self.state.clone();
        Box::pin(VirtualSleep { id, state })
    }

    /// Races `future` against a virtual-time sleep of `duration`.
    fn timeout<F, T>(
        &self,
        duration: Duration,
        future: F,
    ) -> Pin<Box<dyn Future<Output = Option<T>> + Send>>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Reuse `sleep` so the deadline gets the same saturating conversion.
        let sleep = self.sleep(duration);
        Box::pin(async move {
            tokio::select! {
                // `biased` polls `future` first. A result that is ready at the same time
                // as the deadline is therefore returned rather than discarded.
                biased;
                result = future => Some(result),
                _ = sleep => None,
            }
        })
    }

    /// One day: idle connections effectively never time out under virtual time.
    fn connection_idle_timeout(&self) -> Duration {
        Duration::from_secs(86400)
    }

    /// Keepalive is not supported under virtual time.
    fn supports_keepalive(&self) -> bool {
        false
    }
}
/// Future for a virtual-time sleep whose deadline was in the future when it was created.
///
/// It completes once its wakeup id is no longer in the pending-wakeup heap, which means
/// the clock has fired it.
struct VirtualSleep {
    id: WakeupId,
    state: Arc<VirtualTimeState>,
}

impl Future for VirtualSleep {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let id = self.id;
        // The heap lock stays held while the waker is stored. This means the clock cannot
        // fire this wakeup between the membership check and the waker registration, which
        // would lose the wakeup. Lock order (heap, then wakers) matches the firing code.
        let pending = self.state.pending_wakeups.lock().unwrap();
        if !pending.iter().any(|w| w.id == id) {
            return Poll::Ready(());
        }
        let mut pending_wakers = self.state.pending_wakers.lock().unwrap();
        if let Some(pos) = pending_wakers.iter().position(|(wid, _)| *wid == id) {
            let stored = &mut pending_wakers[pos].1;
            // Avoid re-cloning when the task is polled again with an equivalent waker.
            if !stored.will_wake(cx.waker()) {
                *stored = cx.waker().clone();
            }
        } else {
            pending_wakers.push((id, cx.waker().clone()));
        }
        Poll::Pending
    }
}

// `VirtualSleep` is `Send` automatically: it holds a `Copy` id and an `Arc` to state made
// of atomics and mutexes over `Send` data. No `unsafe impl Send` is needed. This
// compile-time check fails if a future field ever makes the type `!Send`.
const _: fn() = || {
    fn assert_send<S: Send>() {}
    assert_send::<VirtualSleep>();
};
/// Periodic ticker driven by any [`TimeSource`].
///
/// When the caller falls behind, the next [`tick`](Self::tick) completes immediately and
/// the schedule skips ahead past the missed ticks.
pub struct TimeSourceInterval<T: TimeSource> {
    time_source: T,
    /// Tick period in nanoseconds; zero means every tick completes immediately.
    period_nanos: u64,
    /// Time (in the source's nanoseconds) at which the next tick is due.
    next_tick_nanos: u64,
}

impl<T: TimeSource> TimeSourceInterval<T> {
    /// Creates an interval whose first tick is due immediately.
    pub fn new(time_source: T, period: Duration) -> Self {
        let now = time_source.now_nanos();
        Self::new_at(time_source, now, period)
    }

    /// Creates an interval whose first tick is due at `start_nanos`.
    ///
    /// Periods longer than `u64::MAX` nanoseconds are clamped rather than wrapped.
    pub fn new_at(time_source: T, start_nanos: u64, period: Duration) -> Self {
        Self {
            time_source,
            period_nanos: Self::to_nanos(period),
            next_tick_nanos: start_nanos,
        }
    }

    /// Waits until the next tick is due.
    ///
    /// If the tick is already due, this returns immediately and moves the schedule to the
    /// first period boundary strictly after now, skipping missed ticks. With a zero period
    /// every call returns immediately. If the returned future is dropped while waiting,
    /// the schedule is left unchanged.
    pub async fn tick(&mut self) {
        let now = self.time_source.now_nanos();
        if now >= self.next_tick_nanos {
            self.next_tick_nanos = if self.period_nanos == 0 {
                // A zero period would otherwise divide by zero below.
                now
            } else {
                let periods_elapsed =
                    ((now - self.next_tick_nanos) / self.period_nanos).saturating_add(1);
                self.next_tick_nanos
                    .saturating_add(periods_elapsed.saturating_mul(self.period_nanos))
            };
            return;
        }
        self.time_source.sleep_until(self.next_tick_nanos).await;
        self.next_tick_nanos = self.next_tick_nanos.saturating_add(self.period_nanos);
    }

    /// The configured tick period.
    pub fn period(&self) -> Duration {
        Duration::from_nanos(self.period_nanos)
    }

    /// Restarts the schedule so that the next tick is due one full period from now.
    pub fn reset(&mut self) {
        self.next_tick_nanos = self
            .time_source
            .now_nanos()
            .saturating_add(self.period_nanos);
    }

    /// Converts `duration` to nanoseconds, saturating at `u64::MAX` instead of wrapping.
    fn to_nanos(duration: Duration) -> u64 {
        duration.as_nanos().min(u128::from(u64::MAX)) as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_virtual_time_starts_at_zero() {
        let clock = VirtualTime::new();
        assert_eq!(clock.now_nanos(), 0);
    }

    #[test]
    fn test_virtual_time_advance() {
        let clock = VirtualTime::new();
        let _ = clock.advance(Duration::from_secs(10));
        assert_eq!(clock.now_nanos(), 10_000_000_000);
    }

    #[test]
    fn test_virtual_time_wakeup_ordering() {
        let clock = VirtualTime::new();
        for deadline in [300u64, 100, 200].iter().copied() {
            let _ = clock.register_wakeup(deadline);
        }
        // Each target passes exactly one more deadline.
        for target in [150u64, 250, 400].iter().copied() {
            assert_eq!(clock.advance_to(target).len(), 1);
        }
    }

    #[test]
    fn test_virtual_time_same_deadline_fifo() {
        let clock = VirtualTime::new();
        let registered: Vec<WakeupId> = (0..3).map(|_| clock.register_wakeup(100)).collect();
        let fired = clock.advance_to(100);
        assert_eq!(fired.len(), 3);
        assert_eq!(fired, registered);
    }

    #[test]
    fn test_virtual_time_advance_to_next() {
        let clock = VirtualTime::new();
        clock.register_wakeup(50);
        clock.register_wakeup(100);
        for expected in [50u64, 100].iter().copied() {
            let next = clock.advance_to_next_wakeup();
            assert!(next.is_some());
            let (_, deadline) = next.unwrap();
            assert_eq!(deadline, expected);
            assert_eq!(clock.now_nanos(), expected);
        }
        assert!(clock.advance_to_next_wakeup().is_none());
    }

    #[test]
    fn test_real_time_basic() {
        let clock = RealTime::new();
        let before = clock.now_nanos();
        std::thread::sleep(Duration::from_millis(10));
        let after = clock.now_nanos();
        assert!(after > before);
    }

    #[tokio::test]
    async fn test_virtual_time_sleep_immediate() {
        let clock = VirtualTime::with_initial_time(1000);
        let already_due = clock.sleep_until(500);
        tokio::time::timeout(Duration::from_millis(10), already_due)
            .await
            .expect("sleep should complete immediately");
    }

    #[test]
    fn test_virtual_time_wakeup_order_reverse_registration() {
        let clock = VirtualTime::new();
        // Register deadlines 500, 400, ..., 100; dropping the futures leaves the wakeups queued.
        for slot in (1..=5u64).rev() {
            drop(clock.sleep_until(slot * 100));
        }
        let fired: Vec<(u64, u64)> = std::iter::from_fn(|| clock.advance_to_next_wakeup())
            .map(|(id, deadline)| (id.as_u64(), deadline))
            .collect();
        assert!(
            fired.windows(2).all(|pair| pair[1].1 >= pair[0].1),
            "Wakeups should be ordered by deadline, not registration order"
        );
        assert_eq!(fired.len(), 5);
    }
}