palimpsest_dataflow/palimpsest/
worker.rs1use std::time::Duration;
4
5use timely::{communication::allocator::thread::Thread, WorkerConfig};
6use tokio::{sync::mpsc, task::JoinHandle};
7
8pub type LocalTimelyWorker = timely::worker::Worker<Thread>;
10
/// Tuning knobs for the blocking step loop that drives a local timely worker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StepLoopConfig {
    /// Capacity of the bounded command channel between the handle and the step loop.
    pub command_capacity: usize,
    /// How many times the worker is stepped for each `WorkerCommand::Step`.
    pub max_steps_per_tick: usize,
    /// Upper bound handed to `step_or_park` after each batch of commands is applied.
    pub idle_park: Duration,
}
21
22impl Default for StepLoopConfig {
23 fn default() -> Self {
24 Self {
25 command_capacity: 128,
26 max_steps_per_tick: 64,
27 idle_park: Duration::from_millis(1),
28 }
29 }
30}
31
32pub enum WorkerCommand {
34 Build(Box<dyn FnOnce(&mut LocalTimelyWorker) + Send + 'static>),
36 Step,
38 Stop,
40}
41
42impl std::fmt::Debug for WorkerCommand {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 match self {
45 Self::Build(_) => f.write_str("Build(..)"),
46 Self::Step => f.write_str("Step"),
47 Self::Stop => f.write_str("Stop"),
48 }
49 }
50}
51
/// Counters accumulated by the step loop and returned when the worker is stopped.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct WorkerStats {
    /// Total worker steps: those run for `Step` commands plus the `step_or_park` performed
    /// after each batch of commands.
    pub steps: usize,
    /// Number of `Build` commands applied.
    pub builds: usize,
}
60
61#[derive(Debug)]
63pub enum WorkerError {
64 Closed,
66 Join(tokio::task::JoinError),
68}
69
70impl std::fmt::Display for WorkerError {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 match self {
73 Self::Closed => f.write_str("dataflow worker command channel is closed"),
74 Self::Join(err) => write!(f, "dataflow worker task failed: {err}"),
75 }
76 }
77}
78
79impl std::error::Error for WorkerError {}
80
81#[derive(Debug)]
83pub struct WorkerHandle {
84 commands: mpsc::Sender<WorkerCommand>,
85 task: JoinHandle<WorkerStats>,
86}
87
88#[must_use]
90pub fn spawn_worker(config: StepLoopConfig) -> WorkerHandle {
91 let (commands, receiver) = mpsc::channel(config.command_capacity);
92 let task = tokio::task::spawn_blocking(move || step_loop(config, receiver));
93
94 WorkerHandle { commands, task }
95}
96
97impl WorkerHandle {
98 pub async fn build(
100 &self,
101 build: impl FnOnce(&mut LocalTimelyWorker) + Send + 'static,
102 ) -> Result<(), WorkerError> {
103 self.commands
104 .send(WorkerCommand::Build(Box::new(build)))
105 .await
106 .map_err(|_| WorkerError::Closed)
107 }
108
109 pub async fn step(&self) -> Result<(), WorkerError> {
111 self.commands
112 .send(WorkerCommand::Step)
113 .await
114 .map_err(|_| WorkerError::Closed)
115 }
116
117 pub async fn stop(self) -> Result<WorkerStats, WorkerError> {
119 self.commands
120 .send(WorkerCommand::Stop)
121 .await
122 .map_err(|_| WorkerError::Closed)?;
123 self.task.await.map_err(WorkerError::Join)
124 }
125}
126
127fn step_loop(config: StepLoopConfig, mut commands: mpsc::Receiver<WorkerCommand>) -> WorkerStats {
128 let mut worker = LocalTimelyWorker::new(WorkerConfig::default(), Thread::default(), None);
129 let mut stats = WorkerStats::default();
130
131 while let Some(command) = commands.blocking_recv() {
132 if apply_command(command, &mut worker, &mut stats, config.max_steps_per_tick) {
133 break;
134 }
135
136 while let Ok(command) = commands.try_recv() {
137 if apply_command(command, &mut worker, &mut stats, config.max_steps_per_tick) {
138 return stats;
139 }
140 }
141
142 worker.step_or_park(Some(config.idle_park));
143 stats.steps = stats.steps.saturating_add(1);
144 }
145
146 stats
147}
148
149fn apply_command(
150 command: WorkerCommand,
151 worker: &mut LocalTimelyWorker,
152 stats: &mut WorkerStats,
153 max_steps_per_tick: usize,
154) -> bool {
155 match command {
156 WorkerCommand::Build(build) => {
157 build(worker);
158 stats.builds = stats.builds.saturating_add(1);
159 false
160 }
161 WorkerCommand::Step => {
162 for _ in 0..max_steps_per_tick {
163 worker.step();
164 stats.steps = stats.steps.saturating_add(1);
165 }
166 false
167 }
168 WorkerCommand::Stop => true,
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use timely::dataflow::operators::{Inspect, ToStream};

    use super::{spawn_worker, StepLoopConfig};

    /// Installs a three-element dataflow, steps it, stops the worker, and checks that every
    /// element reached the inspect operator and the stats were recorded.
    #[tokio::test]
    async fn worker_builds_dataflow_steps_and_stops() {
        let observed = Arc::new(AtomicUsize::new(0));
        let handle = spawn_worker(StepLoopConfig {
            command_capacity: 4,
            max_steps_per_tick: 2,
            ..StepLoopConfig::default()
        });

        let counter = Arc::clone(&observed);
        handle
            .build(move |timely_worker| {
                timely_worker.dataflow::<u64, _, _>(move |scope| {
                    let operator_counter = Arc::clone(&counter);
                    (0..3).to_stream(scope).inspect(move |_| {
                        operator_counter.fetch_add(1, Ordering::SeqCst);
                    });
                });
            })
            .await
            .expect("build command should be accepted");
        handle.step().await.expect("step command should be accepted");
        let stats = handle.stop().await.expect("worker should stop cleanly");

        assert_eq!(observed.load(Ordering::SeqCst), 3);
        assert_eq!(stats.builds, 1);
        assert!(stats.steps > 0);
    }
}