use std::collections::VecDeque;
use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::prelude::SessionContext;
use rustc_hash::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};
use crate::error::DbError;
use crate::metrics::PipelineCounters;
use crate::stream_executor::{
apply_topk_filter, detect_asof_query, detect_stream_join_query, detect_temporal_query,
extract_table_references,
};
use laminar_sql::parser::EmitClause;
use laminar_sql::translator::{
OrderOperatorConfig, TemporalJoinTranslatorConfig, WindowOperatorConfig,
};
/// A node in the streaming operator DAG.
///
/// `process` is invoked once per execution cycle; `checkpoint`/`restore`
/// serialize operator state for fault tolerance. Operators must be `Send`
/// so the graph can be driven from an async executor.
#[async_trait]
pub(crate) trait GraphOperator: Send {
    /// Process one cycle of input. `inputs[port]` holds the batches queued
    /// for that input port; `watermark` is the current event-time watermark.
    /// Returns the operator's output batches for this cycle.
    async fn process(
        &mut self,
        inputs: &[Vec<RecordBatch>],
        watermark: i64,
    ) -> Result<Vec<RecordBatch>, DbError>;
    /// Snapshot operator state; `Ok(None)` means the operator is stateless.
    fn checkpoint(&mut self) -> Result<Option<OperatorCheckpoint>, DbError>;
    /// Restore operator state from a previously taken checkpoint.
    fn restore(&mut self, checkpoint: OperatorCheckpoint) -> Result<(), DbError>;
    /// Approximate in-memory state size in bytes, used for the optional
    /// state-size limit check. The default of 0 effectively opts out.
    fn estimated_state_bytes(&self) -> usize {
        0
    }
}
/// Opaque serialized state for a single operator.
pub(crate) struct OperatorCheckpoint {
    // Operator-defined encoding; the graph never interprets these bytes.
    pub data: Vec<u8>,
}
/// Serializable snapshot of every stateful operator in the graph,
/// keyed by query name (see `snapshot_state` / `restore_state`).
#[derive(Serialize, Deserialize)]
pub(crate) struct GraphCheckpoint {
    // Checkpoint format version; currently always 1.
    pub version: u32,
    // Query name -> serialized operator state bytes.
    pub operators: FxHashMap<String, Vec<u8>>,
}
/// A vertex in the operator DAG.
struct GraphNode {
    // Query or source-table name; also used as a map key and table name.
    name: Arc<str>,
    // The operator executed for this node each cycle.
    operator: Box<dyn GraphOperator>,
    // Number of input ports (2 for binary joins, otherwise 1).
    input_port_count: usize,
    // Downstream (node index, input port) pairs this node feeds.
    output_routes: Vec<(usize, u8)>,
    // Tombstone flag: removed nodes stay in `nodes` so indices stay stable.
    removed: bool,
}
/// Directed dependency: `target` consumes `source`'s output.
struct GraphEdge {
    source: usize,
    target: usize,
}
/// Identity operator for raw source tables: forwards port-0 input unchanged.
struct SourcePassthrough;
#[async_trait]
impl GraphOperator for SourcePassthrough {
    async fn process(
        &mut self,
        inputs: &[Vec<RecordBatch>],
        _watermark: i64,
    ) -> Result<Vec<RecordBatch>, DbError> {
        // Forward whatever arrived on the first input port; no input port
        // at all means no output this cycle.
        let forwarded = match inputs.first() {
            Some(batches) => batches.clone(),
            None => Vec::new(),
        };
        Ok(forwarded)
    }
    fn checkpoint(&mut self) -> Result<Option<OperatorCheckpoint>, DbError> {
        // Stateless: nothing to snapshot.
        Ok(None)
    }
    fn restore(&mut self, _checkpoint: OperatorCheckpoint) -> Result<(), DbError> {
        Ok(())
    }
}
/// Placeholder swapped in for a removed query's operator so the slot keeps
/// a valid (and state-free) operator while node indices remain stable.
struct TombstonedOperator;
#[async_trait]
impl GraphOperator for TombstonedOperator {
    async fn process(
        &mut self,
        _inputs: &[Vec<RecordBatch>],
        _watermark: i64,
    ) -> Result<Vec<RecordBatch>, DbError> {
        // Removed queries never emit anything.
        Ok(Vec::new())
    }
    fn checkpoint(&mut self) -> Result<Option<OperatorCheckpoint>, DbError> {
        Ok(None)
    }
    fn restore(&mut self, _checkpoint: OperatorCheckpoint) -> Result<(), DbError> {
        Ok(())
    }
}
/// A DAG of streaming operators executed once per ingest cycle.
///
/// Nodes are tombstoned rather than removed (see `remove_query`) so the
/// node indices stored in edges, maps, and buffers stay valid.
pub(crate) struct OperatorGraph {
    nodes: Vec<GraphNode>,
    edges: Vec<GraphEdge>,
    // Cached topological execution order; rebuilt when `topo_dirty`.
    topo_order: Vec<usize>,
    topo_dirty: bool,
    // Raw source-table name -> passthrough node index.
    source_map: FxHashMap<Arc<str>, usize>,
    // Query name -> node whose output is returned from `execute_cycle`.
    output_map: FxHashMap<Arc<str>, usize>,
    // Per node, per input port: batches queued for the next process() call.
    input_bufs: Vec<Vec<Vec<RecordBatch>>>,
    // Wall-clock budget for one cycle; operators past it are deferred.
    query_budget_ns: u64,
    // Optional cap on a single operator's estimated state size in bytes.
    max_state_bytes: Option<usize>,
    ctx: SessionContext,
    counters: Option<Arc<PipelineCounters>>,
    lookup_registry: Option<Arc<laminar_sql::datafusion::LookupTableRegistry>>,
    // Declared schemas, used to register empty tables when no data arrives.
    source_schemas: FxHashMap<String, SchemaRef>,
    temporal_configs: Vec<TemporalJoinTranslatorConfig>,
    // Nodes fed by another query's output; their per-cycle errors are
    // treated as "upstream not ready" and soft-skipped.
    depends_on_stream: FxHashSet<usize>,
    // Top-K post-filters applied to a node's output batches.
    order_configs: FxHashMap<usize, OrderOperatorConfig>,
    // Table names registered for the current cycle (for cleanup).
    registered_sources: Vec<String>,
    // Reusable buffer of intermediate table names carried across cycles.
    cycle_intermediates: Vec<String>,
}
impl OperatorGraph {
/// Create an empty graph with an 8 ms per-cycle query budget and no
/// state-size limit.
pub fn new(ctx: SessionContext) -> Self {
    Self {
        nodes: Vec::new(),
        edges: Vec::new(),
        topo_order: Vec::new(),
        // Force a topo-order computation on the first cycle.
        topo_dirty: true,
        source_map: FxHashMap::default(),
        output_map: FxHashMap::default(),
        input_bufs: Vec::new(),
        // Default per-cycle budget: 8 ms.
        query_budget_ns: 8_000_000,
        max_state_bytes: None,
        ctx,
        counters: None,
        lookup_registry: None,
        source_schemas: FxHashMap::default(),
        temporal_configs: Vec::new(),
        depends_on_stream: FxHashSet::default(),
        order_configs: FxHashMap::default(),
        registered_sources: Vec::new(),
        cycle_intermediates: Vec::new(),
    }
}
/// Set (or clear with `None`) the per-operator state-size cap enforced
/// during `execute_cycle`.
pub fn set_max_state_bytes(&mut self, limit: Option<usize>) {
    self.max_state_bytes = limit;
}
/// Set the wall-clock budget (nanoseconds) for a single execution cycle.
pub fn set_query_budget_ns(&mut self, ns: u64) {
    self.query_budget_ns = ns;
}
/// Attach pipeline metrics counters, passed along to SQL/EOWC operators.
pub fn set_counters(&mut self, c: Arc<PipelineCounters>) {
    self.counters = Some(c);
}
/// Attach the lookup-table registry used by temporal join operators.
pub fn set_lookup_registry(
    &mut self,
    registry: Arc<laminar_sql::datafusion::LookupTableRegistry>,
) {
    self.lookup_registry = Some(registry);
}
/// Declare a source table's schema so an empty table can be registered
/// on cycles where the source produced no batches.
pub fn register_source_schema(&mut self, name: String, schema: SchemaRef) {
    self.source_schemas.insert(name, schema);
}
/// All temporal-join configs collected from queries added so far.
pub fn temporal_join_configs(&self) -> Vec<TemporalJoinTranslatorConfig> {
    self.temporal_configs.clone()
}
/// Index of the live (non-removed) node named `name`, if any.
fn find_node(&self, name: &str) -> Option<usize> {
    for (idx, node) in self.nodes.iter().enumerate() {
        if !node.removed && &*node.name == name {
            return Some(idx);
        }
    }
    None
}
/// Return the node index for source table `table_name`, creating a
/// single-port passthrough node (and `source_map` entry) if absent.
fn ensure_source_node(&mut self, table_name: &str) -> usize {
    if let Some(&id) = self.source_map.get(table_name) {
        return id;
    }
    let node_id = self.nodes.len();
    let name: Arc<str> = Arc::from(table_name);
    self.nodes.push(GraphNode {
        name: Arc::clone(&name),
        operator: Box::new(SourcePassthrough),
        input_port_count: 1,
        output_routes: Vec::new(),
        removed: false,
    });
    // Keep input_bufs index-aligned with nodes: one port for a source.
    self.input_bufs.push(vec![Vec::new()]);
    self.source_map.insert(name, node_id);
    node_id
}
/// Connect `source` -> `target` and record the (target, port) route used
/// to deliver the source's output batches during execution.
fn add_edge(&mut self, source: usize, target: usize, target_port: u8) {
    self.edges.push(GraphEdge { source, target });
    self.nodes[source].output_routes.push((target, target_port));
}
/// Register a new continuous query named `name` in the DAG.
///
/// The SQL is probed for specialized join shapes (ASOF, temporal,
/// interval/stream-stream) to pick a dedicated operator; otherwise a
/// generic SQL or EOWC operator is built. Referenced tables that are not
/// yet known become passthrough source nodes. If `name` itself was
/// previously created as a placeholder source (a query referenced it
/// before it was defined), that placeholder node is upgraded in place so
/// node indices stay stable.
#[allow(clippy::too_many_lines, clippy::needless_pass_by_value)]
pub fn add_query(
    &mut self,
    name: String,
    sql: String,
    emit_clause: Option<EmitClause>,
    window_config: Option<WindowOperatorConfig>,
    order_config: Option<OrderOperatorConfig>,
) {
    // Each detector may also yield a post-join projection SQL; the first
    // detector that produced one wins (asof > temporal > stream join).
    let (asof_config, mut projection_sql) = detect_asof_query(&sql);
    let (temporal_config, temporal_projection_sql) = detect_temporal_query(&sql);
    let (stream_join_config, stream_join_projection_sql) = detect_stream_join_query(&sql);
    if projection_sql.is_none() {
        projection_sql = temporal_projection_sql;
    }
    if projection_sql.is_none() {
        projection_sql = stream_join_projection_sql;
    }
    if stream_join_config.is_none() && asof_config.is_none() && temporal_config.is_none() {
        // Heuristic warning: JOIN + BETWEEN usually means the user wanted
        // an interval join that the detector could not recognize.
        let sql_upper = sql.to_uppercase();
        if sql_upper.contains("JOIN") && sql_upper.contains("BETWEEN") {
            tracing::warn!(
                query = %name,
                "Query contains JOIN with BETWEEN but was not detected as an interval join. \
                It will execute as a batch join (matches within one cycle only). \
                Ensure time columns in the BETWEEN clause are simple column references."
            );
        }
    }
    let table_refs = extract_table_references(&sql);
    if let Some(ref tc) = temporal_config {
        self.temporal_configs.push(tc.clone());
    }
    let operator: Box<dyn GraphOperator> = self.create_operator(
        &name,
        &sql,
        emit_clause.as_ref(),
        window_config.as_ref(),
        asof_config.as_ref(),
        temporal_config.as_ref(),
        stream_join_config.as_ref(),
        projection_sql.as_deref(),
    );
    // Binary joins consume two input ports; everything else one.
    let input_port_count = if asof_config.is_some() || stream_join_config.is_some() {
        2
    } else {
        1
    };
    // Case 1: `name` already exists as a placeholder source node because
    // an earlier query referenced it. Upgrade that node in place instead
    // of appending a new one, keeping downstream indices valid.
    if let Some(&placeholder_id) = self.source_map.get(name.as_str()) {
        self.nodes[placeholder_id].operator = operator;
        self.nodes[placeholder_id].input_port_count = input_port_count;
        self.input_bufs[placeholder_id] = vec![Vec::new(); input_port_count];
        self.source_map.remove(name.as_str());
        let node_id = placeholder_id;
        // Make sure every referenced upstream table exists as a node.
        for table_ref in &table_refs {
            if self.find_node(table_ref).is_none() {
                self.ensure_source_node(table_ref);
            }
        }
        if let Some(ref asof_cfg) = asof_config {
            self.find_node(&asof_cfg.left_table)
                .unwrap_or_else(|| self.ensure_source_node(&asof_cfg.left_table));
            self.find_node(&asof_cfg.right_table)
                .unwrap_or_else(|| self.ensure_source_node(&asof_cfg.right_table));
        } else if let Some(ref sjc) = stream_join_config {
            self.find_node(&sjc.left_table)
                .unwrap_or_else(|| self.ensure_source_node(&sjc.left_table));
            self.find_node(&sjc.right_table)
                .unwrap_or_else(|| self.ensure_source_node(&sjc.right_table));
        }
        // Wire upstream edges: binary joins use dedicated ports 0/1,
        // everything else fans into port 0 (deduplicated).
        if let Some(ref asof_cfg) = asof_config {
            let left_id = self
                .find_node(&asof_cfg.left_table)
                .expect("source ensured");
            let right_id = self
                .find_node(&asof_cfg.right_table)
                .expect("source ensured");
            self.add_edge(left_id, node_id, 0);
            self.add_edge(right_id, node_id, 1);
        } else if let Some(ref sjc) = stream_join_config {
            let left_id = self.find_node(&sjc.left_table).expect("source ensured");
            let right_id = self.find_node(&sjc.right_table).expect("source ensured");
            self.add_edge(left_id, node_id, 0);
            self.add_edge(right_id, node_id, 1);
        } else {
            for table_ref in &table_refs {
                let upstream_id = self.find_node(table_ref).expect("source ensured");
                let already_connected = self.nodes[upstream_id]
                    .output_routes
                    .iter()
                    .any(|&(t, p)| t == node_id && p == 0);
                if !already_connected {
                    self.add_edge(upstream_id, node_id, 0);
                }
            }
        }
        // Consumers of this (formerly source) node now depend on a
        // query's output, so their errors become soft skips.
        for &(target, _) in &self.nodes[node_id].output_routes {
            self.depends_on_stream.insert(target);
        }
        if let Some(oc) = order_config {
            self.order_configs.insert(node_id, oc);
        }
        self.output_map.insert(Arc::from(name.as_str()), node_id);
        self.topo_dirty = true;
        return;
    }
    // Case 2: brand-new node. Ensure upstream source nodes exist first.
    if let Some(ref asof_cfg) = asof_config {
        self.find_node(&asof_cfg.left_table)
            .unwrap_or_else(|| self.ensure_source_node(&asof_cfg.left_table));
        self.find_node(&asof_cfg.right_table)
            .unwrap_or_else(|| self.ensure_source_node(&asof_cfg.right_table));
    } else if let Some(ref sjc) = stream_join_config {
        self.find_node(&sjc.left_table)
            .unwrap_or_else(|| self.ensure_source_node(&sjc.left_table));
        self.find_node(&sjc.right_table)
            .unwrap_or_else(|| self.ensure_source_node(&sjc.right_table));
    } else if let Some(ref tc) = temporal_config {
        if self.find_node(&tc.stream_table).is_none() {
            self.ensure_source_node(&tc.stream_table);
        }
    } else {
        for table_ref in &table_refs {
            if self.find_node(table_ref).is_none() {
                self.ensure_source_node(table_ref);
            }
        }
    }
    let node_id = self.nodes.len();
    self.nodes.push(GraphNode {
        name: Arc::from(name.as_str()),
        operator,
        input_port_count,
        output_routes: Vec::new(),
        removed: false,
    });
    self.input_bufs.push(vec![Vec::new(); input_port_count]);
    // Wire upstream edges, mirroring the placeholder path above.
    if let Some(ref asof_cfg) = asof_config {
        let left_id = self
            .find_node(&asof_cfg.left_table)
            .expect("source node ensured above");
        let right_id = self
            .find_node(&asof_cfg.right_table)
            .expect("source node ensured above");
        self.add_edge(left_id, node_id, 0);
        self.add_edge(right_id, node_id, 1);
    } else if let Some(ref sjc) = stream_join_config {
        let left_id = self
            .find_node(&sjc.left_table)
            .expect("source node ensured above");
        let right_id = self
            .find_node(&sjc.right_table)
            .expect("source node ensured above");
        self.add_edge(left_id, node_id, 0);
        self.add_edge(right_id, node_id, 1);
    } else if let Some(ref tc) = temporal_config {
        // Temporal joins only stream from one side; the lookup side is
        // resolved through the lookup registry, not a graph edge.
        let stream_id = self
            .find_node(&tc.stream_table)
            .expect("source node ensured above");
        self.add_edge(stream_id, node_id, 0);
        if self.output_map.contains_key(tc.stream_table.as_str()) {
            self.depends_on_stream.insert(node_id);
        }
    } else {
        let mut depends_on_query = false;
        for table_ref in &table_refs {
            let upstream_id = self
                .find_node(table_ref)
                .expect("source node ensured above");
            let already_connected = self.nodes[upstream_id]
                .output_routes
                .iter()
                .any(|&(t, p)| t == node_id && p == 0);
            if !already_connected {
                self.add_edge(upstream_id, node_id, 0);
            }
            if self.output_map.contains_key(table_ref.as_str()) {
                depends_on_query = true;
            }
        }
        if depends_on_query {
            self.depends_on_stream.insert(node_id);
        }
    }
    if let Some(oc) = order_config {
        self.order_configs.insert(node_id, oc);
    }
    self.output_map.insert(Arc::from(name.as_str()), node_id);
    self.topo_dirty = true;
}
/// Build the operator implementation for a query. Priority order:
/// ASOF join > temporal join > interval (stream-stream) join >
/// EOWC (EMIT ON WINDOW CLOSE / FINAL) > plain SQL.
#[allow(clippy::too_many_arguments)]
fn create_operator(
    &self,
    name: &str,
    sql: &str,
    emit_clause: Option<&EmitClause>,
    window_config: Option<&WindowOperatorConfig>,
    asof_config: Option<&laminar_sql::translator::AsofJoinTranslatorConfig>,
    temporal_config: Option<&TemporalJoinTranslatorConfig>,
    stream_join_config: Option<&laminar_sql::translator::StreamJoinConfig>,
    projection_sql: Option<&str>,
) -> Box<dyn GraphOperator> {
    use crate::operators;
    if let Some(cfg) = asof_config {
        return Box::new(operators::asof_join::AsofJoinOperator::new(
            name,
            cfg.clone(),
            projection_sql.map(Arc::from),
            self.ctx.clone(),
        ));
    }
    if let Some(cfg) = temporal_config {
        return Box::new(operators::temporal_join::TemporalJoinOperator::new(
            name,
            cfg.clone(),
            projection_sql.map(Arc::from),
            self.ctx.clone(),
            self.lookup_registry.clone(),
        ));
    }
    if let Some(cfg) = stream_join_config {
        return Box::new(operators::interval_join::IntervalJoinOperator::new(
            name,
            cfg.clone(),
            projection_sql.map(Arc::from),
            self.ctx.clone(),
        ));
    }
    // EMIT ON WINDOW CLOSE / EMIT FINAL use the buffering EOWC operator.
    let is_eowc = emit_clause
        .is_some_and(|ec| matches!(ec, EmitClause::OnWindowClose | EmitClause::Final));
    if is_eowc {
        return Box::new(operators::eowc_query::EowcQueryOperator::new(
            name,
            sql,
            emit_clause.cloned(),
            window_config.cloned(),
            self.ctx.clone(),
            self.counters.clone(),
        ));
    }
    // Fallback: generic per-cycle SQL execution.
    Box::new(operators::sql_query::SqlQueryOperator::new(
        name,
        sql,
        self.ctx.clone(),
        self.counters.clone(),
    ))
}
/// Tombstone the query named `name`, detaching it from the graph while
/// keeping node indices stable for all remaining nodes.
pub fn remove_query(&mut self, name: &str) {
    let Some(idx) = self.find_node(name) else {
        return;
    };
    // Mark the node dead and drop its (possibly stateful) operator.
    let node = &mut self.nodes[idx];
    node.removed = true;
    node.operator = Box::new(TombstonedOperator);
    node.output_routes.clear();
    // Strip every piece of bookkeeping that references this node.
    self.output_map.remove(name);
    self.order_configs.remove(&idx);
    self.depends_on_stream.remove(&idx);
    self.edges.retain(|e| e.source != idx && e.target != idx);
    for other in &mut self.nodes {
        other.output_routes.retain(|&(target, _)| target != idx);
    }
    self.topo_dirty = true;
}
fn compute_topo_order(&mut self) {
let n = self.nodes.len();
let mut in_degree = vec![0usize; n];
let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); n];
for edge in &self.edges {
if !self.nodes[edge.source].removed && !self.nodes[edge.target].removed {
in_degree[edge.target] += 1;
dependents[edge.source].push(edge.target);
}
}
for deps in &mut dependents {
deps.sort_unstable();
deps.dedup();
}
in_degree.fill(0);
for deps in &dependents {
for &dep in deps {
in_degree[dep] += 1;
}
}
let mut queue = VecDeque::new();
for (i, °) in in_degree.iter().enumerate() {
if deg == 0 && !self.nodes[i].removed {
queue.push_back(i);
}
}
self.topo_order.clear();
while let Some(idx) = queue.pop_front() {
self.topo_order.push(idx);
for &dep in &dependents[idx] {
in_degree[dep] = in_degree[dep].saturating_sub(1);
if in_degree[dep] == 0 {
queue.push_back(dep);
}
}
}
let active_count = self.nodes.iter().filter(|n| !n.removed).count();
if self.topo_order.len() < active_count {
tracing::warn!(
ordered = self.topo_order.len(),
total = active_count,
"circular dependency in operator graph, \
falling back to insertion order for remaining nodes"
);
let in_order: FxHashSet<usize> = self.topo_order.iter().copied().collect();
for i in 0..n {
if !in_order.contains(&i) && !self.nodes[i].removed {
self.topo_order.push(i);
}
}
}
self.topo_dirty = false;
}
/// Register this cycle's source batches as DataFusion `MemTable`s so SQL
/// operators can query them, plus empty tables for any declared source
/// schema that received no data (so table references still resolve).
/// Every registered name is recorded for `cleanup_source_tables`.
fn register_source_tables(
    &mut self,
    source_batches: &FxHashMap<Arc<str>, Vec<RecordBatch>>,
) -> Result<(), DbError> {
    for (name, batches) in source_batches {
        if batches.is_empty() {
            continue;
        }
        let schema = batches[0].schema();
        let mem_table =
            datafusion::datasource::MemTable::try_new(schema, vec![batches.clone()])
                .map_err(|e| DbError::query_pipeline(&**name, &e))?;
        // Drop any stale registration left over from a previous cycle.
        let _ = self.ctx.deregister_table(&**name);
        self.ctx
            .register_table(&**name, Arc::new(mem_table))
            .map_err(|e| DbError::query_pipeline(&**name, &e))?;
        self.registered_sources.push(name.to_string());
    }
    for (name, schema) in &self.source_schemas {
        if source_batches.contains_key(name.as_str()) {
            continue;
        }
        // No data this cycle: register an empty table with the declared
        // schema so dependent queries still plan and run.
        let empty = datafusion::datasource::MemTable::try_new(schema.clone(), vec![vec![]])
            .map_err(|e| DbError::query_pipeline(name, &e))?;
        let _ = self.ctx.deregister_table(name);
        self.ctx
            .register_table(name, Arc::new(empty))
            .map_err(|e| DbError::query_pipeline(name, &e))?;
        self.registered_sources.push(name.clone());
    }
    Ok(())
}
/// Deregister every table registered by `register_source_tables` this
/// cycle; failures are ignored (the table may already be gone).
fn cleanup_source_tables(&mut self) {
    for name in self.registered_sources.drain(..) {
        let _ = self.ctx.deregister_table(&name);
    }
}
/// Run one full execution cycle over the DAG.
///
/// Source batches are registered as tables and delivered to source
/// nodes, then operators run in topological order, each routing its
/// output to downstream input buffers. Returns the outputs of every
/// query node that produced data. All tables registered during the
/// cycle are deregistered before returning, on success or error.
#[allow(clippy::too_many_lines)]
pub async fn execute_cycle(
    &mut self,
    source_batches: &FxHashMap<Arc<str>, Vec<RecordBatch>>,
    current_watermark: i64,
) -> Result<FxHashMap<Arc<str>, Vec<RecordBatch>>, DbError> {
    if self.topo_dirty {
        self.compute_topo_order();
    }
    self.register_source_tables(source_batches)?;
    // Seed the source passthrough nodes with this cycle's batches (port 0).
    for (name, batches) in source_batches {
        if let Some(&node_id) = self.source_map.get(name) {
            self.input_bufs[node_id][0].clone_from(batches);
        }
    }
    let mut results = FxHashMap::default();
    // Reuse the name buffer allocated in earlier cycles.
    let mut intermediate_tables = std::mem::take(&mut self.cycle_intermediates);
    intermediate_tables.clear();
    let cycle_start = std::time::Instant::now();
    let topo_len = self.topo_order.len();
    for i in 0..topo_len {
        let node_id = self.topo_order[i];
        if self.nodes[node_id].removed {
            continue;
        }
        // Budget check (the first operator always runs): once the cycle
        // exceeds query_budget_ns, defer the remaining operators.
        if i > 0 {
            #[allow(clippy::cast_possible_truncation)]
            let elapsed_ns = cycle_start.elapsed().as_nanos() as u64;
            if elapsed_ns > self.query_budget_ns {
                tracing::debug!(
                    skipped = topo_len - i,
                    elapsed_ms = elapsed_ns / 1_000_000,
                    "per-query budget exceeded — deferring remaining operators"
                );
                break;
            }
        }
        // Hand buffered inputs to the operator, then reset the buffers
        // for the next cycle.
        let inputs = std::mem::take(&mut self.input_bufs[node_id]);
        let output_result = self.nodes[node_id]
            .operator
            .process(&inputs, current_watermark)
            .await;
        let port_count = self.nodes[node_id].input_port_count;
        self.input_bufs[node_id] = vec![Vec::new(); port_count];
        let batches = match output_result {
            Ok(b) => b,
            Err(e) => {
                // Queries that read another query's output may fail while
                // the upstream table is not yet registered — soft-skip.
                if self.depends_on_stream.contains(&node_id) {
                    tracing::debug!(
                        query = %self.nodes[node_id].name,
                        error = %e,
                        "Query skipped (upstream not ready)"
                    );
                    continue;
                }
                // Hard failure: clean up all registered tables first.
                self.cleanup_source_tables();
                for name in &intermediate_tables {
                    let _ = self.ctx.deregister_table(name);
                }
                self.cycle_intermediates = intermediate_tables;
                return Err(e);
            }
        };
        // Enforce the optional per-operator state cap; warn at 80%.
        if let Some(limit) = self.max_state_bytes {
            let size = self.nodes[node_id].operator.estimated_state_bytes();
            if size >= limit {
                self.cleanup_source_tables();
                for name in &intermediate_tables {
                    let _ = self.ctx.deregister_table(name);
                }
                self.cycle_intermediates = intermediate_tables;
                return Err(DbError::Pipeline(format!(
                    "state size limit exceeded for query '{}' ({size} bytes >= {limit} limit)",
                    self.nodes[node_id].name
                )));
            }
            if size >= limit * 4 / 5 {
                tracing::warn!(
                    query = %self.nodes[node_id].name,
                    size_bytes = size,
                    limit_bytes = limit,
                    "state size at 80% of limit"
                );
            }
        }
        // Apply any configured top-k post-filter to the output.
        let batches = if let Some(oc) = self.order_configs.get(&node_id) {
            match oc {
                OrderOperatorConfig::TopK(c) => apply_topk_filter(&batches, c.k),
                OrderOperatorConfig::PerGroupTopK(c) => apply_topk_filter(&batches, c.k),
                _ => batches,
            }
        } else {
            batches
        };
        if !batches.is_empty() {
            let node_name = Arc::clone(&self.nodes[node_id].name);
            // Expose the output as a table so downstream SQL operators
            // in this same cycle can reference it by name.
            if !self.nodes[node_id].output_routes.is_empty() {
                let schema = batches[0].schema();
                if let Ok(mem_table) =
                    datafusion::datasource::MemTable::try_new(schema, vec![batches.clone()])
                {
                    let _ = self.ctx.deregister_table(&*node_name);
                    if let Err(e) = self.ctx.register_table(&*node_name, Arc::new(mem_table)) {
                        tracing::warn!(
                            query = %node_name,
                            error = %e,
                            "[LDB-3015] Failed to register intermediate table"
                        );
                    }
                    intermediate_tables.push(node_name.to_string());
                }
            }
            // Only named query outputs are surfaced to the caller.
            if self.output_map.values().any(|&id| id == node_id) {
                results.insert(node_name, batches.clone());
            }
            // Route batches downstream; the single-route case moves the
            // Vec instead of cloning every batch.
            let routes = self.nodes[node_id].output_routes.clone();
            if routes.len() == 1 {
                let (target, port) = routes[0];
                if self.input_bufs[target][port as usize].is_empty() {
                    self.input_bufs[target][port as usize] = batches;
                } else {
                    self.input_bufs[target][port as usize].extend(batches);
                }
            } else if routes.len() > 1 {
                for &(target, port) in &routes {
                    self.input_bufs[target][port as usize].extend(batches.iter().cloned());
                }
            }
        }
    }
    self.cleanup_source_tables();
    for name in &intermediate_tables {
        let _ = self.ctx.deregister_table(name);
    }
    self.cycle_intermediates = intermediate_tables;
    Ok(results)
}
/// Collect a checkpoint from every live, stateful operator.
/// Returns `Ok(None)` when no operator had state to snapshot.
pub fn snapshot_state(&mut self) -> Result<Option<GraphCheckpoint>, DbError> {
    let mut operators = FxHashMap::default();
    for node in self.nodes.iter_mut().filter(|n| !n.removed) {
        if let Some(cp) = node.operator.checkpoint()? {
            operators.insert(node.name.to_string(), cp.data);
        }
    }
    if operators.is_empty() {
        Ok(None)
    } else {
        Ok(Some(GraphCheckpoint {
            version: 1,
            operators,
        }))
    }
}
/// Restore operator state by name from `checkpoint`. Operators without a
/// matching entry are left untouched. Returns how many were restored.
pub fn restore_state(&mut self, checkpoint: &GraphCheckpoint) -> Result<usize, DbError> {
    let mut restored = 0usize;
    for node in self.nodes.iter_mut().filter(|n| !n.removed) {
        let Some(bytes) = checkpoint.operators.get(&*node.name) else {
            continue;
        };
        node.operator.restore(OperatorCheckpoint {
            data: bytes.clone(),
        })?;
        restored += 1;
    }
    Ok(restored)
}
/// Serialize a graph checkpoint to JSON bytes.
pub fn serialize_checkpoint(cp: &GraphCheckpoint) -> Result<Vec<u8>, DbError> {
    serde_json::to_vec(cp)
        .map_err(|e| DbError::Pipeline(format!("operator graph checkpoint serialization: {e}")))
}
/// Deserialize JSON checkpoint bytes and restore matching operators.
/// Returns the number of operators restored.
pub fn restore_from_bytes(&mut self, bytes: &[u8]) -> Result<usize, DbError> {
    let checkpoint: GraphCheckpoint = serde_json::from_slice(bytes).map_err(|e| {
        DbError::Pipeline(format!("operator graph checkpoint deserialization: {e}"))
    })?;
    self.restore_state(&checkpoint)
}
}
/// Run a compiled projection over `batches`, keeping only non-empty
/// results. Returns the first evaluation error, if any.
pub(crate) fn try_evaluate_compiled(
    proj: &crate::aggregate_state::CompiledProjection,
    batches: &[RecordBatch],
) -> Result<Vec<RecordBatch>, crate::error::DbError> {
    // Errors pass through the filter so `collect` short-circuits on the
    // first failure, matching a manual loop with `?`.
    batches
        .iter()
        .map(|batch| proj.evaluate(batch))
        .filter(|outcome| outcome.as_ref().map_or(true, |b| b.num_rows() > 0))
        .collect()
}
#[cfg(test)]
#[allow(clippy::redundant_closure_for_method_calls)]
mod tests {
use super::*;
use arrow::array::{Float64Array, Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
// Schema shared by all tests: (symbol: Utf8, price: Float64, ts: Int64).
fn test_schema() -> Arc<Schema> {
    Arc::new(Schema::new(vec![
        Field::new("symbol", DataType::Utf8, false),
        Field::new("price", DataType::Float64, false),
        Field::new("ts", DataType::Int64, false),
    ]))
}
// Two-row fixture: AAPL @ 150.0 (ts 1000) and GOOG @ 2800.0 (ts 2000).
fn test_batch() -> RecordBatch {
    RecordBatch::try_new(
        test_schema(),
        vec![
            Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
            Arc::new(Float64Array::from(vec![150.0, 2800.0])),
            Arc::new(Int64Array::from(vec![1000, 2000])),
        ],
    )
    .unwrap()
}
// SourcePassthrough must forward its port-0 input unchanged.
#[test]
fn test_source_passthrough() {
    let rt = tokio::runtime::Builder::new_current_thread()
        .build()
        .unwrap();
    rt.block_on(async {
        let mut op = SourcePassthrough;
        let batch = test_batch();
        let result = op.process(&[vec![batch.clone()]], 0).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 2);
    });
}
// One query over one source table: expect a passthrough node for
// "trades" plus the query node, connected by a single edge.
#[test]
fn test_graph_construction() {
    let ctx = laminar_sql::create_session_context();
    let mut graph = OperatorGraph::new(ctx);
    let sql = "SELECT symbol, price FROM trades WHERE price > 100";
    graph.add_query("q1".to_string(), sql.to_string(), None, None, None);
    assert_eq!(graph.nodes.len(), 2);
    assert_eq!(graph.edges.len(), 1);
    assert!(graph.source_map.contains_key("trades"));
    assert!(graph.output_map.contains_key("q1"));
}
// trades -> q1 -> q2: three nodes, two edges, and q2 (node index 2)
// flagged as depending on another query's output.
#[test]
fn test_cascading_queries() {
    let ctx = laminar_sql::create_session_context();
    let mut graph = OperatorGraph::new(ctx);
    let queries = [
        ("q1", "SELECT symbol, price FROM trades"),
        ("q2", "SELECT symbol FROM q1 WHERE price > 100"),
    ];
    for (name, sql) in queries {
        graph.add_query(name.to_string(), sql.to_string(), None, None, None);
    }
    assert_eq!(graph.nodes.len(), 3);
    assert_eq!(graph.edges.len(), 2);
    assert!(graph.depends_on_stream.contains(&2));
}
// Registering q2 before its upstream q1 exercises placeholder resolution;
// the topological order must still schedule q1 before q2. (Previously
// this test only checked that both nodes appeared somewhere in the
// order, which its name promised but never verified.)
#[test]
fn test_topo_order() {
    let ctx = laminar_sql::create_session_context();
    let mut graph = OperatorGraph::new(ctx);
    graph.add_query(
        "q2".to_string(),
        "SELECT * FROM q1".to_string(),
        None,
        None,
        None,
    );
    graph.add_query(
        "q1".to_string(),
        "SELECT * FROM trades".to_string(),
        None,
        None,
        None,
    );
    graph.compute_topo_order();
    let q1_pos = graph
        .topo_order
        .iter()
        .position(|&id| &*graph.nodes[id].name == "q1")
        .expect("q1 missing from topo order");
    let q2_pos = graph
        .topo_order
        .iter()
        .position(|&id| &*graph.nodes[id].name == "q2")
        .expect("q2 missing from topo order");
    // q2 reads from q1, so q1 must run first.
    assert!(q1_pos < q2_pos, "q1 should precede q2 in topological order");
}
// Removing a query tombstones its node in place (index 1; index 0 is the
// trades source) and drops it from the output map.
#[test]
fn test_remove_query() {
    let ctx = laminar_sql::create_session_context();
    let mut graph = OperatorGraph::new(ctx);
    let sql = "SELECT * FROM trades".to_string();
    graph.add_query("q1".to_string(), sql, None, None, None);
    assert!(graph.output_map.contains_key("q1"));
    graph.remove_query("q1");
    assert!(!graph.output_map.contains_key("q1"));
    assert!(graph.nodes[1].removed);
}
// A single filter query over one cycle: only GOOG (2800.0) survives
// the price > 200 predicate.
#[tokio::test]
async fn test_execute_cycle_basic() {
    let ctx = laminar_sql::create_session_context();
    laminar_sql::register_streaming_functions(&ctx);
    let mut graph = OperatorGraph::new(ctx);
    graph.add_query(
        "filtered".to_string(),
        "SELECT symbol, price FROM trades WHERE price > 200".to_string(),
        None,
        None,
        None,
    );
    let mut source_batches = FxHashMap::default();
    source_batches.insert(Arc::from("trades"), vec![test_batch()]);
    let results = graph
        .execute_cycle(&source_batches, i64::MAX)
        .await
        .unwrap();
    let filtered = results.get("filtered").expect("query output missing");
    let row_count: usize = filtered.iter().map(|b| b.num_rows()).sum();
    assert_eq!(row_count, 1);
}
// With a registered schema and no input batches the query still runs
// (against an empty table) and yields zero rows.
#[tokio::test]
async fn test_execute_cycle_empty_source() {
    let ctx = laminar_sql::create_session_context();
    laminar_sql::register_streaming_functions(&ctx);
    let mut graph = OperatorGraph::new(ctx);
    graph.register_source_schema("trades".to_string(), test_schema());
    graph.add_query(
        "q1".to_string(),
        "SELECT * FROM trades".to_string(),
        None,
        None,
        None,
    );
    let source_batches = FxHashMap::default();
    let results = graph
        .execute_cycle(&source_batches, i64::MAX)
        .await
        .unwrap();
    let total: usize = results
        .get("q1")
        .map_or(0, |bs| bs.iter().map(|b| b.num_rows()).sum());
    assert_eq!(total, 0);
}
// Two independent queries over the same source: the source's output
// fans out so both produce results in one cycle.
#[tokio::test]
async fn test_fan_out() {
    let ctx = laminar_sql::create_session_context();
    laminar_sql::register_streaming_functions(&ctx);
    let mut graph = OperatorGraph::new(ctx);
    graph.add_query(
        "q1".to_string(),
        "SELECT symbol, price FROM trades".to_string(),
        None,
        None,
        None,
    );
    graph.add_query(
        "q2".to_string(),
        "SELECT symbol FROM trades".to_string(),
        None,
        None,
        None,
    );
    let batch = test_batch();
    let mut source_batches = FxHashMap::default();
    source_batches.insert(Arc::from("trades"), vec![batch]);
    let results = graph
        .execute_cycle(&source_batches, i64::MAX)
        .await
        .unwrap();
    assert!(results.contains_key("q1"));
    assert!(results.contains_key("q2"));
}
// A graph with only stateless operators has nothing to checkpoint,
// so snapshot_state returns None.
#[test]
fn test_checkpoint_empty() {
    let ctx = laminar_sql::create_session_context();
    let mut graph = OperatorGraph::new(ctx);
    graph.add_query(
        "q1".to_string(),
        "SELECT * FROM trades".to_string(),
        None,
        None,
        None,
    );
    let cp = graph.snapshot_state().unwrap();
    assert!(cp.is_none());
}
fn total_rows(results: &FxHashMap<Arc<str>, Vec<RecordBatch>>, key: &str) -> usize {
results
.get(key)
.map_or(0, |bs| bs.iter().map(|b| b.num_rows()).sum())
}
// Fresh graph with a generous (5 s) per-cycle budget so tests never hit
// budget-based operator deferral.
fn test_graph() -> OperatorGraph {
    let ctx = laminar_sql::create_session_context();
    laminar_sql::register_streaming_functions(&ctx);
    let mut graph = OperatorGraph::new(ctx);
    graph.set_query_budget_ns(5_000_000_000);
    graph
}
// A pure projection should emit all rows, and keep doing so on a second
// identical cycle (no cross-cycle interference).
#[tokio::test]
async fn test_og_compiled_projection() {
    let mut graph = test_graph();
    graph.add_query(
        "projected".to_string(),
        "SELECT symbol, price FROM trades".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let r = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r, "projected"), 2);
    let r2 = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r2, "projected"), 2);
}
// Only one of the two fixture rows (GOOG @ 2800.0) passes price > 200.
#[tokio::test]
async fn test_og_compiled_fallback_on_type_mismatch() {
    let mut graph = test_graph();
    graph.add_query(
        "filtered".to_string(),
        "SELECT symbol, price FROM trades WHERE price > 200".to_string(),
        None,
        None,
        None,
    );
    let mut inputs = FxHashMap::default();
    inputs.insert(Arc::from("trades"), vec![test_batch()]);
    let outputs = graph.execute_cycle(&inputs, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&outputs, "filtered"), 1);
}
// Aggregates must carry state across cycles: feeding the same batch twice
// should double each symbol's SUM (AAPL 150*2=300, GOOG 2800*2=5600).
#[tokio::test]
async fn test_og_aggregate_incremental() {
    let mut graph = test_graph();
    graph.add_query(
        "agg".to_string(),
        "SELECT symbol, SUM(price) AS total FROM trades GROUP BY symbol".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let r = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r, "agg"), 2);
    let r2 = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    let agg_batches = &r2[&Arc::from("agg") as &Arc<str>];
    assert_eq!(total_rows(&r2, "agg"), 2);
    let price_col = agg_batches[0]
        .column_by_name("total")
        .unwrap()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    let symbol_col = agg_batches[0]
        .column_by_name("symbol")
        .unwrap()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();
    // Verify per-symbol running sums after the second cycle.
    for i in 0..agg_batches[0].num_rows() {
        match symbol_col.value(i) {
            "AAPL" => assert!((price_col.value(i) - 300.0).abs() < f64::EPSILON),
            "GOOG" => assert!((price_col.value(i) - 5600.0).abs() < f64::EPSILON),
            other => panic!("unexpected symbol: {other}"),
        }
    }
}
// step2 consumes step1's output within the same cycle: doubled prices are
// 300 and 5600, and only the 5600 row passes doubled > 400.
#[tokio::test]
async fn test_og_cascading() {
    let mut graph = test_graph();
    graph.add_query(
        "step1".to_string(),
        "SELECT symbol, price * 2 AS doubled FROM trades".to_string(),
        None,
        None,
        None,
    );
    graph.add_query(
        "step2".to_string(),
        "SELECT symbol, doubled FROM step1 WHERE doubled > 400".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let r = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r, "step1"), 2);
    assert_eq!(total_rows(&r, "step2"), 1);
}
// Diamond: trades fans out to "high" and "low", which join back in
// "combined". Each filter keeps one row; the join finds no common symbol.
#[tokio::test]
async fn test_og_diamond_dag() {
    let mut graph = test_graph();
    let queries = [
        ("high", "SELECT symbol, price FROM trades WHERE price > 200"),
        ("low", "SELECT symbol, price FROM trades WHERE price <= 200"),
        (
            "combined",
            "SELECT h.symbol, h.price FROM high h INNER JOIN low l ON h.symbol = l.symbol",
        ),
    ];
    for (name, sql) in queries {
        graph.add_query(name.to_string(), sql.to_string(), None, None, None);
    }
    let mut inputs = FxHashMap::default();
    inputs.insert(Arc::from("trades"), vec![test_batch()]);
    let outputs = graph.execute_cycle(&inputs, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&outputs, "high"), 1);
    assert_eq!(total_rows(&outputs, "low"), 1);
    assert_eq!(total_rows(&outputs, "combined"), 0);
}
// With a 1 ns budget only the first operator in topological order runs
// (the budget check is skipped for i == 0), so at most one of the two
// queries can surface output.
#[tokio::test]
async fn test_og_budget_exhaustion() {
    let mut graph = test_graph();
    graph.set_query_budget_ns(1);
    graph.add_query(
        "q1".to_string(),
        "SELECT * FROM trades".to_string(),
        None,
        None,
        None,
    );
    graph.add_query(
        "q2".to_string(),
        "SELECT * FROM trades".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let r = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    let produced = r.len();
    assert!(
        produced < 2,
        "with 1ns budget, at most one query should run"
    );
}
// A 1-byte cap should trip the state-size check on the aggregate and
// surface a pipeline error naming the limit.
#[tokio::test]
async fn test_og_state_size_limit() {
    let mut graph = test_graph();
    graph.set_max_state_bytes(Some(1));
    graph.add_query(
        "agg".to_string(),
        "SELECT symbol, SUM(price) AS total FROM trades GROUP BY symbol".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let result = graph.execute_cycle(&source, i64::MAX).await;
    assert!(result.is_err(), "state size limit should be exceeded");
    let err_msg = result.unwrap_err().to_string();
    assert!(
        err_msg.contains("state size limit exceeded"),
        "unexpected error: {err_msg}"
    );
}
// Snapshot an aggregate's state, serialize it, restore it into a second
// graph, and verify the restored graph keeps producing per-symbol rows.
#[tokio::test]
async fn test_og_checkpoint_roundtrip_aggregate() {
    let mut graph = test_graph();
    graph.add_query(
        "agg".to_string(),
        "SELECT symbol, SUM(price) AS total FROM trades GROUP BY symbol".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let _ = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    let cp = graph
        .snapshot_state()
        .unwrap()
        .expect("aggregate should have state");
    let bytes = OperatorGraph::serialize_checkpoint(&cp).unwrap();
    let mut graph2 = test_graph();
    graph2.add_query(
        "agg".to_string(),
        "SELECT symbol, SUM(price) AS total FROM trades GROUP BY symbol".to_string(),
        None,
        None,
        None,
    );
    // Run one cycle first so the second graph's operator exists/warms up,
    // then overwrite its state from the checkpoint.
    let _ = graph2.execute_cycle(&source, i64::MAX).await.unwrap();
    let restored = graph2.restore_from_bytes(&bytes).unwrap();
    assert!(restored > 0, "should restore at least one operator");
    let r = graph2.execute_cycle(&source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r, "agg"), 2);
}
// After a cycle with data, an empty cycle should still re-emit the
// aggregate's current state (2 per-symbol rows) rather than nothing.
#[tokio::test]
async fn test_og_aggregate_empty_source_emits_state() {
    let mut graph = test_graph();
    graph.register_source_schema("trades".to_string(), test_schema());
    graph.add_query(
        "agg".to_string(),
        "SELECT symbol, SUM(price) AS total FROM trades GROUP BY symbol".to_string(),
        None,
        None,
        None,
    );
    let mut source = FxHashMap::default();
    source.insert(Arc::from("trades"), vec![test_batch()]);
    let r = graph.execute_cycle(&source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r, "agg"), 2);
    let empty_source = FxHashMap::default();
    let r2 = graph.execute_cycle(&empty_source, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&r2, "agg"), 2);
}
// Register the downstream query first so "q1" starts life as a
// placeholder source node and is later upgraded to a real query.
#[tokio::test]
async fn test_og_reverse_order_cascading() {
    let mut graph = test_graph();
    let queries = [
        ("q2", "SELECT symbol FROM q1 WHERE price > 200"),
        ("q1", "SELECT symbol, price FROM trades"),
    ];
    for (name, sql) in queries {
        graph.add_query(name.to_string(), sql.to_string(), None, None, None);
    }
    assert!(
        !graph.source_map.contains_key("q1"),
        "q1 placeholder should be replaced, not in source_map"
    );
    assert!(graph.output_map.contains_key("q1"));
    assert!(graph.output_map.contains_key("q2"));
    let mut inputs = FxHashMap::default();
    inputs.insert(Arc::from("trades"), vec![test_batch()]);
    let outputs = graph.execute_cycle(&inputs, i64::MAX).await.unwrap();
    assert_eq!(total_rows(&outputs, "q1"), 2);
    assert_eq!(total_rows(&outputs, "q2"), 1);
}
}