lunatic_timer_api/
lib.rs

1use std::{
2    cmp::Ordering,
3    collections::BinaryHeap,
4    future::Future,
5    time::{Duration, Instant},
6};
7
8use anyhow::Result;
9use hash_map_id::HashMapId;
10use lunatic_common_api::IntoTrap;
11use lunatic_process::{state::ProcessState, Signal};
12use lunatic_process_api::ProcessCtx;
13use tokio::task::JoinHandle;
14use wasmtime::{Caller, Linker};
15
/// An entry in the timer expiry heap: the deadline of a timer together with the
/// id its join handle was registered under in `TimerResources`.
///
/// Ordering and equality consider only `instant`. The ordering is reversed so that
/// `BinaryHeap`, which is a max-heap, yields the *earliest* deadline first.
#[derive(Debug)]
struct HeapValue {
    /// Point in time at which the timer is due.
    instant: Instant,
    /// Id returned by `HashMapId::add` for the timer's join handle.
    key: u64,
}

impl Ord for HeapValue {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: an earlier instant compares as "greater", so it sits at the top
        // of the max-heap.
        self.instant.cmp(&other.instant).reverse()
    }
}

impl PartialOrd for HeapValue {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Delegate to `Ord` so the two implementations can never disagree.
        Some(self.cmp(other))
    }
}

impl PartialEq for HeapValue {
    fn eq(&self, other: &Self) -> bool {
        // Only the deadline participates, to stay consistent with `Ord`.
        self.instant == other.instant
    }
}

impl Eq for HeapValue {}
41
42#[derive(Debug, Default)]
43pub struct TimerResources {
44    hash_map: HashMapId<JoinHandle<()>>,
45    heap: BinaryHeap<HeapValue>,
46}
47
48impl TimerResources {
49    pub fn add(&mut self, handle: JoinHandle<()>, target_time: Instant) -> u64 {
50        self.cleanup_expired_timers();
51
52        let id = self.hash_map.add(handle);
53        self.heap.push(HeapValue {
54            instant: target_time,
55            key: id,
56        });
57        id
58    }
59
60    fn cleanup_expired_timers(&mut self) {
61        let deadline = Instant::now();
62        while let Some(HeapValue { instant, .. }) = self.heap.peek() {
63            if *instant > deadline {
64                // instant is after the deadline so stop
65                return;
66            }
67
68            let key = self
69                .heap
70                .pop()
71                .expect("not empty because we matched on peek")
72                .key;
73            self.hash_map.remove(key);
74        }
75    }
76
77    pub fn remove(&mut self, id: u64) -> Option<JoinHandle<()>> {
78        self.hash_map.remove(id)
79    }
80}
81
82pub trait TimerCtx {
83    fn timer_resources(&self) -> &TimerResources;
84    fn timer_resources_mut(&mut self) -> &mut TimerResources;
85}
86
87pub fn register<T: ProcessState + ProcessCtx<T> + TimerCtx + Send + 'static>(
88    linker: &mut Linker<T>,
89) -> Result<()> {
90    linker.func_wrap("lunatic::timer", "send_after", send_after)?;
91    linker.func_wrap1_async("lunatic::timer", "cancel_timer", cancel_timer)?;
92
93    #[cfg(feature = "metrics")]
94    metrics::describe_counter!(
95        "lunatic.timers.started",
96        metrics::Unit::Count,
97        "number of timers set since startup, will usually be completed + canceled + active"
98    );
99    #[cfg(feature = "metrics")]
100    metrics::describe_counter!(
101        "lunatic.timers.completed",
102        metrics::Unit::Count,
103        "number of timers completed since startup"
104    );
105    #[cfg(feature = "metrics")]
106    metrics::describe_counter!(
107        "lunatic.timers.canceled",
108        metrics::Unit::Count,
109        "number of timers canceled since startup"
110    );
111    #[cfg(feature = "metrics")]
112    metrics::describe_gauge!(
113        "lunatic.timers.active",
114        metrics::Unit::Count,
115        "number of timers currently active"
116    );
117
118    Ok(())
119}
120
121// Sends the message to a process after a delay.
122//
123// There are no guarantees that the message will be received.
124//
125// Traps:
126// * If the process ID doesn't exist.
127// * If it's called before creating the next message.
128fn send_after<T: ProcessState + ProcessCtx<T> + TimerCtx>(
129    mut caller: Caller<T>,
130    process_id: u64,
131    delay: u64,
132) -> Result<u64> {
133    let message = caller
134        .data_mut()
135        .message_scratch_area()
136        .take()
137        .or_trap("lunatic::message::send_after")?;
138
139    let process = caller.data_mut().environment().get_process(process_id);
140
141    let target_time = Instant::now() + Duration::from_millis(delay);
142    let timer_handle = tokio::task::spawn(async move {
143        #[cfg(feature = "metrics")]
144        metrics::increment_counter!("lunatic.timers.started");
145        #[cfg(feature = "metrics")]
146        metrics::increment_gauge!("lunatic.timers.active", 1.0);
147        let duration_remaining = target_time - Instant::now();
148        if duration_remaining != Duration::ZERO {
149            tokio::time::sleep(duration_remaining).await;
150        }
151        if let Some(process) = process {
152            #[cfg(feature = "metrics")]
153            metrics::increment_counter!("lunatic.timers.completed");
154            #[cfg(feature = "metrics")]
155            metrics::decrement_gauge!("lunatic.timers.active", 1.0);
156            process.send(Signal::Message(message));
157        }
158    });
159
160    let id = caller
161        .data_mut()
162        .timer_resources_mut()
163        .add(timer_handle, target_time);
164    Ok(id)
165}
166
167// Cancels the specified timer.
168//
169// Returns:
170// * 1 if a timer with the timer_id was found
171// * 0 if no timer was found, this can be either because:
172//     - timer had expired
173//     - timer already had been canceled
174//     - timer_id never corresponded to a timer
175fn cancel_timer<T: ProcessState + TimerCtx + Send>(
176    mut caller: Caller<T>,
177    timer_id: u64,
178) -> Box<dyn Future<Output = Result<u32>> + Send + '_> {
179    Box::new(async move {
180        let timer_handle = caller.data_mut().timer_resources_mut().remove(timer_id);
181        match timer_handle {
182            Some(timer_handle) => {
183                timer_handle.abort();
184                #[cfg(feature = "metrics")]
185                metrics::increment_counter!("lunatic.timers.canceled");
186                #[cfg(feature = "metrics")]
187                metrics::decrement_gauge!("lunatic.timers.active", 1.0);
188                Ok(1)
189            }
190            None => Ok(0),
191        }
192    })
193}