1use std::thread::JoinHandle;
12
13use tokio::sync::broadcast;
14use tracing::{error, info};
15
16use crate::{
17 algorithm::AlgorithmConfig,
18 derived::{events::DerivedDataEvent, SharedDerivedDataRef},
19 feed::{
20 events::{MarketEvent, MarketEventHandler},
21 market_data::SharedMarketDataRef,
22 },
23 graph::EdgeWeightUpdaterWithDerived,
24 types::internal::SolveTask,
25 worker_pool::{
26 registry::{
27 spawn_workers_generic, AlgorithmSpawner, SpawnWorkersParams, UnknownAlgorithmError,
28 DEFAULT_ALGORITHM,
29 },
30 task_queue::{TaskQueue, TaskQueueConfig, TaskQueueHandle},
31 },
32};
33
34#[derive(Debug)]
36pub struct WorkerPoolConfig {
37 name: String,
40 spawner: AlgorithmSpawner,
42 num_workers: usize,
44 algorithm_config: AlgorithmConfig,
46 task_queue_capacity: usize,
48}
49
50impl WorkerPoolConfig {
51 pub fn algorithm_name(&self) -> &str {
53 self.spawner.algorithm_name()
54 }
55}
56
57impl Default for WorkerPoolConfig {
58 fn default() -> Self {
59 Self {
60 name: DEFAULT_ALGORITHM.to_string(),
61 spawner: AlgorithmSpawner::Registry { algorithm: DEFAULT_ALGORITHM.to_string() },
62 num_workers: num_cpus::get(),
63 algorithm_config: AlgorithmConfig::default(),
64 task_queue_capacity: 1000,
65 }
66 }
67}
68
69pub struct WorkerPool {
74 name: String,
76 algorithm: String,
78 workers: Vec<JoinHandle<()>>,
80 shutdown_tx: broadcast::Sender<()>,
82}
83
84impl WorkerPool {
85 pub fn spawn(
100 config: WorkerPoolConfig,
101 task_rx: async_channel::Receiver<SolveTask>,
102 market_data: SharedMarketDataRef,
103 derived_data: SharedDerivedDataRef,
104 event_rx: broadcast::Receiver<MarketEvent>,
105 derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
106 ) -> Result<Self, UnknownAlgorithmError> {
107 let (shutdown_tx, _) = broadcast::channel(1);
108 let name = config.name.clone();
109 let algorithm = config
110 .spawner
111 .algorithm_name()
112 .to_string();
113
114 let params = SpawnWorkersParams {
116 algorithm: algorithm.clone(),
117 num_workers: config.num_workers,
118 algorithm_config: config.algorithm_config,
119 task_rx,
120 market_data,
121 derived_data,
122 event_rx,
123 derived_event_rx,
124 shutdown_tx: shutdown_tx.clone(),
125 };
126 let workers = config.spawner.spawn(params)?;
127
128 info!(
129 name = %name,
130 algorithm = %algorithm,
131 num_workers = workers.len(),
132 "worker pool spawned"
133 );
134
135 Ok(Self { name, algorithm, workers, shutdown_tx })
136 }
137
138 pub fn name(&self) -> &str {
140 &self.name
141 }
142
143 pub fn algorithm(&self) -> &str {
145 &self.algorithm
146 }
147
148 pub fn num_workers(&self) -> usize {
150 self.workers.len()
151 }
152
153 pub fn shutdown(self) {
155 info!(name = %self.name, "shutting down worker pool");
156
157 let _ = self.shutdown_tx.send(());
159
160 for (i, handle) in self.workers.into_iter().enumerate() {
162 if let Err(e) = handle.join() {
163 error!(
164 name = %self.name,
165 worker_id = i,
166 "worker thread panicked: {:?}",
167 e
168 );
169 }
170 }
171
172 info!(name = %self.name, "worker pool shut down");
173 }
174}
175
176#[must_use = "a builder does nothing until .build() is called"]
189pub struct WorkerPoolBuilder {
190 config: WorkerPoolConfig,
191}
192
193impl WorkerPoolBuilder {
194 pub fn new() -> Self {
195 Self { config: WorkerPoolConfig::default() }
196 }
197
198 pub fn name(mut self, name: impl Into<String>) -> Self {
200 self.config.name = name.into();
201 self
202 }
203
204 pub fn algorithm(mut self, algorithm: impl Into<String>) -> Self {
208 self.config.spawner = AlgorithmSpawner::Registry { algorithm: algorithm.into() };
209 self
210 }
211
212 pub fn with_algorithm<A, F>(mut self, name: impl Into<String>, factory: F) -> Self
224 where
225 A: crate::algorithm::Algorithm + 'static,
226 A::GraphManager: MarketEventHandler + EdgeWeightUpdaterWithDerived + 'static,
227 F: Fn(AlgorithmConfig) -> A + Clone + Send + Sync + 'static,
228 {
229 let name = name.into();
230 let spawner =
231 Box::new(move |params: SpawnWorkersParams| spawn_workers_generic(params, &factory));
232 self.config.spawner = AlgorithmSpawner::Custom { algorithm: name, spawner };
233 self
234 }
235
236 pub fn algorithm_config(mut self, config: AlgorithmConfig) -> Self {
238 self.config.algorithm_config = config;
239 self
240 }
241
242 pub fn num_workers(mut self, n: usize) -> Self {
244 self.config.num_workers = n;
245 self
246 }
247
248 pub fn task_queue_capacity(mut self, capacity: usize) -> Self {
250 self.config.task_queue_capacity = capacity;
251 self
252 }
253
254 pub fn build(
263 self,
264 market_data: SharedMarketDataRef,
265 derived_data: SharedDerivedDataRef,
266 event_rx: broadcast::Receiver<MarketEvent>,
267 derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
268 ) -> Result<(WorkerPool, TaskQueueHandle), UnknownAlgorithmError> {
269 let task_queue =
271 TaskQueue::new(TaskQueueConfig { capacity: self.config.task_queue_capacity });
272 let (task_handle, task_rx) = task_queue.split();
273
274 let pool = WorkerPool::spawn(
276 self.config,
277 task_rx,
278 market_data,
279 derived_data,
280 event_rx,
281 derived_event_rx,
282 )?;
283
284 Ok((pool, task_handle))
285 }
286}
287
288impl Default for WorkerPoolBuilder {
289 fn default() -> Self {
290 Self::new()
291 }
292}