liquid_cache_client/
optimizer.rs

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4    config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5    execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6    physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7    physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9use liquid_cache_common::CacheMode;
10
11use crate::client_exec::LiquidCacheClientExec;
12
/// PushdownOptimizer is a physical optimizer rule that pushes down filters to the liquid cache server.
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the liquid cache server that wrapped plan fragments are sent to.
    cache_server: String,
    // Cache mode forwarded to the server for this client (e.g. Liquid).
    cache_mode: CacheMode,
    // Object store URLs with their per-store config options, passed through to
    // `LiquidCacheClientExec` — presumably registered on the server side; verify there.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
20
21impl PushdownOptimizer {
22    /// Create a new PushdownOptimizer
23    ///
24    /// # Arguments
25    ///
26    /// * `cache_server` - The address of the liquid cache server
27    /// * `cache_mode` - The cache mode to use
28    ///
29    /// # Returns
30    ///
31    pub fn new(
32        cache_server: String,
33        cache_mode: CacheMode,
34        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
35    ) -> Self {
36        Self {
37            cache_server,
38            cache_mode,
39            object_stores,
40        }
41    }
42
43    /// Apply the optimization by finding nodes to push down and wrapping them
44    fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
45        // If this node is already a LiquidCacheClientExec, return it as is
46        if plan
47            .as_any()
48            .downcast_ref::<LiquidCacheClientExec>()
49            .is_some()
50        {
51            return Ok(plan);
52        }
53
54        // Find the candidate to push down in this branch of the tree
55        if let Some(candidate) = find_pushdown_candidate(&plan) {
56            // If the current node is the one to be pushed down, wrap it
57            if Arc::ptr_eq(&plan, &candidate) {
58                return Ok(Arc::new(LiquidCacheClientExec::new(
59                    plan,
60                    self.cache_server.clone(),
61                    self.cache_mode,
62                    self.object_stores.clone(),
63                )));
64            }
65        }
66
67        // Otherwise, recurse into children
68        let mut new_children = Vec::with_capacity(plan.children().len());
69        let mut children_changed = false;
70
71        for child in plan.children() {
72            let new_child = self.optimize_plan(child.clone())?;
73            if !Arc::ptr_eq(child, &new_child) {
74                children_changed = true;
75            }
76            new_children.push(new_child);
77        }
78
79        // If any children were changed, create a new plan with the updated children
80        if children_changed {
81            plan.with_new_children(new_children)
82        } else {
83            Ok(plan)
84        }
85    }
86}
87
88/// Find the highest pushable node
89fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
90    // Check if this node is already a LiquidCacheClientExec to avoid redundant wrapping
91    if plan
92        .as_any()
93        .downcast_ref::<LiquidCacheClientExec>()
94        .is_some()
95    {
96        return None;
97    }
98
99    let plan_any = plan.as_any();
100
101    // If we have an AggregateExec (partial, no group by) with a pushable child (direct or through RepartitionExec), push it down
102    if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>()
103        && matches!(agg_exec.mode(), AggregateMode::Partial)
104        && agg_exec.group_expr().is_empty()
105    {
106        let child = agg_exec.input();
107
108        // Check if child is DataSourceExec or RepartitionExec->DataSourceExec
109        if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
110            return Some(plan.clone());
111        }
112        if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>()
113            && let Some(repart_child) = repart.children().first()
114            && repart_child
115                .as_any()
116                .downcast_ref::<DataSourceExec>()
117                .is_some()
118        {
119            return Some(plan.clone());
120        }
121    }
122
123    // If we have a RepartitionExec with a DataSourceExec child, push it down
124    if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>()
125        && let Some(child) = repart_exec.children().first()
126        && child.as_any().downcast_ref::<DataSourceExec>().is_some()
127    {
128        return Some(plan.clone());
129    }
130
131    // If this is a DataSourceExec, push it down
132    if plan_any.downcast_ref::<DataSourceExec>().is_some() {
133        return Some(plan.clone());
134    }
135
136    // Otherwise, recurse into children looking for pushdown candidates
137    for child in plan.children() {
138        if let Some(candidate) = find_pushdown_candidate(child) {
139            return Some(candidate);
140        }
141    }
142
143    None
144}
145
146impl PhysicalOptimizerRule for PushdownOptimizer {
147    fn optimize(
148        &self,
149        plan: Arc<dyn ExecutionPlan>,
150        _config: &ConfigOptions,
151    ) -> Result<Arc<dyn ExecutionPlan>> {
152        self.optimize_plan(plan)
153    }
154
155    fn name(&self) -> &str {
156        "PushdownOptimizer"
157    }
158
159    fn schema_check(&self) -> bool {
160        true
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };
    use liquid_cache_common::CacheMode;

    use super::*;

    /// Build a `SessionContext` with the `PushdownOptimizer` rule installed and
    /// the `nano_hits` parquet fixture registered as a table for end-to-end tests.
    /// NOTE(review): depends on `../../examples/nano_hits.parquet` existing
    /// relative to the test working directory.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        // Enable parquet filter pushdown so filters end up next to the scan node.
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:15214".to_string(),
                CacheMode::Liquid,
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    // A filtered projection over a scan should produce a plan containing a
    // LiquidCacheClientExec wrapping the DataSourceExec.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    // An ungrouped MAX() should have its partial aggregate pushed down while
    // the final aggregate stays above the LiquidCacheClientExec boundary.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {plan_str}");

        // With the top-down approach, the LiquidCacheClientExec should contain:
        // 1. The AggregateExec with mode=Partial
        // 2. Any RepartitionExec below that
        // 3. The DataSourceExec at the bottom

        // Verify that AggregateExec: mode=Partial is inside the LiquidCacheClientExec
        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Split the rendered plan at the wrapper: text before it is above the
        // boundary, text after it is the pushed-down subtree.
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    // Create a test schema for our mock plans
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    // Mock DataSourceExec (empty in-memory source) that we can use in our tests
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    // Apply the PushdownOptimizer to a plan and get the result as a string for comparison
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer =
            PushdownOptimizer::new("localhost:15214".to_string(), CacheMode::Liquid, vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    // A bare scan is pushable: the wrapper becomes the new plan root.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    // RepartitionExec directly over a scan is pushed down as one unit.
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // A partial, ungrouped aggregate over a scan is pushed down whole.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    // An intervening RepartitionExec does not block aggregate pushdown.
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> RepartitionExec -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // A Final-mode aggregate must stay above the wrapper; only the scan below
    // it is pushed down.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        // Create an AggregateExec (Final, no group by) -> DataSourceExec plan
        // This should not push down the AggregateExec
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}