liquid_cache_client/
optimizer.rs

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4    config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5    execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6    physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7    physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9use liquid_cache_common::CacheMode;
10
11use crate::client_exec::LiquidCacheClientExec;
12
/// Physical optimizer rule that pushes down plan subtrees (data-source scans,
/// repartitions over scans, and partial no-group-by aggregates) to the liquid
/// cache server by wrapping them in a `LiquidCacheClientExec`.
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the liquid cache server the wrapped subtree will execute on.
    cache_server: String,
    // Cache mode forwarded to the server for the pushed-down subtree.
    cache_mode: CacheMode,
    // Object-store registrations (url + options) the server needs to read the data.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
20
21impl PushdownOptimizer {
22    /// Create a new PushdownOptimizer
23    ///
24    /// # Arguments
25    ///
26    /// * `cache_server` - The address of the liquid cache server
27    /// * `cache_mode` - The cache mode to use
28    ///
29    /// # Returns
30    ///
31    pub fn new(
32        cache_server: String,
33        cache_mode: CacheMode,
34        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
35    ) -> Self {
36        Self {
37            cache_server,
38            cache_mode,
39            object_stores,
40        }
41    }
42
43    /// Apply the optimization by finding nodes to push down and wrapping them
44    fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
45        // If this node is already a LiquidCacheClientExec, return it as is
46        if plan
47            .as_any()
48            .downcast_ref::<LiquidCacheClientExec>()
49            .is_some()
50        {
51            return Ok(plan);
52        }
53
54        // Find the candidate to push down in this branch of the tree
55        if let Some(candidate) = find_pushdown_candidate(&plan) {
56            // If the current node is the one to be pushed down, wrap it
57            if Arc::ptr_eq(&plan, &candidate) {
58                return Ok(Arc::new(LiquidCacheClientExec::new(
59                    plan,
60                    self.cache_server.clone(),
61                    self.cache_mode,
62                    self.object_stores.clone(),
63                )));
64            }
65        }
66
67        // Otherwise, recurse into children
68        let mut new_children = Vec::with_capacity(plan.children().len());
69        let mut children_changed = false;
70
71        for child in plan.children() {
72            let new_child = self.optimize_plan(child.clone())?;
73            if !Arc::ptr_eq(child, &new_child) {
74                children_changed = true;
75            }
76            new_children.push(new_child);
77        }
78
79        // If any children were changed, create a new plan with the updated children
80        if children_changed {
81            plan.with_new_children(new_children)
82        } else {
83            Ok(plan)
84        }
85    }
86}
87
88/// Find the highest pushable node
89fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90    // Check if this node is already a LiquidCacheClientExec to avoid redundant wrapping
91    if plan
92        .as_any()
93        .downcast_ref::<LiquidCacheClientExec>()
94        .is_some()
95    {
96        return None;
97    }
98
99    let plan_any = plan.as_any();
100
101    // If we have an AggregateExec (partial, no group by) with a pushable child (direct or through RepartitionExec), push it down
102    if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>() {
103        if matches!(agg_exec.mode(), AggregateMode::Partial) && agg_exec.group_expr().is_empty() {
104            let child = agg_exec.input();
105
106            // Check if child is DataSourceExec or RepartitionExec->DataSourceExec
107            if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
108                return Some(plan.clone());
109            } else if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>() {
110                if let Some(repart_child) = repart.children().first() {
111                    if repart_child
112                        .as_any()
113                        .downcast_ref::<DataSourceExec>()
114                        .is_some()
115                    {
116                        return Some(plan.clone());
117                    }
118                }
119            }
120        }
121    }
122
123    // If we have a RepartitionExec with a DataSourceExec child, push it down
124    if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>() {
125        if let Some(child) = repart_exec.children().first() {
126            if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
127                return Some(plan.clone());
128            }
129        }
130    }
131
132    // If this is a DataSourceExec, push it down
133    if plan_any.downcast_ref::<DataSourceExec>().is_some() {
134        return Some(plan.clone());
135    }
136
137    // Otherwise, recurse into children looking for pushdown candidates
138    for child in plan.children() {
139        if let Some(candidate) = find_pushdown_candidate(child) {
140            return Some(candidate);
141        }
142    }
143
144    None
145}
146
147impl PhysicalOptimizerRule for PushdownOptimizer {
148    fn optimize(
149        &self,
150        plan: Arc<dyn ExecutionPlan>,
151        _config: &ConfigOptions,
152    ) -> Result<Arc<dyn ExecutionPlan>> {
153        self.optimize_plan(plan)
154    }
155
156    fn name(&self) -> &str {
157        "PushdownOptimizer"
158    }
159
160    fn schema_check(&self) -> bool {
161        true
162    }
163}
164
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };
    use liquid_cache_common::CacheMode;

    use super::*;

    // Builds a SessionContext with the PushdownOptimizer installed and the
    // nano_hits parquet fixture registered. Used by the end-to-end SQL tests;
    // requires ../../examples/nano_hits.parquet to exist on disk.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:50051".to_string(),
                CacheMode::Liquid,
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    // End-to-end: a filtered projection should have its scan wrapped in a
    // LiquidCacheClientExec by the optimizer rule.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    // End-to-end: an ungrouped MAX should have its Partial aggregate pushed
    // below the LiquidCacheClientExec while the Final aggregate stays above.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {}", plan_str);

        // With the top-down approach, the LiquidCacheClientExec should contain:
        // 1. The AggregateExec with mode=Partial
        // 2. Any RepartitionExec below that
        // 3. The DataSourceExec at the bottom

        // Verify that AggregateExec: mode=Partial is inside the LiquidCacheClientExec
        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Split the rendered plan at the wrapper: parts[0] is everything above
        // it, parts[1] everything pushed down to the server.
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    // Create a test schema for our mock plans
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    // Mock DataSourceExec that we can use in our tests
    // (an empty in-memory source; no files or network needed).
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    // Apply the PushdownOptimizer to a plan and get the result as a string for comparison
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer =
            PushdownOptimizer::new("localhost:50051".to_string(), CacheMode::Liquid, vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    // A bare DataSourceExec is itself the pushdown candidate and gets wrapped.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    // RepartitionExec directly over a scan is pushed down as a unit.
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // Partial aggregate with no GROUP BY over a scan is pushed down whole.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    // Same as above, but with a RepartitionExec between aggregate and scan.
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> RepartitionExec -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // A Final-mode aggregate must NOT be pushed down: only the scan below it
    // is wrapped, and the aggregate stays above the wrapper.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        // Create an AggregateExec (Final, no group by) -> DataSourceExec plan
        // This should not push down the AggregateExec
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}
395}