// laminar_sql/planner/streaming_optimizer.rs

1//! Physical optimizer rule for streaming plan validation.
2//!
3//! Detects pipeline-breaking operators (Sort, Final Aggregate) on unbounded
4//! inputs and rejects or warns at plan-creation time, before any execution
5//! begins.
6
7use std::fmt::Debug;
8use std::sync::Arc;
9
10use datafusion::physical_optimizer::PhysicalOptimizerRule;
11use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode};
12use datafusion::physical_plan::execution_plan::Boundedness;
13use datafusion::physical_plan::sorts::sort::SortExec;
14use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
15use datafusion_common::config::ConfigOptions;
16use datafusion_common::DataFusionError;
17
/// How the validator handles streaming plan violations.
///
/// Controls whether [`StreamingPhysicalValidator`] errors out, logs a
/// warning, or does nothing when a pipeline-breaking operator is found
/// on an unbounded input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamingValidatorMode {
    /// Return an error, preventing plan execution. Default.
    Reject,
    /// Log a warning but allow execution.
    Warn,
    /// Disable validation entirely.
    Off,
}
28
/// A streaming plan violation detected during validation.
#[derive(Debug)]
struct StreamingViolation {
    // Name of the offending operator node (e.g. "SortExec").
    operator: String,
    // Human-readable explanation plus a suggested remediation.
    reason: String,
    // "Root -> ... -> Offender" path from the plan root to this node.
    plan_path: String,
}
36
/// Validates that a physical plan is safe for streaming execution.
///
/// Detects pipeline-breaking operators (Sort, Final Aggregate) on
/// unbounded inputs and rejects or warns depending on configuration.
/// Registered as a DataFusion [`PhysicalOptimizerRule`]; it never
/// rewrites the plan, it only inspects it.
#[derive(Debug)]
pub struct StreamingPhysicalValidator {
    // Selected behavior on violation: Reject / Warn / Off.
    mode: StreamingValidatorMode,
}
45
46impl StreamingPhysicalValidator {
47    /// Creates a new validator with the given mode.
48    #[must_use]
49    pub fn new(mode: StreamingValidatorMode) -> Self {
50        Self { mode }
51    }
52}
53
54impl PhysicalOptimizerRule for StreamingPhysicalValidator {
55    fn optimize(
56        &self,
57        plan: Arc<dyn ExecutionPlan>,
58        _config: &ConfigOptions,
59    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
60        if matches!(self.mode, StreamingValidatorMode::Off) {
61            return Ok(plan);
62        }
63
64        let violations = find_streaming_violations(&plan);
65        if violations.is_empty() {
66            return Ok(plan);
67        }
68
69        match self.mode {
70            StreamingValidatorMode::Reject => {
71                Err(DataFusionError::Plan(format_violations(&violations)))
72            }
73            StreamingValidatorMode::Warn => {
74                for v in &violations {
75                    tracing::warn!(
76                        operator = %v.operator,
77                        path = %v.plan_path,
78                        "Streaming plan violation: {}", v.reason
79                    );
80                }
81                Ok(plan)
82            }
83            StreamingValidatorMode::Off => unreachable!(),
84        }
85    }
86
87    #[allow(clippy::unnecessary_literal_bound)]
88    fn name(&self) -> &str {
89        "streaming_physical_validator"
90    }
91
92    fn schema_check(&self) -> bool {
93        true
94    }
95}
96
97fn find_streaming_violations(plan: &Arc<dyn ExecutionPlan>) -> Vec<StreamingViolation> {
98    let mut violations = Vec::new();
99    walk_plan(plan, &mut violations, "");
100    violations
101}
102
103fn walk_plan(plan: &Arc<dyn ExecutionPlan>, violations: &mut Vec<StreamingViolation>, path: &str) {
104    let name = plan.name();
105    let current_path = if path.is_empty() {
106        name.to_string()
107    } else {
108        format!("{path} -> {name}")
109    };
110
111    // Check 1: SortExec on unbounded input
112    if plan.as_any().downcast_ref::<SortExec>().is_some() && has_unbounded_child(plan) {
113        violations.push(StreamingViolation {
114            operator: name.to_string(),
115            reason: "Sort requires buffering all input; unbounded source will \
116                     buffer forever. Remove ORDER BY or add a window."
117                .to_string(),
118            plan_path: current_path.clone(),
119        });
120    }
121
122    // Check 2: Final AggregateExec on unbounded input
123    if let Some(agg) = plan.as_any().downcast_ref::<AggregateExec>() {
124        if matches!(
125            agg.mode(),
126            &AggregateMode::Final | &AggregateMode::FinalPartitioned
127        ) && has_unbounded_child(plan)
128        {
129            violations.push(StreamingViolation {
130                operator: name.to_string(),
131                reason: "Final aggregation on unbounded input will never emit \
132                         results. Use a window function (TUMBLE/HOP/SESSION) or \
133                         add an EMIT clause."
134                    .to_string(),
135                plan_path: current_path.clone(),
136            });
137        }
138    }
139
140    for child in plan.children() {
141        walk_plan(child, violations, &current_path);
142    }
143}
144
145fn has_unbounded_child(plan: &Arc<dyn ExecutionPlan>) -> bool {
146    plan.children()
147        .iter()
148        .any(|c| matches!(c.boundedness(), Boundedness::Unbounded { .. }))
149}
150
151fn format_violations(violations: &[StreamingViolation]) -> String {
152    use std::fmt::Write;
153
154    let mut msg = String::from("Streaming plan validation failed:\n");
155    for (i, v) in violations.iter().enumerate() {
156        let _ = writeln!(
157            msg,
158            "  {}. [{}] {} (at: {})",
159            i + 1,
160            v.operator,
161            v.reason,
162            v.plan_path
163        );
164    }
165    msg
166}
167
#[cfg(test)]
mod tests {
    //! Unit tests for violation detection and validator modes, plus one
    //! integration test that exercises the full SQL planning path.

    use super::*;
    use std::any::Any;

    use arrow_schema::{DataType, Field, Schema, SchemaRef};
    use datafusion::execution::{SendableRecordBatchStream, TaskContext};
    use datafusion::physical_expr::{EquivalenceProperties, LexOrdering, Partitioning};
    use datafusion::physical_plan::execution_plan::EmissionType;
    use datafusion::physical_plan::{DisplayAs, DisplayFormatType, PlanProperties};
    use datafusion_common::config::ConfigOptions;

    // ── Mock unbounded leaf node ────────────────────────────────────

    /// Leaf node whose properties report `Boundedness::Unbounded`,
    /// standing in for a streaming source. `execute` is unimplemented:
    /// these tests only inspect plan properties, they never run the plan.
    #[derive(Debug)]
    struct MockUnboundedExec {
        schema: SchemaRef,
        props: PlanProperties,
    }

    impl MockUnboundedExec {
        fn new(schema: SchemaRef) -> Self {
            let eq = EquivalenceProperties::new(Arc::clone(&schema));
            let props = PlanProperties::new(
                eq,
                Partitioning::UnknownPartitioning(1),
                EmissionType::Incremental,
                Boundedness::Unbounded {
                    requires_infinite_memory: false,
                },
            );
            Self { schema, props }
        }
    }

    impl DisplayAs for MockUnboundedExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "MockUnboundedExec")
        }
    }

    impl ExecutionPlan for MockUnboundedExec {
        fn name(&self) -> &'static str {
            "MockUnboundedExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn properties(&self) -> &PlanProperties {
            &self.props
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> datafusion_common::Result<SendableRecordBatchStream> {
            unimplemented!("mock")
        }
    }

    // ── Mock bounded leaf node ──────────────────────────────────────

    /// Leaf node reporting `Boundedness::Bounded`; plans over it must
    /// pass validation untouched.
    #[derive(Debug)]
    struct MockBoundedExec {
        schema: SchemaRef,
        props: PlanProperties,
    }

    impl MockBoundedExec {
        fn new(schema: SchemaRef) -> Self {
            let eq = EquivalenceProperties::new(Arc::clone(&schema));
            let props = PlanProperties::new(
                eq,
                Partitioning::UnknownPartitioning(1),
                EmissionType::Final,
                Boundedness::Bounded,
            );
            Self { schema, props }
        }
    }

    impl DisplayAs for MockBoundedExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "MockBoundedExec")
        }
    }

    impl ExecutionPlan for MockBoundedExec {
        fn name(&self) -> &'static str {
            "MockBoundedExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn properties(&self) -> &PlanProperties {
            &self.props
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> datafusion_common::Result<SendableRecordBatchStream> {
            unimplemented!("mock")
        }
    }

    // ── Mock passthrough node (not sort/aggregate) ──────────────────

    /// Single-child node that inherits its child's properties (including
    /// boundedness) but is neither a Sort nor an Aggregate, so it must
    /// never trigger a violation itself.
    #[derive(Debug)]
    struct MockPassthroughExec {
        child: Arc<dyn ExecutionPlan>,
        props: PlanProperties,
    }

    impl MockPassthroughExec {
        fn new(child: Arc<dyn ExecutionPlan>) -> Self {
            // Clone the child's properties so boundedness propagates up.
            let props = child.properties().clone();
            Self { child, props }
        }
    }

    impl DisplayAs for MockPassthroughExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "MockPassthroughExec")
        }
    }

    impl ExecutionPlan for MockPassthroughExec {
        fn name(&self) -> &'static str {
            "MockPassthroughExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            self.child.schema()
        }

        fn properties(&self) -> &PlanProperties {
            &self.props
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![&self.child]
        }

        fn with_new_children(
            self: Arc<Self>,
            children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
            Ok(Arc::new(Self::new(Arc::clone(&children[0]))))
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> datafusion_common::Result<SendableRecordBatchStream> {
            unimplemented!("mock")
        }
    }

    // Two-column schema shared by all test plans.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    // Wraps `child` in a SortExec ordering by the "id" column.
    fn make_sort_on(child: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        use arrow_schema::SortOptions;
        use datafusion::physical_expr::{expressions::Column, PhysicalSortExpr};

        let sort_expr =
            PhysicalSortExpr::new(Arc::new(Column::new("id", 0)), SortOptions::default());
        let ordering = LexOrdering::new(vec![sort_expr]).expect("non-empty sort expr list");
        Arc::new(SortExec::new(ordering, child))
    }

    // Wraps `child` in an AggregateExec in Final mode with no group-by
    // keys and no aggregate expressions — the minimal "final aggregate".
    fn make_final_aggregate_on(child: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        use datafusion::physical_plan::aggregates::PhysicalGroupBy;

        let schema = child.schema();
        let group_by = PhysicalGroupBy::new_single(vec![]);
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            child,
            Arc::clone(&schema),
        )
        .expect("failed to create AggregateExec");
        Arc::new(agg)
    }

    // ── Unit tests: violation detection ─────────────────────────────

    #[test]
    fn test_sort_on_unbounded_rejected() {
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let plan = make_sort_on(leaf);
        let violations = find_streaming_violations(&plan);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].reason.contains("Sort requires buffering"));
    }

    #[test]
    fn test_sort_on_bounded_allowed() {
        let leaf = Arc::new(MockBoundedExec::new(test_schema()));
        let plan = make_sort_on(leaf);
        let violations = find_streaming_violations(&plan);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_final_aggregate_on_unbounded_rejected() {
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let plan = make_final_aggregate_on(leaf);
        let violations = find_streaming_violations(&plan);
        assert_eq!(violations.len(), 1);
        assert!(violations[0].reason.contains("Final aggregation"));
    }

    #[test]
    fn test_passthrough_on_unbounded_allowed() {
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let plan: Arc<dyn ExecutionPlan> = Arc::new(MockPassthroughExec::new(leaf));
        let violations = find_streaming_violations(&plan);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_nested_plan_violation_detected() {
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let sort = make_sort_on(leaf);
        // Wrap the sort inside a passthrough so the violation is deep in the tree
        let plan: Arc<dyn ExecutionPlan> = Arc::new(MockPassthroughExec::new(sort));
        let violations = find_streaming_violations(&plan);
        assert_eq!(violations.len(), 1);
        // The recorded path must reach down to the offending SortExec.
        assert!(
            violations[0].plan_path.contains("SortExec"),
            "path was: {}",
            violations[0].plan_path
        );
    }

    // ── Unit tests: modes ───────────────────────────────────────────

    #[test]
    fn test_reject_mode_returns_error() {
        let validator = StreamingPhysicalValidator::new(StreamingValidatorMode::Reject);
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let plan = make_sort_on(leaf);
        let config = ConfigOptions::new();
        let result = validator.optimize(plan, &config);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Streaming plan validation failed"),
            "error was: {err}"
        );
    }

    #[test]
    fn test_warn_mode_passes_through() {
        let validator = StreamingPhysicalValidator::new(StreamingValidatorMode::Warn);
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let plan = make_sort_on(leaf);
        let config = ConfigOptions::new();
        let result = validator.optimize(plan, &config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_off_mode_skips_validation() {
        let validator = StreamingPhysicalValidator::new(StreamingValidatorMode::Off);
        let leaf = Arc::new(MockUnboundedExec::new(test_schema()));
        let plan = make_sort_on(leaf);
        let config = ConfigOptions::new();
        let result = validator.optimize(plan, &config);
        assert!(result.is_ok());
    }

    // ── Integration test via create_streaming_context ───────────────

    #[tokio::test]
    async fn test_streaming_context_rejects_unbounded_sort() {
        use crate::datafusion::{
            create_streaming_context, ChannelStreamSource, StreamingTableProvider,
        };
        use arrow_schema::{DataType, Field, Schema};

        let ctx = create_streaming_context();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        // Keep the sender bound (unused) for the test's duration; no rows
        // are ever sent. NOTE(review): presumably this keeps the channel
        // source open — confirm against ChannelStreamSource semantics.
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // ORDER BY on unbounded source should fail at plan creation
        let result = ctx.sql("SELECT * FROM events ORDER BY id").await;

        // The physical optimizer should reject this plan
        // (DataFusion creates the physical plan during sql() or at collect())
        match result {
            Ok(df) => {
                // Physical plan creation may be deferred to collect()
                let exec_result = df.collect().await;
                assert!(
                    exec_result.is_err(),
                    "Sort on unbounded stream should be rejected"
                );
                let err = exec_result.unwrap_err().to_string();
                assert!(
                    err.contains("Streaming plan validation failed")
                        || err.contains("Sort requires buffering"),
                    "Expected streaming validation error, got: {err}"
                );
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("Streaming plan validation failed")
                        || err.contains("Sort requires buffering"),
                    "Expected streaming validation error, got: {err}"
                );
            }
        }
    }
}