use std::time::Duration;
use timely::{communication::allocator::thread::Thread, WorkerConfig};
use tokio::{sync::mpsc, task::JoinHandle};
pub type LocalTimelyWorker = timely::worker::Worker<Thread>;
/// Tuning knobs for the worker's step loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StepLoopConfig {
    /// Buffer size of the bounded command channel that feeds the worker.
    pub command_capacity: usize,
    /// How many times `worker.step()` is called for each `Step` command.
    pub max_steps_per_tick: usize,
    /// Maximum park duration passed to `step_or_park` after each batch of commands.
    pub idle_park: Duration,
}

impl Default for StepLoopConfig {
    /// 128 queued commands, 64 steps per `Step` command, and a 1 ms park.
    fn default() -> Self {
        let command_capacity = 128;
        let max_steps_per_tick = 64;
        let idle_park = Duration::from_millis(1);
        Self {
            command_capacity,
            max_steps_per_tick,
            idle_park,
        }
    }
}
/// A request handled, in the order received, by the worker's step loop.
pub enum WorkerCommand {
    /// Invoke the boxed closure once with the worker (e.g. to install dataflows).
    Build(Box<dyn FnOnce(&mut LocalTimelyWorker) + Send + 'static>),
    /// Step the worker a configured number of times.
    Step,
    /// End the step loop.
    Stop,
}

impl std::fmt::Debug for WorkerCommand {
    /// Prints only the variant name; the opaque closure inside `Build` is rendered as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Build(_) => "Build(..)",
            Self::Step => "Step",
            Self::Stop => "Stop",
        };
        f.write_str(label)
    }
}
/// Counters accumulated by the step loop and handed back when the worker stops.
///
/// Both counters use saturating addition, so they stop at `usize::MAX` instead of wrapping.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct WorkerStats {
    /// Number of `step`/`step_or_park` calls made on the worker.
    pub steps: usize,
    /// Number of `Build` closures that have been run.
    pub builds: usize,
}
/// Failures reported by `WorkerHandle` operations.
#[derive(Debug)]
pub enum WorkerError {
    /// The command channel rejected a send because its receiving end is gone.
    Closed,
    /// Awaiting the worker's blocking task returned a tokio `JoinError`.
    Join(tokio::task::JoinError),
}

impl std::fmt::Display for WorkerError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            WorkerError::Join(ref cause) => {
                write!(out, "dataflow worker task failed: {}", cause)
            }
            WorkerError::Closed => out.write_str("dataflow worker command channel is closed"),
        }
    }
}

// `Display` already embeds the `JoinError` text, so `source()` is deliberately left at its
// default (`None`) to avoid the cause being printed twice by error-chain reporters.
impl std::error::Error for WorkerError {}
/// Owning handle to a worker started by `spawn_worker`.
///
/// Commands go through `commands`; `task` is the blocking task running the step loop.
/// The handle is not `Clone`, so it holds the only sender `spawn_worker` creates.
#[derive(Debug)]
pub struct WorkerHandle {
/// Sending half of the bounded command channel read by the step loop.
commands: mpsc::Sender<WorkerCommand>,
/// Join handle of the `spawn_blocking` task; resolves to the loop's final stats.
task: JoinHandle<WorkerStats>,
}
/// Starts a single-threaded timely worker on a tokio blocking thread and returns a handle
/// for sending it commands.
///
/// `config.command_capacity` bounds the command queue. A capacity of `0` is treated as
/// `1`, because tokio's bounded channel rejects a zero buffer by panicking.
///
/// Dropping the returned handle without calling `WorkerHandle::stop` closes the command
/// channel. The step loop then exits after applying whatever was already queued.
///
/// # Panics
///
/// Panics if called outside a tokio runtime, since `tokio::task::spawn_blocking` needs one.
#[must_use]
pub fn spawn_worker(config: StepLoopConfig) -> WorkerHandle {
    // `mpsc::channel(0)` panics; clamp so a zeroed config yields a single-slot queue
    // instead of crashing the caller.
    let (commands, receiver) = mpsc::channel(config.command_capacity.max(1));
    // The step loop blocks on the channel, so it must not run on an async worker thread.
    let task = tokio::task::spawn_blocking(move || step_loop(config, receiver));
    WorkerHandle { commands, task }
}
impl WorkerHandle {
pub async fn build(
&self,
build: impl FnOnce(&mut LocalTimelyWorker) + Send + 'static,
) -> Result<(), WorkerError> {
self.commands
.send(WorkerCommand::Build(Box::new(build)))
.await
.map_err(|_| WorkerError::Closed)
}
pub async fn step(&self) -> Result<(), WorkerError> {
self.commands
.send(WorkerCommand::Step)
.await
.map_err(|_| WorkerError::Closed)
}
pub async fn stop(self) -> Result<WorkerStats, WorkerError> {
self.commands
.send(WorkerCommand::Stop)
.await
.map_err(|_| WorkerError::Closed)?;
self.task.await.map_err(WorkerError::Join)
}
}
fn step_loop(config: StepLoopConfig, mut commands: mpsc::Receiver<WorkerCommand>) -> WorkerStats {
let mut worker = LocalTimelyWorker::new(WorkerConfig::default(), Thread::default(), None);
let mut stats = WorkerStats::default();
while let Some(command) = commands.blocking_recv() {
if apply_command(command, &mut worker, &mut stats, config.max_steps_per_tick) {
break;
}
while let Ok(command) = commands.try_recv() {
if apply_command(command, &mut worker, &mut stats, config.max_steps_per_tick) {
return stats;
}
}
worker.step_or_park(Some(config.idle_park));
stats.steps = stats.steps.saturating_add(1);
}
stats
}
/// Executes a single command against `worker` and records its effect in `stats`.
///
/// - `Build` calls the boxed closure once and increments `stats.builds`.
/// - `Step` calls `worker.step()` exactly `max_steps_per_tick` times, incrementing
///   `stats.steps` after each call. A value of `0` makes it a no-op.
/// - `Stop` changes nothing.
///
/// Returns `true` only for `Stop`, which tells the caller to leave its loop.
fn apply_command(
    command: WorkerCommand,
    worker: &mut LocalTimelyWorker,
    stats: &mut WorkerStats,
    max_steps_per_tick: usize,
) -> bool {
    match command {
        WorkerCommand::Stop => true,
        WorkerCommand::Step => {
            // Whatever `step()` returns is ignored: the full batch always runs.
            (0..max_steps_per_tick).for_each(|_| {
                worker.step();
                stats.steps = stats.steps.saturating_add(1);
            });
            false
        }
        WorkerCommand::Build(install) => {
            install(worker);
            stats.builds = stats.builds.saturating_add(1);
            false
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use timely::dataflow::operators::{Inspect, ToStream};

    use super::{spawn_worker, StepLoopConfig};

    /// End-to-end check: a dataflow installed through `build` sees all three records once
    /// the worker has been stepped, and `stop` reports one build and a non-zero step count.
    #[tokio::test]
    async fn worker_builds_dataflow_steps_and_stops() {
        let handle = spawn_worker(StepLoopConfig {
            command_capacity: 4,
            max_steps_per_tick: 2,
            ..StepLoopConfig::default()
        });

        let inspected = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&inspected);
        handle
            .build(move |timely_worker| {
                timely_worker.dataflow::<u64, _, _>(move |scope| {
                    let per_record = Arc::clone(&counter);
                    (0..3).to_stream(scope).inspect(move |_| {
                        per_record.fetch_add(1, Ordering::SeqCst);
                    });
                });
            })
            .await
            .unwrap();
        handle.step().await.unwrap();

        let stats = handle.stop().await.unwrap();
        assert_eq!(inspected.load(Ordering::SeqCst), 3);
        assert_eq!(stats.builds, 1);
        assert!(stats.steps > 0);
    }
}