hydro2_async_scheduler/
worker_pool.rs

// ---------------- [ File: src/worker_pool.rs ]
crate::ix!();

/// A pool of OS threads:
/// - 1 aggregator thread reading from `main_tasks_rx` (single consumer),
/// - N worker threads, each with its own private channel,
/// - A results channel for `TaskResult`.
#[derive(Builder)]
#[builder(setter(into), pattern = "owned")]
pub struct WorkerPool<'threads, T>
where
    T: Debug + Send + Sync + 'threads
{
    /// For sending tasks into the aggregator; `None` once the
    /// main tasks channel has been closed via `close_main_tasks_channel`.
    main_tasks_tx: Option<Sender<TaskItem<'threads, T>>>,

    /// Threads for aggregator + workers
    threads: Vec<ScopedJoinHandle<'threads, ()>>,

    /// For receiving TaskResult from all workers
    results_rx: AsyncMutex<Receiver<TaskResult>>,

    #[cfg(test)]
    #[builder(default)]
    pub(crate) results_tx_for_test: Option<Sender<TaskResult>>
}

impl<'threads, T> WorkerPool<'threads, T>
where
    T: Debug + Send + Sync + 'threads,
{
    /// Build the aggregator + N workers within a synchronous scope.
    /// - The aggregator (thread) reads from `main_tasks_rx` (single consumer).
    /// - It fans out tasks to each worker’s private channel in round-robin.
    /// - Each worker runs in a separate OS thread, with its own mini tokio runtime.
    /// - We also have one results channel so the aggregator can send back TaskResult items,
    ///   which an external consumer (like `process_immediate`) can poll via `try_recv_result`.
    ///
    /// The aggregator closes each worker’s channel at the end, ensuring that idle workers
    /// (which never receive tasks) also exit cleanly.
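    ///
    /// A minimal usage sketch (not compiled here: `MyPayload`, `make_task_item`,
    /// and `block_on` are placeholders for whatever payload type, task
    /// constructor, and runtime handle the caller actually has):
    ///
    /// ```ignore
    /// std::thread::scope(|scope| {
    ///     // aggregator + 4 workers, each channel buffered to 64 entries
    ///     let pool: WorkerPool<'_, MyPayload> = WorkerPool::new_in_scope(scope, 4, 64);
    ///
    ///     // `submit` and `try_recv_result` are async
    ///     block_on(pool.submit(make_task_item(0))).unwrap();
    ///     let _maybe_result = block_on(pool.try_recv_result());
    ///
    ///     // drop the main sender and join aggregator + workers
    ///     pool.shutdown();
    /// });
    /// ```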
    pub fn new_in_scope(
        scope: &'threads Scope<'threads, '_>,
        num_workers: usize,
        buffer_size: usize,
    ) -> Self {

        eprintln!(
            "WorkerPool::new_in_scope => setting up aggregator + {} workers, buffer_size={}",
            num_workers, buffer_size
        );

        //=== (A) Main tasks channel => aggregator is the single consumer
        let (main_tasks_tx, main_tasks_rx) = mpsc::channel::<TaskItem<'threads, T>>(buffer_size);
        eprintln!("WorkerPool::new_in_scope => created main_tasks channel (aggregator consumer)");

        //=== (B) A results channel for all workers => external code can poll results
        let (results_tx, results_rx) = mpsc::channel::<TaskResult>(buffer_size);
        eprintln!("WorkerPool::new_in_scope => created results channel for all workers");

        //=== (C) Worker channels: each worker has its own channel
        // aggregator will send tasks to these
        let (worker_senders, worker_receivers) = create_worker_channels(num_workers, buffer_size);

        // aggregator + N workers => total of num_workers + 1 threads
        let threads = spawn_aggregator_and_workers(
            scope,
            main_tasks_rx,
            worker_senders,
            worker_receivers,
            results_tx
        );

        eprintln!("WorkerPool::new_in_scope => aggregator + {} workers => returning WorkerPool", num_workers);

        WorkerPool {
            main_tasks_tx: Some(main_tasks_tx),
            threads,
            results_rx: AsyncMutex::new(results_rx),

            #[cfg(test)]
            results_tx_for_test: None
        }
    }

    /// Submit a task => aggregator picks it up, fans out to a worker.
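    ///
    /// Because this is a non-blocking `try_send`, a full (or already closed)
    /// aggregator queue surfaces as `NetworkError::ResourceExhaustion`. One
    /// possible caller-side response, sketched here rather than prescribed by
    /// the crate, is to treat that error as backpressure:
    ///
    /// ```ignore
    /// match pool.submit(task_item).await {
    ///     Ok(()) => { /* aggregator accepted the task */ }
    ///     Err(NetworkError::ResourceExhaustion { .. }) => {
    ///         // queue full or closed: back off, rebuild the task, retry later
    ///         eprintln!("WorkerPool::submit => main tasks channel full or closed");
    ///     }
    ///     Err(other) => return Err(other),
    /// }
    /// ```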
    pub async fn submit(&self, item: TaskItem<'threads, T>) -> Result<(), NetworkError> {
        eprintln!("WorkerPool::submit => sending to aggregator main_tasks channel => node_idx={}", item.node_idx());
        // If the sender has already been taken (channel closed), report the
        // same resource-exhaustion error as a full queue.
        let Some(tx) = self.main_tasks_tx.as_ref() else {
            return Err(NetworkError::ResourceExhaustion {
                resource: "WorkerPool Main Tasks Channel".into(),
            });
        };
        // Instead of .send(...).await, do a non-blocking try_send:
        match tx.try_send(item) {
            Ok(()) => Ok(()),
            Err(_e) => Err(NetworkError::ResourceExhaustion {
                resource: "WorkerPool Main Tasks Channel".into(),
            }),
        }
    }

    /// Non-blocking poll of the results channel from workers
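    ///
    /// A sketch of how a driver loop might drain results between other work
    /// (the surrounding async runtime and the `handle_result` function are
    /// assumptions, not part of this crate):
    ///
    /// ```ignore
    /// while let Some(result) = pool.try_recv_result().await {
    ///     handle_result(result);
    /// }
    /// // nothing ready right now; yield and come back later
    /// tokio::task::yield_now().await;
    /// ```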
    pub async fn try_recv_result(&self) -> Option<TaskResult> {
        let mut guard = self.results_rx.lock().await;
        let res = guard.try_recv().ok();
        if let Some(ref r) = res {
            eprintln!("WorkerPool::try_recv_result => got a result => node_idx={}", r.node_idx());
        }
        res
    }

    pub fn is_main_channel_closed(&self) -> bool {
        // Closed if the sender has already been taken, or if the aggregator
        // has dropped its receiver.
        self.main_tasks_tx.as_ref().map_or(true, |tx| tx.is_closed())
    }

    /// Force aggregator to see "None" => aggregator returns => shuts down
    pub fn close_main_tasks_channel(&mut self) {
        eprintln!("WorkerPool::close_main_tasks_channel => dropping main_tasks_tx");
        // Take the sender out of the Option and drop it; this actually closes
        // the channel. (Dropping a shared reference, `drop(&self.main_tasks_tx)`,
        // would be a no-op and the aggregator would never see `None`.)
        drop(self.main_tasks_tx.take());
        eprintln!("WorkerPool::close_main_tasks_channel => main_tasks_tx dropped");
    }

    /// Shut down everything: aggregator + workers
    /// (if aggregator is still open, close it, then join threads).
    pub fn shutdown(self) {
        eprintln!("WorkerPool::shutdown => dropping main_tasks_tx => aggregator sees None => eventually done");
        drop(self.main_tasks_tx);

        for (i, th) in self.threads.into_iter().enumerate() {
            eprintln!("WorkerPool::shutdown => joining aggregator/worker thread #{}", i);
            let _ = th.join();
        }
        eprintln!("WorkerPool::shutdown => all aggregator+worker threads joined => done");
    }
}
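
// A minimal smoke-test sketch of the lifecycle described above: build the pool
// inside a scoped-thread block, confirm the main channel starts open, then shut
// down. It assumes (as the docs above state) that the aggregator and workers
// exit once the main tasks channel is closed; it is marked `#[ignore]` because
// that behavior lives in the rest of the crate, not in this file.
#[cfg(test)]
mod worker_pool_lifecycle_sketch {
    use super::*;

    #[test]
    #[ignore = "illustrative sketch; depends on aggregator/worker shutdown behavior"]
    fn builds_and_shuts_down_cleanly() {
        std::thread::scope(|scope| {
            // 2 workers, small buffers; `u32` is just a stand-in payload type.
            let pool: WorkerPool<'_, u32> = WorkerPool::new_in_scope(scope, 2, 8);
            assert!(!pool.is_main_channel_closed());

            // Dropping the main sender lets the aggregator finish, which in turn
            // closes each worker channel; `shutdown` then joins every thread.
            pool.shutdown();
        });
    }
}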