1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4 config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5 execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6 physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7 physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9use liquid_cache_common::CacheMode;
10
11use crate::client_exec::LiquidCacheClientExec;
12
/// Physical optimizer rule that pushes eligible plan subtrees down to a
/// liquid-cache server by wrapping them in a [`LiquidCacheClientExec`].
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the cache server (e.g. "localhost:15214" in the tests below).
    cache_server: String,
    // Cache strategy forwarded to the server when a subtree is wrapped.
    cache_mode: CacheMode,
    // Object store URLs plus their per-store configuration options,
    // handed to the server so it can read the underlying data.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
20
21impl PushdownOptimizer {
22 pub fn new(
32 cache_server: String,
33 cache_mode: CacheMode,
34 object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
35 ) -> Self {
36 Self {
37 cache_server,
38 cache_mode,
39 object_stores,
40 }
41 }
42
43 fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
45 if plan
47 .as_any()
48 .downcast_ref::<LiquidCacheClientExec>()
49 .is_some()
50 {
51 return Ok(plan);
52 }
53
54 if let Some(candidate) = find_pushdown_candidate(&plan) {
56 if Arc::ptr_eq(&plan, &candidate) {
58 return Ok(Arc::new(LiquidCacheClientExec::new(
59 plan,
60 self.cache_server.clone(),
61 self.cache_mode,
62 self.object_stores.clone(),
63 )));
64 }
65 }
66
67 let mut new_children = Vec::with_capacity(plan.children().len());
69 let mut children_changed = false;
70
71 for child in plan.children() {
72 let new_child = self.optimize_plan(child.clone())?;
73 if !Arc::ptr_eq(child, &new_child) {
74 children_changed = true;
75 }
76 new_children.push(new_child);
77 }
78
79 if children_changed {
81 plan.with_new_children(new_children)
82 } else {
83 Ok(plan)
84 }
85 }
86}
87
88fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90 if plan
92 .as_any()
93 .downcast_ref::<LiquidCacheClientExec>()
94 .is_some()
95 {
96 return None;
97 }
98
99 let plan_any = plan.as_any();
100
101 if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>()
103 && matches!(agg_exec.mode(), AggregateMode::Partial)
104 && agg_exec.group_expr().is_empty()
105 {
106 let child = agg_exec.input();
107
108 if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
110 return Some(plan.clone());
111 } else if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>()
112 && let Some(repart_child) = repart.children().first()
113 && repart_child
114 .as_any()
115 .downcast_ref::<DataSourceExec>()
116 .is_some()
117 {
118 return Some(plan.clone());
119 }
120 }
121
122 if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>()
124 && let Some(child) = repart_exec.children().first()
125 && child.as_any().downcast_ref::<DataSourceExec>().is_some()
126 {
127 return Some(plan.clone());
128 }
129
130 if plan_any.downcast_ref::<DataSourceExec>().is_some() {
132 return Some(plan.clone());
133 }
134
135 for child in plan.children() {
137 if let Some(candidate) = find_pushdown_candidate(child) {
138 return Some(candidate);
139 }
140 }
141
142 None
143}
144
impl PhysicalOptimizerRule for PushdownOptimizer {
    /// Entry point invoked by DataFusion's physical optimizer; delegates to
    /// `optimize_plan` and ignores the session configuration.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_plan(plan)
    }

    /// Rule name shown in optimizer diagnostics and explain output.
    fn name(&self) -> &str {
        "PushdownOptimizer"
    }

    // Returning true asks DataFusion to verify the plan schema is unchanged
    // after this rule runs (wrapping a subtree should not alter its schema).
    fn schema_check(&self) -> bool {
        true
    }
}
162
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };
    use liquid_cache_common::CacheMode;

    use super::*;

    /// Three-column schema (int, string, float) used by the unit tests.
    fn create_test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Empty in-memory `DataSourceExec` over `schema`.
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        let source = MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap();
        Arc::new(DataSourceExec::new(Arc::new(source)))
    }

    /// Runs the pushdown rule over `plan` and renders the optimized plan.
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let rule =
            PushdownOptimizer::new("localhost:15214".to_string(), CacheMode::Liquid, vec![]);
        let optimized = rule.optimize(plan, &ConfigOptions::default()).unwrap();
        DisplayableExecutionPlan::new(optimized.as_ref())
            .indent(false)
            .to_string()
    }

    /// Session with filter pushdown enabled, the rule installed, and the
    /// sample parquet file registered as `nano_hits`.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        config.options_mut().execution.parquet.pushdown_filters = true;
        let rule =
            PushdownOptimizer::new("localhost:15214".to_string(), CacheMode::Liquid, vec![]);
        let state = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(rule))
            .build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let physical = df.create_physical_plan().await.unwrap();
        let rendered = DisplayableExecutionPlan::new(physical.as_ref())
            .indent(false)
            .to_string();

        assert!(rendered.contains("LiquidCacheClientExec"));
        assert!(rendered.contains("DataSourceExec"));
    }

    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let physical = df.create_physical_plan().await.unwrap();
        let rendered = DisplayableExecutionPlan::new(physical.as_ref())
            .indent(false)
            .to_string();

        println!("Plan: {}", rendered);

        assert!(rendered.contains("LiquidCacheClientExec"));

        // Everything printed before the client exec runs locally; everything
        // after it is shipped to the cache server.
        let parts: Vec<&str> = rendered.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);
        let local = parts[0];
        let remote = parts[1];

        assert!(local.contains("AggregateExec: mode=Final"));
        assert!(remote.contains("AggregateExec: mode=Partial"));
        assert!(remote.contains("DataSourceExec"));
    }

    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let rendered = apply_optimizer(create_datasource_exec(create_test_schema()));
        assert!(rendered.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let repartition = RepartitionExec::try_new(
            create_datasource_exec(create_test_schema()),
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?;

        let rendered = apply_optimizer(Arc::new(repartition));

        assert!(rendered.starts_with("LiquidCacheClientExec"));
        assert!(rendered.contains("RepartitionExec"));

        Ok(())
    }

    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let aggregate = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::new_single(vec![]),
            vec![],
            vec![],
            create_datasource_exec(schema.clone()),
            schema,
        )?;

        let rendered = apply_optimizer(Arc::new(aggregate));

        assert!(rendered.starts_with("LiquidCacheClientExec"));
        assert!(rendered.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let repartition = Arc::new(RepartitionExec::try_new(
            create_datasource_exec(schema.clone()),
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);
        let aggregate = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::new_single(vec![]),
            vec![],
            vec![],
            repartition,
            schema,
        )?;

        let rendered = apply_optimizer(Arc::new(aggregate));

        assert!(rendered.starts_with("LiquidCacheClientExec"));
        assert!(rendered.contains("AggregateExec: mode=Partial"));
        assert!(rendered.contains("RepartitionExec"));

        Ok(())
    }

    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        let schema = create_test_schema();
        let aggregate = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::new_single(vec![]),
            vec![],
            vec![],
            create_datasource_exec(schema.clone()),
            schema,
        )?;

        let rendered = apply_optimizer(Arc::new(aggregate));

        // The Final aggregate must stay local; only its source is pushed.
        let parts: Vec<&str> = rendered.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);
        assert!(parts[0].contains("AggregateExec: mode=Final"));
        assert!(parts[1].contains("DataSourceExec"));

        Ok(())
    }
}