use copybook_sequence_ring::{SequenceRing, SequenceRingStats, SequencedRecord};
use crossbeam_channel::{Sender, bounded};
use std::sync::Arc;
use std::thread;
use tracing::{debug, warn};
use super::ScratchBuffers;
#[derive(Debug)]
pub struct WorkerPool<Input, Output> {
input_sender: Sender<SequencedRecord<Input>>,
output_ring: SequenceRing<Output>,
worker_handles: Vec<thread::JoinHandle<()>>,
next_input_sequence: u64,
}
impl<Input, Output> WorkerPool<Input, Output>
where
Input: Send + 'static,
Output: Send + 'static,
{
#[inline]
#[must_use]
pub fn new<F>(
num_workers: usize,
channel_capacity: usize,
max_window_size: usize,
worker_fn: F,
) -> Self
where
F: Fn(Input, &mut ScratchBuffers) -> Output + Send + Sync + Clone + 'static,
{
let (input_sender, input_receiver) = bounded(channel_capacity);
let output_ring = SequenceRing::new(channel_capacity, max_window_size);
let output_sender = output_ring.sender();
let worker_fn = Arc::new(worker_fn);
let mut worker_handles = Vec::with_capacity(num_workers);
for worker_id in 0..num_workers {
let input_receiver = input_receiver.clone();
let output_sender = output_sender.clone();
let worker_fn = Arc::clone(&worker_fn);
let handle = thread::spawn(move || {
let mut scratch_buffers = ScratchBuffers::new();
debug!("Worker {} started", worker_id);
while let Ok(sequenced_input) = input_receiver.recv() {
let SequencedRecord {
sequence_id,
data: input,
} = sequenced_input;
scratch_buffers.clear();
let output = worker_fn(input, &mut scratch_buffers);
let sequenced_output = SequencedRecord::new(sequence_id, output);
if output_sender.send(sequenced_output).is_err() {
debug!("Worker {} output channel closed", worker_id);
break;
}
}
debug!("Worker {} finished", worker_id);
});
worker_handles.push(handle);
}
Self {
input_sender,
output_ring,
worker_handles,
next_input_sequence: 1,
}
}
#[inline]
#[must_use = "Handle the Result or propagate the error"]
pub fn submit(
&mut self,
input: Input,
) -> Result<(), crossbeam_channel::SendError<SequencedRecord<Input>>> {
let sequenced_input = SequencedRecord::new(self.next_input_sequence, input);
self.next_input_sequence += 1;
self.input_sender.send(sequenced_input)
}
#[inline]
#[must_use = "Handle the Result or propagate the error"]
pub fn recv_ordered(&mut self) -> Result<Option<Output>, crossbeam_channel::RecvError> {
self.output_ring.recv_ordered()
}
#[inline]
#[must_use = "Handle the Result or propagate the error"]
pub fn try_recv_ordered(&mut self) -> Result<Option<Output>, crossbeam_channel::TryRecvError> {
self.output_ring.try_recv_ordered()
}
#[inline]
#[must_use = "Handle the Result or propagate the error"]
pub fn shutdown(self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
drop(self.input_sender);
for (i, handle) in self.worker_handles.into_iter().enumerate() {
if let Err(e) = handle.join() {
warn!("Worker {} panicked: {:?}", i, e);
}
}
Ok(())
}
#[inline]
#[must_use]
pub fn stats(&self) -> WorkerPoolStats {
WorkerPoolStats {
num_workers: self.worker_handles.len(),
next_input_sequence: self.next_input_sequence,
sequence_ring_stats: self.output_ring.stats(),
}
}
}
#[derive(Debug, Clone)]
pub struct WorkerPoolStats {
pub num_workers: usize,
pub next_input_sequence: u64,
pub sequence_ring_stats: SequenceRingStats,
}