1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4 config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5 execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6 physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7 physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9
10use crate::client_exec::LiquidCacheClientExec;
11
/// Physical optimizer rule that offloads eligible sub-plans to a LiquidCache
/// server by wrapping them in a [`LiquidCacheClientExec`].
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the cache server that will execute pushed-down plan fragments.
    cache_server: String,
    // Object-store locations, each paired with its configuration options,
    // forwarded to the server so it can read the underlying data.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
18
19impl PushdownOptimizer {
20 pub fn new(
30 cache_server: String,
31 object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
32 ) -> Self {
33 Self {
34 cache_server,
35 object_stores,
36 }
37 }
38
39 fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
41 if plan
43 .as_any()
44 .downcast_ref::<LiquidCacheClientExec>()
45 .is_some()
46 {
47 return Ok(plan);
48 }
49
50 if let Some(candidate) = find_pushdown_candidate(&plan) {
52 if Arc::ptr_eq(&plan, &candidate) {
54 return Ok(Arc::new(LiquidCacheClientExec::new(
55 plan,
56 self.cache_server.clone(),
57 self.object_stores.clone(),
58 )));
59 }
60 }
61
62 let mut new_children = Vec::with_capacity(plan.children().len());
64 let mut children_changed = false;
65
66 for child in plan.children() {
67 let new_child = self.optimize_plan(child.clone())?;
68 if !Arc::ptr_eq(child, &new_child) {
69 children_changed = true;
70 }
71 new_children.push(new_child);
72 }
73
74 if children_changed {
76 plan.with_new_children(new_children)
77 } else {
78 Ok(plan)
79 }
80 }
81}
82
83fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
85 if plan
87 .as_any()
88 .downcast_ref::<LiquidCacheClientExec>()
89 .is_some()
90 {
91 return None;
92 }
93
94 let plan_any = plan.as_any();
95
96 if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>()
98 && matches!(agg_exec.mode(), AggregateMode::Partial)
99 && agg_exec.group_expr().is_empty()
100 {
101 let child = agg_exec.input();
102
103 if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
105 return Some(plan.clone());
106 }
107 if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>()
108 && let Some(repart_child) = repart.children().first()
109 && repart_child
110 .as_any()
111 .downcast_ref::<DataSourceExec>()
112 .is_some()
113 {
114 return Some(plan.clone());
115 }
116 }
117
118 if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>()
120 && let Some(child) = repart_exec.children().first()
121 && child.as_any().downcast_ref::<DataSourceExec>().is_some()
122 {
123 return Some(plan.clone());
124 }
125
126 if plan_any.downcast_ref::<DataSourceExec>().is_some() {
128 return Some(plan.clone());
129 }
130
131 for child in plan.children() {
133 if let Some(candidate) = find_pushdown_candidate(child) {
134 return Some(candidate);
135 }
136 }
137
138 None
139}
140
impl PhysicalOptimizerRule for PushdownOptimizer {
    /// Entry point invoked by DataFusion's optimizer pipeline; delegates to
    /// the recursive rewriter.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_plan(plan)
    }

    fn name(&self) -> &str {
        "PushdownOptimizer"
    }

    /// Wrapping a subtree in `LiquidCacheClientExec` is intended to preserve
    /// its schema, so schema verification after this rule is left enabled.
    fn schema_check(&self) -> bool {
        true
    }
}
158
#[cfg(test)]
mod tests {
    //! Tests for `PushdownOptimizer`, in two styles:
    //! - SQL queries planned through a full `SessionContext`, asserting on
    //!   the rendered physical-plan string;
    //! - hand-assembled physical plans fed straight into the rule.

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };

    use super::*;

    /// Builds a `SessionContext` with parquet filter pushdown enabled and the
    /// `PushdownOptimizer` registered, then registers the `nano_hits` table.
    /// Assumes `../../examples/nano_hits.parquet` exists relative to the
    /// crate root — the SQL-based tests below depend on this fixture.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:15214".to_string(),
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    /// A filtered projection should end up wrapped: both the client exec and
    /// the underlying data source must appear in the plan text.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    /// For an ungrouped aggregate, the Partial aggregate (and the source)
    /// must be pushed below LiquidCacheClientExec while the Final aggregate
    /// stays above it.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {plan_str}");

        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Split the rendered plan at the client exec: parts[0] is everything
        // above the pushdown boundary, parts[1] everything below it.
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    /// Small three-column schema used by the hand-built-plan tests.
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    /// Wraps an empty in-memory partition in a `DataSourceExec` so the rule
    /// can be exercised without any file I/O.
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    /// Runs the optimizer rule over `plan` and returns the rendered plan
    /// string for assertions.
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer = PushdownOptimizer::new("localhost:15214".to_string(), vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    /// A bare data source is itself a pushdown root, so the client exec must
    /// become the new plan root.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    /// Repartition directly over a data source is pushed as a unit: the
    /// RepartitionExec must appear below the client exec root.
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    /// A Partial, ungrouped aggregate over a data source is pushed whole:
    /// the aggregate must appear below the client exec root.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        // Empty GROUP BY — required for the aggregate-pushdown pattern.
        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    /// Partial aggregate -> repartition -> data source is also recognized as
    /// a single pushdown root (the intermediate repartition is allowed).
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    /// A Final-mode aggregate must NOT be pushed: the boundary sits below it,
    /// so the Final aggregate stays above the client exec and only the data
    /// source is pushed down.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}