// async_callback_manager/manager.rs

1use crate::task::{
2    AsyncTask, AsyncTaskKind, FutureTask, SpawnedTask, StreamTask, TaskInformation, TaskList,
3    TaskOutcome, TaskWaiter,
4};
5use crate::{Constraint, DEFAULT_STREAM_CHANNEL_SIZE};
6use futures::{Stream, StreamExt};
7use std::any::TypeId;
8use std::future::Future;
9use std::sync::Arc;
10
/// Identifier assigned to each task spawned by an `AsyncCallbackManager`.
///
/// Ids come from a per-manager counter that starts at zero and increases by one per spawned
/// future/stream task, so they are only unique within a single manager.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub struct TaskId(pub(crate) u64);
13
14pub(crate) type DynStateMutation<Frntend, Bkend, Md> =
15    Box<dyn FnOnce(&mut Frntend) -> AsyncTask<Frntend, Bkend, Md> + Send>;
16pub(crate) type DynMutationFuture<Frntend, Bkend, Md> =
17    Box<dyn Future<Output = DynStateMutation<Frntend, Bkend, Md>> + Unpin + Send>;
18pub(crate) type DynMutationStream<Frntend, Bkend, Md> =
19    Box<dyn Stream<Item = DynStateMutation<Frntend, Bkend, Md>> + Unpin + Send>;
20pub(crate) type DynFutureTask<Frntend, Bkend, Md> =
21    Box<dyn FnOnce(&Bkend) -> DynMutationFuture<Frntend, Bkend, Md>>;
22pub(crate) type DynStreamTask<Frntend, Bkend, Md> =
23    Box<dyn FnOnce(&Bkend) -> DynMutationStream<Frntend, Bkend, Md>>;
24
25pub(crate) type DynTaskSpawnCallback<Cstrnt> = dyn Fn(TaskInformation<Cstrnt>);
26
27pub struct AsyncCallbackManager<Frntend, Bkend, Md> {
28    next_task_id: u64,
29    tasks_list: TaskList<Frntend, Bkend, Md>,
30    // It could be possible to make this generic instead of dynamic, however this type would then
31    // require more type parameters.
32    on_task_spawn: Box<DynTaskSpawnCallback<Md>>,
33}
34
35/// Temporary struct to store task details before it is added to the task list.
36pub(crate) struct TempSpawnedTask<Frntend, Bkend, Md> {
37    waiter: TaskWaiter<Frntend, Bkend, Md>,
38    type_id: TypeId,
39    type_name: &'static str,
40    type_debug: Arc<String>,
41}
42
43impl<Frntend, Bkend, Md: PartialEq> Default for AsyncCallbackManager<Frntend, Bkend, Md> {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl<Frntend, Bkend, Md: PartialEq> AsyncCallbackManager<Frntend, Bkend, Md> {
50    /// Get a new AsyncCallbackManager.
51    pub fn new() -> Self {
52        Self {
53            next_task_id: Default::default(),
54            tasks_list: TaskList::new(),
55            on_task_spawn: Box::new(|_| {}),
56        }
57    }
58    pub fn with_on_task_spawn_callback(
59        mut self,
60        cb: impl Fn(TaskInformation<Md>) + 'static,
61    ) -> Self {
62        self.on_task_spawn = Box::new(cb);
63        self
64    }
65    /// Await for the next response from one of the spawned tasks, or returns
66    /// None if no tasks were in the list.
67    pub async fn get_next_response(&mut self) -> Option<TaskOutcome<Frntend, Bkend, Md>> {
68        self.tasks_list.get_next_response().await
69    }
70    pub fn spawn_task(&mut self, backend: &Bkend, task: AsyncTask<Frntend, Bkend, Md>)
71    where
72        Frntend: 'static,
73        Bkend: 'static,
74        Md: 'static,
75    {
76        let AsyncTask {
77            task,
78            constraint,
79            metadata,
80        } = task;
81        match task {
82            AsyncTaskKind::Future(future_task) => {
83                let outcome = self.spawn_future_task(backend, future_task, &constraint);
84                self.add_task_to_list(outcome, metadata, constraint);
85            }
86            AsyncTaskKind::Stream(stream_task) => {
87                let outcome = self.spawn_stream_task(backend, stream_task, &constraint);
88                self.add_task_to_list(outcome, metadata, constraint);
89            }
90            // Don't call (self.on_task_spawn)() for NoOp.
91            AsyncTaskKind::Multi(tasks) => {
92                for task in tasks {
93                    self.spawn_task(backend, task)
94                }
95            }
96            AsyncTaskKind::NoOp => (),
97        }
98    }
99    fn add_task_to_list(
100        &mut self,
101        details: TempSpawnedTask<Frntend, Bkend, Md>,
102        metadata: Vec<Md>,
103        constraint: Option<Constraint<Md>>,
104    ) {
105        let TempSpawnedTask {
106            waiter,
107            type_id,
108            type_name,
109            type_debug,
110        } = details;
111        let sp = SpawnedTask {
112            type_id,
113            task_id: TaskId(self.next_task_id),
114            type_name,
115            type_debug,
116            receiver: waiter,
117            metadata,
118        };
119        // At one task per nanosecond, it would take 584.6 years for a library user to
120        // trigger overflow.
121        //
122        // https://www.wolframalpha.com/input?i=2%5E64+nanoseconds
123        let new_id = self
124            .next_task_id
125            .checked_add(1)
126            .expect("u64 shouldn't overflow!");
127        self.next_task_id = new_id;
128        if let Some(constraint) = constraint {
129            self.tasks_list.handle_constraint(constraint, type_id);
130        }
131        self.tasks_list.push(sp);
132    }
133    fn spawn_future_task(
134        &self,
135        backend: &Bkend,
136        future_task: FutureTask<Frntend, Bkend, Md>,
137        constraint: &Option<Constraint<Md>>,
138    ) -> TempSpawnedTask<Frntend, Bkend, Md>
139    where
140        Frntend: 'static,
141        Bkend: 'static,
142        Md: 'static,
143    {
144        (self.on_task_spawn)(TaskInformation {
145            type_id: future_task.type_id,
146            type_name: future_task.type_name,
147            type_debug: &future_task.type_debug,
148            constraint,
149        });
150        let future = (future_task.task)(backend);
151        let handle = tokio::spawn(future);
152        TempSpawnedTask {
153            waiter: TaskWaiter::Future(handle),
154            type_id: future_task.type_id,
155            type_name: future_task.type_name,
156            type_debug: Arc::new(future_task.type_debug),
157        }
158    }
159    fn spawn_stream_task(
160        &self,
161        backend: &Bkend,
162        stream_task: StreamTask<Frntend, Bkend, Md>,
163        constraint: &Option<Constraint<Md>>,
164    ) -> TempSpawnedTask<Frntend, Bkend, Md>
165    where
166        Frntend: 'static,
167        Bkend: 'static,
168        Md: 'static,
169    {
170        let StreamTask {
171            task,
172            type_id,
173            type_name,
174            type_debug,
175        } = stream_task;
176        (self.on_task_spawn)(TaskInformation {
177            type_id,
178            type_name,
179            type_debug: &type_debug,
180            constraint,
181        });
182        let mut stream = task(backend);
183        let (tx, rx) = tokio::sync::mpsc::channel(DEFAULT_STREAM_CHANNEL_SIZE);
184        let abort_handle = tokio::spawn(async move {
185            loop {
186                if let Some(mutation) = stream.next().await {
187                    // Error could occur here if receiver is dropped.
188                    // Doesn't seem to be a big deal to ignore this error.
189                    let _ = tx.send(mutation).await;
190                    continue;
191                }
192                return;
193            }
194        })
195        .abort_handle();
196        TempSpawnedTask {
197            waiter: TaskWaiter::Stream {
198                receiver: rx,
199                abort_handle,
200            },
201            type_id,
202            type_name,
203            type_debug: Arc::new(type_debug),
204        }
205    }
206}