pros_simulator/host/
task.rs

1use std::{
2    collections::{HashMap, HashSet},
3    future::Future,
4    pin::Pin,
5    sync::Arc,
6    task::Poll,
7};
8
9use anyhow::{bail, Context};
10use pros_simulator_interface::SimulatorEvent;
11use tokio::sync::{Mutex, MutexGuard};
12use wasmtime::{
13    AsContextMut, Caller, Engine, Func, Instance, Linker, Module, SharedMemory, Store, Table,
14    TypedFunc, WasmParams,
15};
16
17use super::{memory::SharedMemoryExt, thread_local::TaskStorage, Host, HostCtx, WasmAllocator};
18use crate::{api::configure_api, interface::SimulatorInterface};
19
/// Lifecycle state of a simulated task.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskState {
    /// Active and currently executing. This is the current task.
    Running,
    /// Idle and ready to resume.
    Ready,
    /// Finished executing and will be removed from the task pool.
    Finished,
    /// Not runnable until something wakes it.
    /// NOTE(review): not assigned by the scheduler in this module — confirm where it is used.
    Blocked,
    /// Deleted; the task is no longer part of the task pool.
    Deleted,
}
32
/// Number of distinct task priority levels; valid priorities are `0..TASK_PRIORITIES`.
pub const TASK_PRIORITIES: u32 = 16;
34
/// Everything needed to spawn a task: its store, entrypoint, priority and optional name.
///
/// Build one with `TaskOptions::new_extern`, `TaskOptions::new_closure` or
/// `TaskOptions::new_global`, then pass it to `TaskPool::spawn`.
pub struct TaskOptions {
    /// Scheduling priority; `TaskOptions::priority` enforces that it is below `TASK_PRIORITIES`.
    priority: u32,
    /// Store the task's module will be instantiated into.
    store: Store<Host>,
    /// Host function (created in `new_closure`) that runs the task body when called.
    entrypoint: TypedFunc<(), ()>,
    /// Display name; `TaskPool::spawn` falls back to `task {id}` when this is `None`.
    name: Option<String>,
}
41
42impl TaskOptions {
43    /// Create options for a task who's entrypoint is a function from robot code.
44    ///
45    /// # Arguments
46    ///
47    /// * `pool` - The task pool to create the task in.
48    /// * `host` - The host to use for the task.
49    /// * `task_start` - The index of the task entrypoint in the task table.
50    ///   Function pointers are transformed into indices in the `__indirect_function_table`
51    ///   by the linker.
52    /// * `args` - The arguments to pass to the task entrypoint.
53    pub fn new_extern<P: WasmParams + 'static>(
54        pool: &mut TaskPool,
55        host: &Host,
56        task_start: u32,
57        args: P,
58    ) -> anyhow::Result<Self> {
59        let args = Arc::new(Mutex::new(Some(args)));
60        Self::new_closure(pool, host, move |mut caller| {
61            let args = args.clone();
62            Box::new(async move {
63                let entrypoint = {
64                    let task_handle = caller.current_task().await;
65                    let current_task = task_handle.lock().await;
66                    current_task
67                        .indirect_call_table
68                        .get(&mut caller, task_start)
69                        .context("Task entrypoint is out of bounds")?
70                };
71
72                let entrypoint = entrypoint
73                    .funcref()
74                    .context("Task entrypoint is not a function")?
75                    .context("Task entrypoint is NULL")?
76                    .typed::<P, ()>(&mut caller)
77                    .context("Task entrypoint has invalid signature")?;
78
79                entrypoint
80                    .call_async(&mut caller, args.lock().await.take().unwrap())
81                    .await?;
82                Ok(())
83            })
84        })
85    }
86
87    /// Create options for a task who's entrypoint is a custom closure created by the host.
88    /// These are treated the same as "real" tasks that have entrypoints in robot code.
89    pub fn new_closure(
90        pool: &mut TaskPool,
91        host: &Host,
92        task_closure: impl for<'a> FnOnce(
93                Caller<'a, Host>,
94            ) -> Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>
95            + Send
96            + 'static,
97    ) -> anyhow::Result<Self> {
98        let mut store = pool.create_store(host)?;
99        let task_closure = Arc::new(Mutex::new(Some(task_closure)));
100        let entrypoint = Func::wrap0_async(&mut store, move |caller: Caller<'_, Host>| {
101            let task_closure = task_closure.clone();
102            Box::new(async move {
103                let task_closure = task_closure
104                    .lock()
105                    .await
106                    .take()
107                    .expect("Expected task to only be started once");
108                Pin::from(task_closure(caller)).await
109            })
110        })
111        .typed::<(), ()>(&mut store)?;
112
113        Ok(Self {
114            priority: 7,
115            entrypoint,
116            store,
117            name: None,
118        })
119    }
120
121    /// Create options for a task who's entrypoint is a global function from robot code.
122    pub fn new_global(
123        pool: &mut TaskPool,
124        host: &Host,
125        func_name: &'static str,
126    ) -> anyhow::Result<Self> {
127        Self::new_closure(pool, host, move |mut caller| {
128            Box::new(async move {
129                let instance = {
130                    let task_handle = caller.current_task().await;
131                    let this_task = task_handle.lock().await;
132                    this_task.instance
133                };
134
135                let func = instance.get_func(&mut caller, func_name).with_context(|| {
136                    format!("entrypoint missing: expected {func_name} to be defined")
137                })?;
138                let func = func
139                    .typed(&mut caller)
140                    .with_context(|| format!("invalid {func_name} signature: expected () -> ()"))?;
141
142                func.call_async(&mut caller, ()).await
143            })
144        })
145    }
146
147    pub fn name(mut self, name: impl Into<String>) -> Self {
148        self.name = Some(name.into());
149        self
150    }
151
152    pub fn priority(mut self, priority: u32) -> Self {
153        assert!(priority < TASK_PRIORITIES);
154        self.priority = priority;
155        self
156    }
157}
158
/// A single simulated task: its own wasm store and instance plus scheduler bookkeeping.
pub struct Task {
    /// Unique id assigned by `TaskPool::spawn`.
    id: u32,
    /// Display name.
    name: String,
    /// Task-local storage, allocated on first use by `Task::local_storage`.
    local_storage: Option<TaskStorage>,
    /// Entrypoint invoked by `Task::start`.
    task_impl: TypedFunc<(), ()>,
    /// Scheduling priority; the scheduler only considers the highest-priority tasks.
    priority: u32,
    /// Errno slot in wasm memory, allocated on first use by `Task::errno`.
    errno: Option<Errno>,
    /// The module instance this task runs in.
    pub instance: Instance,
    /// Allocator used to reserve memory (errno, task storage) in this task's instance.
    allocator: WasmAllocator,
    /// The instance's `__indirect_function_table`, used to resolve function pointers.
    pub indirect_call_table: Table,
    /// The task's store; `Task::start`'s future keeps it locked while the task runs.
    store: Arc<Mutex<Store<Host>>>,
    /// Current lifecycle state.
    state: TaskState,
    /// Set when the task must be removed from the pool at its next yield.
    marked_for_delete: bool,
}
173
174impl Task {
175    fn new(
176        id: u32,
177        name: String,
178        mut store: Store<Host>,
179        instance: Instance,
180        task_impl: TypedFunc<(), ()>,
181    ) -> Self {
182        Self {
183            id,
184            name,
185            local_storage: None,
186            task_impl,
187            priority: 0,
188            errno: None,
189            allocator: WasmAllocator::new(&mut store, &instance),
190            indirect_call_table: instance
191                .get_table(&mut store, "__indirect_function_table")
192                .unwrap(),
193            instance,
194            store: Arc::new(Mutex::new(store)),
195            state: TaskState::Ready,
196            marked_for_delete: false,
197        }
198    }
199
200    pub async fn local_storage(
201        &mut self,
202        store: impl AsContextMut<Data = impl Send>,
203    ) -> TaskStorage {
204        if let Some(storage) = self.local_storage {
205            return storage;
206        }
207        let storage = TaskStorage::new(store, &self.allocator).await;
208        self.local_storage = Some(storage);
209        storage
210    }
211
212    pub async fn errno(&mut self, store: impl AsContextMut<Data = impl Send>) -> Errno {
213        if let Some(errno) = self.errno {
214            return errno;
215        }
216        let errno = Errno::new(store, &self.allocator).await;
217        self.errno = Some(errno);
218        errno
219    }
220
221    pub fn id(&self) -> u32 {
222        self.id
223    }
224
225    pub fn start(&mut self) -> impl Future<Output = anyhow::Result<()>> {
226        let store = self.store.clone();
227        let task_impl = self.task_impl;
228        async move {
229            let mut store = store.lock().await;
230            task_impl.call_async(&mut *store, ()).await
231        }
232    }
233
234    pub fn state(&self) -> TaskState {
235        self.state
236    }
237
238    pub fn name(&self) -> &str {
239        &self.name
240    }
241
242    pub fn allocator(&self) -> WasmAllocator {
243        self.allocator.clone()
244    }
245}
246impl PartialEq for Task {
247    fn eq(&self, other: &Self) -> bool {
248        self.id == other.id
249    }
250}
251impl Eq for Task {}
252
/// Shared, lockable handle to a [`Task`], as stored in the task pool.
pub type TaskHandle = Arc<Mutex<Task>>;
254
/// Scheduler state: every live task plus the bookkeeping used to pick which one runs next.
pub struct TaskPool {
    /// Live tasks keyed by id.
    pool: HashMap<u32, TaskHandle>,
    /// Ids of tasks that have been deleted, so `task_state` can still report them as deleted.
    deleted_tasks: HashSet<u32>,
    /// Id given to the most recently spawned task. Ids start at 1; `by_id` treats 0 as
    /// "the current task".
    newest_task_id: u32,
    /// The task that is currently being executed, if any.
    current_task: Option<TaskHandle>,
    /// Engine every task store is created on.
    engine: Engine,
    /// Memory passed to `configure_api` when instantiating each task.
    shared_memory: SharedMemory,
    /// Nesting depth of `suspend_all`; context switches are deferred while this is non-zero.
    scheduler_suspended: u32,
    /// Set when a context switch was requested while the scheduler was suspended.
    yield_pending: bool,
    /// Set by `start_shutdown`; makes `run_to_completion` stop after the current poll.
    shutdown_pending: bool,
    /// Sink for `SimulatorEvent`s such as scheduler warnings.
    interface: SimulatorInterface,
}
267
268impl TaskPool {
269    pub fn new(
270        engine: Engine,
271        shared_memory: SharedMemory,
272        interface: SimulatorInterface,
273    ) -> anyhow::Result<Self> {
274        Ok(Self {
275            pool: HashMap::new(),
276            deleted_tasks: HashSet::new(),
277            newest_task_id: 0,
278            current_task: None,
279            engine,
280            shared_memory,
281            scheduler_suspended: 0,
282            yield_pending: false,
283            shutdown_pending: false,
284            interface,
285        })
286    }
287
288    pub fn create_store(&mut self, host: &Host) -> anyhow::Result<Store<Host>> {
289        let store = Store::new(&self.engine, host.clone());
290        Ok(store)
291    }
292
293    pub async fn instantiate(
294        &mut self,
295        store: &mut Store<Host>,
296        module: &Module,
297        interface: &SimulatorInterface,
298    ) -> anyhow::Result<Instance> {
299        let mut linker = Linker::<Host>::new(&self.engine);
300
301        configure_api(&mut linker, store, self.shared_memory.clone())?;
302
303        for import in module.imports() {
304            if linker
305                .get(&mut *store, import.module(), import.name())
306                .is_none()
307            {
308                interface.send(SimulatorEvent::Warning(format!(
309                    "Unimplemented API `{}` (Robot code will crash if this is used)",
310                    import.name()
311                )));
312            }
313        }
314
315        linker.define_unknown_imports_as_traps(module)?;
316        let instance = linker.instantiate_async(store, module).await?;
317
318        Ok(instance)
319    }
320
321    pub async fn spawn(
322        &mut self,
323        opts: TaskOptions,
324        module: &Module,
325        interface: &SimulatorInterface,
326    ) -> anyhow::Result<TaskHandle> {
327        let TaskOptions {
328            priority,
329            entrypoint,
330            mut store,
331            name,
332            ..
333        } = opts;
334
335        let instance = self.instantiate(&mut store, module, interface).await?;
336
337        self.newest_task_id += 1;
338        let id = self.newest_task_id;
339
340        let mut task = Task::new(
341            id,
342            name.unwrap_or_else(|| format!("task {id}")),
343            store,
344            instance,
345            entrypoint,
346        );
347        task.priority = priority;
348        let task = Arc::new(Mutex::new(task));
349        self.pool.insert(id, task.clone());
350        Ok(task)
351    }
352
353    pub fn by_id(&self, task_id: u32) -> Option<TaskHandle> {
354        if task_id == 0 {
355            return Some(self.current());
356        }
357        self.pool.get(&task_id).cloned()
358    }
359
360    pub fn current(&self) -> TaskHandle {
361        self.current_task
362            .clone()
363            .expect("using the current task may only happen while a task is being executed")
364    }
365
366    pub async fn current_lock(&self) -> MutexGuard<'_, Task> {
367        self.current_task
368            .as_ref()
369            .expect("using the current task may only happen while a task is being executed")
370            .lock()
371            .await
372    }
373
374    #[inline]
375    pub async fn yield_now() {
376        futures_util::pending!();
377    }
378
379    /// Prevent context switches from happening until `resume_all` is called.
380    pub fn suspend_all(&mut self) {
381        self.scheduler_suspended += 1;
382    }
383
384    /// Resumes the scheduler, causing a yield if one is pending
385    ///
386    /// Returns whether resuming the scheduler caused a yield.
387    pub async fn resume_all(&mut self) -> anyhow::Result<bool> {
388        if self.scheduler_suspended == 0 {
389            bail!("rtos_resume_all called without a matching rtos_suspend_all");
390        }
391
392        self.scheduler_suspended -= 1;
393
394        if self.yield_pending && self.scheduler_suspended == 0 {
395            Self::yield_now().await;
396            Ok(true)
397        } else {
398            Ok(false)
399        }
400    }
401
402    async fn highest_priority_task_ids(&self) -> Vec<u32> {
403        let mut highest_priority = 0;
404        let mut highest_priority_tasks = vec![];
405        for task in self.pool.values() {
406            let task = task.lock().await;
407            if task.priority > highest_priority {
408                highest_priority = task.priority;
409                highest_priority_tasks.clear();
410            }
411            if task.priority == highest_priority {
412                highest_priority_tasks.push(task.id);
413            }
414        }
415        highest_priority_tasks.sort();
416        highest_priority_tasks
417    }
418
419    /// Switches to the next task in the task pool, if any. Returns whether there are running
420    /// tasks remaining.
421    ///
422    /// This function will loop through the tasks in a round-robin fashion, giving each task a
423    /// chance to run before looping back around to the beginning. Only tasks with the highest
424    /// priority will be considered.
425    pub async fn cycle_tasks(&mut self) -> bool {
426        if self.scheduler_suspended != 0 {
427            if self.current_task.is_some() {
428                self.yield_pending = true;
429                return true;
430            } else {
431                panic!("Scheduler is suspended and current task is missing");
432            }
433        }
434        self.yield_pending = false;
435
436        let task_candidates = self.highest_priority_task_ids().await;
437        let current_task_id = if let Some(task) = &self.current_task {
438            task.lock().await.id
439        } else {
440            0
441        };
442        let next_task = task_candidates
443            .iter()
444            .find(|id| **id > current_task_id)
445            .or_else(|| task_candidates.first())
446            .and_then(|id| self.by_id(*id));
447        self.current_task = next_task;
448        self.current_task.is_some()
449    }
450
451    pub async fn run_to_completion(host: &Host) -> anyhow::Result<()> {
452        let mut futures =
453            HashMap::<u32, Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>>::new();
454        loop {
455            let mut tasks = host.tasks_lock().await;
456            let running = tasks.cycle_tasks().await;
457            if !running {
458                break Ok(());
459            }
460
461            let mut task = tasks.current_lock().await;
462            let id = task.id();
463            let future = futures.entry(id).or_insert_with(|| Box::pin(task.start()));
464            drop(task);
465            drop(tasks);
466
467            let result = futures::poll!(future);
468
469            let tasks = host.tasks();
470            let mut tasks = tasks
471                .try_lock()
472                .expect("attempt to yield while task mutex is locked");
473            let task = tasks.current();
474            let mut task = task
475                .try_lock()
476                .expect("attempt to yield while current task is locked");
477
478            if tasks.shutdown_pending {
479                break Ok(());
480            }
481
482            if let Poll::Ready(result) = result {
483                task.marked_for_delete = true;
484                task.state = TaskState::Finished;
485                result?;
486            } else if task.marked_for_delete {
487                task.state = TaskState::Deleted;
488            }
489
490            if task.marked_for_delete {
491                if tasks.scheduler_suspended != 0 {
492                    // task called rtos_suspend_all and ended before calling rtos_resume_all
493                    tasks.interface.send(SimulatorEvent::Warning(format!(
494                        "Task `{}` (#{}) exited with scheduler in suspended state",
495                        &task.name, task.id,
496                    )));
497                }
498                drop(task);
499
500                tasks.scheduler_suspended = 0;
501                futures.remove(&id);
502                tasks.pool.remove(&id);
503            }
504        }
505    }
506
507    pub async fn task_state(&self, task_id: u32) -> Option<TaskState> {
508        if self.deleted_tasks.contains(&task_id) {
509            return Some(TaskState::Deleted);
510        }
511        if let Some(task) = self.pool.get(&task_id) {
512            let task = task.lock().await;
513            Some(task.state)
514        } else {
515            None
516        }
517    }
518
519    pub async fn delete_task(&mut self, task_id: u32) {
520        let task = self.pool.get(&task_id);
521        if let Some(task) = task {
522            let mut task = task.lock().await;
523            if task.state == TaskState::Running {
524                task.marked_for_delete = true;
525                Self::yield_now().await;
526                unreachable!("Deleted task may not continue execution");
527            }
528
529            task.state = TaskState::Deleted;
530            drop(task);
531            self.pool.remove(&task_id).unwrap();
532            self.deleted_tasks.insert(task_id);
533        }
534    }
535
536    pub fn start_shutdown(&mut self) {
537        self.shutdown_pending = true;
538    }
539}
540
/// Location of a task's `errno` value inside wasm memory.
#[derive(Clone, Copy, Debug)]
pub struct Errno {
    /// Address of the `i32`-sized slot; `Errno::set` writes it as little-endian bytes.
    address: u32,
}
545
546impl Errno {
547    pub async fn new(
548        store: impl AsContextMut<Data = impl Send>,
549        allocator: &WasmAllocator,
550    ) -> Self {
551        let address = allocator
552            .memalign(store, std::alloc::Layout::new::<i32>())
553            .await;
554        Self { address }
555    }
556    pub fn address(&self) -> u32 {
557        self.address
558    }
559    pub fn set(&self, memory: &SharedMemory, new_errno: i32) {
560        let buffer = new_errno.to_le_bytes();
561        memory
562            .write_relaxed(self.address as usize, &buffer)
563            .unwrap();
564    }
565}