rfs_runner/
worker_handle.rs

1use colored::Colorize;
2use tokio::select;
3use tokio::sync::mpsc::Receiver;
4
5use crate::error::Error;
6use crate::templates::WorkerTemplate;
7use crate::{limit_string, wsupdate, wsupdate_async, MainProgress, Result, Uid, WorkerState, WorkerStatus};
8
9pub struct WorkerHandle<D, F, Fut, S>
10where
11  Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
12  Fut::Output: Send + 'static,
13  F: Fn(D, MainProgress<S>) -> Fut,
14{
15  fn_ptr: F,
16  receiver: Receiver<D>,
17  main_ui: MainProgress<S>,
18  worker_state: WorkerState<S>,
19}
20
21impl<D, F, Fut, S: WorkerTemplate> WorkerHandle<D, F, Fut, S>
22where
23  D: Send,
24  Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
25  Fut::Output: Send + 'static,
26  F: Fn(D, MainProgress<S>) -> Fut,
27{
28  fn initialize(mut self) -> Self {
29    self.main_ui.add_worker(self.worker_state.id, &self.worker_state);
30
31    wsupdate! {
32      self.worker_state,
33      self.receiver,
34      "New worker spawned",
35      WorkerStatus::Spawned,
36    }
37
38    self.worker_state.set_status(WorkerStatus::Waiting);
39    self
40  }
41
42  pub async fn run(mut self) {
43    use WorkerStatus::*;
44
45    while let Some(data) = self.receiver.recv().await {
46      self.main_ui.inc(1);
47
48      wsupdate! {
49        self.worker_state,
50        self.receiver,
51        "Processing...",
52        Running
53      }
54
55      let call = &mut self.fn_ptr;
56      select! {
57        _ = wsupdate_async!(self.worker_state, self.receiver,) => (),
58        result = call(data, self.main_ui.clone()) => self.update_state(result)
59      }
60    }
61
62    self.worker_state.set_jobs(self.receiver.len());
63    self.worker_state.set_status(Stopped);
64    self.worker_state.set_task("All finished!");
65  }
66
67  fn update_state(&mut self, output: anyhow::Result<()>) {
68    if let Err(error) = output {
69      let (task, status) = if self.receiver.is_empty() {
70        ("Waiting", WorkerStatus::Waiting)
71      } else {
72        ("Processing", WorkerStatus::Running)
73      };
74
75      wsupdate! {
76        &mut self.worker_state,
77        self.receiver,
78        "Waiting Jobs",
79        status
80      }
81
82      self.worker_state.set_jobs(self.receiver.len());
83      self.worker_state.set_task(task);
84
85      let error = limit_string(116, error, None).bright_red();
86      self.main_ui.set_message(error.to_string());
87    }
88  }
89}
90
91pub struct WorkerHandleBuilder<D, F, Fut, S: WorkerTemplate>
92where
93  D: Send,
94  Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
95  Fut::Output: Send + 'static,
96  F: Fn(D, MainProgress<S>) -> Fut,
97{
98  fn_ptr: Option<F>,
99  main_ui: Option<MainProgress<S>>,
100  receiver: Option<Receiver<D>>,
101}
102
103impl<D, F, Fut, S: WorkerTemplate> Default for WorkerHandleBuilder<D, F, Fut, S>
104where
105  D: Send,
106  Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
107  Fut::Output: Send + 'static,
108  F: Fn(D, MainProgress<S>) -> Fut,
109{
110  fn default() -> Self {
111    Self {
112      fn_ptr: None,
113      main_ui: None,
114      receiver: None,
115    }
116  }
117}
118
119impl<D, F, Fut, S: WorkerTemplate> WorkerHandleBuilder<D, F, Fut, S>
120where
121  D: Send,
122  Fut: Future<Output = anyhow::Result<()>> + Send + 'static,
123  Fut::Output: Send + 'static,
124  F: Fn(D, MainProgress<S>) -> Fut,
125{
126  pub fn fn_ptr(mut self, f: F) -> Self {
127    self.fn_ptr = Some(f);
128    self
129  }
130
131  pub fn main_ui(mut self, ui: MainProgress<S>) -> Self {
132    self.main_ui = Some(ui);
133    self
134  }
135
136  pub fn receiver(mut self, receiver: Receiver<D>) -> Self {
137    self.receiver = Some(receiver);
138    self
139  }
140
141  pub fn build(self, uid: Uid) -> Result<WorkerHandle<D, F, Fut, S>> {
142    let fn_ptr = self.fn_ptr.ok_or(Error::Missing("[`WorkerHandleBuilder`] missing [`FnMut`]"))?;
143    let receiver = self.receiver.ok_or(Error::Missing("[`WorkerHandleBuilder`] missing [`Receiver<T>`]"))?;
144    let main_ui = self.main_ui.ok_or(Error::Missing("[`WorkerHandleBuilder`] missing [`MainProgress`]"))?;
145    let worker_state = WorkerState::new(uid, main_ui.clone());
146
147    let handle = WorkerHandle {
148      fn_ptr,
149      receiver,
150      main_ui,
151      worker_state,
152    };
153
154    Ok(handle.initialize())
155  }
156}
157
#[cfg(test)]
mod test {
  use std::num::NonZero;
  use std::sync::atomic::AtomicBool;
  use std::sync::Arc;

  use indicatif::MultiProgress;
  use tokio::spawn;

  use super::*;
  use crate::DefaultTemplate;

  /// A fully specified builder succeeds, and `run` terminates once the job
  /// channel is closed.
  #[tokio::test]
  async fn test_build() {
    let style = DefaultTemplate::new("Test");
    // Discarding the sender closes the channel, so `run` returns right away.
    let (_, rx) = tokio::sync::mpsc::channel::<String>(1);
    let main_ui = MainProgress::new(10, MultiProgress::new(), style);

    let handle = WorkerHandleBuilder::default()
      .fn_ptr(async |_: String, _: MainProgress<DefaultTemplate>| Ok(()))
      .main_ui(main_ui)
      .receiver(rx)
      .build(NonZero::new(1).unwrap());

    assert!(handle.is_ok());

    let handle = handle.unwrap();
    spawn(handle.run()).await.expect("worker task panicked");
  }

  /// A closure that captures owned state can serve as the job function, and
  /// that state is released once the worker finishes.
  #[tokio::test]
  async fn test_capture_owned_var() {
    async fn the_solver_fn(_: String, _: MainProgress<DefaultTemplate>, _: Arc<AtomicBool>) -> anyhow::Result<()> {
      Ok(())
    }

    // Discarding the sender closes the channel, so `run` returns right away.
    let (_, rx) = tokio::sync::mpsc::channel::<String>(1);
    let main_ui = MainProgress::new(10, MultiProgress::new(), DefaultTemplate::new("Test"));
    let moveable = Arc::new(AtomicBool::default());

    let ref_4_fn_ptr_owned = moveable.clone();

    let handle = WorkerHandleBuilder::default()
      .fn_ptr(move |msg: String, ui: MainProgress<DefaultTemplate>| the_solver_fn(msg, ui, ref_4_fn_ptr_owned.clone()))
      .main_ui(main_ui)
      .receiver(rx)
      .build(NonZero::new(1).unwrap());

    assert!(handle.is_ok());

    // One reference lives in the closure owned by the handle, one in `moveable`.
    assert_eq!(Arc::strong_count(&moveable), 2);

    let handle = handle.unwrap();
    spawn(handle.run()).await.expect("worker task panicked");

    // `run` consumed the handle, so the closure and its captured clone are gone.
    assert_eq!(Arc::strong_count(&moveable), 1);
  }
}