stuck/
runtime.rs

1//! Constructions to bootstrap stuck runtime.
2
3use std::cell::Cell;
4use std::collections::{HashMap, VecDeque};
5use std::mem::MaybeUninit;
6use std::num::NonZeroUsize;
7use std::sync::{Arc, Condvar, Mutex, MutexGuard};
8use std::time::Duration;
9use std::{ptr, thread};
10
11use ignore_result::Ignore;
12
13use crate::channel::parallel::{self, Sender};
14use crate::channel::prelude::*;
15use crate::task::{self, SchedFlow, Task};
16use crate::{io, net, time};
17
18thread_local! {
19    static SCHEDULER: Cell<Option<ptr::NonNull<Scheduler>>> = const {  Cell::new(None) };
20}
21
22const STOP_MSG: &str = "runtime stopped";
23
24struct TaskPointer(ptr::NonNull<Task>);
25
26unsafe impl Send for TaskPointer {}
27
28impl TaskPointer {
29    fn from(task: &Task) -> TaskPointer {
30        TaskPointer(ptr::NonNull::from(task))
31    }
32}
33
34struct Scope {}
35
36impl Scope {
37    fn enter(scheduler: &Scheduler) -> Self {
38        SCHEDULER.with(|cell| {
39            assert!(cell.get().is_none(), "runtime scheduler existed");
40            cell.set(Some(ptr::NonNull::from(scheduler)));
41        });
42        Scope {}
43    }
44}
45
46impl Drop for Scope {
47    fn drop(&mut self) {
48        SCHEDULER.with(|cell| {
49            assert!(cell.get().is_some(), "runtime scheduler does not exist");
50            cell.set(None);
51        });
52    }
53}
54
55/// Builder for [Runtime].
56#[derive(Default)]
57pub struct Builder {
58    parallelism: Option<usize>,
59}
60
61impl Builder {
62    /// Specifies the number of parallel threads for scheduling. Defaults to [std::thread::available_parallelism].
63    pub fn parallelism(&mut self, n: usize) -> &mut Self {
64        assert!(n > 0, "parallelism must not be zero");
65        self.parallelism = Some(n);
66        self
67    }
68
69    /// Constructs an [Runtime] to spawn and schedule tasks.
70    pub fn build(&mut self) -> Runtime {
71        let parallelism =
72            self.parallelism.unwrap_or_else(|| thread::available_parallelism().map_or(4, NonZeroUsize::get));
73        let (time_sender, time_receiver) = parallel::unbounded(512);
74        let poller = net::Poller::new().unwrap();
75        let (io_poller, io_requester) = io::Poller::new();
76        let scheduler = Scheduler::new(parallelism, time_sender.clone(), poller.registry(), io_requester);
77        let net_stopper = poller.start().unwrap();
78        let io_stopper = io_poller.start(&scheduler.registry).unwrap();
79        let timer = task::Builder::with_scheduler(&scheduler).spawn(move || {
80            time::timer(time_receiver);
81        });
82        let ticker = thread::Builder::new()
83            .name("stuck::time::ticker".to_string())
84            .spawn(move || {
85                time::tickr(time_sender);
86            })
87            .expect("failed to spawn stuck::time::ticker thread");
88        let scheduling_threads = Scheduler::start(&scheduler);
89        Runtime {
90            scheduler,
91            timer: MaybeUninit::new(timer),
92            ticker: MaybeUninit::new(ticker),
93            io_stopper,
94            net_stopper: MaybeUninit::new(net_stopper),
95            scheduling_threads,
96        }
97    }
98}
99
100/// Runtime encapsulates io selecter, timer and task scheduler to serve spawned tasks.
101///
102/// [Runtime::drop] will stop and join all serving threads.
103pub struct Runtime {
104    scheduler: Arc<Scheduler>,
105    timer: MaybeUninit<task::JoinHandle<()>>,
106    ticker: MaybeUninit<thread::JoinHandle<()>>,
107    io_stopper: io::Stopper,
108    net_stopper: MaybeUninit<net::Stopper>,
109    scheduling_threads: Vec<thread::JoinHandle<()>>,
110}
111
112impl Runtime {
113    /// Constructs an runtime to serve spawned tasks.
114    pub fn new() -> Runtime {
115        Builder::default().build()
116    }
117
118    /// Constructs a task builder to spawn task.
119    pub fn builder(&self) -> task::Builder<'_> {
120        task::Builder::with_scheduler(&self.scheduler)
121    }
122
123    /// Spawns a concurrent task and returns a [task::JoinHandle] for it.
124    ///
125    /// See [task::spawn] for more details
126    pub fn spawn<F, T>(&mut self, f: F) -> task::JoinHandle<T>
127    where
128        F: FnOnce() -> T,
129        F: Send + 'static,
130        T: Send + 'static,
131    {
132        task::Builder::with_scheduler(&self.scheduler).spawn(f)
133    }
134}
135
136impl Default for Runtime {
137    fn default() -> Self {
138        Runtime::new()
139    }
140}
141
impl Drop for Runtime {
    /// Stops the runtime and joins every serving thread.
    ///
    /// The steps below are strictly ordered; they drive the two-step stop protocol of
    /// `Scheduler::stop` / `Scheduler::serve`.
    fn drop(&mut self) {
        // First stop step: scheduling threads leave their normal run loop; the last of them sends
        // `time::Message::Stop` to the timer task and keeps running tasks so the timer can finish.
        self.scheduler.stop();
        // SAFETY: these fields are initialized by `Builder::build` and moved out exactly once, here.
        // Being `MaybeUninit`, they are not dropped again when `self` is dropped.
        let timer = unsafe { ptr::read(self.timer.as_ptr()) };
        let ticker = unsafe { ptr::read(self.ticker.as_ptr()) };
        let mut net_stopper = unsafe { ptr::read(self.net_stopper.as_ptr()) };
        timer.join().ignore();
        // NOTE(review): presumably `time::tickr` returns once the timer task has dropped its receiver — confirm.
        ticker.join().ignore();
        // Second stop step: the last scheduling thread now aborts every remaining task.
        self.scheduler.stop();
        // uring completion is notified through eventfd which is monitored through net::Poller.
        self.io_stopper.stop();
        net_stopper.stop();
        for handle in self.scheduling_threads.drain(..) {
            handle.join().ignore();
        }
    }
}
159
160struct SchedulerState {
161    runq: VecDeque<TaskPointer>,
162    registry: HashMap<u64, Arc<Task>>,
163
164    // -1: running
165    //  0: start stopping
166    // +n: n stopped threads
167    stopped: isize,
168}
169
170impl SchedulerState {
171    fn new() -> Self {
172        SchedulerState { runq: VecDeque::with_capacity(256), registry: HashMap::with_capacity(256), stopped: -1 }
173    }
174}
175
/// Task scheduler shared, through `Arc`, by the runtime and its scheduling threads.
pub(crate) struct Scheduler {
    /// Number of scheduling threads spawned by `Scheduler::start`.
    parallelism: usize,
    /// Sender to the timer task; the last scheduling thread sends `time::Message::Stop` through it.
    timer: Sender<time::Message>,
    /// Run queue, task registry and stop progress.
    state: Mutex<SchedulerState>,
    /// Wakes scheduling threads waiting for runnable tasks or stop steps.
    waker: Condvar,
    /// Registry of the net poller; `Builder::build` also starts the io poller on it.
    registry: Arc<net::Registry>,
    /// Requester installed into every scheduling thread through `io::Scope`.
    requester: io::Requester,
}

// NOTE(review): manual impls, presumably needed because `Arc<Task>` in the registry (or the io/net
// handles) is not auto `Send`/`Sync`; soundness relies on all task access going through `state` — confirm.
unsafe impl Send for Scheduler {}
unsafe impl Sync for Scheduler {}
187
188impl Scheduler {
189    fn new(
190        parallelism: usize,
191        timer: Sender<time::Message>,
192        registry: Arc<net::Registry>,
193        requester: io::Requester,
194    ) -> Arc<Scheduler> {
195        Arc::new(Scheduler {
196            parallelism,
197            timer,
198            state: Mutex::new(SchedulerState::new()),
199            waker: Condvar::new(),
200            registry,
201            requester,
202        })
203    }
204
205    /// Starts threads to serve spawned tasks.
206    fn start(self: &Arc<Scheduler>) -> Vec<thread::JoinHandle<()>> {
207        let parallelism = self.parallelism;
208        (0..parallelism)
209            .map(move |i| {
210                let scheduler = self.clone();
211                let name = format!("stuck::scheduler({}/{})", i + 1, parallelism);
212                thread::Builder::new()
213                    .name(name)
214                    .spawn(move || scheduler.serve())
215                    .expect("failed to spawn stuck::scheduler thread")
216            })
217            .collect()
218    }
219
220    /// This method is designed to be called twice. One for stop signal and one after all attendant
221    /// threads stopped.
222    fn stop(&self) {
223        let mut state = self.state.lock().unwrap();
224        state.stopped += 1;
225        self.waker.notify_all();
226    }
227
228    pub unsafe fn registry<'a>() -> &'a net::Registry {
229        &Self::current().registry
230    }
231
232    pub(crate) unsafe fn current<'a>() -> &'a Scheduler {
233        SCHEDULER.with(|s| s.get().unwrap_unchecked().as_ref())
234    }
235
236    pub(crate) fn try_current<'a>() -> Option<&'a Scheduler> {
237        SCHEDULER.with(|s| s.get().map(|s| unsafe { s.as_ref() }))
238    }
239
240    pub(crate) fn try_time_sender() -> Option<Sender<time::Message>> {
241        Self::try_current().map(|s| s.timer.clone())
242    }
243
244    pub fn sched(&self, t: Arc<Task>) {
245        let mut state = self.state.lock().unwrap();
246        let id = t.id();
247        let pointer = TaskPointer::from(&t);
248        state.registry.insert(id, t);
249        state.runq.push_back(pointer);
250        self.waker.notify_one();
251    }
252
253    pub(crate) fn resume(&self, t: &Task) {
254        let mut state = self.state.lock().unwrap();
255        state.runq.push_back(TaskPointer::from(t));
256        self.waker.notify_one();
257    }
258
259    fn run<'a>(&'a self, mut state: MutexGuard<'a, SchedulerState>) -> MutexGuard<'a, SchedulerState> {
260        if let Some(mut task) = state.runq.pop_front() {
261            drop(state);
262            let task = unsafe { task.0.as_mut() };
263            let flow = task.sched();
264            let id = task.id();
265            state = self.state.lock().unwrap();
266            match flow {
267                SchedFlow::Yield => state.runq.push_back(TaskPointer::from(task)),
268                SchedFlow::Block => {},
269                SchedFlow::Cease => {
270                    state.registry.remove(&id);
271                },
272            }
273            state
274        } else {
275            self.waker.wait(state).unwrap()
276        }
277    }
278
279    fn serve(&self) {
280        let _scope = Scope::enter(self);
281        let _io_scope = io::Scope::enter(self.requester.clone());
282        let mut state = self.state.lock().unwrap();
283        while state.stopped < 0 {
284            state = self.run(state)
285        }
286        let stopped = state.stopped + 1;
287        state.stopped = stopped;
288        if stopped as usize != self.parallelism {
289            return;
290        }
291        // This is the last scheduling thread.
292        drop(state);
293        self.timer.clone().send(time::Message::Stop).ignore();
294        state = self.state.lock().unwrap();
295        while state.stopped == self.parallelism as isize {
296            state = self.run(state)
297        }
298        // No timer and io poller now, this is the sole execution thread.
299        while !state.registry.is_empty() {
300            // SAFETY: Avoid compilation warning in read to `registry` and write to `runq`.
301            let registry: &HashMap<u64, Arc<Task>> = unsafe { std::mem::transmute::<_, _>(&state.registry) };
302            registry.values().filter(|t| t.grab()).map(|t| TaskPointer::from(t)).for_each(|t| state.runq.push_back(t));
303            while let Some(mut task) = state.runq.pop_front() {
304                drop(state);
305                let task = unsafe { task.0.as_mut() };
306                let id = task.id();
307                task.abort(STOP_MSG);
308                state = self.state.lock().unwrap();
309                state.registry.remove(&id);
310            }
311            drop(state);
312            // Sleep to let waker resume task after winning Task::grab(eg. `running` state).
313            std::thread::sleep(Duration::from_millis(500));
314            state = self.state.lock().unwrap();
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;

    thread_local! {
        static LOCAL_SECRET: Cell<usize> = Cell::new(0);
    }

    /// Zero parallelism is rejected by the builder.
    #[test]
    #[should_panic]
    fn runtime_builder_parallelism_zero() {
        let mut builder = Builder::default();
        builder.parallelism(0).build();
    }

    /// With one scheduling thread both tasks run on that thread in spawn order, so the getter
    /// observes the secret stored by the setter.
    #[test]
    fn runtime_builder_parallelism_one() {
        let mut runtime = Builder::default().parallelism(1).build();
        let secret = 333;
        let setter = runtime.spawn(move || {
            thread::sleep(Duration::from_secs(10));
            LOCAL_SECRET.with(|slot| slot.set(secret));
        });
        let getter = runtime.spawn(move || LOCAL_SECRET.with(|slot| slot.get()));
        setter.join().unwrap();
        assert_eq!(secret, getter.join().unwrap());
    }

    /// With two scheduling threads the getter runs while the setter still sleeps, on another thread,
    /// so it does not observe the secret.
    #[test]
    fn runtime_builder_parallelism_multiple() {
        let mut runtime = Builder::default().parallelism(2).build();
        let secret = 111;
        let setter = runtime.spawn(move || {
            thread::sleep(Duration::from_secs(10));
            LOCAL_SECRET.with(|slot| slot.set(secret));
        });
        let getter = runtime.spawn(move || LOCAL_SECRET.with(|slot| slot.get()));
        setter.join().unwrap();
        assert_ne!(secret, getter.join().unwrap());
    }
}