1use std::{
2 collections::{HashMap, HashSet},
3 future::Future,
4 pin::Pin,
5 sync::Arc,
6 task::Poll,
7};
8
9use anyhow::{bail, Context};
10use pros_simulator_interface::SimulatorEvent;
11use tokio::sync::{Mutex, MutexGuard};
12use wasmtime::{
13 AsContextMut, Caller, Engine, Func, Instance, Linker, Module, SharedMemory, Store, Table,
14 TypedFunc, WasmParams,
15};
16
17use super::{memory::SharedMemoryExt, thread_local::TaskStorage, Host, HostCtx, WasmAllocator};
18use crate::{api::configure_api, interface::SimulatorInterface};
19
/// Lifecycle state of a simulated task.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskState {
    /// Presumably the task currently being executed (not assigned within this file).
    Running,
    /// Initial state of every newly created task.
    Ready,
    /// The task's entrypoint future completed.
    Finished,
    /// Presumably a task waiting on something (not assigned within this file).
    Blocked,
    /// The task was deleted, either explicitly or while it was running.
    Deleted,
}

/// Number of distinct task priorities; valid priorities are `0..TASK_PRIORITIES`.
pub const TASK_PRIORITIES: u32 = 16;
34
35pub struct TaskOptions {
36 priority: u32,
37 store: Store<Host>,
38 entrypoint: TypedFunc<(), ()>,
39 name: Option<String>,
40}
41
42impl TaskOptions {
43 pub fn new_extern<P: WasmParams + 'static>(
54 pool: &mut TaskPool,
55 host: &Host,
56 task_start: u32,
57 args: P,
58 ) -> anyhow::Result<Self> {
59 let args = Arc::new(Mutex::new(Some(args)));
60 Self::new_closure(pool, host, move |mut caller| {
61 let args = args.clone();
62 Box::new(async move {
63 let entrypoint = {
64 let task_handle = caller.current_task().await;
65 let current_task = task_handle.lock().await;
66 current_task
67 .indirect_call_table
68 .get(&mut caller, task_start)
69 .context("Task entrypoint is out of bounds")?
70 };
71
72 let entrypoint = entrypoint
73 .funcref()
74 .context("Task entrypoint is not a function")?
75 .context("Task entrypoint is NULL")?
76 .typed::<P, ()>(&mut caller)
77 .context("Task entrypoint has invalid signature")?;
78
79 entrypoint
80 .call_async(&mut caller, args.lock().await.take().unwrap())
81 .await?;
82 Ok(())
83 })
84 })
85 }
86
87 pub fn new_closure(
90 pool: &mut TaskPool,
91 host: &Host,
92 task_closure: impl for<'a> FnOnce(
93 Caller<'a, Host>,
94 ) -> Box<dyn Future<Output = anyhow::Result<()>> + Send + 'a>
95 + Send
96 + 'static,
97 ) -> anyhow::Result<Self> {
98 let mut store = pool.create_store(host)?;
99 let task_closure = Arc::new(Mutex::new(Some(task_closure)));
100 let entrypoint = Func::wrap0_async(&mut store, move |caller: Caller<'_, Host>| {
101 let task_closure = task_closure.clone();
102 Box::new(async move {
103 let task_closure = task_closure
104 .lock()
105 .await
106 .take()
107 .expect("Expected task to only be started once");
108 Pin::from(task_closure(caller)).await
109 })
110 })
111 .typed::<(), ()>(&mut store)?;
112
113 Ok(Self {
114 priority: 7,
115 entrypoint,
116 store,
117 name: None,
118 })
119 }
120
121 pub fn new_global(
123 pool: &mut TaskPool,
124 host: &Host,
125 func_name: &'static str,
126 ) -> anyhow::Result<Self> {
127 Self::new_closure(pool, host, move |mut caller| {
128 Box::new(async move {
129 let instance = {
130 let task_handle = caller.current_task().await;
131 let this_task = task_handle.lock().await;
132 this_task.instance
133 };
134
135 let func = instance.get_func(&mut caller, func_name).with_context(|| {
136 format!("entrypoint missing: expected {func_name} to be defined")
137 })?;
138 let func = func
139 .typed(&mut caller)
140 .with_context(|| format!("invalid {func_name} signature: expected () -> ()"))?;
141
142 func.call_async(&mut caller, ()).await
143 })
144 })
145 }
146
147 pub fn name(mut self, name: impl Into<String>) -> Self {
148 self.name = Some(name.into());
149 self
150 }
151
152 pub fn priority(mut self, priority: u32) -> Self {
153 assert!(priority < TASK_PRIORITIES);
154 self.priority = priority;
155 self
156 }
157}
158
159pub struct Task {
160 id: u32,
161 name: String,
162 local_storage: Option<TaskStorage>,
163 task_impl: TypedFunc<(), ()>,
164 priority: u32,
165 errno: Option<Errno>,
166 pub instance: Instance,
167 allocator: WasmAllocator,
168 pub indirect_call_table: Table,
169 store: Arc<Mutex<Store<Host>>>,
170 state: TaskState,
171 marked_for_delete: bool,
172}
173
174impl Task {
175 fn new(
176 id: u32,
177 name: String,
178 mut store: Store<Host>,
179 instance: Instance,
180 task_impl: TypedFunc<(), ()>,
181 ) -> Self {
182 Self {
183 id,
184 name,
185 local_storage: None,
186 task_impl,
187 priority: 0,
188 errno: None,
189 allocator: WasmAllocator::new(&mut store, &instance),
190 indirect_call_table: instance
191 .get_table(&mut store, "__indirect_function_table")
192 .unwrap(),
193 instance,
194 store: Arc::new(Mutex::new(store)),
195 state: TaskState::Ready,
196 marked_for_delete: false,
197 }
198 }
199
200 pub async fn local_storage(
201 &mut self,
202 store: impl AsContextMut<Data = impl Send>,
203 ) -> TaskStorage {
204 if let Some(storage) = self.local_storage {
205 return storage;
206 }
207 let storage = TaskStorage::new(store, &self.allocator).await;
208 self.local_storage = Some(storage);
209 storage
210 }
211
212 pub async fn errno(&mut self, store: impl AsContextMut<Data = impl Send>) -> Errno {
213 if let Some(errno) = self.errno {
214 return errno;
215 }
216 let errno = Errno::new(store, &self.allocator).await;
217 self.errno = Some(errno);
218 errno
219 }
220
221 pub fn id(&self) -> u32 {
222 self.id
223 }
224
225 pub fn start(&mut self) -> impl Future<Output = anyhow::Result<()>> {
226 let store = self.store.clone();
227 let task_impl = self.task_impl;
228 async move {
229 let mut store = store.lock().await;
230 task_impl.call_async(&mut *store, ()).await
231 }
232 }
233
234 pub fn state(&self) -> TaskState {
235 self.state
236 }
237
238 pub fn name(&self) -> &str {
239 &self.name
240 }
241
242 pub fn allocator(&self) -> WasmAllocator {
243 self.allocator.clone()
244 }
245}
246impl PartialEq for Task {
247 fn eq(&self, other: &Self) -> bool {
248 self.id == other.id
249 }
250}
251impl Eq for Task {}
252
253pub type TaskHandle = Arc<Mutex<Task>>;
254
255pub struct TaskPool {
256 pool: HashMap<u32, TaskHandle>,
257 deleted_tasks: HashSet<u32>,
258 newest_task_id: u32,
259 current_task: Option<TaskHandle>,
260 engine: Engine,
261 shared_memory: SharedMemory,
262 scheduler_suspended: u32,
263 yield_pending: bool,
264 shutdown_pending: bool,
265 interface: SimulatorInterface,
266}
267
268impl TaskPool {
269 pub fn new(
270 engine: Engine,
271 shared_memory: SharedMemory,
272 interface: SimulatorInterface,
273 ) -> anyhow::Result<Self> {
274 Ok(Self {
275 pool: HashMap::new(),
276 deleted_tasks: HashSet::new(),
277 newest_task_id: 0,
278 current_task: None,
279 engine,
280 shared_memory,
281 scheduler_suspended: 0,
282 yield_pending: false,
283 shutdown_pending: false,
284 interface,
285 })
286 }
287
288 pub fn create_store(&mut self, host: &Host) -> anyhow::Result<Store<Host>> {
289 let store = Store::new(&self.engine, host.clone());
290 Ok(store)
291 }
292
293 pub async fn instantiate(
294 &mut self,
295 store: &mut Store<Host>,
296 module: &Module,
297 interface: &SimulatorInterface,
298 ) -> anyhow::Result<Instance> {
299 let mut linker = Linker::<Host>::new(&self.engine);
300
301 configure_api(&mut linker, store, self.shared_memory.clone())?;
302
303 for import in module.imports() {
304 if linker
305 .get(&mut *store, import.module(), import.name())
306 .is_none()
307 {
308 interface.send(SimulatorEvent::Warning(format!(
309 "Unimplemented API `{}` (Robot code will crash if this is used)",
310 import.name()
311 )));
312 }
313 }
314
315 linker.define_unknown_imports_as_traps(module)?;
316 let instance = linker.instantiate_async(store, module).await?;
317
318 Ok(instance)
319 }
320
321 pub async fn spawn(
322 &mut self,
323 opts: TaskOptions,
324 module: &Module,
325 interface: &SimulatorInterface,
326 ) -> anyhow::Result<TaskHandle> {
327 let TaskOptions {
328 priority,
329 entrypoint,
330 mut store,
331 name,
332 ..
333 } = opts;
334
335 let instance = self.instantiate(&mut store, module, interface).await?;
336
337 self.newest_task_id += 1;
338 let id = self.newest_task_id;
339
340 let mut task = Task::new(
341 id,
342 name.unwrap_or_else(|| format!("task {id}")),
343 store,
344 instance,
345 entrypoint,
346 );
347 task.priority = priority;
348 let task = Arc::new(Mutex::new(task));
349 self.pool.insert(id, task.clone());
350 Ok(task)
351 }
352
353 pub fn by_id(&self, task_id: u32) -> Option<TaskHandle> {
354 if task_id == 0 {
355 return Some(self.current());
356 }
357 self.pool.get(&task_id).cloned()
358 }
359
360 pub fn current(&self) -> TaskHandle {
361 self.current_task
362 .clone()
363 .expect("using the current task may only happen while a task is being executed")
364 }
365
366 pub async fn current_lock(&self) -> MutexGuard<'_, Task> {
367 self.current_task
368 .as_ref()
369 .expect("using the current task may only happen while a task is being executed")
370 .lock()
371 .await
372 }
373
374 #[inline]
375 pub async fn yield_now() {
376 futures_util::pending!();
377 }
378
379 pub fn suspend_all(&mut self) {
381 self.scheduler_suspended += 1;
382 }
383
384 pub async fn resume_all(&mut self) -> anyhow::Result<bool> {
388 if self.scheduler_suspended == 0 {
389 bail!("rtos_resume_all called without a matching rtos_suspend_all");
390 }
391
392 self.scheduler_suspended -= 1;
393
394 if self.yield_pending && self.scheduler_suspended == 0 {
395 Self::yield_now().await;
396 Ok(true)
397 } else {
398 Ok(false)
399 }
400 }
401
402 async fn highest_priority_task_ids(&self) -> Vec<u32> {
403 let mut highest_priority = 0;
404 let mut highest_priority_tasks = vec![];
405 for task in self.pool.values() {
406 let task = task.lock().await;
407 if task.priority > highest_priority {
408 highest_priority = task.priority;
409 highest_priority_tasks.clear();
410 }
411 if task.priority == highest_priority {
412 highest_priority_tasks.push(task.id);
413 }
414 }
415 highest_priority_tasks.sort();
416 highest_priority_tasks
417 }
418
419 pub async fn cycle_tasks(&mut self) -> bool {
426 if self.scheduler_suspended != 0 {
427 if self.current_task.is_some() {
428 self.yield_pending = true;
429 return true;
430 } else {
431 panic!("Scheduler is suspended and current task is missing");
432 }
433 }
434 self.yield_pending = false;
435
436 let task_candidates = self.highest_priority_task_ids().await;
437 let current_task_id = if let Some(task) = &self.current_task {
438 task.lock().await.id
439 } else {
440 0
441 };
442 let next_task = task_candidates
443 .iter()
444 .find(|id| **id > current_task_id)
445 .or_else(|| task_candidates.first())
446 .and_then(|id| self.by_id(*id));
447 self.current_task = next_task;
448 self.current_task.is_some()
449 }
450
451 pub async fn run_to_completion(host: &Host) -> anyhow::Result<()> {
452 let mut futures =
453 HashMap::<u32, Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>>::new();
454 loop {
455 let mut tasks = host.tasks_lock().await;
456 let running = tasks.cycle_tasks().await;
457 if !running {
458 break Ok(());
459 }
460
461 let mut task = tasks.current_lock().await;
462 let id = task.id();
463 let future = futures.entry(id).or_insert_with(|| Box::pin(task.start()));
464 drop(task);
465 drop(tasks);
466
467 let result = futures::poll!(future);
468
469 let tasks = host.tasks();
470 let mut tasks = tasks
471 .try_lock()
472 .expect("attempt to yield while task mutex is locked");
473 let task = tasks.current();
474 let mut task = task
475 .try_lock()
476 .expect("attempt to yield while current task is locked");
477
478 if tasks.shutdown_pending {
479 break Ok(());
480 }
481
482 if let Poll::Ready(result) = result {
483 task.marked_for_delete = true;
484 task.state = TaskState::Finished;
485 result?;
486 } else if task.marked_for_delete {
487 task.state = TaskState::Deleted;
488 }
489
490 if task.marked_for_delete {
491 if tasks.scheduler_suspended != 0 {
492 tasks.interface.send(SimulatorEvent::Warning(format!(
494 "Task `{}` (#{}) exited with scheduler in suspended state",
495 &task.name, task.id,
496 )));
497 }
498 drop(task);
499
500 tasks.scheduler_suspended = 0;
501 futures.remove(&id);
502 tasks.pool.remove(&id);
503 }
504 }
505 }
506
507 pub async fn task_state(&self, task_id: u32) -> Option<TaskState> {
508 if self.deleted_tasks.contains(&task_id) {
509 return Some(TaskState::Deleted);
510 }
511 if let Some(task) = self.pool.get(&task_id) {
512 let task = task.lock().await;
513 Some(task.state)
514 } else {
515 None
516 }
517 }
518
519 pub async fn delete_task(&mut self, task_id: u32) {
520 let task = self.pool.get(&task_id);
521 if let Some(task) = task {
522 let mut task = task.lock().await;
523 if task.state == TaskState::Running {
524 task.marked_for_delete = true;
525 Self::yield_now().await;
526 unreachable!("Deleted task may not continue execution");
527 }
528
529 task.state = TaskState::Deleted;
530 drop(task);
531 self.pool.remove(&task_id).unwrap();
532 self.deleted_tasks.insert(task_id);
533 }
534 }
535
536 pub fn start_shutdown(&mut self) {
537 self.shutdown_pending = true;
538 }
539}
540
541#[derive(Debug, Clone, Copy)]
542pub struct Errno {
543 address: u32,
544}
545
546impl Errno {
547 pub async fn new(
548 store: impl AsContextMut<Data = impl Send>,
549 allocator: &WasmAllocator,
550 ) -> Self {
551 let address = allocator
552 .memalign(store, std::alloc::Layout::new::<i32>())
553 .await;
554 Self { address }
555 }
556 pub fn address(&self) -> u32 {
557 self.address
558 }
559 pub fn set(&self, memory: &SharedMemory, new_errno: i32) {
560 let buffer = new_errno.to_le_bytes();
561 memory
562 .write_relaxed(self.address as usize, &buffer)
563 .unwrap();
564 }
565}