use std::{
sync::Arc,
thread::{self, JoinHandle},
};
use tokio::sync::broadcast;
use tracing::info;
use crate::{
algorithm::{AlgorithmConfig, BellmanFordAlgorithm, MostLiquidAlgorithm},
derived::{events::DerivedDataEvent, SharedDerivedDataRef},
feed::{events::MarketEvent, market_data::SharedMarketDataRef},
types::internal::SolveTask,
worker_pool::worker::SolverWorker,
};
/// Algorithm names understood by [`AlgorithmSpawner::Registry`]; this list is also rendered into
/// the [`UnknownAlgorithmError`] message.
///
/// NOTE(review): must be kept in sync by hand with the name match in `AlgorithmSpawner::spawn`.
pub(crate) const AVAILABLE_ALGORITHMS: &[&str] = &["most_liquid", "bellman_ford"];
/// Default algorithm name (presumably used when the caller does not pick one — the consumer is
/// not visible here; confirm at call sites).
pub(crate) const DEFAULT_ALGORITHM: &str = "most_liquid";
/// Inputs needed to spawn a pool of solver worker threads.
///
/// Consumed by value by [`AlgorithmSpawner::spawn`]. The field notes below describe how
/// [`spawn_workers_generic`] uses each field; a custom spawner may use them differently.
/// Field order is also drop order.
pub(crate) struct SpawnWorkersParams {
    /// Algorithm label; used in worker thread names and the "spawned workers" log line.
    pub algorithm: String,
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Configuration; a clone is handed to the algorithm factory inside each worker thread.
    pub algorithm_config: AlgorithmConfig,
    /// Shared solve-task queue; each worker gets a clone of this receiver.
    pub task_rx: async_channel::Receiver<SolveTask>,
    /// Shared market data; each worker holds its own `Arc` clone.
    pub market_data: SharedMarketDataRef,
    /// Shared derived data; each worker holds its own `Arc` clone.
    pub derived_data: SharedDerivedDataRef,
    /// Market event stream; each worker gets its own `resubscribe()`d receiver.
    pub event_rx: broadcast::Receiver<MarketEvent>,
    /// Derived-data event stream; each worker gets its own `resubscribe()`d receiver.
    pub derived_event_rx: broadcast::Receiver<DerivedDataEvent>,
    /// Shutdown broadcast; each worker `subscribe()`s to it before its thread starts.
    pub shutdown_tx: broadcast::Sender<()>,
}
#[derive(Debug, Clone, thiserror::Error)]
#[error("unknown algorithm '{name}'. Available: {}", AVAILABLE_ALGORITHMS.join(", "))]
pub struct UnknownAlgorithmError {
pub(crate) name: String,
}
impl UnknownAlgorithmError {
pub fn name(&self) -> &str {
&self.name
}
}
/// Boxed closure that spawns the worker threads for a caller-supplied algorithm.
pub(crate) type CustomSpawnFn =
    Box<dyn Fn(SpawnWorkersParams) -> Vec<JoinHandle<()>> + Send + Sync>;

/// How the solver worker pool gets spawned.
pub(crate) enum AlgorithmSpawner {
    /// Resolve `algorithm` against the built-in names (see [`AVAILABLE_ALGORITHMS`]).
    Registry { algorithm: String },
    /// Delegate to `spawner`; `algorithm` is only a label and is never checked
    /// against the registry.
    Custom { algorithm: String, spawner: CustomSpawnFn },
}
/// Hand-written because the boxed `spawner` closure is not `Debug`: only the variant name and
/// its `algorithm` label are printed.
impl std::fmt::Debug for AlgorithmSpawner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, algorithm) = match self {
            Self::Registry { algorithm } => ("Registry", algorithm),
            Self::Custom { algorithm, .. } => ("Custom", algorithm),
        };
        f.debug_struct(variant).field("algorithm", algorithm).finish()
    }
}
impl AlgorithmSpawner {
pub(crate) fn spawn(
self,
params: SpawnWorkersParams,
) -> Result<Vec<JoinHandle<()>>, UnknownAlgorithmError> {
match self {
Self::Registry { algorithm } => match algorithm.as_str() {
"most_liquid" => Ok(spawn_most_liquid_workers(params)),
"bellman_ford" => Ok(spawn_bellman_ford_workers(params)),
_ => Err(UnknownAlgorithmError { name: algorithm }),
},
Self::Custom { spawner, .. } => Ok(spawner(params)),
}
}
pub(crate) fn algorithm_name(&self) -> &str {
match self {
Self::Registry { algorithm } | Self::Custom { algorithm, .. } => algorithm,
}
}
}
/// Spawns `params.num_workers` OS threads, each running one [`SolverWorker`] for algorithm `A`.
///
/// Each thread:
/// - builds its own current-thread tokio runtime;
/// - constructs an algorithm instance with `factory(algorithm_config)` (a per-worker clone of
///   `params.algorithm_config`);
/// - initializes the worker's graph, then runs the worker loop until it returns.
///
/// Threads are named `"{algorithm}-worker-{worker_id}"`.
///
/// Per-worker channel handles are created on the calling thread *before* each worker thread
/// starts. In particular, the shutdown receiver is already subscribed when this function returns,
/// so a shutdown broadcast sent afterwards reaches every worker.
///
/// # Panics
///
/// Panics if the OS fails to spawn a worker thread. A failure inside a worker thread (building
/// the runtime, or a panic in `factory`) panics only that thread and surfaces through its
/// `JoinHandle`.
pub(crate) fn spawn_workers_generic<A, F>(
    params: SpawnWorkersParams,
    factory: &F,
) -> Vec<JoinHandle<()>>
where
    A: crate::algorithm::Algorithm + 'static,
    A::GraphManager:
        crate::feed::events::MarketEventHandler + crate::graph::EdgeWeightUpdaterWithDerived,
    F: Fn(AlgorithmConfig) -> A + Clone + Send + Sync + 'static,
{
    let mut workers = Vec::with_capacity(params.num_workers);
    for worker_id in 0..params.num_workers {
        // Each worker needs its own handles: the task receiver and data Arcs are cloned, and the
        // broadcast receivers are re-subscribed so every worker observes every later event.
        let task_rx = params.task_rx.clone();
        let market_data = Arc::clone(&params.market_data);
        let derived_data = Arc::clone(&params.derived_data);
        let event_rx = params.event_rx.resubscribe();
        let derived_event_rx = params.derived_event_rx.resubscribe();
        let algorithm_config = params.algorithm_config.clone();
        // Subscribe before spawning so a shutdown sent right after we return is not missed.
        let shutdown_rx = params.shutdown_tx.subscribe();
        let algorithm_name = params.algorithm.clone();
        let factory = factory.clone();
        let handle = thread::Builder::new()
            .name(format!("{}-worker-{}", algorithm_name, worker_id))
            .spawn(move || {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("failed to create tokio runtime");
                rt.block_on(async move {
                    // The algorithm is built inside the worker thread, so it never has to
                    // cross thread boundaries itself.
                    let algorithm = factory(algorithm_config);
                    let mut worker =
                        SolverWorker::new(market_data, derived_data, algorithm, worker_id);
                    worker.initialize_graph().await;
                    worker
                        .run(event_rx, derived_event_rx, task_rx, shutdown_rx)
                        .await;
                });
            })
            .expect("failed to spawn worker thread");
        workers.push(handle);
    }
    info!(
        algorithm = %params.algorithm,
        num_workers = params.num_workers,
        "spawned workers"
    );
    workers
}
/// Registry entry for `"most_liquid"`: spawns workers that each own a [`MostLiquidAlgorithm`].
///
/// # Panics
///
/// A worker thread panics if `MostLiquidAlgorithm::with_config` rejects the configuration.
fn spawn_most_liquid_workers(params: SpawnWorkersParams) -> Vec<JoinHandle<()>> {
    spawn_workers_generic(params, &|cfg: AlgorithmConfig| {
        MostLiquidAlgorithm::with_config(cfg)
            .expect("invalid worker configuration for MostLiquidAlgorithm")
    })
}
/// Registry entry for `"bellman_ford"`: spawns workers that each own a [`BellmanFordAlgorithm`].
///
/// # Panics
///
/// A worker thread panics if `BellmanFordAlgorithm::with_config` rejects the configuration.
fn spawn_bellman_ford_workers(params: SpawnWorkersParams) -> Vec<JoinHandle<()>> {
    spawn_workers_generic(params, &|cfg: AlgorithmConfig| {
        BellmanFordAlgorithm::with_config(cfg)
            .expect("invalid worker configuration for BellmanFordAlgorithm")
    })
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::sync::RwLock;

    use super::*;
    use crate::{derived::DerivedData, feed::market_data::SharedMarketData};

    /// Builds `SpawnWorkersParams` with fresh shared state and channels.
    ///
    /// The task and event senders are dropped when this returns, so use it only for tests that
    /// never start workers (e.g. registry lookup failures).
    fn make_params(algorithm: &str, num_workers: usize) -> SpawnWorkersParams {
        let (_task_tx, task_rx) = async_channel::bounded(10);
        let market_data = Arc::new(RwLock::new(SharedMarketData::new()));
        let derived_data = Arc::new(RwLock::new(DerivedData::new()));
        let (_event_tx, event_rx) = broadcast::channel(10);
        let (_derived_event_tx, derived_event_rx) = broadcast::channel(10);
        let (shutdown_tx, _) = broadcast::channel(1);
        SpawnWorkersParams {
            algorithm: algorithm.to_string(),
            num_workers,
            algorithm_config: AlgorithmConfig::default(),
            task_rx,
            market_data,
            derived_data,
            event_rx,
            derived_event_rx,
            shutdown_tx,
        }
    }

    #[test]
    fn test_registry_unknown_algorithm_returns_error() {
        let params = make_params("unknown_algorithm", 1);
        let err = AlgorithmSpawner::Registry { algorithm: "unknown_algorithm".to_string() }
            .spawn(params)
            .expect_err("unregistered algorithm names must be rejected");
        assert_eq!(err.name(), "unknown_algorithm");
        let message = err.to_string();
        assert!(message.contains("unknown_algorithm"));
        // The message should advertise the registered algorithms.
        assert!(message.contains("most_liquid"));
    }

    #[test]
    fn test_registry_spawns_correct_number_of_workers() {
        let (shutdown_tx, _) = broadcast::channel(1);
        let (_task_tx, task_rx) = async_channel::bounded(10);
        let market_data = Arc::new(RwLock::new(SharedMarketData::new()));
        let derived_data = Arc::new(RwLock::new(DerivedData::new()));
        let (event_tx, event_rx) = broadcast::channel(10);
        let (_derived_event_tx, derived_event_rx) = broadcast::channel(10);
        let params = SpawnWorkersParams {
            algorithm: "most_liquid".to_string(),
            num_workers: 3,
            algorithm_config: AlgorithmConfig::new(1, 2, Duration::from_millis(50), None).unwrap(),
            task_rx,
            market_data,
            derived_data,
            event_rx,
            derived_event_rx,
            shutdown_tx: shutdown_tx.clone(),
        };
        let workers = AlgorithmSpawner::Registry { algorithm: "most_liquid".to_string() }
            .spawn(params)
            .expect("most_liquid is a registered algorithm");
        assert_eq!(workers.len(), 3);
        let _ = shutdown_tx.send(());
        drop(event_tx);
        for handle in workers {
            let _ = handle.join();
        }
    }

    #[test]
    fn test_custom_spawner_bypasses_registry_for_unknown_names() {
        let (shutdown_tx, _) = broadcast::channel(1);
        let (_task_tx, task_rx) = async_channel::bounded(10);
        let market_data = Arc::new(RwLock::new(SharedMarketData::new()));
        let derived_data = Arc::new(RwLock::new(DerivedData::new()));
        let (event_tx, _) = broadcast::channel::<MarketEvent>(10);
        let (derived_event_tx, _) = broadcast::channel(10);

        // The same name must be rejected by the registry...
        let registry_err = AlgorithmSpawner::Registry { algorithm: "my_custom_algo".to_string() }
            .spawn(SpawnWorkersParams {
                algorithm: "my_custom_algo".to_string(),
                num_workers: 1,
                algorithm_config: AlgorithmConfig::default(),
                task_rx: task_rx.clone(),
                market_data: Arc::clone(&market_data),
                derived_data: Arc::clone(&derived_data),
                event_rx: event_tx.subscribe(),
                derived_event_rx: derived_event_tx.subscribe(),
                shutdown_tx: shutdown_tx.clone(),
            });
        assert!(registry_err.is_err());

        // ...but accepted when a custom spawner is supplied.
        let spawner: Box<dyn Fn(SpawnWorkersParams) -> Vec<JoinHandle<()>> + Send + Sync> =
            Box::new(|p| {
                let factory = |config: AlgorithmConfig| {
                    MostLiquidAlgorithm::with_config(config)
                        .expect("invalid config in test custom spawner")
                };
                spawn_workers_generic(p, &factory)
            });
        let workers = AlgorithmSpawner::Custom { algorithm: "my_custom_algo".to_string(), spawner }
            .spawn(SpawnWorkersParams {
                algorithm: "my_custom_algo".to_string(),
                num_workers: 2,
                algorithm_config: AlgorithmConfig::new(1, 2, Duration::from_millis(50), None)
                    .unwrap(),
                task_rx,
                market_data,
                derived_data,
                event_rx: event_tx.subscribe(),
                derived_event_rx: derived_event_tx.subscribe(),
                shutdown_tx: shutdown_tx.clone(),
            })
            .expect("custom spawners bypass the registry");
        assert_eq!(workers.len(), 2);

        // Shut down and join the workers (same sequence as the registry test) so no worker
        // threads outlive this test.
        let _ = shutdown_tx.send(());
        drop(event_tx);
        for handle in workers {
            let _ = handle.join();
        }
    }
}