1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4 config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5 execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6 physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7 physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9use liquid_cache_common::CacheMode;
10
11use crate::client_exec::LiquidCacheClientExec;
12
#[derive(Debug)]
/// Physical optimizer rule that rewrites eligible subplans to execute on a
/// remote liquid cache server by wrapping them in [`LiquidCacheClientExec`].
pub struct PushdownOptimizer {
    /// Address of the cache server to push work to (e.g. `"localhost:15214"`,
    /// as used in the tests below).
    cache_server: String,
    /// Cache mode forwarded to the server (e.g. `CacheMode::Liquid`).
    cache_mode: CacheMode,
    /// Object store URLs and their configuration options, passed along so the
    /// server can access the underlying data.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
20
21impl PushdownOptimizer {
22 pub fn new(
32 cache_server: String,
33 cache_mode: CacheMode,
34 object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
35 ) -> Self {
36 Self {
37 cache_server,
38 cache_mode,
39 object_stores,
40 }
41 }
42
43 fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
45 if plan
47 .as_any()
48 .downcast_ref::<LiquidCacheClientExec>()
49 .is_some()
50 {
51 return Ok(plan);
52 }
53
54 if let Some(candidate) = find_pushdown_candidate(&plan) {
56 if Arc::ptr_eq(&plan, &candidate) {
58 return Ok(Arc::new(LiquidCacheClientExec::new(
59 plan,
60 self.cache_server.clone(),
61 self.cache_mode,
62 self.object_stores.clone(),
63 )));
64 }
65 }
66
67 let mut new_children = Vec::with_capacity(plan.children().len());
69 let mut children_changed = false;
70
71 for child in plan.children() {
72 let new_child = self.optimize_plan(child.clone())?;
73 if !Arc::ptr_eq(child, &new_child) {
74 children_changed = true;
75 }
76 new_children.push(new_child);
77 }
78
79 if children_changed {
81 plan.with_new_children(new_children)
82 } else {
83 Ok(plan)
84 }
85 }
86}
87
88fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90 if plan
92 .as_any()
93 .downcast_ref::<LiquidCacheClientExec>()
94 .is_some()
95 {
96 return None;
97 }
98
99 let plan_any = plan.as_any();
100
101 if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>()
103 && matches!(agg_exec.mode(), AggregateMode::Partial)
104 && agg_exec.group_expr().is_empty()
105 {
106 let child = agg_exec.input();
107
108 if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
110 return Some(plan.clone());
111 }
112 if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>()
113 && let Some(repart_child) = repart.children().first()
114 && repart_child
115 .as_any()
116 .downcast_ref::<DataSourceExec>()
117 .is_some()
118 {
119 return Some(plan.clone());
120 }
121 }
122
123 if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>()
125 && let Some(child) = repart_exec.children().first()
126 && child.as_any().downcast_ref::<DataSourceExec>().is_some()
127 {
128 return Some(plan.clone());
129 }
130
131 if plan_any.downcast_ref::<DataSourceExec>().is_some() {
133 return Some(plan.clone());
134 }
135
136 for child in plan.children() {
138 if let Some(candidate) = find_pushdown_candidate(child) {
139 return Some(candidate);
140 }
141 }
142
143 None
144}
145
impl PhysicalOptimizerRule for PushdownOptimizer {
    /// Entry point called by DataFusion's physical optimizer; delegates to
    /// `optimize_plan`. The session `ConfigOptions` are unused.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_plan(plan)
    }

    /// Rule name reported in optimizer traces and error messages.
    fn name(&self) -> &str {
        "PushdownOptimizer"
    }

    /// Wrapping a subplan preserves its schema, so the optimizer may verify
    /// schemas before and after this rule.
    fn schema_check(&self) -> bool {
        true
    }
}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };
    use liquid_cache_common::CacheMode;

    use super::*;

    /// Build a `SessionContext` with the `PushdownOptimizer` installed and the
    /// sample parquet file registered as table `nano_hits`.
    ///
    /// NOTE(review): relies on `SessionConfig::from_env()` and a relative path
    /// to `../../examples/nano_hits.parquet`, so these tests assume a
    /// repo-checkout working directory.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        // Enable parquet filter pushdown so filters reach the data source.
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:15214".to_string(),
                CacheMode::Liquid,
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    /// A filtered scan should be wrapped: both the client exec and the data
    /// source must appear in the rendered plan.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    /// An ungrouped aggregate should be split: the Final aggregate stays on the
    /// client (above the wrapper) while the Partial aggregate and the data
    /// source are pushed below `LiquidCacheClientExec`.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {plan_str}");

        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Split the rendered plan at the wrapper: parts[0] is everything
        // above it, parts[1] everything pushed down to the server.
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    /// Minimal three-column schema for the synthetic-plan unit tests below.
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    /// Build an empty in-memory `DataSourceExec` over `schema`.
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    /// Run the optimizer rule on `plan` and return the rendered result for
    /// string-based assertions.
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer =
            PushdownOptimizer::new("localhost:15214".to_string(), CacheMode::Liquid, vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    /// A bare data source is wrapped directly.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    /// Repartition-over-datasource is wrapped as a whole (the repartition goes
    /// below the wrapper).
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    /// A partial, ungrouped aggregate directly over a data source is wrapped
    /// at the aggregate level.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        // Empty grouping: required for aggregate pushdown eligibility.
        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    /// Same as above but with a repartition between aggregate and source —
    /// still wrapped at the aggregate level.
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    /// A Final-mode aggregate is NOT pushed down itself; only its data-source
    /// input ends up below the wrapper.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}