// liquid_cache_client/optimizer.rs

1use std::{collections::HashMap, sync::Arc};
2
3use datafusion::{
4    config::ConfigOptions, datasource::source::DataSourceExec, error::Result,
5    execution::object_store::ObjectStoreUrl, physical_optimizer::PhysicalOptimizerRule,
6    physical_plan::ExecutionPlan, physical_plan::aggregates::AggregateExec,
7    physical_plan::aggregates::AggregateMode, physical_plan::repartition::RepartitionExec,
8};
9
10use crate::client_exec::LiquidCacheClientExec;
11
/// PushdownOptimizer is a physical optimizer rule that pushes down filters to the liquid cache server.
///
/// Pushable subtrees — a `DataSourceExec`, a `RepartitionExec` directly over a
/// `DataSourceExec`, or a partial `AggregateExec` with no group-by above either
/// of those — are wrapped in a `LiquidCacheClientExec` so they execute on the
/// cache server instead of locally.
#[derive(Debug)]
pub struct PushdownOptimizer {
    // Address of the liquid cache server; handed to each `LiquidCacheClientExec`.
    cache_server: String,
    // Object store URLs and their option maps; passed along with the wrapped plan.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
}
18
19impl PushdownOptimizer {
20    /// Create a new PushdownOptimizer
21    ///
22    /// # Arguments
23    ///
24    /// * `cache_server` - The address of the liquid cache server
25    /// * `object_stores` - The object stores to use
26    ///
27    /// # Returns
28    ///
29    pub fn new(
30        cache_server: String,
31        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
32    ) -> Self {
33        Self {
34            cache_server,
35            object_stores,
36        }
37    }
38
39    /// Apply the optimization by finding nodes to push down and wrapping them
40    fn optimize_plan(&self, plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
41        // If this node is already a LiquidCacheClientExec, return it as is
42        if plan
43            .as_any()
44            .downcast_ref::<LiquidCacheClientExec>()
45            .is_some()
46        {
47            return Ok(plan);
48        }
49
50        // Find the candidate to push down in this branch of the tree
51        if let Some(candidate) = find_pushdown_candidate(&plan) {
52            // If the current node is the one to be pushed down, wrap it
53            if Arc::ptr_eq(&plan, &candidate) {
54                return Ok(Arc::new(LiquidCacheClientExec::new(
55                    plan,
56                    self.cache_server.clone(),
57                    self.object_stores.clone(),
58                )));
59            }
60        }
61
62        // Otherwise, recurse into children
63        let mut new_children = Vec::with_capacity(plan.children().len());
64        let mut children_changed = false;
65
66        for child in plan.children() {
67            let new_child = self.optimize_plan(child.clone())?;
68            if !Arc::ptr_eq(child, &new_child) {
69                children_changed = true;
70            }
71            new_children.push(new_child);
72        }
73
74        // If any children were changed, create a new plan with the updated children
75        if children_changed {
76            plan.with_new_children(new_children)
77        } else {
78            Ok(plan)
79        }
80    }
81}
82
83/// Find the highest pushable node
84fn find_pushdown_candidate(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
85    // Check if this node is already a LiquidCacheClientExec to avoid redundant wrapping
86    if plan
87        .as_any()
88        .downcast_ref::<LiquidCacheClientExec>()
89        .is_some()
90    {
91        return None;
92    }
93
94    let plan_any = plan.as_any();
95
96    // If we have an AggregateExec (partial, no group by) with a pushable child (direct or through RepartitionExec), push it down
97    if let Some(agg_exec) = plan_any.downcast_ref::<AggregateExec>()
98        && matches!(agg_exec.mode(), AggregateMode::Partial)
99        && agg_exec.group_expr().is_empty()
100    {
101        let child = agg_exec.input();
102
103        // Check if child is DataSourceExec or RepartitionExec->DataSourceExec
104        if child.as_any().downcast_ref::<DataSourceExec>().is_some() {
105            return Some(plan.clone());
106        }
107        if let Some(repart) = child.as_any().downcast_ref::<RepartitionExec>()
108            && let Some(repart_child) = repart.children().first()
109            && repart_child
110                .as_any()
111                .downcast_ref::<DataSourceExec>()
112                .is_some()
113        {
114            return Some(plan.clone());
115        }
116    }
117
118    // If we have a RepartitionExec with a DataSourceExec child, push it down
119    if let Some(repart_exec) = plan_any.downcast_ref::<RepartitionExec>()
120        && let Some(child) = repart_exec.children().first()
121        && child.as_any().downcast_ref::<DataSourceExec>().is_some()
122    {
123        return Some(plan.clone());
124    }
125
126    // If this is a DataSourceExec, push it down
127    if plan_any.downcast_ref::<DataSourceExec>().is_some() {
128        return Some(plan.clone());
129    }
130
131    // Otherwise, recurse into children looking for pushdown candidates
132    for child in plan.children() {
133        if let Some(candidate) = find_pushdown_candidate(child) {
134            return Some(candidate);
135        }
136    }
137
138    None
139}
140
impl PhysicalOptimizerRule for PushdownOptimizer {
    // Entry point called by DataFusion's optimizer pipeline; delegates the
    // whole rewrite to `optimize_plan`. `_config` is unused by this rule.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.optimize_plan(plan)
    }

    // Human-readable rule name used by DataFusion to identify this rule.
    fn name(&self) -> &str {
        "PushdownOptimizer"
    }

    // Opt in to DataFusion's schema check after this rule runs.
    // NOTE(review): assumes `LiquidCacheClientExec` preserves the schema of the
    // plan it wraps — confirm against its implementation.
    fn schema_check(&self) -> bool {
        true
    }
}
158
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
    use datafusion::{
        config::ConfigOptions,
        datasource::memory::MemorySourceConfig,
        error::Result,
        execution::SessionStateBuilder,
        physical_plan::{
            ExecutionPlan,
            aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy},
            display::DisplayableExecutionPlan,
            repartition::RepartitionExec,
        },
        prelude::{SessionConfig, SessionContext},
    };

    use super::*;

    // Build a SessionContext with the PushdownOptimizer registered and the
    // nano_hits parquet fixture mounted, for end-to-end SQL plan tests.
    // NOTE(review): depends on ../../examples/nano_hits.parquet existing
    // relative to the test working directory.
    async fn create_session_context() -> SessionContext {
        let mut config = SessionConfig::from_env().unwrap();
        config.options_mut().execution.parquet.pushdown_filters = true;
        let builder = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
                "localhost:15214".to_string(),
                vec![],
            )));
        let state = builder.build();
        let ctx = SessionContext::new_with_state(state);
        ctx.register_parquet(
            "nano_hits",
            "../../examples/nano_hits.parquet",
            Default::default(),
        )
        .await
        .unwrap();
        ctx
    }

    // A filtered scan should be wrapped: both the client exec and the
    // underlying data source must appear in the rendered plan.
    #[tokio::test]
    async fn test_plan_rewrite() {
        let ctx = create_session_context().await;
        let df = ctx
            .sql("SELECT \"URL\" FROM nano_hits WHERE \"URL\" like 'https://%' limit 10")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        assert!(plan_str.contains("LiquidCacheClientExec"));
        assert!(plan_str.contains("DataSourceExec"));
    }

    // A global MAX() should split into Final (local) and Partial (pushed down)
    // aggregates, with the Partial half inside the LiquidCacheClientExec.
    #[tokio::test]
    async fn test_aggregate_pushdown() {
        let ctx = create_session_context().await;

        let df = ctx
            .sql("SELECT MAX(\"URL\") FROM nano_hits WHERE \"URL\" like 'https://%'")
            .await
            .unwrap();
        let plan = df.create_physical_plan().await.unwrap();
        let display_plan = DisplayableExecutionPlan::new(plan.as_ref());
        let plan_str = display_plan.indent(false).to_string();

        println!("Plan: {plan_str}");

        // With the top-down approach, the LiquidCacheClientExec should contain:
        // 1. The AggregateExec with mode=Partial
        // 2. Any RepartitionExec below that
        // 3. The DataSourceExec at the bottom

        // Verify that AggregateExec: mode=Partial is inside the LiquidCacheClientExec
        assert!(plan_str.contains("LiquidCacheClientExec"));

        // Split the rendered plan at the wrapper: parts[0] is above it,
        // parts[1] is the pushed-down subtree.
        let parts: Vec<&str> = plan_str.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        let pushed_down = parts[1];

        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        assert!(pushed_down.contains("AggregateExec: mode=Partial"));
        assert!(pushed_down.contains("DataSourceExec"));
    }

    // Create a test schema for our mock plans
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("c1", DataType::Int32, true),
            Field::new("c2", DataType::Utf8, true),
            Field::new("c3", DataType::Float64, true),
        ]))
    }

    // Mock DataSourceExec that we can use in our tests
    // (an empty in-memory source; the optimizer only inspects node types).
    fn create_datasource_exec(schema: SchemaRef) -> Arc<dyn ExecutionPlan> {
        Arc::new(DataSourceExec::new(Arc::new(
            MemorySourceConfig::try_new(&[vec![]], schema, None).unwrap(),
        )))
    }

    // Apply the PushdownOptimizer to a plan and get the result as a string for comparison
    fn apply_optimizer(plan: Arc<dyn ExecutionPlan>) -> String {
        let optimizer = PushdownOptimizer::new("localhost:15214".to_string(), vec![]);

        let optimized = optimizer.optimize(plan, &ConfigOptions::default()).unwrap();
        let display_plan = DisplayableExecutionPlan::new(optimized.as_ref());
        display_plan.indent(false).to_string()
    }

    // A bare DataSourceExec is pushable and becomes the wrapper's direct child.
    #[test]
    fn test_simple_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let result = apply_optimizer(datasource);
        assert!(result.starts_with("LiquidCacheClientExec"));
        Ok(())
    }

    // RepartitionExec over a DataSourceExec is pushed down as one unit.
    #[test]
    fn test_repartition_datasource_pushdown() -> Result<()> {
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema);
        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let result = apply_optimizer(repartition);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // Partial aggregate with no group-by directly over a source is pushed down.
    #[test]
    fn test_partial_aggregate_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));

        Ok(())
    }

    // Partial aggregate over RepartitionExec -> DataSourceExec is also pushed
    // down as a whole (the repartition does not block the aggregate pushdown).
    #[test]
    fn test_aggregate_with_repartition_pushdown() -> Result<()> {
        // Create an AggregateExec (Partial, no group by) -> RepartitionExec -> DataSourceExec plan
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let repartition = Arc::new(RepartitionExec::try_new(
            datasource,
            datafusion::physical_plan::Partitioning::RoundRobinBatch(4),
        )?);

        let group_by = PhysicalGroupBy::new_single(vec![]);
        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Partial,
            group_by,
            vec![],
            vec![],
            repartition,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        assert!(result.starts_with("LiquidCacheClientExec"));
        assert!(result.contains("AggregateExec: mode=Partial"));
        assert!(result.contains("RepartitionExec"));

        Ok(())
    }

    // A Final-mode aggregate is NOT pushable itself: only its DataSourceExec
    // input gets wrapped, leaving the aggregate above the wrapper.
    #[test]
    fn test_non_pushable_aggregate() -> Result<()> {
        // Create an AggregateExec (Final, no group by) -> DataSourceExec plan
        // This should not push down the AggregateExec
        let schema = create_test_schema();
        let datasource = create_datasource_exec(schema.clone());

        let group_by = PhysicalGroupBy::new_single(vec![]);

        let aggregate = Arc::new(AggregateExec::try_new(
            AggregateMode::Final,
            group_by,
            vec![],
            vec![],
            datasource,
            schema.clone(),
        )?);

        let result = apply_optimizer(aggregate);

        let parts: Vec<&str> = result.split("LiquidCacheClientExec").collect();
        assert!(parts.len() > 1);

        let higher_layers = parts[0];
        assert!(higher_layers.contains("AggregateExec: mode=Final"));
        let lower_layers = parts[1];
        assert!(lower_layers.contains("DataSourceExec"));

        Ok(())
    }
}