use std::thread::JoinHandle;
use tokio::sync::broadcast;
use tracing::{error, info};
use crate::{
algorithm::AlgorithmConfig,
derived::{events::DerivedDataEvent, SharedDerivedDataRef},
feed::{
events::{MarketEvent, MarketEventHandler},
market_data::SharedMarketDataRef,
},
graph::EdgeWeightUpdaterWithDerived,
types::internal::SolveTask,
worker_pool::{
registry::{
spawn_workers_generic, AlgorithmSpawner, SpawnWorkersParams, UnknownAlgorithmError,
DEFAULT_ALGORITHM,
},
task_queue::{TaskQueue, TaskQueueConfig, TaskQueueHandle},
},
};
/// Settings used to spawn a [`WorkerPool`].
///
/// All fields are private; values are normally set through [`WorkerPoolBuilder`].
#[derive(Debug)]
pub struct WorkerPoolConfig {
    /// Pool name used in log output; defaults to `DEFAULT_ALGORITHM`.
    name: String,
    /// Selects the algorithm the workers run and performs the actual spawning.
    spawner: AlgorithmSpawner,
    /// Number of workers requested from the spawner; defaults to `num_cpus::get()`.
    num_workers: usize,
    /// Configuration handed to the algorithm via the spawn parameters.
    algorithm_config: AlgorithmConfig,
    /// Capacity of the task queue created when building the pool; defaults to 1000.
    task_queue_capacity: usize,
}

impl WorkerPoolConfig {
    /// Task queue capacity used by [`Default`].
    const DEFAULT_TASK_QUEUE_CAPACITY: usize = 1000;

    /// Name of the algorithm selected by the configured spawner.
    pub fn algorithm_name(&self) -> &str {
        let spawner = &self.spawner;
        spawner.algorithm_name()
    }
}

impl Default for WorkerPoolConfig {
    /// Registry-backed `DEFAULT_ALGORITHM` (also used as the pool name),
    /// `num_cpus::get()` workers, the default `AlgorithmConfig`, and a task
    /// queue capacity of 1000.
    fn default() -> Self {
        let algorithm = DEFAULT_ALGORITHM.to_string();
        Self {
            name: algorithm.clone(),
            spawner: AlgorithmSpawner::Registry { algorithm },
            num_workers: num_cpus::get(),
            algorithm_config: AlgorithmConfig::default(),
            task_queue_capacity: Self::DEFAULT_TASK_QUEUE_CAPACITY,
        }
    }
}
pub struct WorkerPool {
name: String,
algorithm: String,
workers: Vec<JoinHandle<()>>,
shutdown_tx: broadcast::Sender<()>,
}
impl WorkerPool {
pub fn spawn(
config: WorkerPoolConfig,
task_rx: async_channel::Receiver<SolveTask>,
market_data: SharedMarketDataRef,
derived_data: SharedDerivedDataRef,
event_rx: broadcast::Receiver<MarketEvent>,
derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
) -> Result<Self, UnknownAlgorithmError> {
let (shutdown_tx, _) = broadcast::channel(1);
let name = config.name.clone();
let algorithm = config
.spawner
.algorithm_name()
.to_string();
let params = SpawnWorkersParams {
algorithm: algorithm.clone(),
num_workers: config.num_workers,
algorithm_config: config.algorithm_config,
task_rx,
market_data,
derived_data,
event_rx,
derived_event_rx,
shutdown_tx: shutdown_tx.clone(),
};
let workers = config.spawner.spawn(params)?;
info!(
name = %name,
algorithm = %algorithm,
num_workers = workers.len(),
"worker pool spawned"
);
Ok(Self { name, algorithm, workers, shutdown_tx })
}
pub fn name(&self) -> &str {
&self.name
}
pub fn algorithm(&self) -> &str {
&self.algorithm
}
pub fn num_workers(&self) -> usize {
self.workers.len()
}
pub fn shutdown(self) {
info!(name = %self.name, "shutting down worker pool");
let _ = self.shutdown_tx.send(());
for (i, handle) in self.workers.into_iter().enumerate() {
if let Err(e) = handle.join() {
error!(
name = %self.name,
worker_id = i,
"worker thread panicked: {:?}",
e
);
}
}
info!(name = %self.name, "worker pool shut down");
}
}
/// Fluent builder for a [`WorkerPool`] together with its task queue.
///
/// Starts from the default [`WorkerPoolConfig`]; each setter overrides one
/// setting. Nothing is spawned until [`WorkerPoolBuilder::build`] is called.
#[must_use = "a builder does nothing until .build() is called"]
pub struct WorkerPoolBuilder {
    config: WorkerPoolConfig,
}

impl WorkerPoolBuilder {
    /// Creates a builder holding the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Applies `edit` to the pending configuration and hands the builder back.
    fn configure(mut self, edit: impl FnOnce(&mut WorkerPoolConfig)) -> Self {
        edit(&mut self.config);
        self
    }

    /// Sets the pool name used in log output.
    pub fn name(self, name: impl Into<String>) -> Self {
        let name = name.into();
        self.configure(|cfg| cfg.name = name)
    }

    /// Selects a registry algorithm by name.
    ///
    /// The name is only stored here; an unrecognised name is expected to
    /// surface as an [`UnknownAlgorithmError`] from `build`.
    pub fn algorithm(self, algorithm: impl Into<String>) -> Self {
        let algorithm = algorithm.into();
        self.configure(|cfg| cfg.spawner = AlgorithmSpawner::Registry { algorithm })
    }

    /// Uses a custom algorithm, created per worker by `factory`, under the label `name`.
    pub fn with_algorithm<A, F>(self, name: impl Into<String>, factory: F) -> Self
    where
        A: crate::algorithm::Algorithm + 'static,
        A::GraphManager: MarketEventHandler + EdgeWeightUpdaterWithDerived + 'static,
        F: Fn(AlgorithmConfig) -> A + Clone + Send + Sync + 'static,
    {
        let algorithm = name.into();
        let spawner = Box::new(move |params: SpawnWorkersParams| {
            spawn_workers_generic(params, &factory)
        });
        self.configure(move |cfg| cfg.spawner = AlgorithmSpawner::Custom { algorithm, spawner })
    }

    /// Sets the configuration handed to the algorithm.
    pub fn algorithm_config(self, config: AlgorithmConfig) -> Self {
        self.configure(|cfg| cfg.algorithm_config = config)
    }

    /// Sets the number of workers requested from the spawner.
    pub fn num_workers(self, n: usize) -> Self {
        self.configure(|cfg| cfg.num_workers = n)
    }

    /// Sets the capacity of the task queue created by `build`.
    pub fn task_queue_capacity(self, capacity: usize) -> Self {
        self.configure(|cfg| cfg.task_queue_capacity = capacity)
    }

    /// Creates the task queue, spawns the pool, and returns both the pool and
    /// the handle used to submit tasks to it.
    ///
    /// # Errors
    ///
    /// Returns the [`UnknownAlgorithmError`] propagated from [`WorkerPool::spawn`].
    pub fn build(
        self,
        market_data: SharedMarketDataRef,
        derived_data: SharedDerivedDataRef,
        event_rx: broadcast::Receiver<MarketEvent>,
        derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
    ) -> Result<(WorkerPool, TaskQueueHandle), UnknownAlgorithmError> {
        let Self { config } = self;
        let queue_config = TaskQueueConfig { capacity: config.task_queue_capacity };
        let queue = TaskQueue::new(queue_config);
        let (task_handle, task_rx) = queue.split();
        WorkerPool::spawn(config, task_rx, market_data, derived_data, event_rx, derived_event_rx)
            .map(|pool| (pool, task_handle))
    }
}

impl Default for WorkerPoolBuilder {
    fn default() -> Self {
        Self { config: WorkerPoolConfig::default() }
    }
}