1use std::collections::HashMap;
22use std::sync::mpsc as std_mpsc;
23use std::sync::{Arc, Mutex};
24use std::thread::JoinHandle;
25
26use palimpsest_wal::TableId;
27use timely::communication::allocator::thread::Thread;
28use timely::dataflow::operators::probe::Probe;
29use timely::dataflow::operators::Inspect;
30use timely::dataflow::ProbeHandle;
31use timely::worker::Worker as TimelyWorker;
32use timely::WorkerConfig;
33
34use crate::input::{Input, InputSession};
35use crate::palimpsest::compile_mir::{install_plan, CompiledPlan};
36use crate::palimpsest::time::Lsn;
37use crate::palimpsest::wal::{Row, WalTransaction};
38
39#[must_use]
46pub fn snapshot_run(plan: &CompiledPlan, inputs: HashMap<TableId, Vec<Row>>) -> Vec<Row> {
47 let captured: Arc<Mutex<Vec<Row>>> = Arc::new(Mutex::new(Vec::new()));
48 let cap = Arc::clone(&captured);
49 let plan = plan.clone();
50
51 timely::execute_directly(move |worker| {
52 worker.dataflow::<u64, _, _>(|scope| {
53 let mut input_collections = HashMap::new();
54 for table in &plan.inputs {
55 let rows = inputs.get(table).cloned().unwrap_or_default();
56 let (_, collection) = scope.new_collection_from(rows);
57 input_collections.insert(*table, collection);
58 }
59 let output = install_plan(&plan, scope, &input_collections);
60 let cap_inner = Arc::clone(&cap);
61 output.inner.inspect(move |entry: &(Row, u64, isize)| {
62 let (row, _time, diff) = entry;
63 if *diff > 0 {
64 cap_inner.lock().expect("capture mutex").push(row.clone());
65 }
66 });
67 });
68 });
69
70 let mut rows = captured.lock().expect("capture mutex");
71 std::mem::take(&mut *rows)
72}
73
74#[derive(Debug, Clone)]
82pub struct AggregateDelta {
83 pub row: Row,
85 pub lsn: Lsn,
87 pub diff: isize,
89}
90
91enum DataflowCommand {
94 Seed {
97 inputs: HashMap<TableId, Vec<Row>>,
98 reply: std_mpsc::SyncSender<Vec<Row>>,
99 },
100 Apply {
104 diffs: Vec<(TableId, Row, isize)>,
105 lsn: Lsn,
106 reply: std_mpsc::SyncSender<Vec<AggregateDelta>>,
107 },
108 Stop,
110}
111
112struct IncrementalDataflow {
115 cmd_tx: std_mpsc::Sender<DataflowCommand>,
116 join: Option<JoinHandle<()>>,
117}
118
119impl IncrementalDataflow {
120 fn spawn(plan: CompiledPlan) -> Self {
121 let (cmd_tx, cmd_rx) = std_mpsc::channel::<DataflowCommand>();
122 let join = std::thread::Builder::new()
123 .name("palimpsest-dataflow".into())
124 .spawn(move || run_worker(plan, cmd_rx))
125 .expect("spawn dataflow worker thread");
126 Self {
127 cmd_tx,
128 join: Some(join),
129 }
130 }
131
132 fn seed(&self, inputs: HashMap<TableId, Vec<Row>>) -> Vec<Row> {
133 let (tx, rx) = std_mpsc::sync_channel(0);
134 if self
135 .cmd_tx
136 .send(DataflowCommand::Seed { inputs, reply: tx })
137 .is_err()
138 {
139 return Vec::new();
140 }
141 rx.recv().unwrap_or_default()
142 }
143
144 fn apply(&self, diffs: Vec<(TableId, Row, isize)>, lsn: Lsn) -> Vec<AggregateDelta> {
145 let (tx, rx) = std_mpsc::sync_channel(0);
146 if self
147 .cmd_tx
148 .send(DataflowCommand::Apply {
149 diffs,
150 lsn,
151 reply: tx,
152 })
153 .is_err()
154 {
155 return Vec::new();
156 }
157 rx.recv().unwrap_or_default()
158 }
159}
160
161impl Drop for IncrementalDataflow {
162 fn drop(&mut self) {
163 let _ = self.cmd_tx.send(DataflowCommand::Stop);
164 if let Some(join) = self.join.take() {
165 let _ = join.join();
166 }
167 }
168}
169
170fn run_worker(plan: CompiledPlan, cmd_rx: std_mpsc::Receiver<DataflowCommand>) {
174 let mut worker = TimelyWorker::new(WorkerConfig::default(), Thread::default(), None);
175
176 let captured: Arc<Mutex<Vec<(Row, Lsn, isize)>>> = Arc::new(Mutex::new(Vec::new()));
180 let cap_for_dataflow = Arc::clone(&captured);
181
182 let mut inputs: HashMap<TableId, InputSession<Lsn, Row, isize>> = HashMap::new();
183 let mut probe: ProbeHandle<Lsn> = ProbeHandle::new();
184
185 worker.dataflow::<Lsn, _, _>(|scope| {
189 let mut input_collections = HashMap::new();
190 for table in &plan.inputs {
191 let mut input = InputSession::<Lsn, Row, isize>::new();
192 let collection = input.to_collection(scope);
193 input_collections.insert(*table, collection);
194 inputs.insert(*table, input);
195 }
196 let output = install_plan(&plan, scope, &input_collections);
197 let cap_for_inspect = Arc::clone(&cap_for_dataflow);
198 output
199 .inner
200 .probe_with(&mut probe)
201 .inspect(move |entry: &(Row, Lsn, isize)| {
202 cap_for_inspect.lock().expect("capture").push(entry.clone());
203 });
204 });
205
206 while let Ok(cmd) = cmd_rx.recv() {
207 match cmd {
208 DataflowCommand::Seed {
209 inputs: seed_rows,
210 reply,
211 } => {
212 for (table, rows) in seed_rows {
217 if let Some(session) = inputs.get_mut(&table) {
218 for row in rows {
219 session.update_at(row, Lsn::new(0), 1);
220 }
221 }
222 }
223 advance_and_step(&mut worker, &mut inputs, &probe, Lsn::new(1));
224 let drained = drain_captures(&captured);
225 let initial: Vec<Row> = drained
226 .into_iter()
227 .filter(|(_, _, diff)| *diff > 0)
228 .map(|(row, _, _)| row)
229 .collect();
230 let _ = reply.send(initial);
231 }
232 DataflowCommand::Apply { diffs, lsn, reply } => {
233 for (table, row, diff) in diffs {
234 if let Some(session) = inputs.get_mut(&table) {
235 session.update_at(row, lsn, diff);
236 }
237 }
238 let next = Lsn::new(lsn.get().saturating_add(1));
239 advance_and_step(&mut worker, &mut inputs, &probe, next);
240 let drained = drain_captures(&captured);
241 let deltas: Vec<AggregateDelta> = drained
242 .into_iter()
243 .map(|(row, t, diff)| AggregateDelta { row, lsn: t, diff })
244 .collect();
245 let _ = reply.send(deltas);
246 }
247 DataflowCommand::Stop => break,
248 }
249 }
250}
251
252fn advance_and_step(
256 worker: &mut TimelyWorker<Thread>,
257 inputs: &mut HashMap<TableId, InputSession<Lsn, Row, isize>>,
258 probe: &ProbeHandle<Lsn>,
259 target: Lsn,
260) {
261 for session in inputs.values_mut() {
262 session.advance_to(target);
263 session.flush();
264 }
265 while probe.less_than(&target) {
266 worker.step();
267 }
268}
269
270fn drain_captures(cap: &Arc<Mutex<Vec<(Row, Lsn, isize)>>>) -> Vec<(Row, Lsn, isize)> {
271 let mut guard = cap.lock().expect("capture");
272 std::mem::take(&mut *guard)
273}
274
275struct PlanState {
277 dataflow: IncrementalDataflow,
278 last_output: Vec<Row>,
283 last_lsn: Lsn,
289 subscribers: Vec<u64>,
293}
294
295#[derive(Default)]
296struct HostInner {
297 plans: HashMap<String, PlanState>,
298}
299
300pub struct PersistentHost {
305 inner: Arc<Mutex<HostInner>>,
306}
307
308impl PersistentHost {
309 #[must_use]
311 pub fn new() -> Self {
312 Self {
313 inner: Arc::new(Mutex::new(HostInner::default())),
314 }
315 }
316
317 pub fn cached_view(&self, canonical: &str, subscriber: u64) -> Option<(Vec<Row>, Lsn)> {
337 let mut inner = self.inner.lock().expect("host inner");
338 let state = inner.plans.get_mut(canonical)?;
339 state.subscribers.push(subscriber);
340 Some((state.last_output.clone(), state.last_lsn))
341 }
342
343 pub fn register_or_seed(
354 &self,
355 canonical: &str,
356 plan: &CompiledPlan,
357 inputs: HashMap<TableId, Vec<Row>>,
358 snapshot_lsn: Lsn,
359 subscriber: u64,
360 ) -> Vec<Row> {
361 {
364 let mut inner = self.inner.lock().expect("host inner");
365 if let Some(state) = inner.plans.get_mut(canonical) {
366 state.subscribers.push(subscriber);
367 return state.last_output.clone();
368 }
369 }
370
371 let dataflow = IncrementalDataflow::spawn(plan.clone());
374 let initial = dataflow.seed(inputs);
375
376 let mut inner = self.inner.lock().expect("host inner");
377 if let Some(state) = inner.plans.get_mut(canonical) {
381 state.subscribers.push(subscriber);
382 return state.last_output.clone();
383 }
384 inner.plans.insert(
385 canonical.to_owned(),
386 PlanState {
387 dataflow,
388 last_output: initial.clone(),
389 last_lsn: snapshot_lsn,
390 subscribers: vec![subscriber],
391 },
392 );
393 initial
394 }
395
396 #[must_use]
400 pub fn subscribers(&self, canonical: &str) -> Option<Vec<u64>> {
401 let inner = self.inner.lock().expect("host inner");
402 inner.plans.get(canonical).map(|s| s.subscribers.clone())
403 }
404
405 pub fn push_table_diff(
408 &self,
409 canonical: &str,
410 table_id: TableId,
411 row: Row,
412 diff: isize,
413 lsn: Lsn,
414 ) -> Vec<AggregateDelta> {
415 self.apply_and_fanout(canonical, vec![(table_id, row, diff)], lsn)
416 .map_or_else(Vec::new, |(deltas, _subs)| deltas)
417 }
418
419 pub fn apply_and_fanout(
428 &self,
429 canonical: &str,
430 diffs: Vec<(TableId, Row, isize)>,
431 lsn: Lsn,
432 ) -> Option<(Vec<AggregateDelta>, Vec<u64>)> {
433 let mut inner = self.inner.lock().expect("host inner");
434 let state = inner.plans.get_mut(canonical)?;
435 let deltas = state.dataflow.apply(diffs, lsn);
436 apply_deltas_to_cache(&mut state.last_output, &deltas);
437 state.last_lsn = lsn;
438 Some((deltas, state.subscribers.clone()))
439 }
440
441 pub fn push_table_batch(
444 &self,
445 canonical: &str,
446 diffs: Vec<(TableId, Row, isize)>,
447 lsn: Lsn,
448 ) -> Vec<AggregateDelta> {
449 self.apply_and_fanout(canonical, diffs, lsn)
450 .map_or_else(Vec::new, |(deltas, _subs)| deltas)
451 }
452
453 pub fn push_transaction(
455 &self,
456 canonical: &str,
457 transaction: &WalTransaction,
458 ) -> Vec<AggregateDelta> {
459 let diffs = transaction
460 .updates
461 .iter()
462 .map(|update| (update.table, update.row.clone(), update.diff))
463 .collect();
464 self.push_table_batch(canonical, diffs, transaction.commit_lsn)
465 }
466
467 pub fn release(&self, canonical: &str, subscriber: u64) -> usize {
472 let mut inner = self.inner.lock().expect("host inner");
473 let Some(state) = inner.plans.get_mut(canonical) else {
474 return 0;
475 };
476 if let Some(pos) = state.subscribers.iter().position(|s| *s == subscriber) {
477 state.subscribers.swap_remove(pos);
478 }
479 let remaining = state.subscribers.len();
480 if remaining == 0 {
481 inner.plans.remove(canonical);
482 }
483 remaining
484 }
485}
486
487impl Default for PersistentHost {
488 fn default() -> Self {
489 Self::new()
490 }
491}
492
493fn apply_deltas_to_cache(rows: &mut Vec<Row>, deltas: &[AggregateDelta]) {
497 for delta in deltas {
498 if delta.diff > 0 {
499 for _ in 0..delta.diff {
500 rows.push(delta.row.clone());
501 }
502 } else if delta.diff < 0 {
503 for _ in 0..-delta.diff {
504 if let Some(pos) = rows.iter().position(|r| r == &delta.row) {
505 rows.swap_remove(pos);
506 }
507 }
508 }
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use palimpsest_sql::catalog::ColumnType;
    use palimpsest_sql::lower::parse_and_lower;
    use palimpsest_wal::Datum;

    use crate::palimpsest::compile_mir::compile_mir;
    use crate::palimpsest::eval::ScalarSchema;

    /// Query shared by every test: per-category count and sum over `events`.
    const TOP_CATEGORIES_SQL: &str = "WITH per_category AS (
            SELECT category_id, COUNT(*) AS n, SUM(value) AS total
            FROM events
            GROUP BY category_id
        )
        SELECT category_id, n, total
        FROM per_category
        ORDER BY total DESC
        LIMIT 8";

    /// Table id the test catalog assigns to `events`.
    fn events_table() -> TableId {
        TableId::new(2)
    }

    fn events_schema() -> ScalarSchema {
        ScalarSchema::from_pairs([
            ("id".to_owned(), ColumnType::Int),
            ("category_id".to_owned(), ColumnType::Int),
            ("value".to_owned(), ColumnType::Int),
        ])
    }

    fn lookup(table: &str) -> Option<(TableId, ScalarSchema)> {
        match table {
            "events" => Some((events_table(), events_schema())),
            _ => None,
        }
    }

    fn row(values: Vec<Datum>) -> Row {
        values.into_iter().collect()
    }

    /// Builds an `events` row from `(id, category_id, value)`.
    fn event(id: i64, category_id: i64, value: i64) -> Row {
        row(vec![
            Datum::I64(id),
            Datum::I64(category_id),
            Datum::I64(value),
        ])
    }

    /// Wraps `rows` as the seed input for the `events` table.
    fn events_seed(rows: Vec<Row>) -> HashMap<TableId, Vec<Row>> {
        HashMap::from([(events_table(), rows)])
    }

    /// Parses and compiles [`TOP_CATEGORIES_SQL`] against the test catalog.
    fn top_categories_plan() -> CompiledPlan {
        let graph = parse_and_lower(TOP_CATEGORIES_SQL).unwrap();
        compile_mir(&graph, &lookup).unwrap()
    }

    #[test]
    fn snapshot_run_emits_aggregate_rows() {
        let plan = top_categories_plan();
        let inputs = events_seed(vec![
            event(1, 7, 100),
            event(2, 7, 50),
            event(3, 9, 20),
            event(4, 9, 20),
            event(5, 11, 5),
        ]);

        let output = snapshot_run(&plan, inputs);

        assert_eq!(output.len(), 3, "three categories");
    }

    #[test]
    fn persistent_host_emits_initial_and_diffs() {
        let plan = top_categories_plan();
        let host = PersistentHost::new();
        let canonical = "events.top_categories";

        let seed = events_seed(vec![event(1, 7, 100), event(2, 7, 50), event(3, 9, 20)]);
        let initial = host.register_or_seed(canonical, &plan, seed, Lsn::new(1), 42);
        assert_eq!(initial.len(), 2, "initial has cat 7 + cat 9");

        let next_lsn = Lsn::new(2);
        let deltas = host.push_table_diff(canonical, events_table(), event(4, 9, 100), 1, next_lsn);

        let retracts: Vec<_> = deltas.iter().filter(|d| d.diff < 0).collect();
        let asserts: Vec<_> = deltas.iter().filter(|d| d.diff > 0).collect();
        assert_eq!(retracts.len(), 1, "one retract — old cat 9 row");
        assert_eq!(asserts.len(), 1, "one assert — new cat 9 row");

        // Category 9 goes from (n = 1, total = 20) to (n = 2, total = 120).
        let retracted = &retracts[0].row;
        assert_eq!(retracted.get(0), Some(&Datum::I64(9)));
        assert_eq!(retracted.get(1), Some(&Datum::I64(1)));
        assert_eq!(retracted.get(2), Some(&Datum::I64(20)));
        let asserted = &asserts[0].row;
        assert_eq!(asserted.get(0), Some(&Datum::I64(9)));
        assert_eq!(asserted.get(1), Some(&Datum::I64(2)));
        assert_eq!(asserted.get(2), Some(&Datum::I64(120)));

        host.release(canonical, 42);
    }

    #[test]
    fn persistent_host_batch_coalesces() {
        let plan = top_categories_plan();
        let host = PersistentHost::new();
        let canonical = "events.batch";

        let seed = events_seed(vec![event(1, 7, 10)]);
        host.register_or_seed(canonical, &plan, seed, Lsn::new(1), 7);

        let batch = vec![
            (events_table(), event(2, 7, 20), 1),
            (events_table(), event(3, 7, 30), 1),
        ];
        let deltas = host.push_table_batch(canonical, batch, Lsn::new(2));
        // Two input rows in one batch yield a single retract/assert pair.
        assert_eq!(deltas.len(), 2);
        assert!(deltas.iter().all(|d| d.lsn == Lsn::new(2)));

        host.release(canonical, 7);
    }

    #[test]
    fn cached_view_attaches_late_subscriber_to_current_state() {
        let plan = top_categories_plan();
        let host = PersistentHost::new();
        let canonical = "events.shared";

        let seed = events_seed(vec![event(1, 7, 10)]);
        host.register_or_seed(canonical, &plan, seed, Lsn::new(5), 1);

        let apply_lsn = Lsn::new(6);
        let (_deltas, subs_after_apply) = host
            .apply_and_fanout(
                canonical,
                vec![(events_table(), event(2, 7, 20), 1)],
                apply_lsn,
            )
            .expect("plan still registered");
        assert_eq!(subs_after_apply, vec![1]);

        // A late subscriber sees the cache with the apply folded in (total 10 + 20 = 30).
        let (cached, cached_lsn) = host
            .cached_view(canonical, 2)
            .expect("cache hit on registered plan");
        assert_eq!(cached_lsn, apply_lsn);
        assert!(cached.iter().any(|r| r.get(2) == Some(&Datum::I64(30))));

        let subs_view = host.subscribers(canonical).expect("plan registered");
        assert_eq!(subs_view, vec![1, 2]);

        // The plan stays registered until its last subscriber is released.
        assert_eq!(host.release(canonical, 1), 1);
        assert!(host.subscribers(canonical).is_some());
        assert_eq!(host.release(canonical, 2), 0);
        assert!(host.subscribers(canonical).is_none());
    }
}