hirn_exec/rules/
depth_scheduling.rs1use std::sync::Arc;
15
16use datafusion_common::Result;
17use datafusion_common::tree_node::{Transformed, TreeNode};
18use datafusion_physical_optimizer::PhysicalOptimizerRule;
19use datafusion_physical_plan::ExecutionPlan;
20
21use crate::operators::{
22 CausalChainExec, Complexity, GraphActivationExec, IterativeRetrievalExec, QualityGateExec,
23};
24
25#[derive(Debug, Default)]
31pub struct DepthSchedulingRule {
32 forced_complexity: Option<Complexity>,
35}
36
37impl DepthSchedulingRule {
38 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn with_complexity(complexity: Complexity) -> Self {
44 Self {
45 forced_complexity: Some(complexity),
46 }
47 }
48
49 fn should_prune(plan: &dyn ExecutionPlan, complexity: Complexity) -> bool {
51 match complexity {
52 Complexity::Simple => {
53 plan.as_any()
55 .downcast_ref::<GraphActivationExec>()
56 .is_some()
57 || plan.as_any().downcast_ref::<CausalChainExec>().is_some()
58 || plan
59 .as_any()
60 .downcast_ref::<IterativeRetrievalExec>()
61 .is_some()
62 || plan.as_any().downcast_ref::<QualityGateExec>().is_some()
63 }
64 Complexity::Medium => {
65 plan.as_any().downcast_ref::<CausalChainExec>().is_some()
67 || plan
68 .as_any()
69 .downcast_ref::<IterativeRetrievalExec>()
70 .is_some()
71 }
72 Complexity::Complex => false, }
74 }
75}
76
77impl PhysicalOptimizerRule for DepthSchedulingRule {
78 fn optimize(
79 &self,
80 plan: Arc<dyn ExecutionPlan>,
81 _config: &datafusion_common::config::ConfigOptions,
82 ) -> Result<Arc<dyn ExecutionPlan>> {
83 let complexity = self.forced_complexity.unwrap_or(Complexity::Complex);
84
85 if complexity == Complexity::Complex {
86 return Ok(plan);
88 }
89
90 plan.transform_down(|node| {
91 if Self::should_prune(node.as_ref(), complexity) {
92 let children = node.children();
94 if let Some(child) = children.first() {
95 Ok(Transformed::yes(Arc::clone(child)))
96 } else {
97 let schema = node.schema();
99 Ok(Transformed::yes(
100 Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
101 as Arc<dyn ExecutionPlan>,
102 ))
103 }
104 } else {
105 Ok(Transformed::no(node))
106 }
107 })
108 .map(|t| t.data)
109 }
110
111 fn name(&self) -> &str {
112 "DepthSchedulingRule"
113 }
114
115 fn schema_check(&self) -> bool {
116 true
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operators::{ActivationMode, IterativeConfig};
    use arrow_array::{RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_datasource::memory::MemorySourceConfig;

    /// Single-column, single-row in-memory scan used as the input of every
    /// operator under test.
    fn leaf_plan() -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "content",
            DataType::Utf8,
            false,
        )]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(StringArray::from(vec!["test"]))],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// Wraps `input` in a `GraphActivationExec` with fixed test parameters.
    fn graph_activation(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(
            GraphActivationExec::new(input, 10, ActivationMode::Spreading, 2, 0.01, 0.1).unwrap(),
        )
    }

    /// Wraps `input` in a `CausalChainExec` with fixed test parameters.
    fn causal_chain(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(CausalChainExec::new(input, 3, 0.3))
    }

    /// Wraps `input` in an `IterativeRetrievalExec` with the default config.
    fn iterative_retrieval(input: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        Arc::new(IterativeRetrievalExec::new(input, IterativeConfig::default()))
    }

    /// Runs `rule` over `plan` with default session options.
    fn run(rule: &DepthSchedulingRule, plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        rule.optimize(plan, &ConfigOptions::new()).unwrap()
    }

    /// True when the root operator of `plan` has concrete type `T`.
    fn root_is<T: 'static>(plan: &Arc<dyn ExecutionPlan>) -> bool {
        plan.as_any().is::<T>()
    }

    #[test]
    fn simple_prunes_graph_activation() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = run(&rule, graph_activation(leaf_plan()));

        assert!(
            !root_is::<GraphActivationExec>(&optimized),
            "Simple should prune GraphActivationExec"
        );
    }

    #[test]
    fn simple_prunes_causal_chain() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = run(&rule, causal_chain(leaf_plan()));

        assert!(
            !root_is::<CausalChainExec>(&optimized),
            "Simple should prune CausalChainExec"
        );
    }

    #[test]
    fn simple_prunes_iterative_retrieval() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = run(&rule, iterative_retrieval(leaf_plan()));

        assert!(
            !root_is::<IterativeRetrievalExec>(&optimized),
            "Simple should prune IterativeRetrievalExec"
        );
    }

    #[test]
    fn pruned_operator_is_replaced_by_its_input() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Simple);
        let optimized = run(&rule, graph_activation(leaf_plan()));

        // The only remaining node should be the childless in-memory leaf.
        assert!(
            optimized.children().is_empty(),
            "pruned operator should be replaced by its input"
        );
    }

    #[test]
    fn medium_keeps_graph_activation() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
        let optimized = run(&rule, graph_activation(leaf_plan()));

        assert!(
            root_is::<GraphActivationExec>(&optimized),
            "Medium should keep GraphActivationExec"
        );
    }

    #[test]
    fn medium_prunes_causal_chain() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
        let optimized = run(&rule, causal_chain(leaf_plan()));

        assert!(
            !root_is::<CausalChainExec>(&optimized),
            "Medium should prune CausalChainExec"
        );
    }

    #[test]
    fn medium_prunes_iterative_retrieval() {
        let rule = DepthSchedulingRule::with_complexity(Complexity::Medium);
        let optimized = run(&rule, iterative_retrieval(leaf_plan()));

        assert!(
            !root_is::<IterativeRetrievalExec>(&optimized),
            "Medium should prune IterativeRetrievalExec"
        );
    }

    #[test]
    fn complex_keeps_all() {
        let graph = graph_activation(leaf_plan());
        let rule = DepthSchedulingRule::with_complexity(Complexity::Complex);
        let optimized = run(&rule, Arc::clone(&graph));

        assert!(
            root_is::<GraphActivationExec>(&optimized),
            "Complex should keep GraphActivationExec"
        );
        assert!(
            Arc::ptr_eq(&optimized, &graph),
            "Complex should return the input plan unchanged"
        );
    }

    #[test]
    fn default_rule_is_complex() {
        let rule = DepthSchedulingRule::new();
        let optimized = run(&rule, graph_activation(leaf_plan()));

        assert!(
            root_is::<GraphActivationExec>(&optimized),
            "Default (Complex) should keep all operators"
        );
    }
}