1use colored::Colorize;
2use tokio::select;
3use tokio::sync::mpsc::Receiver;
4
5use crate::error::Error;
6use crate::templates::WorkerTemplate;
7use crate::{limit_string, wsupdate, wsupdate_async, MainProgress, Result, Uid, WorkerState, WorkerStatus};
8
9pub struct WorkerHandle<D, F, Fut, S>
10where
11 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
12 Fut::Output: Send + 'static,
13 F: Fn(D, MainProgress<S>) -> Fut,
14{
15 fn_ptr: F,
16 receiver: Receiver<D>,
17 main_ui: MainProgress<S>,
18 worker_state: WorkerState<S>,
19}
20
21impl<D, F, Fut, S: WorkerTemplate> WorkerHandle<D, F, Fut, S>
22where
23 D: Send,
24 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
25 Fut::Output: Send + 'static,
26 F: Fn(D, MainProgress<S>) -> Fut,
27{
28 fn initialize(mut self) -> Self {
29 self.main_ui.add_worker(self.worker_state.id, &self.worker_state);
30
31 wsupdate! {
32 self.worker_state,
33 self.receiver,
34 "New worker spawned",
35 WorkerStatus::Spawned,
36 }
37
38 self.worker_state.set_status(WorkerStatus::Waiting);
39 self
40 }
41
42 pub async fn run(mut self) {
43 use WorkerStatus::*;
44
45 while let Some(data) = self.receiver.recv().await {
46 self.main_ui.inc(1);
47
48 wsupdate! {
49 self.worker_state,
50 self.receiver,
51 "Processing...",
52 Running
53 }
54
55 let call = &mut self.fn_ptr;
56 select! {
57 _ = wsupdate_async!(self.worker_state, self.receiver,) => (),
58 result = call(data, self.main_ui.clone()) => self.update_state(result)
59 }
60 }
61
62 self.worker_state.set_jobs(self.receiver.len());
63 self.worker_state.set_status(Stopped);
64 self.worker_state.set_task("All finished!");
65 }
66
67 fn update_state(&mut self, output: anyhow::Result<()>) {
68 if let Err(error) = output {
69 let (task, status) = if self.receiver.is_empty() {
70 ("Waiting", WorkerStatus::Waiting)
71 } else {
72 ("Processing", WorkerStatus::Running)
73 };
74
75 wsupdate! {
76 &mut self.worker_state,
77 self.receiver,
78 "Waiting Jobs",
79 status
80 }
81
82 self.worker_state.set_jobs(self.receiver.len());
83 self.worker_state.set_task(task);
84
85 let error = limit_string(116, error, None).bright_red();
86 self.main_ui.set_message(error.to_string());
87 }
88 }
89}
90
91pub struct WorkerHandleBuilder<D, F, Fut, S: WorkerTemplate>
92where
93 D: Send,
94 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
95 Fut::Output: Send + 'static,
96 F: Fn(D, MainProgress<S>) -> Fut,
97{
98 fn_ptr: Option<F>,
99 main_ui: Option<MainProgress<S>>,
100 receiver: Option<Receiver<D>>,
101}
102
103impl<D, F, Fut, S: WorkerTemplate> Default for WorkerHandleBuilder<D, F, Fut, S>
104where
105 D: Send,
106 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
107 Fut::Output: Send + 'static,
108 F: Fn(D, MainProgress<S>) -> Fut,
109{
110 fn default() -> Self {
111 Self {
112 fn_ptr: None,
113 main_ui: None,
114 receiver: None,
115 }
116 }
117}
118
119impl<D, F, Fut, S: WorkerTemplate> WorkerHandleBuilder<D, F, Fut, S>
120where
121 D: Send,
122 Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
123 Fut::Output: Send + 'static,
124 F: Fn(D, MainProgress<S>) -> Fut,
125{
126 pub fn fn_ptr(mut self, f: F) -> Self {
127 self.fn_ptr = Some(f);
128 self
129 }
130
131 pub fn main_ui(mut self, ui: MainProgress<S>) -> Self {
132 self.main_ui = Some(ui);
133 self
134 }
135
136 pub fn receiver(mut self, receiver: Receiver<D>) -> Self {
137 self.receiver = Some(receiver);
138 self
139 }
140
141 pub fn build(self, uid: Uid) -> Result<WorkerHandle<D, F, Fut, S>> {
142 let fn_ptr = self.fn_ptr.ok_or(Error::Missing("[`WorkerHandleBuilder`] missing [`FnMut`]"))?;
143 let receiver = self.receiver.ok_or(Error::Missing("[`WorkerHandleBuilder`] missing [`Receiver<T>`]"))?;
144 let main_ui = self.main_ui.ok_or(Error::Missing("[`WorkerHandleBuilder`] missing [`MainProgress`]"))?;
145 let worker_state = WorkerState::new(uid, main_ui.clone());
146
147 let handle = WorkerHandle {
148 fn_ptr,
149 receiver,
150 main_ui,
151 worker_state,
152 };
153
154 Ok(handle.initialize())
155 }
156}
157
#[cfg(test)]
mod test {
    use std::num::NonZero;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;

    use indicatif::MultiProgress;
    use tokio::spawn;

    use super::*;
    use crate::DefaultTemplate;

    /// Progress handle shared by the tests: length 10, template titled "Test".
    fn progress_ui() -> MainProgress<DefaultTemplate> {
        MainProgress::new(10, MultiProgress::new(), DefaultTemplate::new("Test"))
    }

    /// A fully configured builder produces a handle whose `run` future can be
    /// spawned (and aborted again).
    #[tokio::test]
    async fn test_build() {
        // The sender half is dropped right away; only the receiver is needed.
        let (_, jobs) = tokio::sync::mpsc::channel::<String>(1);
        let ui = progress_ui();

        let built = WorkerHandleBuilder::default()
            .fn_ptr(async |_: String, _: MainProgress<DefaultTemplate>| Ok(()))
            .main_ui(ui)
            .receiver(jobs)
            .build(NonZero::new(1).unwrap());

        assert!(built.is_ok());

        let worker = built.unwrap();
        let task = spawn(worker.run());
        task.abort();
    }

    /// The callback may be a `move` closure that owns extra state and clones
    /// it into every invocation.
    #[tokio::test]
    async fn test_capture_owned_var() {
        async fn the_solver_fn(_: String, _: MainProgress<DefaultTemplate>, _: Arc<AtomicBool>) -> anyhow::Result<()> {
            Ok(())
        }

        // The sender half is dropped right away; only the receiver is needed.
        let (_, jobs) = tokio::sync::mpsc::channel::<String>(1);
        let ui = progress_ui();
        let flag = Arc::new(AtomicBool::default());

        let flag_for_callback = flag.clone();

        let built = WorkerHandleBuilder::default()
            .fn_ptr(move |msg: String, ui: MainProgress<DefaultTemplate>| the_solver_fn(msg, ui, flag_for_callback.clone()))
            .main_ui(ui)
            .receiver(jobs)
            .build(NonZero::new(1).unwrap());

        assert!(built.is_ok());

        let worker = built.unwrap();
        let task = spawn(worker.run());
        task.abort();
    }
}