1use std::{
2 cmp::Ordering,
3 collections::BinaryHeap,
4 future::Future,
5 time::{Duration, Instant},
6};
7
8use anyhow::Result;
9use hash_map_id::HashMapId;
10use lunatic_common_api::IntoTrap;
11use lunatic_process::{state::ProcessState, Signal};
12use lunatic_process_api::ProcessCtx;
13use tokio::task::JoinHandle;
14use wasmtime::{Caller, Linker};
15
/// Entry in the timer deadline heap: the timer with id `key` fires at `instant`.
///
/// Ordering and equality consider only `instant` (never `key`). The ordering is
/// reversed so that the max-heap `BinaryHeap` yields the *earliest* deadline first.
#[derive(Debug)]
struct HeapValue {
    instant: Instant,
    key: u64,
}

impl PartialOrd for HeapValue {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `Ord` so the partial and total orderings can never disagree.
        Some(self.cmp(other))
    }
}

impl Ord for HeapValue {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on purpose: `BinaryHeap` is a max-heap and we want the earliest
        // instant on top.
        other.instant.cmp(&self.instant)
    }
}

impl PartialEq for HeapValue {
    fn eq(&self, other: &Self) -> bool {
        // Must stay consistent with `Ord`, which ignores `key`.
        self.instant == other.instant
    }
}

impl Eq for HeapValue {}
41
42#[derive(Debug, Default)]
43pub struct TimerResources {
44 hash_map: HashMapId<JoinHandle<()>>,
45 heap: BinaryHeap<HeapValue>,
46}
47
48impl TimerResources {
49 pub fn add(&mut self, handle: JoinHandle<()>, target_time: Instant) -> u64 {
50 self.cleanup_expired_timers();
51
52 let id = self.hash_map.add(handle);
53 self.heap.push(HeapValue {
54 instant: target_time,
55 key: id,
56 });
57 id
58 }
59
60 fn cleanup_expired_timers(&mut self) {
61 let deadline = Instant::now();
62 while let Some(HeapValue { instant, .. }) = self.heap.peek() {
63 if *instant > deadline {
64 return;
66 }
67
68 let key = self
69 .heap
70 .pop()
71 .expect("not empty because we matched on peek")
72 .key;
73 self.hash_map.remove(key);
74 }
75 }
76
77 pub fn remove(&mut self, id: u64) -> Option<JoinHandle<()>> {
78 self.hash_map.remove(id)
79 }
80}
81
82pub trait TimerCtx {
83 fn timer_resources(&self) -> &TimerResources;
84 fn timer_resources_mut(&mut self) -> &mut TimerResources;
85}
86
87pub fn register<T: ProcessState + ProcessCtx<T> + TimerCtx + Send + 'static>(
88 linker: &mut Linker<T>,
89) -> Result<()> {
90 linker.func_wrap("lunatic::timer", "send_after", send_after)?;
91 linker.func_wrap1_async("lunatic::timer", "cancel_timer", cancel_timer)?;
92
93 #[cfg(feature = "metrics")]
94 metrics::describe_counter!(
95 "lunatic.timers.started",
96 metrics::Unit::Count,
97 "number of timers set since startup, will usually be completed + canceled + active"
98 );
99 #[cfg(feature = "metrics")]
100 metrics::describe_counter!(
101 "lunatic.timers.completed",
102 metrics::Unit::Count,
103 "number of timers completed since startup"
104 );
105 #[cfg(feature = "metrics")]
106 metrics::describe_counter!(
107 "lunatic.timers.canceled",
108 metrics::Unit::Count,
109 "number of timers canceled since startup"
110 );
111 #[cfg(feature = "metrics")]
112 metrics::describe_gauge!(
113 "lunatic.timers.active",
114 metrics::Unit::Count,
115 "number of timers currently active"
116 );
117
118 Ok(())
119}
120
121fn send_after<T: ProcessState + ProcessCtx<T> + TimerCtx>(
129 mut caller: Caller<T>,
130 process_id: u64,
131 delay: u64,
132) -> Result<u64> {
133 let message = caller
134 .data_mut()
135 .message_scratch_area()
136 .take()
137 .or_trap("lunatic::message::send_after")?;
138
139 let process = caller.data_mut().environment().get_process(process_id);
140
141 let target_time = Instant::now() + Duration::from_millis(delay);
142 let timer_handle = tokio::task::spawn(async move {
143 #[cfg(feature = "metrics")]
144 metrics::increment_counter!("lunatic.timers.started");
145 #[cfg(feature = "metrics")]
146 metrics::increment_gauge!("lunatic.timers.active", 1.0);
147 let duration_remaining = target_time - Instant::now();
148 if duration_remaining != Duration::ZERO {
149 tokio::time::sleep(duration_remaining).await;
150 }
151 if let Some(process) = process {
152 #[cfg(feature = "metrics")]
153 metrics::increment_counter!("lunatic.timers.completed");
154 #[cfg(feature = "metrics")]
155 metrics::decrement_gauge!("lunatic.timers.active", 1.0);
156 process.send(Signal::Message(message));
157 }
158 });
159
160 let id = caller
161 .data_mut()
162 .timer_resources_mut()
163 .add(timer_handle, target_time);
164 Ok(id)
165}
166
167fn cancel_timer<T: ProcessState + TimerCtx + Send>(
176 mut caller: Caller<T>,
177 timer_id: u64,
178) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
179 Box::new(async move {
180 let timer_handle = caller.data_mut().timer_resources_mut().remove(timer_id);
181 match timer_handle {
182 Some(timer_handle) => {
183 timer_handle.abort();
184 #[cfg(feature = "metrics")]
185 metrics::increment_counter!("lunatic.timers.canceled");
186 #[cfg(feature = "metrics")]
187 metrics::decrement_gauge!("lunatic.timers.active", 1.0);
188 Ok(1)
189 }
190 None => Ok(0),
191 }
192 })
193}