1use core::{
2 future::Future,
3 marker::PhantomData,
4 mem,
5 pin::Pin,
6 task::{
7 Context,
8 Poll,
9 Waker,
10 },
11};
12
13use alloc::{
14 collections::VecDeque,
15 sync::Arc,
16};
17
18use futures::{
19 future::{
20 FutureObj,
21 LocalFutureObj,
22 UnsafeFutureObj,
23 },
24 task::{
25 LocalSpawn,
26 Spawn,
27 SpawnError,
28 },
29};
30
31use lock_api::{
32 Mutex,
33 RawMutex,
34};
35
36use generational_arena::{
37 Arena,
38 Index,
39};
40
41use crate::{
42 future_box,
43 sleep::*,
44 wake::{
45 Wake,
46 WakeExt,
47 },
48};
49
/// Number of task slots reserved in the registry by `AllocExecutor::new`.
const REG_CAP: usize = 16;

/// Number of queue slots reserved by `AllocExecutor::new`; half of `REG_CAP`.
const QUEUE_CAP: usize = REG_CAP / 2;
55
56pub struct AllocExecutor<'a, R, S>
70where
71 R: RawMutex,
72{
73 registry: Arena<Task<'a>>,
74 queue: QueueHandle<'a, R>,
75 sleep_waker: S,
76}
77
/// Which end of the run queue a freshly spawned task's first poll request
/// is pushed onto.
enum SpawnLoc {
    /// Push to the front, so the task is the next item dequeued.
    Front,
    /// Push to the back, behind everything already queued.
    Back,
}
83
84impl<'a, R, S> Default for AllocExecutor<'a, R, S>
85where
86 R: RawMutex,
87 S: Sleep + Wake + Clone + Default,
88{
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl<'a, R, S> AllocExecutor<'a, R, S>
95where
96 R: RawMutex,
97 S: Sleep + Wake + Clone + Default,
98{
99 pub fn new() -> Self {
103 Self::with_capacity(REG_CAP, QUEUE_CAP)
104 }
105
106 pub fn with_capacity(registry: usize, queue: usize) -> Self {
110 AllocExecutor {
111 registry: Arena::with_capacity(registry),
112 queue: new_queue(queue),
113 sleep_waker: S::default(),
114 }
115 }
116
117 pub fn spawner(&self) -> Spawner<'a, R> {
120 Spawner::new(self.queue.clone())
121 }
122
123 pub fn local_spawner(&self) -> LocalSpawner<'a, R> {
126 LocalSpawner::new(Spawner::new(self.queue.clone()))
127 }
128
129 fn spawn_local(&mut self, future: LocalFutureObj<'a, ()>, loc: SpawnLoc) {
138 let id = self.registry.insert(Task::new(future));
139
140 let queue_waker = Arc::new(QueueWaker::new(
141 self.queue.clone(),
142 id,
143 self.sleep_waker.clone(),
144 ));
145
146 let waker = queue_waker.into_waker();
147 self.registry.get_mut(id).unwrap().set_waker(waker);
148
149 let item = QueueItem::Poll(id);
150 let mut lock = self.queue.lock();
151
152 match loc {
153 SpawnLoc::Front => lock.push_front(item),
154 SpawnLoc::Back => lock.push_back(item),
155 }
156 }
157
158 pub fn spawn_raw<F>(&mut self, future: F)
160 where
161 F: UnsafeFutureObj<'a, ()>,
162 {
163 self.spawn_local(LocalFutureObj::new(future), SpawnLoc::Back)
164 }
165
166 pub fn spawn<F>(&mut self, future: F)
170 where
171 F: Future<Output = ()> + 'a,
172 {
173 self.spawn_raw(future_box::make_local(future));
174 }
175
176 fn poll_task(&mut self, id: Index) {
181 if let Some(Task { future, waker }) = self.registry.get_mut(id) {
185 let future = Pin::new(future);
186
187 let waker = waker
188 .as_ref()
189 .expect("waker not set, task spawned incorrectly");
190
191 match future.poll(&mut Context::from_waker(waker)) {
192 Poll::Ready(_) => {
193 self.registry.remove(id);
194 }
195 Poll::Pending => {}
196 }
197 }
198 }
199
200 fn dequeue(&self) -> Option<QueueItem<'a>> {
202 self.queue.lock().pop_front()
203 }
204
205 pub fn run(&mut self) {
214 'outer: loop {
215 while let Some(item) = self.dequeue() {
216 match item {
217 QueueItem::Poll(id) => {
218 self.poll_task(id);
219 }
220 QueueItem::Spawn(task) => {
221 self.spawn_local(task.into(), SpawnLoc::Front);
222 }
223 }
224 if self.registry.is_empty() {
225 break 'outer;
226 }
227 self.sleep_waker.sleep();
228 }
229 }
230 }
231}
232
/// A spawned future together with the waker passed to it on each poll.
struct Task<'a> {
    /// The future being driven to completion.
    future: LocalFutureObj<'a, ()>,
    /// Waker used to build the poll `Context`; `None` until `set_waker` runs.
    waker: Option<Waker>,
}
238
239impl<'a> Task<'a> {
240 fn new(future: LocalFutureObj<'a, ()>) -> Task<'a> {
241 Task {
242 future,
243 waker: None,
244 }
245 }
246 fn set_waker(&mut self, waker: Waker) {
247 self.waker = Some(waker);
248 }
249}
250
/// Double-ended queue of pending executor work.
type Queue<'a> = VecDeque<QueueItem<'a>>;

/// Shared, mutex-protected handle to the run queue; `R` is the raw mutex.
type QueueHandle<'a, R> = Arc<Mutex<R, Queue<'a>>>;
254
255fn new_queue<'a, R>(capacity: usize) -> QueueHandle<'a, R>
256where
257 R: RawMutex,
258{
259 Arc::new(Mutex::new(Queue::with_capacity(capacity)))
260}
261
/// One unit of work waiting in the run queue.
enum QueueItem<'a> {
    /// Poll the task registered under this arena index.
    Poll(Index),
    /// Register this future as a new task.
    Spawn(FutureObj<'a, ()>),
}
266
267struct QueueWaker<R, W>
270where
271 R: RawMutex,
272{
273 queue: QueueHandle<'static, R>,
274 id: Index,
275 waker: W,
276}
277
278impl<R, W> QueueWaker<R, W>
279where
280 R: RawMutex,
281 W: Wake,
282{
283 fn new(queue: QueueHandle<'_, R>, id: Index, waker: W) -> Self {
284 QueueWaker {
285 queue: unsafe { mem::transmute(queue) },
289 id,
290 waker,
291 }
292 }
293}
294
295impl<R, W> Wake for QueueWaker<R, W>
296where
297 R: RawMutex,
298 W: Wake,
299{
300 fn wake(&self) {
301 self.queue.lock().push_back(QueueItem::Poll(self.id));
302 self.waker.wake();
303 }
304}
305
306#[derive(Clone)]
312pub struct LocalSpawner<'a, R>(Spawner<'a, R>, PhantomData<LocalFutureObj<'a, ()>>)
313where
314 R: RawMutex;
315
316impl<'a, R> LocalSpawner<'a, R>
317where
318 R: RawMutex,
319{
320 fn new(spawner: Spawner<'a, R>) -> Self {
321 LocalSpawner(spawner, PhantomData)
322 }
323}
324
325impl<'a, R> LocalSpawner<'a, R>
326where
327 R: RawMutex,
328{
329 fn spawn_local(&self, future: LocalFutureObj<'a, ()>) -> Result<(), SpawnError> {
330 Ok(self
333 .0
334 .spawn_obj(unsafe { mem::transmute(future.into_future_obj()) }))
335 }
336
337 pub fn spawn_raw<F>(&mut self, future: F) -> Result<(), SpawnError>
339 where
340 F: UnsafeFutureObj<'a, ()>,
341 {
342 self.spawn_local(LocalFutureObj::new(future))
343 }
344
345 pub fn spawn<F>(&mut self, future: F) -> Result<(), SpawnError>
351 where
352 F: Future<Output = ()> + 'a,
353 {
354 self.spawn_raw(future_box::make_local(future))
355 }
356}
357
358impl<'a, R> LocalSpawn for LocalSpawner<'a, R>
359where
360 R: RawMutex,
361{
362 fn spawn_local_obj(&self, future: LocalFutureObj<'a, ()>) -> Result<(), SpawnError> {
363 self.spawn_local(future)
364 }
365}
366
367pub struct Spawner<'a, R>(QueueHandle<'a, R>)
372where
373 R: RawMutex;
374
375impl<'a, R> Spawner<'a, R>
376where
377 R: RawMutex,
378{
379 fn new(handle: QueueHandle<'a, R>) -> Self {
380 Spawner(handle)
381 }
382
383 fn spawn_obj(&self, future: FutureObj<'a, ()>) {
384 self.0.lock().push_back(QueueItem::Spawn(future));
385 }
386
387 pub fn spawn_raw<F>(&self, future: F)
389 where
390 F: UnsafeFutureObj<'a, ()> + Send + 'a,
391 {
392 Spawner::spawn_obj(self, FutureObj::new(future));
393 }
394
395 pub fn spawn<F>(&mut self, future: F)
401 where
402 F: Future<Output = ()> + Send + 'a,
403 {
404 self.spawn_raw(future_box::make_obj(future));
405 }
406}
407
408impl<'a, R> Clone for Spawner<'a, R>
409where
410 R: RawMutex,
411{
412 fn clone(&self) -> Self {
413 Spawner(self.0.clone())
414 }
415}
416
417impl<'a, R> Spawn for Spawner<'a, R>
418where
419 R: RawMutex,
420{
421 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
422 Ok(Spawner::spawn_obj(self, future))
423 }
424}
425
426impl<'a, R> From<LocalSpawner<'a, R>> for Spawner<'a, R>
427where
428 R: RawMutex,
429{
430 fn from(other: LocalSpawner<'a, R>) -> Self {
431 other.0
432 }
433}
434
#[cfg(test)]
mod test {
    use super::*;
    use crate::sleep::Sleep;
    use core::sync::atomic::{
        AtomicBool,
        Ordering,
    };
    use futures::{
        future::{
            self,
            FutureObj,
        },
        task::Spawn,
    };
    use lock_api::GuardSend;

    /// Minimal spinlock used as the raw mutex for the executor under test.
    /// The flag is `true` while the lock is held.
    pub struct RawSpinlock(AtomicBool);

    unsafe impl RawMutex for RawSpinlock {
        const INIT: RawSpinlock = RawSpinlock(AtomicBool::new(false));

        type GuardMarker = GuardSend;

        fn lock(&self) {
            // Spin until the flag is successfully flipped from unlocked to
            // locked.
            while !self.try_lock() {}
        }

        fn try_lock(&self) -> bool {
            // Succeeds only when the flag was `false` (unlocked) and this
            // call set it to `true`. Returning the previous value of a plain
            // `swap` would report the opposite of what happened.
            self.0
                .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        }

        unsafe fn unlock(&self) {
            self.0.store(false, Ordering::Release);
        }
    }

    /// Sleep/wake handle whose `sleep` and `wake` both do nothing.
    #[derive(Copy, Clone, Default)]
    struct NopSleep;

    impl Sleep for NopSleep {
        fn sleep(&self) {}
    }

    impl Wake for NopSleep {
        fn wake(&self) {}
    }

    async fn foo() -> i32 {
        5
    }

    async fn bar() -> i32 {
        let a = foo().await;
        println!("{}", a);
        let b = a + 1;
        b
    }

    /// Awaits `bar` and then queues one printing future per value in
    /// `c..25` through `spawner`.
    async fn baz<S: Spawn>(spawner: S) {
        let c = bar().await;
        for i in c..25 {
            let spam = async move {
                println!("{}", i);
            };
            println!("spawning!");
            spawner
                .spawn_obj(FutureObj::new(future_box::make_obj(spam)))
                .unwrap();
        }
    }

    /// Smoke test: spawns tasks that themselves spawn more tasks and runs
    /// the executor until it returns.
    #[test]
    fn executor() {
        let mut executor = AllocExecutor::<RawSpinlock, NopSleep>::new();
        let spawner = executor.spawner();
        let entry = future::lazy(move |_| {
            for i in 0..10 {
                spawner.spawn_raw(future_box::make_obj(future::lazy(move |_| {
                    println!("{}", i);
                })));
            }
        });
        executor.spawn(entry);
        executor.spawn(baz(executor.spawner()));
        executor.run();
    }
}