1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4 config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5 execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6 physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7 physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9use liquid_cache_common::CacheMode;
10
11use crate::client_exec::LiquidCacheClientExec;
12
/// Physical optimizer rule that finds pushdown-eligible subtrees (scans,
/// repartitions directly above a scan, and grouping-free partial aggregates
/// above a scan) and wraps them in a [`LiquidCacheClientExec`] so they run
/// on a remote liquid-cache server.
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the liquid-cache server the wrapped subtree is sent to
    // (e.g. "localhost:50051" — see the tests below).
    cache_server: String,
    // Caching mode forwarded to the server.
    cache_mode: CacheMode,
    // Object store URLs plus their configuration options, forwarded so the
    // server can read the underlying files.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
20
21impl PushdownOptimizer {
22 pub fn new(
32 cache_server: String,
33 cache_mode: CacheMode,
34 object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
35 ) -> Self {
36 Self {
37 cache_server,
38 cache_mode,
39 object_stores,
40 }
41 }
42
43 fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
45 if plan
47 .as_any()
48 .downcast_ref::<LiquidCacheClientExec>()
49 .is_some()
50 {
51 return Ok(plan);
52 }
53
54 if let Some(candidate) = find_pushdown_candidate(&plan) {
56 if Arc::ptr_eq(&plan, &candidate) {
58 return Ok(Arc::new(LiquidCacheClientExec::new(
59 plan,
60 self.cache_server.clone(),
61 self.cache_mode,
62 self.object_stores.clone(),
63 )));
64 }
65 }
66
67 let mut new_children = Vec::with_capacity(plan.children().len());
69 let mut children_changed = false;
70
71 for child in plan.children() {
72 let new_child = self.optimize_plan(child.clone())?;
73 if !Arc::ptr_eq(child, &new_child) {
74 children_changed = true;
75 }
76 new_children.push(new_child);
77 }
78
79 if children_changed {
81 plan.with_new_children(new_children)
82 } else {
83 Ok(plan)
84 }
85 }
86}
87
88fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90 if plan
92 .as_any()
93 .downcast_ref::<LiquidCacheClientExec>()
94 .is_some()
95 {
96 return None;
97 }
98
99 let plan_any = plan.as_any();
100
101 if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>() {
103 if matches!(agg_exec.mode(), AggregateMode::Partial) && agg_exec.group_expr().is_empty() {
104 let child = agg_exec.input();
105
106 if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
108 return Some(plan.clone());
109 } else if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>() {
110 if let Some(repart_child) = repart.children().first() {
111 if repart_child
112 .as_any()
113 .downcast_ref::<DataSourceExec>()
114 .is_some()
115 {
116 return Some(plan.clone());
117 }
118 }
119 }
120 }
121 }
122
123 if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>() {
125 if let Some(child) = repart_exec.children().first() {
126 if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
127 return Some(plan.clone());
128 }
129 }
130 }
131
132 if plan_any.downcast_ref::<DataSourceExec>().is_some() {
134 return Some(plan.clone());
135 }
136
137 for child in plan.children() {
139 if let Some(candidate) = find_pushdown_candidate(child) {
140 return Some(candidate);
141 }
142 }
143
144 None
145}
146
impl PhysicalOptimizerRule for PushdownOptimizer {
    /// Entry point invoked by DataFusion's optimizer driver; delegates to
    /// the recursive rewrite. The config is currently unused.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_plan(plan)
    }

    fn name(&self) -> &str {
        "PushdownOptimizer"
    }

    /// Ask the driver to verify the schema is unchanged by this rule.
    // NOTE(review): this assumes wrapping a subtree in LiquidCacheClientExec
    // preserves its output schema — confirm against LiquidCacheClientExec.
    fn schema_check(&self) -> bool {
        true
    }
}
164
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };
    use liquid_cache_common::CacheMode;

    use super::*;

    /// Build a `SessionContext` with the `PushdownOptimizer` installed and
    /// a small parquet file registered as table `nano_hits`.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        // Enable parquet filter pushdown so filters end up inside the scan.
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:50051".to_string(),
                CacheMode::Liquid,
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    /// End-to-end: a filtered scan should get wrapped by the client exec.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    /// End-to-end: the Partial aggregate (plus the scan) is pushed below the
    /// client exec while the Final aggregate stays above it.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {}", plan_str);

        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Split the rendered plan at the wrapper: parts[0] is everything
        // above it, parts[1] is the subtree that was pushed down.
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    /// Small three-column schema for the synthetic-plan unit tests below.
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    /// An empty in-memory `DataSourceExec` used as the scan leaf.
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    /// Run the optimizer on `plan` and render the result for assertions.
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer =
            PushdownOptimizer::new("localhost:50051".to_string(), CacheMode::Liquid, vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    /// A bare scan is itself a candidate and gets wrapped at the root.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    /// Repartition-over-scan is wrapped as a unit (wrapper at the root).
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    /// A grouping-free Partial aggregate over a scan is wrapped as a unit.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        // Empty group-by: required for aggregate pushdown eligibility.
        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    /// Partial aggregate -> repartition -> scan is also wrapped as a unit.
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    /// A Final-mode aggregate is NOT pushed down: the wrapper must land on
    /// the scan beneath it, leaving the aggregate above the wrapper.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}