//! Physical-plan pushdown optimizer for the liquid cache client
//! (`liquid_cache_client/optimizer.rs`).

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4    config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5    execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6    physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7    physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9use liquid_cache_common::CacheMode;
10
11use crate::client_exec::LiquidCacheClientExec;
12
/// PushdownOptimizer is a physical optimizer rule that pushes down filters to the liquid cache server.
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the liquid cache server that wrapped plan fragments are sent to
    // (tests use "localhost:15214").
    cache_server: String,
    // Caching strategy forwarded to the server for the pushed-down fragment.
    cache_mode: CacheMode,
    // Object-store registrations (URL plus config key/value pairs) forwarded to
    // the server so it can read the underlying data itself.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
20
21impl PushdownOptimizer {
22    /// Create a new PushdownOptimizer
23    ///
24    /// # Arguments
25    ///
26    /// * `cache_server` - The address of the liquid cache server
27    /// * `cache_mode` - The cache mode to use
28    ///
29    /// # Returns
30    ///
31    pub fn new(
32        cache_server: String,
33        cache_mode: CacheMode,
34        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
35    ) -> Self {
36        Self {
37            cache_server,
38            cache_mode,
39            object_stores,
40        }
41    }
42
43    /// Apply the optimization by finding nodes to push down and wrapping them
44    fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
45        // If this node is already a LiquidCacheClientExec, return it as is
46        if plan
47            .as_any()
48            .downcast_ref::<LiquidCacheClientExec>()
49            .is_some()
50        {
51            return Ok(plan);
52        }
53
54        // Find the candidate to push down in this branch of the tree
55        if let Some(candidate) = find_pushdown_candidate(&plan) {
56            // If the current node is the one to be pushed down, wrap it
57            if Arc::ptr_eq(&plan, &candidate) {
58                return Ok(Arc::new(LiquidCacheClientExec::new(
59                    plan,
60                    self.cache_server.clone(),
61                    self.cache_mode,
62                    self.object_stores.clone(),
63                )));
64            }
65        }
66
67        // Otherwise, recurse into children
68        let mut new_children = Vec::with_capacity(plan.children().len());
69        let mut children_changed = false;
70
71        for child in plan.children() {
72            let new_child = self.optimize_plan(child.clone())?;
73            if !Arc::ptr_eq(child, &new_child) {
74                children_changed = true;
75            }
76            new_children.push(new_child);
77        }
78
79        // If any children were changed, create a new plan with the updated children
80        if children_changed {
81            plan.with_new_children(new_children)
82        } else {
83            Ok(plan)
84        }
85    }
86}
87
88/// Find the highest pushable node
89fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90    // Check if this node is already a LiquidCacheClientExec to avoid redundant wrapping
91    if plan
92        .as_any()
93        .downcast_ref::<LiquidCacheClientExec>()
94        .is_some()
95    {
96        return None;
97    }
98
99    let plan_any = plan.as_any();
100
101    // If we have an AggregateExec (partial, no group by) with a pushable child (direct or through RepartitionExec), push it down
102    if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>()
103        && matches!(agg_exec.mode(), AggregateMode::Partial)
104        && agg_exec.group_expr().is_empty()
105    {
106        let child = agg_exec.input();
107
108        // Check if child is DataSourceExec or RepartitionExec->DataSourceExec
109        if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
110            return Some(plan.clone());
111        } else if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>()
112            && let Some(repart_child) = repart.children().first()
113            && repart_child
114                .as_any()
115                .downcast_ref::<DataSourceExec>()
116                .is_some()
117        {
118            return Some(plan.clone());
119        }
120    }
121
122    // If we have a RepartitionExec with a DataSourceExec child, push it down
123    if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>()
124        && let Some(child) = repart_exec.children().first()
125        && child.as_any().downcast_ref::<DataSourceExec>().is_some()
126    {
127        return Some(plan.clone());
128    }
129
130    // If this is a DataSourceExec, push it down
131    if plan_any.downcast_ref::<DataSourceExec>().is_some() {
132        return Some(plan.clone());
133    }
134
135    // Otherwise, recurse into children looking for pushdown candidates
136    for child in plan.children() {
137        if let Some(candidate) = find_pushdown_candidate(child) {
138            return Some(candidate);
139        }
140    }
141
142    None
143}
144
impl PhysicalOptimizerRule for PushdownOptimizer {
    /// Entry point invoked by DataFusion's physical optimizer; delegates to
    /// `optimize_plan`. The configuration options are not consulted.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_plan(plan)
    }

    /// Rule name reported in optimizer traces and EXPLAIN output.
    fn name(&self) -> &str {
        "PushdownOptimizer"
    }

    /// Returning `true` signals this rule keeps the plan schema unchanged
    /// (wrapping a subtree in `LiquidCacheClientExec` preserves its output),
    /// so DataFusion may verify schemas after the rule runs.
    fn schema_check(&self) -> bool {
        true
    }
}
162
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };
    use liquid_cache_common::CacheMode;

    use super::*;

    // Build a SessionContext with the PushdownOptimizer installed and a small
    // parquet table registered. Assumes `../../examples/nano_hits.parquet`
    // exists relative to the crate when the async tests run.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        // Enable parquet filter pushdown so predicates end up inside the scan.
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:15214".to_string(),
                CacheMode::Liquid,
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    // A filtered scan should be rewritten so the DataSourceExec ends up under
    // a LiquidCacheClientExec node.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    // A grouped-less MAX aggregation should have its partial phase pushed
    // down while the final phase stays above the LiquidCacheClientExec.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {}", plan_str);

        // With the top-down approach, the LiquidCacheClientExec should contain:
        // 1. The AggregateExec with mode=Partial
        // 2. Any RepartitionExec below that
        // 3. The DataSourceExec at the bottom

        // Verify that AggregateExec: mode=Partial is inside the LiquidCacheClientExec
        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Splitting the rendered plan at the wrapper puts everything above it
        // in parts[0] and the pushed-down fragment in parts[1].
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    // Create a test schema for our mock plans
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    // Mock DataSourceExec that we can use in our tests
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        // An in-memory source with a single empty partition is enough for
        // plan-shape tests; no data is ever read.
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    // Apply the PushdownOptimizer to a plan and get the result as a string for comparison
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer =
            PushdownOptimizer::new("localhost:15214".to_string(), CacheMode::Liquid, vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    // A bare DataSourceExec is pushable on its own and should become the root
    // of a LiquidCacheClientExec.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    // RepartitionExec -> DataSourceExec should be pushed down as a unit,
    // keeping the repartition inside the wrapper.
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // A partial, group-by-free aggregate directly over a scan is pushable.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    // The aggregate should still be pushed down when a RepartitionExec sits
    // between it and the scan.
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> RepartitionExec -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // A Final-mode aggregate must NOT be pushed down: only its DataSourceExec
    // input gets wrapped, leaving the aggregate above the wrapper.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        // Create an AggregateExec (Final, no group by) -> DataSourceExec plan
        // This should not push down the AggregateExec
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}