Skip to main content

lance_datafusion/
projection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_array::RecordBatch;
5use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::{Column, DFSchema};
8use datafusion_physical_expr::PhysicalExpr;
9use futures::TryStreamExt;
10use std::{
11    collections::{HashMap, HashSet},
12    sync::Arc,
13};
14use tracing::instrument;
15
16use lance_core::{
17    Error, ROW_ADDR, ROW_CREATED_AT_VERSION, ROW_ID, ROW_LAST_UPDATED_AT_VERSION, ROW_OFFSET,
18    Result, WILDCARD,
19    datatypes::{OnMissing, Projectable, Projection, Schema},
20};
21
22use crate::{
23    exec::{LanceExecutionOptions, OneShotExec, execute_plan},
24    planner::Planner,
25};
26
27struct ProjectionBuilder {
28    base: Arc<dyn Projectable>,
29    planner: Planner,
30    output: HashMap<String, Expr>,
31    output_cols: Vec<OutputColumn>,
32    physical_cols_set: HashSet<String>,
33    physical_cols: Vec<String>,
34    needs_row_id: bool,
35    needs_row_addr: bool,
36    needs_row_last_updated_at: bool,
37    needs_row_created_at: bool,
38    must_add_row_offset: bool,
39    has_wildcard: bool,
40}
41
42impl ProjectionBuilder {
43    fn new(base: Arc<dyn Projectable>) -> Self {
44        let full_schema = Arc::new(Projection::full(base.clone()).to_arrow_schema());
45        let full_schema = Arc::new(ProjectionPlan::add_system_columns(&full_schema));
46        let planner = Planner::new(full_schema);
47
48        Self {
49            base,
50            planner,
51            output: HashMap::default(),
52            output_cols: Vec::default(),
53            physical_cols_set: HashSet::default(),
54            physical_cols: Vec::default(),
55            needs_row_id: false,
56            needs_row_addr: false,
57            needs_row_created_at: false,
58            needs_row_last_updated_at: false,
59            must_add_row_offset: false,
60            has_wildcard: false,
61        }
62    }
63
64    fn check_duplicate_column(&self, name: &str) -> Result<()> {
65        if self.output.contains_key(name) {
66            return Err(Error::invalid_input(format!(
67                "Duplicate column name: {}",
68                name
69            )));
70        }
71        Ok(())
72    }
73
74    fn add_column(&mut self, output_name: &str, raw_expr: &str) -> Result<()> {
75        self.check_duplicate_column(output_name)?;
76
77        let expr = self.planner.parse_expr(raw_expr)?;
78        // Run simplification + coercion so that expressions like `coalesce(...)`
79        // (which DataFusion's physical evaluator expects to have been rewritten
80        // into a `CASE` expression by the simplifier) work correctly.
81        let expr = self.planner.optimize_expr(expr)?;
82
83        // If the expression is a bare column reference to a system column, mark that we need it
84        if let Expr::Column(Column {
85            name,
86            relation: None,
87            ..
88        }) = &expr
89        {
90            if name == ROW_ID {
91                self.needs_row_id = true;
92            } else if name == ROW_ADDR {
93                self.needs_row_addr = true;
94            } else if name == ROW_OFFSET {
95                self.must_add_row_offset = true;
96            } else if name == ROW_LAST_UPDATED_AT_VERSION {
97                self.needs_row_last_updated_at = true;
98            } else if name == ROW_CREATED_AT_VERSION {
99                self.needs_row_created_at = true;
100            }
101        }
102
103        for col in Planner::column_names_in_expr(&expr) {
104            if self.physical_cols_set.contains(&col) {
105                continue;
106            }
107            self.physical_cols.push(col.clone());
108            self.physical_cols_set.insert(col);
109        }
110        self.output.insert(output_name.to_string(), expr.clone());
111
112        self.output_cols.push(OutputColumn {
113            expr,
114            name: output_name.to_string(),
115        });
116
117        Ok(())
118    }
119
120    fn add_columns(&mut self, columns: &[(impl AsRef<str>, impl AsRef<str>)]) -> Result<()> {
121        for (output_name, raw_expr) in columns {
122            if raw_expr.as_ref() == WILDCARD {
123                self.has_wildcard = true;
124                for col in self.base.schema().fields.iter().map(|f| f.name.as_str()) {
125                    self.check_duplicate_column(col)?;
126                    self.output_cols.push(OutputColumn {
127                        expr: Expr::Column(Column::from_name(col)),
128                        name: col.to_string(),
129                    });
130                    // Throw placeholder expr in self.output, this will trigger error on duplicates
131                    self.output.insert(col.to_string(), Expr::default());
132                }
133            } else {
134                self.add_column(output_name.as_ref(), raw_expr.as_ref())?;
135            }
136        }
137        Ok(())
138    }
139
140    fn build(self) -> Result<ProjectionPlan> {
141        // Now, calculate the physical projection from the columns referenced by the expressions
142        //
143        // If a column is missing it might be a system column (_rowid, _distance, etc.) and so
144        // we ignore it.  We don't need to load that column from disk at least, which is all we are
145        // trying to calculate here.
146        let mut physical_projection = if self.has_wildcard {
147            Projection::full(self.base.clone())
148        } else {
149            Projection::empty(self.base.clone())
150                .union_columns(&self.physical_cols, OnMissing::Ignore)?
151        };
152
153        physical_projection.with_row_id = self.needs_row_id;
154        physical_projection.with_row_addr = self.needs_row_addr || self.must_add_row_offset;
155        physical_projection.with_row_last_updated_at_version = self.needs_row_last_updated_at;
156        physical_projection.with_row_created_at_version = self.needs_row_created_at;
157
158        Ok(ProjectionPlan {
159            physical_projection,
160            must_add_row_offset: self.must_add_row_offset,
161            requested_output_expr: self.output_cols,
162        })
163    }
164}
165
166#[derive(Clone, Debug)]
167pub struct OutputColumn {
168    /// The expression that represents the output column
169    pub expr: Expr,
170    /// The name of the output column
171    pub name: String,
172}
173
174#[derive(Clone, Debug)]
175pub struct ProjectionPlan {
176    /// The physical schema that must be loaded from the dataset
177    pub physical_projection: Projection,
178
179    /// Needs the row address converted into a row offset
180    pub must_add_row_offset: bool,
181
182    /// The desired output columns
183    pub requested_output_expr: Vec<OutputColumn>,
184}
185
186impl ProjectionPlan {
187    fn add_system_columns(schema: &ArrowSchema) -> ArrowSchema {
188        let mut fields = Vec::from_iter(schema.fields.iter().cloned());
189        fields.push(Arc::new(ArrowField::new(ROW_ID, DataType::UInt64, true)));
190        fields.push(Arc::new(ArrowField::new(ROW_ADDR, DataType::UInt64, true)));
191        fields.push(Arc::new(ArrowField::new(
192            ROW_OFFSET,
193            DataType::UInt64,
194            true,
195        )));
196        fields.push(Arc::new(
197            (*lance_core::ROW_LAST_UPDATED_AT_VERSION_FIELD).clone(),
198        ));
199        fields.push(Arc::new(
200            (*lance_core::ROW_CREATED_AT_VERSION_FIELD).clone(),
201        ));
202        ArrowSchema::new(fields)
203    }
204
205    /// Set the projection from SQL expressions
206    pub fn from_expressions(
207        base: Arc<dyn Projectable>,
208        columns: &[(impl AsRef<str>, impl AsRef<str>)],
209    ) -> Result<Self> {
210        let mut builder = ProjectionBuilder::new(base);
211        builder.add_columns(columns)?;
212        builder.build()
213    }
214
215    /// Set the projection from a schema
216    ///
217    /// This plan will have no complex expressions, the schema must be a subset of the dataset schema.
218    ///
219    /// With this approach it is possible to refer to portions of nested fields.
220    ///
221    /// For example, if the schema is:
222    ///
223    /// ```ignore
224    /// {
225    ///   "metadata": {
226    ///     "location": {
227    ///       "x": f32,
228    ///       "y": f32,
229    ///     },
230    ///     "age": i32,
231    ///   }
232    /// }
233    /// ```
234    ///
235    /// It is possible to project a partial schema that drops `y` like:
236    ///
237    /// ```ignore
238    /// {
239    ///   "metadata": {
240    ///     "location": {
241    ///       "x": f32,
242    ///     },
243    ///     "age": i32,
244    ///   }
245    /// }
246    /// ```
247    ///
248    /// This is something that cannot be done easily using expressions.
249    pub fn from_schema(base: Arc<dyn Projectable>, projection: &Schema) -> Result<Self> {
250        // Separate data columns from system columns
251        // System columns (_rowid, _rowaddr, etc.) are handled via flags in Projection,
252        // not as fields in the Schema
253        let mut data_fields = Vec::new();
254        let mut with_row_id = false;
255        let mut with_row_addr = false;
256        let mut must_add_row_offset = false;
257        let mut with_row_last_updated_at_version = false;
258        let mut with_row_created_at_version = false;
259
260        for field in projection.fields.iter() {
261            if lance_core::is_system_column(&field.name) {
262                // Handle known system columns that can be included in projections
263                if field.name == ROW_ID {
264                    with_row_id = true;
265                    must_add_row_offset = true;
266                } else if field.name == ROW_ADDR {
267                    with_row_addr = true;
268                } else if field.name == ROW_OFFSET {
269                    with_row_addr = true;
270                    must_add_row_offset = true;
271                } else if field.name == ROW_LAST_UPDATED_AT_VERSION {
272                    with_row_last_updated_at_version = true;
273                } else if field.name == ROW_CREATED_AT_VERSION {
274                    with_row_created_at_version = true;
275                }
276            } else {
277                // Regular data column - validate it exists in base schema
278                if base.schema().field(&field.name).is_none() {
279                    return Err(Error::invalid_input(format!(
280                        "Column '{}' not found in schema",
281                        field.name
282                    )));
283                }
284                data_fields.push(field.clone());
285            }
286        }
287
288        // Create a schema with only data columns for the physical projection
289        let data_schema = Schema {
290            fields: data_fields,
291            metadata: projection.metadata.clone(),
292        };
293
294        // Calculate the physical projection from data columns only
295        let mut physical_projection = Projection::empty(base).union_schema(&data_schema);
296        physical_projection.with_row_id = with_row_id;
297        physical_projection.with_row_addr = with_row_addr;
298        physical_projection.with_row_last_updated_at_version = with_row_last_updated_at_version;
299        physical_projection.with_row_created_at_version = with_row_created_at_version;
300
301        // Build output expressions preserving the original order (including system columns)
302        let exprs = projection
303            .fields
304            .iter()
305            .map(|f| OutputColumn {
306                expr: Expr::Column(Column::from_name(&f.name)),
307                name: f.name.clone(),
308            })
309            .collect::<Vec<_>>();
310
311        Ok(Self {
312            physical_projection,
313            requested_output_expr: exprs,
314            must_add_row_offset,
315        })
316    }
317
318    pub fn full(base: Arc<dyn Projectable>) -> Result<Self> {
319        let physical_cols: Vec<&str> = base
320            .schema()
321            .fields
322            .iter()
323            .map(|f| f.name.as_ref())
324            .collect::<Vec<_>>();
325
326        let physical_projection =
327            Projection::empty(base.clone()).union_columns(&physical_cols, OnMissing::Ignore)?;
328
329        let requested_output_expr = physical_cols
330            .into_iter()
331            .map(|col_name| OutputColumn {
332                expr: Expr::Column(Column::from_name(col_name)),
333                name: col_name.to_string(),
334            })
335            .collect();
336
337        Ok(Self {
338            physical_projection,
339            must_add_row_offset: false,
340            requested_output_expr,
341        })
342    }
343
344    /// Convert the projection to a list of physical expressions
345    ///
346    /// This is used to apply the final projection (including dynamic expressions) to the data.
347    pub fn to_physical_exprs(
348        &self,
349        current_schema: &ArrowSchema,
350    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
351        let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
352        self.requested_output_expr
353            .iter()
354            .map(|output_column| {
355                Ok((
356                    datafusion::physical_expr::create_physical_expr(
357                        &output_column.expr,
358                        physical_df_schema.as_ref(),
359                        &Default::default(),
360                    )?,
361                    output_column.name.clone(),
362                ))
363            })
364            .collect::<Result<Vec<_>>>()
365    }
366
367    /// Include the row id in the output
368    pub fn include_row_id(&mut self) {
369        self.physical_projection.with_row_id = true;
370        if !self
371            .requested_output_expr
372            .iter()
373            .any(|OutputColumn { name, .. }| name == ROW_ID)
374        {
375            self.requested_output_expr.push(OutputColumn {
376                expr: Expr::Column(Column::from_name(ROW_ID)),
377                name: ROW_ID.to_string(),
378            });
379        }
380    }
381
382    /// Include the row address in the output
383    pub fn include_row_addr(&mut self) {
384        self.physical_projection.with_row_addr = true;
385        if !self
386            .requested_output_expr
387            .iter()
388            .any(|OutputColumn { name, .. }| name == ROW_ADDR)
389        {
390            self.requested_output_expr.push(OutputColumn {
391                expr: Expr::Column(Column::from_name(ROW_ADDR)),
392                name: ROW_ADDR.to_string(),
393            });
394        }
395    }
396
397    /// Check if the projection has any output columns
398    ///
399    /// This doesn't mean there is a physical projection.  For example, we may someday support
400    /// something like `SELECT 1 AS foo` which would have an output column (foo) but no physical projection
401    pub fn has_output_cols(&self) -> bool {
402        !self.requested_output_expr.is_empty()
403    }
404
405    pub fn output_schema(&self) -> Result<ArrowSchema> {
406        let physical_schema = self.physical_projection.to_arrow_schema();
407        let exprs = self.to_physical_exprs(&physical_schema)?;
408        let fields = exprs
409            .iter()
410            .map(|(expr, name)| {
411                let metadata = expr.return_field(&physical_schema)?.metadata().clone();
412                Ok(ArrowField::new(
413                    name,
414                    expr.data_type(&physical_schema)?,
415                    expr.nullable(&physical_schema)?,
416                )
417                .with_metadata(metadata))
418            })
419            .collect::<Result<Vec<_>>>()?;
420        Ok(ArrowSchema::new_with_metadata(
421            fields,
422            physical_schema.metadata().clone(),
423        ))
424    }
425
426    #[instrument(skip_all, level = "debug")]
427    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
428        let src = Arc::new(OneShotExec::from_batch(batch));
429
430        // Need to add ROW_OFFSET to get filterable schema
431        let extra_columns = vec![
432            ArrowField::new(ROW_ADDR, DataType::UInt64, true),
433            ArrowField::new(ROW_OFFSET, DataType::UInt64, true),
434        ];
435        let mut filterable_schema = self.physical_projection.to_schema();
436        filterable_schema = filterable_schema.merge(&ArrowSchema::new(extra_columns))?;
437
438        let physical_exprs = self.to_physical_exprs(&(&filterable_schema).into())?;
439        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
440
441        // Run dummy plan to execute projection, do not log the plan run
442        let stream = execute_plan(
443            projection,
444            LanceExecutionOptions {
445                skip_logging: true,
446                ..Default::default()
447            },
448        )?;
449        let batches = stream.try_collect::<Vec<_>>().await?;
450        if batches.len() != 1 {
451            Err(Error::internal("Expected exactly one batch".to_string()))
452        } else {
453            Ok(batches.into_iter().next().unwrap())
454        }
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::Int64Array;
    use lance_arrow::json::{is_json_field, json_field};

    /// Regression test: `coalesce` in a column-map expression used to fail with
    /// "coalesce should have been simplified to case" because the parsed expression
    /// was passed straight to `create_physical_expr` without running the simplifier.
    #[tokio::test]
    async fn test_coalesce_in_column_map() {
        let input_fields = vec![
            ArrowField::new("col_a", DataType::Int64, true),
            ArrowField::new("col_b", DataType::Int64, true),
        ];
        let arrow_schema = Arc::new(ArrowSchema::new(input_fields));
        let base = Arc::new(Schema::try_from(arrow_schema.as_ref()).unwrap());

        let column_map = [("foo", "coalesce(col_a, col_b)")];
        let plan = ProjectionPlan::from_expressions(base, &column_map).unwrap();

        let batch = RecordBatch::try_new(
            arrow_schema,
            vec![
                Arc::new(Int64Array::from(vec![Some(1), None, Some(3), None])),
                Arc::new(Int64Array::from(vec![Some(10), Some(20), None, None])),
            ],
        )
        .unwrap();

        let projected = plan.project_batch(batch).await.unwrap();
        let foo = projected
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        // First non-null of (col_a, col_b) per row; null only when both are null
        let actual = foo.iter().collect::<Vec<_>>();
        let expected = vec![Some(1), Some(20), Some(3), None];
        assert_eq!(actual, expected);
    }

    /// The JSON extension marker on a field must survive both the physical projection
    /// and the computed output schema.
    #[test]
    fn test_output_schema_preserves_json_extension_metadata() {
        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("id", DataType::Int32, false),
            json_field("meta", true),
        ]);
        let base_schema = Schema::try_from(&arrow_schema).unwrap();

        let plan = ProjectionPlan::from_schema(Arc::new(base_schema.clone()), &base_schema).unwrap();

        let physical_schema = plan.physical_projection.to_arrow_schema();
        let physical_meta = physical_schema.field_with_name("meta").unwrap();
        assert!(is_json_field(physical_meta));

        let output_schema = plan.output_schema().unwrap();
        let output_meta = output_schema.field_with_name("meta").unwrap();
        assert!(is_json_field(output_meta));
    }
}
519}