use std::collections::HashMap;
use std::sync::mpsc as std_mpsc;
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;
use palimpsest_wal::TableId;
use timely::communication::allocator::thread::Thread;
use timely::dataflow::operators::probe::Probe;
use timely::dataflow::operators::Inspect;
use timely::dataflow::ProbeHandle;
use timely::worker::Worker as TimelyWorker;
use timely::WorkerConfig;
use crate::input::{Input, InputSession};
use crate::palimpsest::compile_mir::{install_plan, CompiledPlan};
use crate::palimpsest::time::Lsn;
use crate::palimpsest::wal::{Row, WalTransaction};
/// Runs `plan` once over a fixed set of input rows and returns the rows it emits.
///
/// The computation is executed to completion with `timely::execute_directly`.
/// Every table listed in `plan.inputs` becomes an input collection; a table that
/// has no entry in `inputs` is fed an empty collection. Each output update with a
/// positive diff contributes exactly one copy of its row to the result (the
/// magnitude of the diff is not expanded); updates with a non-positive diff are
/// not recorded.
///
/// # Panics
///
/// Panics if the capture mutex is poisoned.
#[must_use]
pub fn snapshot_run(plan: &CompiledPlan, inputs: HashMap<TableId, Vec<Row>>) -> Vec<Row> {
    let sink: Arc<Mutex<Vec<Row>>> = Arc::default();
    let sink_for_worker = Arc::clone(&sink);
    let owned_plan = plan.clone();

    timely::execute_directly(move |worker| {
        worker.dataflow::<u64, _, _>(|scope| {
            // One input collection per table the plan reads.
            let mut collections = HashMap::new();
            for table in &owned_plan.inputs {
                let seed = inputs.get(table).map_or_else(Vec::new, Clone::clone);
                let (_, collection) = scope.new_collection_from(seed);
                collections.insert(*table, collection);
            }

            let output = install_plan(&owned_plan, scope, &collections);
            let sink_for_inspect = Arc::clone(&sink_for_worker);
            output
                .inner
                .inspect(move |(row, _time, diff): &(Row, u64, isize)| {
                    if *diff > 0 {
                        sink_for_inspect
                            .lock()
                            .expect("capture mutex")
                            .push(row.clone());
                    }
                });
        });
    });

    // Keep the guard in a named local so it is released before `sink` goes away.
    let mut collected = sink.lock().expect("capture mutex");
    std::mem::take(&mut *collected)
}
/// One change to a plan's output, as captured from the incremental dataflow.
#[derive(Debug, Clone)]
pub struct AggregateDelta {
/// The output row whose multiplicity changes.
pub row: Row,
/// Timestamp the dataflow attached to this update.
pub lsn: Lsn,
/// Signed multiplicity change: positive adds copies of `row`, negative removes them.
pub diff: isize,
}
/// Requests sent from an `IncrementalDataflow` handle to its worker thread.
enum DataflowCommand {
/// Load the initial rows for each table and reply with the initial output rows.
Seed {
// Initial rows, keyed by the table they belong to.
inputs: HashMap<TableId, Vec<Row>>,
// Channel on which the worker sends back the initial output.
reply: std_mpsc::SyncSender<Vec<Row>>,
},
/// Apply a batch of row-level changes at `lsn` and reply with the output changes.
Apply {
// `(table, row, diff)` triples to feed into the matching input sessions.
diffs: Vec<(TableId, Row, isize)>,
// Timestamp at which every change in `diffs` is applied.
lsn: Lsn,
// Channel on which the worker sends back the captured output deltas.
reply: std_mpsc::SyncSender<Vec<AggregateDelta>>,
},
/// Ask the worker loop to exit.
Stop,
}
/// Handle to a dedicated worker thread that keeps one plan's dataflow alive.
struct IncrementalDataflow {
// Command channel into the worker thread.
cmd_tx: std_mpsc::Sender<DataflowCommand>,
// Join handle for the worker; taken (set to `None`) when the handle is dropped.
join: Option<JoinHandle<()>>,
}
impl IncrementalDataflow {
    /// Starts a named worker thread that runs `run_worker` for `plan`.
    ///
    /// # Panics
    ///
    /// Panics if the operating system refuses to spawn the thread.
    fn spawn(plan: CompiledPlan) -> Self {
        let (cmd_tx, cmd_rx) = std_mpsc::channel::<DataflowCommand>();
        let builder = std::thread::Builder::new().name("palimpsest-dataflow".into());
        let handle = builder
            .spawn(move || run_worker(plan, cmd_rx))
            .expect("spawn dataflow worker thread");
        Self {
            cmd_tx,
            join: Some(handle),
        }
    }

    /// Sends the initial rows to the worker and waits for the initial output.
    ///
    /// Returns an empty vector if the worker can no longer be reached or goes
    /// away before replying.
    fn seed(&self, inputs: HashMap<TableId, Vec<Row>>) -> Vec<Row> {
        // Zero-capacity channel: the worker hands the reply over directly.
        let (reply, response) = std_mpsc::sync_channel(0);
        match self.cmd_tx.send(DataflowCommand::Seed { inputs, reply }) {
            Ok(()) => response.recv().unwrap_or_default(),
            Err(_) => Vec::new(),
        }
    }

    /// Sends a batch of changes at `lsn` to the worker and waits for the output deltas.
    ///
    /// Returns an empty vector if the worker can no longer be reached or goes
    /// away before replying.
    fn apply(&self, diffs: Vec<(TableId, Row, isize)>, lsn: Lsn) -> Vec<AggregateDelta> {
        let (reply, response) = std_mpsc::sync_channel(0);
        let command = DataflowCommand::Apply { diffs, lsn, reply };
        match self.cmd_tx.send(command) {
            Ok(()) => response.recv().unwrap_or_default(),
            Err(_) => Vec::new(),
        }
    }
}
impl Drop for IncrementalDataflow {
    /// Asks the worker to stop and waits for its thread to finish.
    ///
    /// Both the send and the join are best-effort: a worker that has already
    /// exited or panicked is ignored.
    fn drop(&mut self) {
        let _ = self.cmd_tx.send(DataflowCommand::Stop);
        let Some(worker) = self.join.take() else {
            return;
        };
        let _ = worker.join();
    }
}
/// Body of the dataflow worker thread: builds the plan's dataflow, then serves commands.
///
/// One input session is created per table in `plan.inputs`. The plan's output is
/// probed and every output update is appended to a shared capture buffer. The
/// loop then handles commands until `Stop` arrives or the command channel is
/// disconnected:
///
/// * `Seed` inserts every row at `Lsn::new(0)` with diff `1`, advances to
///   `Lsn::new(1)`, and replies with the rows of the captured updates whose diff
///   is positive.
/// * `Apply` feeds the given diffs at `lsn`, advances to `lsn + 1` (saturating),
///   and replies with every captured update as an `AggregateDelta`.
///
/// Rows and diffs for tables the plan does not read are ignored. A failed reply
/// (requester gone) is ignored.
fn run_worker(plan: CompiledPlan, cmd_rx: std_mpsc::Receiver<DataflowCommand>) {
    let mut worker = TimelyWorker::new(WorkerConfig::default(), Thread::default(), None);
    let captured: Arc<Mutex<Vec<(Row, Lsn, isize)>>> = Arc::new(Mutex::new(Vec::new()));
    let mut sessions: HashMap<TableId, InputSession<Lsn, Row, isize>> = HashMap::new();
    let mut probe: ProbeHandle<Lsn> = ProbeHandle::new();

    worker.dataflow::<Lsn, _, _>(|scope| {
        let mut collections = HashMap::new();
        for table in &plan.inputs {
            let mut session = InputSession::<Lsn, Row, isize>::new();
            let collection = session.to_collection(scope);
            collections.insert(*table, collection);
            sessions.insert(*table, session);
        }

        let output = install_plan(&plan, scope, &collections);
        let sink = Arc::clone(&captured);
        output
            .inner
            .probe_with(&mut probe)
            .inspect(move |update: &(Row, Lsn, isize)| {
                sink.lock().expect("capture").push(update.clone());
            });
    });

    for cmd in cmd_rx.iter() {
        match cmd {
            DataflowCommand::Seed {
                inputs: seed_rows,
                reply,
            } => {
                for (table, rows) in seed_rows {
                    let Some(session) = sessions.get_mut(&table) else {
                        continue;
                    };
                    for row in rows {
                        session.update_at(row, Lsn::new(0), 1);
                    }
                }
                advance_and_step(&mut worker, &mut sessions, &probe, Lsn::new(1));

                // Only assertions form the initial view.
                let initial: Vec<Row> = drain_captures(&captured)
                    .into_iter()
                    .filter_map(|(row, _, diff)| (diff > 0).then_some(row))
                    .collect();
                let _ = reply.send(initial);
            }
            DataflowCommand::Apply { diffs, lsn, reply } => {
                for (table, row, diff) in diffs {
                    let Some(session) = sessions.get_mut(&table) else {
                        continue;
                    };
                    session.update_at(row, lsn, diff);
                }
                let frontier = Lsn::new(lsn.get().saturating_add(1));
                advance_and_step(&mut worker, &mut sessions, &probe, frontier);

                let deltas: Vec<AggregateDelta> = drain_captures(&captured)
                    .into_iter()
                    .map(|(row, at, diff)| AggregateDelta { row, lsn: at, diff })
                    .collect();
                let _ = reply.send(deltas);
            }
            DataflowCommand::Stop => break,
        }
    }
}
/// Moves every input session to `target`, flushes it, and then steps `worker`
/// for as long as `probe` reports that it is still behind `target`.
fn advance_and_step(
    worker: &mut TimelyWorker<Thread>,
    inputs: &mut HashMap<TableId, InputSession<Lsn, Row, isize>>,
    probe: &ProbeHandle<Lsn>,
    target: Lsn,
) {
    inputs.values_mut().for_each(|session| {
        session.advance_to(target);
        session.flush();
    });

    loop {
        if !probe.less_than(&target) {
            break;
        }
        worker.step();
    }
}
/// Empties the shared capture buffer and returns everything it held.
///
/// # Panics
///
/// Panics if the capture mutex is poisoned.
fn drain_captures(cap: &Arc<Mutex<Vec<(Row, Lsn, isize)>>>) -> Vec<(Row, Lsn, isize)> {
    let mut pending = Vec::new();
    {
        // Swap in the fresh vector while the lock is held, then release it.
        let mut buffer = cap.lock().expect("capture");
        std::mem::swap(&mut *buffer, &mut pending);
    }
    pending
}
/// Everything the host keeps for one registered plan.
struct PlanState {
// Handle to the worker thread running this plan's dataflow.
dataflow: IncrementalDataflow,
// Cached copy of the plan's current output rows.
last_output: Vec<Row>,
// LSN recorded at seeding time, then overwritten by each applied batch.
last_lsn: Lsn,
// Ids of the subscribers currently attached to this plan.
subscribers: Vec<u64>,
}
/// Mutable state of a `PersistentHost`, guarded by the host's mutex.
#[derive(Default)]
struct HostInner {
// Registered plans, keyed by their canonical name.
plans: HashMap<String, PlanState>,
}
/// Registry of long-lived dataflows, one per canonical plan name, shared by subscribers.
pub struct PersistentHost {
// All plan state lives behind a single mutex.
inner: Arc<Mutex<HostInner>>,
}
impl PersistentHost {
    /// Creates a host with no registered plans.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(HostInner::default())),
        }
    }

    /// Attaches `subscriber` to an already-registered plan and returns a copy of
    /// its cached output together with the LSN of the last applied change.
    ///
    /// Returns `None` (and attaches nothing) when `canonical` is not registered.
    /// The subscriber id is appended unconditionally; no de-duplication is done.
    ///
    /// # Panics
    ///
    /// Panics if the host mutex is poisoned.
    pub fn cached_view(&self, canonical: &str, subscriber: u64) -> Option<(Vec<Row>, Lsn)> {
        let mut inner = self.inner.lock().expect("host inner");
        let state = inner.plans.get_mut(canonical)?;
        state.subscribers.push(subscriber);
        Some((state.last_output.clone(), state.last_lsn))
    }

    /// Attaches `subscriber` to the plan named `canonical`, creating and seeding
    /// its dataflow first if it is not registered yet.
    ///
    /// When the plan already exists, the subscriber is attached and the cached
    /// output is returned; `plan`, `inputs` and `snapshot_lsn` are not used.
    /// Otherwise a new dataflow is spawned and seeded with `inputs` while the
    /// host lock is released. If the plan was registered in the meantime, the
    /// freshly seeded dataflow is discarded and the existing one is used instead.
    ///
    /// # Panics
    ///
    /// Panics if the host mutex is poisoned or the worker thread cannot be spawned.
    pub fn register_or_seed(
        &self,
        canonical: &str,
        plan: &CompiledPlan,
        inputs: HashMap<TableId, Vec<Row>>,
        snapshot_lsn: Lsn,
        subscriber: u64,
    ) -> Vec<Row> {
        {
            let mut inner = self.inner.lock().expect("host inner");
            if let Some(state) = inner.plans.get_mut(canonical) {
                state.subscribers.push(subscriber);
                return state.last_output.clone();
            }
        }

        // Seeding blocks on the worker thread, so it runs without the host lock.
        let dataflow = IncrementalDataflow::spawn(plan.clone());
        let initial = dataflow.seed(inputs);

        let mut inner = self.inner.lock().expect("host inner");
        // Re-check: the plan may have been registered while the lock was released.
        if let Some(state) = inner.plans.get_mut(canonical) {
            state.subscribers.push(subscriber);
            return state.last_output.clone();
        }
        inner.plans.insert(
            canonical.to_owned(),
            PlanState {
                dataflow,
                last_output: initial.clone(),
                last_lsn: snapshot_lsn,
                subscribers: vec![subscriber],
            },
        );
        initial
    }

    /// Returns a copy of the subscriber ids attached to `canonical`, or `None`
    /// if the plan is not registered.
    ///
    /// # Panics
    ///
    /// Panics if the host mutex is poisoned.
    #[must_use]
    pub fn subscribers(&self, canonical: &str) -> Option<Vec<u64>> {
        let inner = self.inner.lock().expect("host inner");
        inner.plans.get(canonical).map(|s| s.subscribers.clone())
    }

    /// Applies a single row change at `lsn` and returns the resulting output deltas.
    ///
    /// Returns an empty vector when `canonical` is not registered.
    pub fn push_table_diff(
        &self,
        canonical: &str,
        table_id: TableId,
        row: Row,
        diff: isize,
        lsn: Lsn,
    ) -> Vec<AggregateDelta> {
        self.apply_and_fanout(canonical, vec![(table_id, row, diff)], lsn)
            .map_or_else(Vec::new, |(deltas, _subs)| deltas)
    }

    /// Applies `diffs` at `lsn`, folds the resulting deltas into the cached
    /// output, records `lsn` as the plan's last LSN, and returns the deltas
    /// together with a copy of the current subscriber list.
    ///
    /// Returns `None` when `canonical` is not registered. The host lock is held
    /// while the dataflow processes the batch.
    ///
    /// # Panics
    ///
    /// Panics if the host mutex is poisoned.
    pub fn apply_and_fanout(
        &self,
        canonical: &str,
        diffs: Vec<(TableId, Row, isize)>,
        lsn: Lsn,
    ) -> Option<(Vec<AggregateDelta>, Vec<u64>)> {
        let mut inner = self.inner.lock().expect("host inner");
        let state = inner.plans.get_mut(canonical)?;
        let deltas = state.dataflow.apply(diffs, lsn);
        apply_deltas_to_cache(&mut state.last_output, &deltas);
        state.last_lsn = lsn;
        Some((deltas, state.subscribers.clone()))
    }

    /// Applies a batch of row changes at `lsn` and returns the resulting output deltas.
    ///
    /// Returns an empty vector when `canonical` is not registered.
    pub fn push_table_batch(
        &self,
        canonical: &str,
        diffs: Vec<(TableId, Row, isize)>,
        lsn: Lsn,
    ) -> Vec<AggregateDelta> {
        self.apply_and_fanout(canonical, diffs, lsn)
            .map_or_else(Vec::new, |(deltas, _subs)| deltas)
    }

    /// Applies every update of `transaction` as one batch at its `commit_lsn`
    /// and returns the resulting output deltas.
    pub fn push_transaction(
        &self,
        canonical: &str,
        transaction: &WalTransaction,
    ) -> Vec<AggregateDelta> {
        let diffs = transaction
            .updates
            .iter()
            .map(|update| (update.table, update.row.clone(), update.diff))
            .collect();
        self.push_table_batch(canonical, diffs, transaction.commit_lsn)
    }

    /// Detaches one occurrence of `subscriber` from `canonical` and returns how
    /// many subscribers remain.
    ///
    /// Returns `0` when the plan is not registered. When the last subscriber
    /// leaves, the plan is unregistered and its dataflow is shut down. Removal
    /// uses `swap_remove`, so the order of the remaining subscribers may change.
    ///
    /// # Panics
    ///
    /// Panics if the host mutex is poisoned.
    pub fn release(&self, canonical: &str, subscriber: u64) -> usize {
        let mut inner = self.inner.lock().expect("host inner");
        let Some(state) = inner.plans.get_mut(canonical) else {
            return 0;
        };
        if let Some(pos) = state.subscribers.iter().position(|s| *s == subscriber) {
            state.subscribers.swap_remove(pos);
        }
        let remaining = state.subscribers.len();
        let retired = if remaining == 0 {
            inner.plans.remove(canonical)
        } else {
            None
        };
        // Dropping a plan joins its worker thread; release the host lock first so
        // other callers are not blocked for the duration of the join.
        drop(inner);
        drop(retired);
        remaining
    }
}
impl Default for PersistentHost {
fn default() -> Self {
Self::new()
}
}
/// Folds output deltas into the cached row list, treating it as a multiset.
///
/// A positive `diff` appends that many copies of the row. A negative `diff`
/// removes up to that many copies, each via `swap_remove` on the first match, so
/// the order of `rows` is not preserved. Retracting a row that is not (or no
/// longer) present is a no-op. A zero `diff` changes nothing.
fn apply_deltas_to_cache(rows: &mut Vec<Row>, deltas: &[AggregateDelta]) {
    for delta in deltas {
        // `unsigned_abs` cannot overflow, unlike negating `isize::MIN`.
        let copies = delta.diff.unsigned_abs();
        if delta.diff > 0 {
            rows.extend((0..copies).map(|_| delta.row.clone()));
        } else {
            for _ in 0..copies {
                // Once no copy is left, further scans cannot find one either.
                let Some(pos) = rows.iter().position(|r| r == &delta.row) else {
                    break;
                };
                rows.swap_remove(pos);
            }
        }
    }
}
// Tests for the one-shot snapshot runner and the persistent host, all driven by
// the same "per-category count and sum over `events`" query.
#[cfg(test)]
mod tests {
use super::*;
use palimpsest_sql::catalog::ColumnType;
use palimpsest_sql::lower::parse_and_lower;
use palimpsest_wal::Datum;
use crate::palimpsest::compile_mir::compile_mir;
use crate::palimpsest::eval::ScalarSchema;
// Schema of the `events` table: three integer columns (id, category_id, value).
fn events_schema() -> ScalarSchema {
ScalarSchema::from_pairs([
("id".to_owned(), ColumnType::Int),
("category_id".to_owned(), ColumnType::Int),
("value".to_owned(), ColumnType::Int),
])
}
// Catalog lookup used when compiling: only `events` (table id 2) is known.
fn lookup(table: &str) -> Option<(TableId, ScalarSchema)> {
match table {
"events" => Some((TableId::new(2), events_schema())),
_ => None,
}
}
// Builds a `Row` from a list of datums.
fn row(values: Vec<Datum>) -> Row {
values.into_iter().collect()
}
// A one-shot run over five events in three categories yields three output rows.
#[test]
fn snapshot_run_emits_aggregate_rows() {
let sql = "WITH per_category AS (
SELECT category_id, COUNT(*) AS n, SUM(value) AS total
FROM events
GROUP BY category_id
)
SELECT category_id, n, total
FROM per_category
ORDER BY total DESC
LIMIT 8";
let graph = parse_and_lower(sql).unwrap();
let plan = compile_mir(&graph, &lookup).unwrap();
let mut inputs = HashMap::new();
inputs.insert(
TableId::new(2),
vec![
row(vec![Datum::I64(1), Datum::I64(7), Datum::I64(100)]),
row(vec![Datum::I64(2), Datum::I64(7), Datum::I64(50)]),
row(vec![Datum::I64(3), Datum::I64(9), Datum::I64(20)]),
row(vec![Datum::I64(4), Datum::I64(9), Datum::I64(20)]),
row(vec![Datum::I64(5), Datum::I64(11), Datum::I64(5)]),
],
);
let mut output = snapshot_run(&plan, inputs);
output.sort();
assert_eq!(output.len(), 3, "three categories");
}
// Seeding returns the initial aggregates; a later insert into category 9 produces
// one retraction of the old aggregate row and one assertion of the new one.
#[test]
fn persistent_host_emits_initial_and_diffs() {
let sql = "WITH per_category AS (
SELECT category_id, COUNT(*) AS n, SUM(value) AS total
FROM events
GROUP BY category_id
)
SELECT category_id, n, total
FROM per_category
ORDER BY total DESC
LIMIT 8";
let graph = parse_and_lower(sql).unwrap();
let plan = compile_mir(&graph, &lookup).unwrap();
let host = PersistentHost::new();
let canonical = "events.top_categories";
let mut seed = HashMap::new();
seed.insert(
TableId::new(2),
vec![
row(vec![Datum::I64(1), Datum::I64(7), Datum::I64(100)]),
row(vec![Datum::I64(2), Datum::I64(7), Datum::I64(50)]),
row(vec![Datum::I64(3), Datum::I64(9), Datum::I64(20)]),
],
);
let mut initial = host.register_or_seed(canonical, &plan, seed, Lsn::new(1), 42);
initial.sort();
assert_eq!(initial.len(), 2, "initial has cat 7 + cat 9");
let next_lsn = Lsn::new(2);
let deltas = host.push_table_diff(
canonical,
TableId::new(2),
row(vec![Datum::I64(4), Datum::I64(9), Datum::I64(100)]),
1,
next_lsn,
);
let retracts: Vec<_> = deltas.iter().filter(|d| d.diff < 0).collect();
let asserts: Vec<_> = deltas.iter().filter(|d| d.diff > 0).collect();
assert_eq!(retracts.len(), 1, "one retract — old cat 9 row");
assert_eq!(asserts.len(), 1, "one assert — new cat 9 row");
// Old aggregate for category 9: n = 1, total = 20.
let retracted = &retracts[0].row;
assert_eq!(retracted.get(0), Some(&Datum::I64(9)));
assert_eq!(retracted.get(1), Some(&Datum::I64(1)));
assert_eq!(retracted.get(2), Some(&Datum::I64(20)));
// New aggregate for category 9: n = 2, total = 120.
let asserted = &asserts[0].row;
assert_eq!(asserted.get(0), Some(&Datum::I64(9)));
assert_eq!(asserted.get(1), Some(&Datum::I64(2)));
assert_eq!(asserted.get(2), Some(&Datum::I64(120)));
host.release(canonical, 42);
}
// Two inserts pushed as one batch at the same LSN yield exactly two deltas,
// both stamped with that LSN.
#[test]
fn persistent_host_batch_coalesces() {
let sql = "WITH per_category AS (
SELECT category_id, COUNT(*) AS n, SUM(value) AS total
FROM events
GROUP BY category_id
)
SELECT category_id, n, total
FROM per_category
ORDER BY total DESC
LIMIT 8";
let plan = compile_mir(&parse_and_lower(sql).unwrap(), &lookup).unwrap();
let host = PersistentHost::new();
let canonical = "events.batch";
let mut seed = HashMap::new();
seed.insert(
TableId::new(2),
vec![row(vec![Datum::I64(1), Datum::I64(7), Datum::I64(10)])],
);
host.register_or_seed(canonical, &plan, seed, Lsn::new(1), 7);
let batch = vec![
(
TableId::new(2),
row(vec![Datum::I64(2), Datum::I64(7), Datum::I64(20)]),
1,
),
(
TableId::new(2),
row(vec![Datum::I64(3), Datum::I64(7), Datum::I64(30)]),
1,
),
];
let deltas = host.push_table_batch(canonical, batch, Lsn::new(2));
assert_eq!(deltas.len(), 2);
assert!(deltas.iter().all(|d| d.lsn == Lsn::new(2)));
host.release(canonical, 7);
}
// A subscriber that joins after a change sees the updated cache and LSN, and the
// plan is unregistered only once the last subscriber releases it.
#[test]
fn cached_view_attaches_late_subscriber_to_current_state() {
let sql = "WITH per_category AS (
SELECT category_id, COUNT(*) AS n, SUM(value) AS total
FROM events
GROUP BY category_id
)
SELECT category_id, n, total
FROM per_category
ORDER BY total DESC
LIMIT 8";
let plan = compile_mir(&parse_and_lower(sql).unwrap(), &lookup).unwrap();
let host = PersistentHost::new();
let canonical = "events.shared";
let mut seed = HashMap::new();
seed.insert(
TableId::new(2),
vec![row(vec![Datum::I64(1), Datum::I64(7), Datum::I64(10)])],
);
host.register_or_seed(canonical, &plan, seed, Lsn::new(5), 1);
let apply_lsn = Lsn::new(6);
let (_deltas, subs_after_apply) = host
.apply_and_fanout(
canonical,
vec![(
TableId::new(2),
row(vec![Datum::I64(2), Datum::I64(7), Datum::I64(20)]),
1,
)],
apply_lsn,
)
.expect("plan still registered");
assert_eq!(subs_after_apply, vec![1]);
let (cached, cached_lsn) = host
.cached_view(canonical, 2)
.expect("cache hit on registered plan");
assert_eq!(cached_lsn, apply_lsn);
// The cached total for category 7 is now 10 + 20 = 30.
assert!(cached.iter().any(|r| r.get(2) == Some(&Datum::I64(30))));
let subs_view = host.subscribers(canonical).expect("plan registered");
assert_eq!(subs_view, vec![1, 2]);
assert_eq!(host.release(canonical, 1), 1);
assert!(host.subscribers(canonical).is_some());
assert_eq!(host.release(canonical, 2), 0);
assert!(host.subscribers(canonical).is_none());
}
}