//! Multi-threaded data-loading workers: thread pools that fetch dataset
//! samples and collate them into batches off the main thread. Despite the
//! `MultiProcess` naming below, parallelism comes from OS threads sharing the
//! dataset via `Arc`, not from separate processes.

use crate::{collate::Collate, dataset::Dataset, sampler::BatchSampler};
use std::collections::HashMap;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use torsh_core::error::Result;
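/// A unit of work: a batch of dataset indices tagged with the ID that the
/// submitter uses to match results back to tasks.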
#[derive(Debug, Clone)]
struct WorkerTask {
task_id: usize,
indices: Vec<usize>,
}
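/// The outcome of one processed task, carrying the originating task ID so
/// callers can reorder out-of-order completions if they need to.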
#[derive(Debug)]
pub struct WorkerResult<T> {
pub task_id: usize,
pub result: Result<T>,
}
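/// Keeps a worker's `JoinHandle` alive for the pool's lifetime; the phantom
/// type ties the handle to the batch type it produces.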
struct WorkerHandle<T> {
_thread: thread::JoinHandle<()>,
_phantom: std::marker::PhantomData<T>,
}
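/// A pool of background threads that fetch dataset samples and collate them
/// into batches. Workers pull tasks from a shared channel, so batches are
/// delivered in completion order, not submission order.
///
/// # Example
///
/// A minimal sketch, assuming the `TensorDataset` and `DefaultCollate` types
/// used in the tests below:
///
/// ```ignore
/// let dataset = Arc::new(TensorDataset::from_tensor(tensor));
/// let pool = WorkerPool::new(dataset, Arc::new(DefaultCollate), 2);
/// pool.submit_task(0, vec![0, 1])?;
/// let WorkerResult { result, .. } = pool.get_result()?;
/// let batch = result?;
/// ```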
pub struct WorkerPool<D, C>
where
D: Dataset + Clone + Send + Sync + 'static,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
#[allow(dead_code)]
dataset: Arc<D>,
#[allow(dead_code)]
collate_fn: Arc<C>,
num_workers: usize,
#[allow(dead_code)]
workers: Vec<WorkerHandle<C::Output>>,
task_sender: mpsc::Sender<WorkerTask>,
result_receiver: mpsc::Receiver<WorkerResult<C::Output>>,
}
impl<D, C> WorkerPool<D, C>
where
D: Dataset + Clone + Send + Sync + 'static,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
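    /// Spawns `num_workers` threads that share one task queue behind a
    /// `Mutex`, all reporting into a single result channel.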
pub fn new(dataset: Arc<D>, collate_fn: Arc<C>, num_workers: usize) -> Self {
let (task_sender, task_receiver) = mpsc::channel::<WorkerTask>();
let (result_sender, result_receiver) = mpsc::channel::<WorkerResult<C::Output>>();
let task_receiver = Arc::new(Mutex::new(task_receiver));
let mut workers = Vec::with_capacity(num_workers);
for worker_id in 0..num_workers {
let dataset_clone = Arc::clone(&dataset);
let collate_fn_clone = Arc::clone(&collate_fn);
let task_receiver_clone = Arc::clone(&task_receiver);
let result_sender_clone = result_sender.clone();
let worker_thread = thread::spawn(move || {
Self::worker_loop(
worker_id,
dataset_clone,
collate_fn_clone,
task_receiver_clone,
result_sender_clone,
);
});
workers.push(WorkerHandle {
_thread: worker_thread,
_phantom: std::marker::PhantomData,
});
}
Self {
dataset,
collate_fn,
num_workers,
workers,
task_sender,
result_receiver,
}
}
    /// Blocks on the shared task channel, processing one batch per task, until
    /// the channel closes or the result receiver is dropped.
fn worker_loop(
_worker_id: usize,
dataset: Arc<D>,
collate_fn: Arc<C>,
task_receiver: Arc<Mutex<mpsc::Receiver<WorkerTask>>>,
result_sender: mpsc::Sender<WorkerResult<C::Output>>,
) {
loop {
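            // The mutex is held across the blocking recv, so idle workers
            // queue on the lock and take turns pulling tasks.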
let task = {
let receiver = match task_receiver.lock() {
Ok(receiver) => receiver,
Err(_) => {
break;
}
};
receiver.recv()
};
match task {
Ok(WorkerTask { task_id, indices }) => {
let batch_result = Self::process_batch(&*dataset, &*collate_fn, indices);
let result = WorkerResult {
task_id,
result: batch_result,
};
if result_sender.send(result).is_err() {
break;
}
}
Err(_) => {
break;
}
}
}
}
    /// Fetches every sample in `indices` and collates them into one batch.
    /// The first failing `get` aborts the whole batch.
    fn process_batch(dataset: &D, collate_fn: &C, indices: Vec<usize>) -> Result<C::Output> {
        let mut samples = Vec::with_capacity(indices.len());
        for idx in indices {
            samples.push(dataset.get(idx)?);
        }
        collate_fn.collate(samples)
    }
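    /// Queues a batch of indices for processing. Fails only if every worker
    /// has exited and the channel is disconnected.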
pub fn submit_task(&self, task_id: usize, indices: Vec<usize>) -> Result<()> {
let task = WorkerTask { task_id, indices };
self.task_sender.send(task).map_err(|_| {
torsh_core::error::TorshError::RuntimeError(
"Failed to send task to worker pool".to_string(),
)
})?;
Ok(())
}
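    /// Blocks until the next finished batch arrives, in completion order.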
pub fn get_result(&self) -> Result<WorkerResult<C::Output>> {
self.result_receiver.recv().map_err(|_| {
torsh_core::error::TorshError::RuntimeError(
"Failed to receive result from worker pool".to_string(),
)
})
}
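    /// Non-blocking variant of `get_result`; returns `None` if nothing is ready.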
pub fn try_get_result(&self) -> Option<WorkerResult<C::Output>> {
self.result_receiver.try_recv().ok()
}
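    /// Number of worker threads owned by this pool.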
pub fn num_workers(&self) -> usize {
self.num_workers
}
}
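/// An iterator adapter that drives a `WorkerPool` from a batch sampler,
/// keeping up to `max_pending` tasks in flight for pipelining. Batches are
/// yielded in completion order, which may differ from the order the sampler
/// produced them.
///
/// A minimal sketch, assuming `batches` is any `Iterator<Item = Vec<usize>>`
/// obtained from a `BatchSampler`:
///
/// ```ignore
/// let iter = MultiProcessIterator::new(batches, &pool);
/// for batch in iter {
///     let batch = batch?;
///     // consume the collated batch
/// }
/// ```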
pub struct MultiProcessIterator<'a, D, S, C>
where
D: Dataset + Clone + Send + Sync + 'static,
S: BatchSampler,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
sampler_iter: S::Iter,
worker_pool: &'a WorkerPool<D, C>,
pending_tasks: HashMap<usize, Vec<usize>>,
next_task_id: usize,
max_pending: usize,
}
impl<'a, D, S, C> MultiProcessIterator<'a, D, S, C>
where
    D: Dataset + Clone + Send + Sync + 'static,
    S: BatchSampler,
    S::Iter: Iterator<Item = Vec<usize>>,
    C: Collate<D::Item> + Clone + Send + Sync + 'static,
    D::Item: Send + 'static,
    C::Output: Send + 'static,
{
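    /// Uses a default in-flight buffer of twice the pool's worker count.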
pub fn new(sampler_iter: S::Iter, worker_pool: &'a WorkerPool<D, C>) -> Self {
let max_pending = worker_pool.num_workers() * 2;
Self {
sampler_iter,
worker_pool,
pending_tasks: HashMap::new(),
next_task_id: 0,
max_pending,
}
}
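    /// Like `new`, but with an explicit cap on in-flight tasks.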
pub fn with_buffer_size(
sampler_iter: S::Iter,
worker_pool: &'a WorkerPool<D, C>,
max_pending: usize,
) -> Self {
Self {
sampler_iter,
worker_pool,
pending_tasks: HashMap::new(),
next_task_id: 0,
max_pending,
}
}
    /// Tops up the in-flight queue until `max_pending` tasks are outstanding
    /// or the sampler is exhausted.
    fn submit_tasks(&mut self) {
        while self.pending_tasks.len() < self.max_pending {
            let Some(indices) = self.sampler_iter.next() else {
                break;
            };
            let task_id = self.next_task_id;
            self.next_task_id += 1;
            if self
                .worker_pool
                .submit_task(task_id, indices.clone())
                .is_ok()
            {
                self.pending_tasks.insert(task_id, indices);
            } else {
                // The pool has shut down; stop submitting.
                break;
            }
        }
    }
pub fn pending_count(&self) -> usize {
self.pending_tasks.len()
}
pub fn has_pending_tasks(&self) -> bool {
!self.pending_tasks.is_empty()
}
}
impl<D, S, C> Iterator for MultiProcessIterator<'_, D, S, C>
where
D: Dataset + Clone + Send + Sync + 'static,
S: BatchSampler,
S::Iter: Iterator<Item = Vec<usize>>,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
type Item = Result<C::Output>;
    fn next(&mut self) -> Option<Self::Item> {
        self.submit_tasks();
        if self.pending_tasks.is_empty() {
            return None;
        }
        match self.worker_pool.get_result() {
            Ok(WorkerResult { task_id, result }) => {
                self.pending_tasks.remove(&task_id);
                Some(result)
            }
            Err(e) => {
                // The pool disconnected; clear pending state so the iterator
                // terminates instead of yielding this error forever.
                self.pending_tasks.clear();
                Some(Err(e))
            }
        }
    }
}
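/// Control messages for persistent workers: either a task to process or an
/// explicit shutdown request.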
#[derive(Debug)]
enum PersistentWorkerMessage {
Task(WorkerTask),
Shutdown,
}
struct PersistentWorkerHandle {
_thread: thread::JoinHandle<()>,
}
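/// A `WorkerPool` variant whose workers survive across epochs: they poll a
/// shutdown flag instead of exiting when the task channel goes quiet, and are
/// stopped explicitly via `shutdown` (also invoked on `Drop`).
///
/// # Example
///
/// A minimal sketch, reusing the pool across two epochs:
///
/// ```ignore
/// let pool = PersistentWorkerPool::new(dataset, collate_fn, 4);
/// for _epoch in 0..2 {
///     pool.reset_for_epoch()?;
///     pool.submit_task(0, vec![0, 1])?;
///     let _batch = pool.get_result_timeout(std::time::Duration::from_secs(5))?;
/// }
/// pool.shutdown()?;
/// ```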
pub struct PersistentWorkerPool<D, C>
where
D: Dataset + Clone + Send + Sync + 'static,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
#[allow(dead_code)]
dataset: Arc<D>,
#[allow(dead_code)]
collate_fn: Arc<C>,
num_workers: usize,
#[allow(dead_code)]
workers: Vec<PersistentWorkerHandle>,
task_sender: mpsc::Sender<PersistentWorkerMessage>,
result_receiver: mpsc::Receiver<WorkerResult<C::Output>>,
is_shutdown: Arc<std::sync::atomic::AtomicBool>,
}
impl<D, C> PersistentWorkerPool<D, C>
where
D: Dataset + Clone + Send + Sync + 'static,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
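    /// Spawns `num_workers` persistent threads sharing one task queue and a
    /// common shutdown flag.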
pub fn new(dataset: Arc<D>, collate_fn: Arc<C>, num_workers: usize) -> Self {
let (task_sender, task_receiver) = mpsc::channel::<PersistentWorkerMessage>();
let (result_sender, result_receiver) = mpsc::channel::<WorkerResult<C::Output>>();
let task_receiver = Arc::new(Mutex::new(task_receiver));
let is_shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
let mut workers = Vec::with_capacity(num_workers);
for worker_id in 0..num_workers {
let dataset_clone = Arc::clone(&dataset);
let collate_fn_clone = Arc::clone(&collate_fn);
let task_receiver_clone = Arc::clone(&task_receiver);
let result_sender_clone = result_sender.clone();
let is_shutdown_clone = Arc::clone(&is_shutdown);
let worker_thread = thread::spawn(move || {
Self::persistent_worker_loop(
worker_id,
dataset_clone,
collate_fn_clone,
task_receiver_clone,
result_sender_clone,
is_shutdown_clone,
);
});
workers.push(PersistentWorkerHandle {
_thread: worker_thread,
});
}
Self {
dataset,
collate_fn,
num_workers,
workers,
task_sender,
result_receiver,
is_shutdown,
}
}
    /// Like `worker_loop`, but wakes every 100 ms so it can poll the shutdown
    /// flag even when no tasks arrive.
fn persistent_worker_loop(
_worker_id: usize,
dataset: Arc<D>,
collate_fn: Arc<C>,
task_receiver: Arc<Mutex<mpsc::Receiver<PersistentWorkerMessage>>>,
result_sender: mpsc::Sender<WorkerResult<C::Output>>,
is_shutdown: Arc<std::sync::atomic::AtomicBool>,
) {
loop {
if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
break;
}
let message = {
let receiver = match task_receiver.lock() {
Ok(receiver) => receiver,
Err(_) => {
break;
}
};
receiver.recv_timeout(std::time::Duration::from_millis(100))
};
match message {
Ok(PersistentWorkerMessage::Task(WorkerTask { task_id, indices })) => {
let batch_result = Self::process_batch(&*dataset, &*collate_fn, indices);
let result = WorkerResult {
task_id,
result: batch_result,
};
if result_sender.send(result).is_err() {
break;
}
}
Ok(PersistentWorkerMessage::Shutdown) => {
break;
}
Err(mpsc::RecvTimeoutError::Timeout) => {
continue;
}
Err(mpsc::RecvTimeoutError::Disconnected) => {
break;
}
}
}
}
    /// Identical to `WorkerPool::process_batch`: fetches every sample in
    /// `indices` and collates them into one batch.
    fn process_batch(dataset: &D, collate_fn: &C, indices: Vec<usize>) -> Result<C::Output> {
        let mut samples = Vec::with_capacity(indices.len());
        for idx in indices {
            samples.push(dataset.get(idx)?);
        }
        collate_fn.collate(samples)
    }
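    /// Queues a batch of indices; fails if the pool has shut down and the
    /// channel is disconnected.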
pub fn submit_task(&self, task_id: usize, indices: Vec<usize>) -> Result<()> {
let message = PersistentWorkerMessage::Task(WorkerTask { task_id, indices });
self.task_sender.send(message).map_err(|_| {
torsh_core::error::TorshError::RuntimeError(
"Failed to send task to persistent worker pool".to_string(),
)
})?;
Ok(())
}
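    /// Blocks until the next finished batch arrives.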
pub fn get_result(&self) -> Result<WorkerResult<C::Output>> {
self.result_receiver.recv().map_err(|_| {
torsh_core::error::TorshError::RuntimeError(
"Failed to receive result from persistent worker pool".to_string(),
)
})
}
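    /// Like `get_result`, but gives up after `timeout`, distinguishing a
    /// timeout from a disconnected pool in the returned error.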
pub fn get_result_timeout(
&self,
timeout: std::time::Duration,
) -> Result<WorkerResult<C::Output>> {
self.result_receiver
.recv_timeout(timeout)
.map_err(|e| match e {
mpsc::RecvTimeoutError::Timeout => torsh_core::error::TorshError::RuntimeError(
"Timeout waiting for result from persistent worker pool".to_string(),
),
mpsc::RecvTimeoutError::Disconnected => {
torsh_core::error::TorshError::RuntimeError(
"Persistent worker pool disconnected".to_string(),
)
}
})
}
pub fn try_get_result(&self) -> Option<WorkerResult<C::Output>> {
self.result_receiver.try_recv().ok()
}
pub fn num_workers(&self) -> usize {
self.num_workers
}
pub fn is_shutdown(&self) -> bool {
self.is_shutdown.load(std::sync::atomic::Ordering::Relaxed)
}
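    /// Signals every worker to exit: sets the shutdown flag, then sends one
    /// `Shutdown` message per worker. Does not join the worker threads.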
pub fn shutdown(&self) -> Result<()> {
self.is_shutdown
.store(true, std::sync::atomic::Ordering::Relaxed);
for _ in 0..self.num_workers {
if self
.task_sender
.send(PersistentWorkerMessage::Shutdown)
.is_err()
{
break;
}
}
Ok(())
}
    /// No-op: persistent workers carry no per-epoch state. Kept for API
    /// symmetry with pools that must be rebuilt between epochs.
    pub fn reset_for_epoch(&self) -> Result<()> {
        Ok(())
    }
}
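// Dropping the pool performs a best-effort shutdown so worker threads stop
// blocking on the task channel.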
impl<D, C> Drop for PersistentWorkerPool<D, C>
where
D: Dataset + Clone + Send + Sync + 'static,
C: Collate<D::Item> + Clone + Send + Sync + 'static,
D::Item: Send + 'static,
C::Output: Send + 'static,
{
fn drop(&mut self) {
let _ = self.shutdown();
}
}
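/// Builder-style configuration describing how a data loader should set up its
/// workers. Defaults to one worker per available core, non-persistent.
///
/// ```ignore
/// let config = WorkerConfig::new()
///     .num_workers(4)
///     .max_pending_tasks(8)
///     .persistent(true);
/// ```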
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Upper bound on in-flight tasks; `None` lets the consumer pick a default.
    pub max_pending_tasks: Option<usize>,
    /// Whether workers should persist across epochs (see `PersistentWorkerPool`).
    pub persistent: bool,
}
impl Default for WorkerConfig {
fn default() -> Self {
Self {
num_workers: std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1),
max_pending_tasks: None,
persistent: false,
}
}
}
impl WorkerConfig {
pub fn new() -> Self {
Self::default()
}
pub fn num_workers(mut self, num_workers: usize) -> Self {
self.num_workers = num_workers;
self
}
pub fn max_pending_tasks(mut self, max_pending: usize) -> Self {
self.max_pending_tasks = Some(max_pending);
self
}
pub fn persistent(mut self, persistent: bool) -> Self {
self.persistent = persistent;
self
}
}
pub mod utils {
use super::*;
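    /// Heuristic worker count: CPU-bound loading leaves roughly a quarter of
    /// the cores free for the main loop, while IO-bound loading oversubscribes
    /// the cores 2x to hide latency.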
pub fn optimal_worker_count(cpu_intensive: bool) -> usize {
let cpu_count = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
if cpu_intensive {
(cpu_count * 3 / 4).max(1)
} else {
cpu_count * 2
}
}
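    /// Persistent, IO-oriented workers with a deep task buffer, suited to
    /// epoch-based training loops.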
pub fn training_config() -> WorkerConfig {
WorkerConfig::new()
.num_workers(optimal_worker_count(false))
.persistent(true)
.max_pending_tasks(optimal_worker_count(false) * 3)
}
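    /// Short-lived, CPU-oriented workers with a shallower buffer, suited to
    /// one-shot inference runs.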
pub fn inference_config() -> WorkerConfig {
WorkerConfig::new()
.num_workers(optimal_worker_count(true))
.persistent(false)
.max_pending_tasks(optimal_worker_count(true) * 2)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{collate::DefaultCollate, dataset::TensorDataset};
#[test]
fn test_worker_pool_creation() {
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;
let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
.expect("Tensor should succeed");
let dataset = Arc::new(TensorDataset::from_tensor(tensor));
let collate_fn = Arc::new(DefaultCollate);
let worker_pool = WorkerPool::new(dataset, collate_fn, 2);
assert_eq!(worker_pool.num_workers(), 2);
}
#[test]
fn test_worker_pool_task_submission() {
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;
let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
.expect("Tensor should succeed");
let dataset = Arc::new(TensorDataset::from_tensor(tensor));
let collate_fn = Arc::new(DefaultCollate);
let worker_pool = WorkerPool::new(dataset, collate_fn, 2);
assert!(worker_pool.submit_task(0, vec![0, 1]).is_ok());
}
#[test]
fn test_persistent_worker_pool_creation() {
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;
let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
.expect("Tensor should succeed");
let dataset = Arc::new(TensorDataset::from_tensor(tensor));
let collate_fn = Arc::new(DefaultCollate);
let worker_pool = PersistentWorkerPool::new(dataset, collate_fn, 2);
assert_eq!(worker_pool.num_workers(), 2);
assert!(!worker_pool.is_shutdown());
}
#[test]
fn test_persistent_worker_pool_shutdown() {
use torsh_core::device::DeviceType;
use torsh_tensor::Tensor;
let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
.expect("Tensor should succeed");
let dataset = Arc::new(TensorDataset::from_tensor(tensor));
let collate_fn = Arc::new(DefaultCollate);
let worker_pool = PersistentWorkerPool::new(dataset, collate_fn, 2);
assert!(worker_pool.shutdown().is_ok());
assert!(worker_pool.is_shutdown());
}
#[test]
fn test_worker_config() {
let config = WorkerConfig::new()
.num_workers(4)
.max_pending_tasks(8)
.persistent(true);
assert_eq!(config.num_workers, 4);
assert_eq!(config.max_pending_tasks, Some(8));
assert!(config.persistent);
}
#[test]
fn test_optimal_worker_count() {
let cpu_intensive_count = utils::optimal_worker_count(true);
let io_bound_count = utils::optimal_worker_count(false);
assert!(cpu_intensive_count > 0);
assert!(io_bound_count > 0);
assert!(io_bound_count >= cpu_intensive_count);
}
#[test]
fn test_training_config() {
let config = utils::training_config();
assert!(config.num_workers > 0);
assert!(config.persistent);
assert!(config.max_pending_tasks.is_some());
}
#[test]
fn test_inference_config() {
let config = utils::inference_config();
assert!(config.num_workers > 0);
assert!(!config.persistent);
assert!(config.max_pending_tasks.is_some());
}
}