async_callback_manager/
manager.rs

1use crate::manager::task_list::{SpawnedTask, TaskInformation, TaskList, TaskOutcome, TaskWaiter};
2use crate::task::{AsyncTask, AsyncTaskKind, FutureTask, StreamTask};
3use crate::{Constraint, DEFAULT_STREAM_CHANNEL_SIZE};
4use futures::StreamExt;
5use std::any::TypeId;
6use std::sync::Arc;
7
8pub mod task_list;
9
/// Identifier assigned to a task when it is added to an `AsyncCallbackManager`'s task list.
///
/// Each manager hands ids out sequentially from its own counter, so an id is only unique
/// within the manager that produced it. `Hash` is derived so ids can be used as keys in
/// hash-based collections.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TaskId(pub(crate) u64);
12
13pub(crate) type DynTaskSpawnCallback<Cstrnt> = dyn Fn(TaskInformation<Cstrnt>);
14
15pub struct AsyncCallbackManager<Frntend, Bkend, Md> {
16    next_task_id: u64,
17    tasks_list: TaskList<Frntend, Bkend, Md>,
18    // It could be possible to make this generic instead of dynamic, however this type would then
19    // require more type parameters.
20    on_task_spawn: Box<DynTaskSpawnCallback<Md>>,
21}
22
23/// Temporary struct to store task details before it is added to the task list.
24pub(crate) struct TempSpawnedTask<Frntend, Bkend, Md> {
25    waiter: TaskWaiter<Frntend, Bkend, Md>,
26    type_id: TypeId,
27    type_name: &'static str,
28    type_debug: Arc<String>,
29}
30
31impl<Frntend, Bkend, Md: PartialEq> Default for AsyncCallbackManager<Frntend, Bkend, Md> {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<Frntend, Bkend, Md: PartialEq> AsyncCallbackManager<Frntend, Bkend, Md> {
38    /// Get a new AsyncCallbackManager.
39    pub fn new() -> Self {
40        Self {
41            next_task_id: Default::default(),
42            tasks_list: TaskList::new(),
43            on_task_spawn: Box::new(|_| {}),
44        }
45    }
46    pub fn with_on_task_spawn_callback(
47        mut self,
48        cb: impl Fn(TaskInformation<Md>) + 'static,
49    ) -> Self {
50        self.on_task_spawn = Box::new(cb);
51        self
52    }
53    /// Await for the next response from one of the spawned tasks, or returns
54    /// None if no tasks were in the list.
55    pub async fn get_next_response(&mut self) -> Option<TaskOutcome<Frntend, Bkend, Md>> {
56        self.tasks_list.get_next_response().await
57    }
58    pub fn spawn_task(&mut self, backend: &Bkend, task: AsyncTask<Frntend, Bkend, Md>)
59    where
60        Frntend: 'static,
61        Bkend: 'static,
62        Md: 'static,
63    {
64        let AsyncTask {
65            task,
66            constraint,
67            metadata,
68        } = task;
69        match task {
70            AsyncTaskKind::Future(future_task) => {
71                let outcome = self.spawn_future_task(backend, future_task, &constraint);
72                self.add_task_to_list(outcome, metadata, constraint);
73            }
74            AsyncTaskKind::Stream(stream_task) => {
75                let outcome = self.spawn_stream_task(backend, stream_task, &constraint);
76                self.add_task_to_list(outcome, metadata, constraint);
77            }
78            AsyncTaskKind::Multi(tasks) => {
79                for task in tasks {
80                    self.spawn_task(backend, task)
81                }
82            }
83            AsyncTaskKind::NoOp => (),
84        }
85    }
86    fn add_task_to_list(
87        &mut self,
88        details: TempSpawnedTask<Frntend, Bkend, Md>,
89        metadata: Vec<Md>,
90        constraint: Option<Constraint<Md>>,
91    ) {
92        let TempSpawnedTask {
93            waiter,
94            type_id,
95            type_name,
96            type_debug,
97        } = details;
98        let sp = SpawnedTask {
99            type_id,
100            task_id: TaskId(self.next_task_id),
101            type_name,
102            type_debug,
103            receiver: waiter,
104            metadata,
105        };
106        // At one task per nanosecond, it would take 584.6 years for a library user to
107        // trigger overflow.
108        //
109        // https://www.wolframalpha.com/input?i=2%5E64+nanoseconds
110        let new_id = self
111            .next_task_id
112            .checked_add(1)
113            .expect("u64 shouldn't overflow!");
114        self.next_task_id = new_id;
115        if let Some(constraint) = constraint {
116            self.tasks_list.handle_constraint(constraint, type_id);
117        }
118        self.tasks_list.push(sp);
119    }
120    fn spawn_future_task(
121        &self,
122        backend: &Bkend,
123        future_task: FutureTask<Frntend, Bkend, Md>,
124        constraint: &Option<Constraint<Md>>,
125    ) -> TempSpawnedTask<Frntend, Bkend, Md>
126    where
127        Frntend: 'static,
128        Bkend: 'static,
129        Md: 'static,
130    {
131        let FutureTask {
132            task,
133            type_id,
134            type_name,
135            ref type_debug,
136        } = future_task;
137        (self.on_task_spawn)(TaskInformation {
138            type_id,
139            type_name,
140            type_debug,
141            constraint,
142        });
143        let future = task.into_dyn_task()(backend);
144        let handle = tokio::spawn(future);
145        TempSpawnedTask {
146            waiter: TaskWaiter::Future(handle),
147            type_id: future_task.type_id,
148            type_name: future_task.type_name,
149            type_debug: Arc::new(future_task.type_debug),
150        }
151    }
152    fn spawn_stream_task(
153        &self,
154        backend: &Bkend,
155        stream_task: StreamTask<Frntend, Bkend, Md>,
156        constraint: &Option<Constraint<Md>>,
157    ) -> TempSpawnedTask<Frntend, Bkend, Md>
158    where
159        Frntend: 'static,
160        Bkend: 'static,
161        Md: 'static,
162    {
163        let StreamTask {
164            task,
165            type_id,
166            type_name,
167            type_debug,
168        } = stream_task;
169        (self.on_task_spawn)(TaskInformation {
170            type_id,
171            type_name,
172            type_debug: &type_debug,
173            constraint,
174        });
175        let mut stream = task.into_dyn_stream()(backend);
176        let (tx, rx) = tokio::sync::mpsc::channel(DEFAULT_STREAM_CHANNEL_SIZE);
177        let join_handle = tokio::spawn(async move {
178            loop {
179                if let Some(mutation) = stream.next().await {
180                    // Error could occur here if receiver is dropped.
181                    // Doesn't seem to be a big deal to ignore this error.
182                    let _ = tx.send(mutation).await;
183                    continue;
184                }
185                return;
186            }
187        });
188        TempSpawnedTask {
189            waiter: TaskWaiter::Stream {
190                receiver: rx,
191                join_handle,
192            },
193            type_id,
194            type_name,
195            type_debug: Arc::new(type_debug),
196        }
197    }
198}