1use std::thread::JoinHandle;
12
13use tokio::sync::broadcast;
14use tracing::{error, info};
15
16use crate::{
17 algorithm::AlgorithmConfig,
18 derived::{events::DerivedDataEvent, SharedDerivedDataRef},
19 feed::{
20 events::{MarketEvent, MarketEventHandler},
21 market_data::SharedMarketDataRef,
22 },
23 graph::EdgeWeightUpdaterWithDerived,
24 types::internal::SolveTask,
25 worker_pool::{
26 registry::{
27 spawn_workers_generic, AlgorithmSpawner, SpawnWorkersParams, UnknownAlgorithmError,
28 DEFAULT_ALGORITHM,
29 },
30 task_queue::{TaskQueue, TaskQueueConfig, TaskQueueHandle},
31 },
32};
33
34#[derive(Debug)]
36pub struct WorkerPoolConfig {
37 name: String,
40 spawner: AlgorithmSpawner,
42 num_workers: usize,
44 algorithm_config: AlgorithmConfig,
46 task_queue_capacity: usize,
48}
49
50impl WorkerPoolConfig {
51 pub fn algorithm_name(&self) -> &str {
53 self.spawner.algorithm_name()
54 }
55}
56
57impl Default for WorkerPoolConfig {
58 fn default() -> Self {
59 Self {
60 name: DEFAULT_ALGORITHM.to_string(),
61 spawner: AlgorithmSpawner::Registry { algorithm: DEFAULT_ALGORITHM.to_string() },
62 num_workers: num_cpus::get(),
63 algorithm_config: AlgorithmConfig::default(),
64 task_queue_capacity: 1000,
65 }
66 }
67}
68
69pub struct WorkerPool {
74 name: String,
76 algorithm: String,
78 workers: Vec<JoinHandle<()>>,
80 shutdown_tx: broadcast::Sender<()>,
82}
83
84impl WorkerPool {
85 pub fn spawn(
100 config: WorkerPoolConfig,
101 task_rx: async_channel::Receiver<SolveTask>,
102 market_data: SharedMarketDataRef,
103 derived_data: SharedDerivedDataRef,
104 event_rx: broadcast::Receiver<MarketEvent>,
105 derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
106 ) -> Result<Self, UnknownAlgorithmError> {
107 let (shutdown_tx, _) = broadcast::channel(1);
108 let name = config.name.clone();
109 let algorithm = config
110 .spawner
111 .algorithm_name()
112 .to_string();
113
114 let params = SpawnWorkersParams {
116 algorithm: algorithm.clone(),
117 num_workers: config.num_workers,
118 algorithm_config: config.algorithm_config,
119 task_rx,
120 market_data,
121 derived_data,
122 event_rx,
123 derived_event_rx,
124 shutdown_tx: shutdown_tx.clone(),
125 };
126 let workers = config.spawner.spawn(params)?;
127
128 info!(
129 name = %name,
130 algorithm = %algorithm,
131 num_workers = workers.len(),
132 "worker pool spawned"
133 );
134
135 Ok(Self { name, algorithm, workers, shutdown_tx })
136 }
137
138 pub fn name(&self) -> &str {
140 &self.name
141 }
142
143 pub fn algorithm(&self) -> &str {
145 &self.algorithm
146 }
147
148 pub fn num_workers(&self) -> usize {
150 self.workers.len()
151 }
152
153 pub fn shutdown(self) {
155 info!(name = %self.name, "shutting down worker pool");
156
157 let _ = self.shutdown_tx.send(());
159
160 for (i, handle) in self.workers.into_iter().enumerate() {
162 if let Err(e) = handle.join() {
163 error!(
164 name = %self.name,
165 worker_id = i,
166 "worker thread panicked: {:?}",
167 e
168 );
169 }
170 }
171
172 info!(name = %self.name, "worker pool shut down");
173 }
174}
175
176#[must_use = "a builder does nothing until .build() is called"]
189pub struct WorkerPoolBuilder {
190 config: WorkerPoolConfig,
191}
192
193impl WorkerPoolBuilder {
194 pub fn new() -> Self {
196 Self { config: WorkerPoolConfig::default() }
197 }
198
199 pub fn name(mut self, name: impl Into<String>) -> Self {
201 self.config.name = name.into();
202 self
203 }
204
205 pub fn algorithm(mut self, algorithm: impl Into<String>) -> Self {
209 self.config.spawner = AlgorithmSpawner::Registry { algorithm: algorithm.into() };
210 self
211 }
212
213 pub fn with_algorithm<A, F>(mut self, name: impl Into<String>, factory: F) -> Self
225 where
226 A: crate::algorithm::Algorithm + 'static,
227 A::GraphManager: MarketEventHandler + EdgeWeightUpdaterWithDerived + 'static,
228 F: Fn(AlgorithmConfig) -> A + Clone + Send + Sync + 'static,
229 {
230 let name = name.into();
231 let spawner =
232 Box::new(move |params: SpawnWorkersParams| spawn_workers_generic(params, &factory));
233 self.config.spawner = AlgorithmSpawner::Custom { algorithm: name, spawner };
234 self
235 }
236
237 pub fn algorithm_config(mut self, config: AlgorithmConfig) -> Self {
239 self.config.algorithm_config = config;
240 self
241 }
242
243 pub fn num_workers(mut self, n: usize) -> Self {
245 self.config.num_workers = n;
246 self
247 }
248
249 pub fn task_queue_capacity(mut self, capacity: usize) -> Self {
251 self.config.task_queue_capacity = capacity;
252 self
253 }
254
255 pub fn build(
264 self,
265 market_data: SharedMarketDataRef,
266 derived_data: SharedDerivedDataRef,
267 event_rx: broadcast::Receiver<MarketEvent>,
268 derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
269 ) -> Result<(WorkerPool, TaskQueueHandle), UnknownAlgorithmError> {
270 let task_queue =
272 TaskQueue::new(TaskQueueConfig { capacity: self.config.task_queue_capacity });
273 let (task_handle, task_rx) = task_queue.split();
274
275 let pool = WorkerPool::spawn(
277 self.config,
278 task_rx,
279 market_data,
280 derived_data,
281 event_rx,
282 derived_event_rx,
283 )?;
284
285 Ok((pool, task_handle))
286 }
287}
288
289impl Default for WorkerPoolBuilder {
290 fn default() -> Self {
291 Self::new()
292 }
293}