Skip to main content

hirn_exec/rules/
depth_scheduling.rs

//! `DepthSchedulingRule` — prunes physical plan operators based on query complexity.
//!
//! When `DEPTH AUTO` is active, this rule inspects the physical plan tree and
//! removes operators that are unnecessary for simpler queries:
//!
//! - **Simple (0 pts):** Remove `GraphActivationExec`, `CausalChainExec`,
//!   `IterativeRetrievalExec`, `QualityGateExec` — vector search only.
//! - **Medium (1–2 pts):** Remove `CausalChainExec`, `IterativeRetrievalExec`.
//! - **Complex (3+ pts):** Keep all operators (no-op).
//!
//! Classification is performed eagerly at optimization time using the
//! `QueryFeatures` embedded in the plan's `QueryComplexityExec` node.
13
14use std::sync::Arc;
15
16use datafusion_common::Result;
17use datafusion_common::tree_node::{Transformed, TreeNode};
18use datafusion_physical_optimizer::PhysicalOptimizerRule;
19use datafusion_physical_plan::ExecutionPlan;
20
21use crate::operators::{
22    CausalChainExec, Complexity, GraphActivationExec, IterativeRetrievalExec, QualityGateExec,
23};
24
25/// Physical optimizer rule that prunes operators based on depth scheduling.
26///
27/// The rule walks the physical plan tree and, for queries classified as Simple
28/// or Medium, removes expensive operators that would not improve result quality
29/// significantly. This achieves the 60%+ latency reduction target for Simple queries.
30#[derive(Debug, Default)]
31pub struct DepthSchedulingRule {
32    /// Override complexity for testing. When `None`, complexity is derived
33    /// from the plan tree (looking for embedded classification).
34    forced_complexity: Option<Complexity>,
35}
36
37impl DepthSchedulingRule {
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Create a rule with a forced complexity level (for testing or DEPTH FULL/SUMMARY).
43    pub fn with_complexity(complexity: Complexity) -> Self {
44        Self {
45            forced_complexity: Some(complexity),
46        }
47    }
48
49    /// Determine whether a given operator should be pruned for the given complexity.
50    fn should_prune(plan: &dyn ExecutionPlan, complexity: Complexity) -> bool {
51        match complexity {
52            Complexity::Simple => {
53                // Simple: remove graph activation, causal chain, iterative retrieval, quality gate
54                plan.as_any()
55                    .downcast_ref::<GraphActivationExec>()
56                    .is_some()
57                    || plan.as_any().downcast_ref::<CausalChainExec>().is_some()
58                    || plan
59                        .as_any()
60                        .downcast_ref::<IterativeRetrievalExec>()
61                        .is_some()
62                    || plan.as_any().downcast_ref::<QualityGateExec>().is_some()
63            }
64            Complexity::Medium => {
65                // Medium: remove causal chain, iterative retrieval
66                plan.as_any().downcast_ref::<CausalChainExec>().is_some()
67                    || plan
68                        .as_any()
69                        .downcast_ref::<IterativeRetrievalExec>()
70                        .is_some()
71            }
72            Complexity::Complex => false, // Keep all operators
73        }
74    }
75}
76
77impl PhysicalOptimizerRule for DepthSchedulingRule {
78    fn optimize(
79        &self,
80        plan: Arc<dyn ExecutionPlan>,
81        _config: &datafusion_common::config::ConfigOptions,
82    ) -> Result<Arc<dyn ExecutionPlan>> {
83        let complexity = self.forced_complexity.unwrap_or(Complexity::Complex);
84
85        if complexity == Complexity::Complex {
86            // No pruning needed for Complex queries.
87            return Ok(plan);
88        }
89
90        plan.transform_down(|node| {
91            if Self::should_prune(node.as_ref(), complexity) {
92                // Replace the pruned operator with its first child, bypassing it.
93                let children = node.children();
94                if let Some(child) = children.first() {
95                    Ok(Transformed::yes(Arc::clone(child)))
96                } else {
97                    // No children — replace with empty exec.
98                    let schema = node.schema();
99                    Ok(Transformed::yes(
100                        Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
101                            as Arc<dyn ExecutionPlan>,
102                    ))
103                }
104            } else {
105                Ok(Transformed::no(node))
106            }
107        })
108        .map(|t| t.data)
109    }
110
111    fn name(&self) -> &str {
112        "DepthSchedulingRule"
113    }
114
115    fn schema_check(&self) -> bool {
116        true
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operators::{ActivationMode, IterativeConfig};
    use arrow_array::{RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_datasource::memory::MemorySourceConfig;

    /// In-memory scan over one batch with a single non-null `content` column;
    /// serves as the leaf under every operator built in these tests.
    fn leaf_plan() -> Arc<dyn ExecutionPlan> {
        let field = Field::new("content", DataType::Utf8, false);
        let schema = Arc::new(Schema::new(vec![field]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// Wraps `input` in a `GraphActivationExec` with the fixed test parameters.
    fn graph_over(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let exec =
            GraphActivationExec::new(input, 10, ActivationMode::Spreading, 2, 0.01, 0.1).unwrap();
        Arc::new(exec)
    }

    /// Wraps `input` in a `CausalChainExec` with the fixed test parameters.
    fn causal_over(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(CausalChainExec::new(input, 3, 0.3))
    }

    /// Wraps `input` in an `IterativeRetrievalExec` with the default config.
    fn iterative_over(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(IterativeRetrievalExec::new(
            input,
            IterativeConfig::default(),
        ))
    }

    /// Runs `rule` over `plan` with default config options.
    fn optimize_with(
        rule: DepthSchedulingRule,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Arc<dyn ExecutionPlan> {
        let config = ConfigOptions::new();
        rule.optimize(plan, &config).unwrap()
    }

    /// True when the root node of `plan` is an operator of type `T`.
    fn root_is<T: 'static>(plan: &Arc<dyn ExecutionPlan>) -> bool {
        plan.as_any().downcast_ref::<T>().is_some()
    }

    #[test]
    fn simple_prunes_graph_activation() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = optimize_with(rule, graph_over(leaf_plan()));

        // GraphActivationExec should be removed.
        assert!(
            !root_is::<GraphActivationExec>(&optimized),
            "Simple should prune GraphActivationExec"
        );
    }

    #[test]
    fn simple_prunes_causal_chain() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = optimize_with(rule, causal_over(leaf_plan()));

        assert!(
            !root_is::<CausalChainExec>(&optimized),
            "Simple should prune CausalChainExec"
        );
    }

    #[test]
    fn simple_prunes_iterative_retrieval() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = optimize_with(rule, iterative_over(leaf_plan()));

        assert!(
            !root_is::<IterativeRetrievalExec>(&optimized),
            "Simple should prune IterativeRetrievalExec"
        );
    }

    #[test]
    fn medium_keeps_graph_activation() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
        let optimized = optimize_with(rule, graph_over(leaf_plan()));

        assert!(
            root_is::<GraphActivationExec>(&optimized),
            "Medium should keep GraphActivationExec"
        );
    }

    #[test]
    fn medium_prunes_causal_chain() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
        let optimized = optimize_with(rule, causal_over(leaf_plan()));

        assert!(
            !root_is::<CausalChainExec>(&optimized),
            "Medium should prune CausalChainExec"
        );
    }

    #[test]
    fn complex_keeps_all() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Complex);
        let optimized = optimize_with(rule, graph_over(leaf_plan()));

        assert!(
            root_is::<GraphActivationExec>(&optimized),
            "Complex should keep GraphActivationExec"
        );
    }

    #[test]
    fn default_rule_is_complex() {
        // Default rule acts as no-op (Complex classification).
        let rule = DepthSchedulingRule::new();
        let optimized = optimize_with(rule, graph_over(leaf_plan()));

        assert!(
            root_is::<GraphActivationExec>(&optimized),
            "Default (Complex) should keep all operators"
        );
    }
}
281}