// lance_datafusion/projection.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

4use arrow_array::RecordBatch;
5use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::{Column, DFSchema};
8use datafusion_physical_expr::PhysicalExpr;
9use futures::TryStreamExt;
10use snafu::location;
11use std::{
12    collections::{HashMap, HashSet},
13    sync::Arc,
14};
15use tracing::instrument;
16
17use lance_core::{
18    datatypes::{BlobVersion, OnMissing, Projectable, Projection, Schema},
19    Error, Result, ROW_ADDR, ROW_CREATED_AT_VERSION, ROW_ID, ROW_LAST_UPDATED_AT_VERSION,
20    ROW_OFFSET, WILDCARD,
21};
22
23use crate::{
24    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
25    planner::Planner,
26};
27
28struct ProjectionBuilder {
29    base: Arc<dyn Projectable>,
30    planner: Planner,
31    output: HashMap<String, Expr>,
32    output_cols: Vec<OutputColumn>,
33    physical_cols_set: HashSet<String>,
34    physical_cols: Vec<String>,
35    needs_row_id: bool,
36    needs_row_addr: bool,
37    needs_row_last_updated_at: bool,
38    needs_row_created_at: bool,
39    must_add_row_offset: bool,
40    has_wildcard: bool,
41    blob_version: BlobVersion,
42}
43
44impl ProjectionBuilder {
45    fn new(base: Arc<dyn Projectable>, blob_version: BlobVersion) -> Self {
46        let full_schema = Arc::new(
47            Projection::full(base.clone())
48                .with_blob_version(blob_version)
49                .to_arrow_schema(),
50        );
51        let full_schema = Arc::new(ProjectionPlan::add_system_columns(&full_schema));
52        let planner = Planner::new(full_schema);
53
54        Self {
55            base,
56            planner,
57            output: HashMap::default(),
58            output_cols: Vec::default(),
59            physical_cols_set: HashSet::default(),
60            physical_cols: Vec::default(),
61            needs_row_id: false,
62            needs_row_addr: false,
63            needs_row_created_at: false,
64            needs_row_last_updated_at: false,
65            must_add_row_offset: false,
66            has_wildcard: false,
67            blob_version,
68        }
69    }
70
71    fn check_duplicate_column(&self, name: &str) -> Result<()> {
72        if self.output.contains_key(name) {
73            return Err(Error::io(
74                format!("Duplicate column name: {}", name),
75                location!(),
76            ));
77        }
78        Ok(())
79    }
80
81    fn add_column(&mut self, output_name: &str, raw_expr: &str) -> Result<()> {
82        self.check_duplicate_column(output_name)?;
83
84        let expr = self.planner.parse_expr(raw_expr)?;
85
86        // If the expression is a bare column reference to a system column, mark that we need it
87        if let Expr::Column(Column {
88            name,
89            relation: None,
90            ..
91        }) = &expr
92        {
93            if name == ROW_ID {
94                self.needs_row_id = true;
95            } else if name == ROW_ADDR {
96                self.needs_row_addr = true;
97            } else if name == ROW_OFFSET {
98                self.must_add_row_offset = true;
99            } else if name == ROW_LAST_UPDATED_AT_VERSION {
100                self.needs_row_last_updated_at = true;
101            } else if name == ROW_CREATED_AT_VERSION {
102                self.needs_row_created_at = true;
103            }
104        }
105
106        for col in Planner::column_names_in_expr(&expr) {
107            if self.physical_cols_set.contains(&col) {
108                continue;
109            }
110            self.physical_cols.push(col.clone());
111            self.physical_cols_set.insert(col);
112        }
113        self.output.insert(output_name.to_string(), expr.clone());
114
115        self.output_cols.push(OutputColumn {
116            expr,
117            name: output_name.to_string(),
118        });
119
120        Ok(())
121    }
122
123    fn add_columns(&mut self, columns: &[(impl AsRef<str>, impl AsRef<str>)]) -> Result<()> {
124        for (output_name, raw_expr) in columns {
125            if raw_expr.as_ref() == WILDCARD {
126                self.has_wildcard = true;
127                for col in self.base.schema().fields.iter().map(|f| f.name.as_str()) {
128                    self.check_duplicate_column(col)?;
129                    self.output_cols.push(OutputColumn {
130                        expr: Expr::Column(Column::from_name(col)),
131                        name: col.to_string(),
132                    });
133                    // Throw placeholder expr in self.output, this will trigger error on duplicates
134                    self.output.insert(col.to_string(), Expr::default());
135                }
136            } else {
137                self.add_column(output_name.as_ref(), raw_expr.as_ref())?;
138            }
139        }
140        Ok(())
141    }
142
143    fn build(self) -> Result<ProjectionPlan> {
144        // Now, calculate the physical projection from the columns referenced by the expressions
145        //
146        // If a column is missing it might be a system column (_rowid, _distance, etc.) and so
147        // we ignore it.  We don't need to load that column from disk at least, which is all we are
148        // trying to calculate here.
149        let mut physical_projection = if self.has_wildcard {
150            Projection::full(self.base.clone())
151        } else {
152            Projection::empty(self.base.clone())
153                .union_columns(&self.physical_cols, OnMissing::Ignore)?
154        };
155
156        physical_projection = physical_projection.with_blob_version(self.blob_version);
157
158        physical_projection.with_row_id = self.needs_row_id;
159        physical_projection.with_row_addr = self.needs_row_addr || self.must_add_row_offset;
160        physical_projection.with_row_last_updated_at_version = self.needs_row_last_updated_at;
161        physical_projection.with_row_created_at_version = self.needs_row_created_at;
162
163        Ok(ProjectionPlan {
164            physical_projection,
165            must_add_row_offset: self.must_add_row_offset,
166            requested_output_expr: self.output_cols,
167        })
168    }
169}
170
171#[derive(Clone, Debug)]
172pub struct OutputColumn {
173    /// The expression that represents the output column
174    pub expr: Expr,
175    /// The name of the output column
176    pub name: String,
177}
178
179#[derive(Clone, Debug)]
180pub struct ProjectionPlan {
181    /// The physical schema that must be loaded from the dataset
182    pub physical_projection: Projection,
183
184    /// Needs the row address converted into a row offset
185    pub must_add_row_offset: bool,
186
187    /// The desired output columns
188    pub requested_output_expr: Vec<OutputColumn>,
189}
190
191impl ProjectionPlan {
192    fn add_system_columns(schema: &ArrowSchema) -> ArrowSchema {
193        let mut fields = Vec::from_iter(schema.fields.iter().cloned());
194        fields.push(Arc::new(ArrowField::new(ROW_ID, DataType::UInt64, true)));
195        fields.push(Arc::new(ArrowField::new(ROW_ADDR, DataType::UInt64, true)));
196        fields.push(Arc::new(ArrowField::new(
197            ROW_OFFSET,
198            DataType::UInt64,
199            true,
200        )));
201        fields.push(Arc::new(
202            (*lance_core::ROW_LAST_UPDATED_AT_VERSION_FIELD).clone(),
203        ));
204        fields.push(Arc::new(
205            (*lance_core::ROW_CREATED_AT_VERSION_FIELD).clone(),
206        ));
207        ArrowSchema::new(fields)
208    }
209
210    /// Set the projection from SQL expressions
211    pub fn from_expressions(
212        base: Arc<dyn Projectable>,
213        columns: &[(impl AsRef<str>, impl AsRef<str>)],
214        blob_version: BlobVersion,
215    ) -> Result<Self> {
216        let mut builder = ProjectionBuilder::new(base, blob_version);
217        builder.add_columns(columns)?;
218        builder.build()
219    }
220
221    /// Set the projection from a schema
222    ///
223    /// This plan will have no complex expressions, the schema must be a subset of the dataset schema.
224    ///
225    /// With this approach it is possible to refer to portions of nested fields.
226    ///
227    /// For example, if the schema is:
228    ///
229    /// ```ignore
230    /// {
231    ///   "metadata": {
232    ///     "location": {
233    ///       "x": f32,
234    ///       "y": f32,
235    ///     },
236    ///     "age": i32,
237    ///   }
238    /// }
239    /// ```
240    ///
241    /// It is possible to project a partial schema that drops `y` like:
242    ///
243    /// ```ignore
244    /// {
245    ///   "metadata": {
246    ///     "location": {
247    ///       "x": f32,
248    ///     },
249    ///     "age": i32,
250    ///   }
251    /// }
252    /// ```
253    ///
254    /// This is something that cannot be done easily using expressions.
255    pub fn from_schema(
256        base: Arc<dyn Projectable>,
257        projection: &Schema,
258        blob_version: BlobVersion,
259    ) -> Result<Self> {
260        // Separate data columns from system columns
261        // System columns (_rowid, _rowaddr, etc.) are handled via flags in Projection,
262        // not as fields in the Schema
263        let mut data_fields = Vec::new();
264        let mut with_row_id = false;
265        let mut with_row_addr = false;
266        let mut must_add_row_offset = false;
267
268        for field in projection.fields.iter() {
269            if lance_core::is_system_column(&field.name) {
270                // Handle known system columns that can be included in projections
271                if field.name == ROW_ID {
272                    with_row_id = true;
273                    must_add_row_offset = true;
274                } else if field.name == ROW_ADDR {
275                    with_row_addr = true;
276                    must_add_row_offset = true;
277                }
278                // Note: Other system columns like _rowoffset are computed differently
279                // and shouldn't appear in the schema at this point
280            } else {
281                // Regular data column - validate it exists in base schema
282                if base.schema().field(&field.name).is_none() {
283                    return Err(Error::io(
284                        format!("Column '{}' not found in schema", field.name),
285                        location!(),
286                    ));
287                }
288                data_fields.push(field.clone());
289            }
290        }
291
292        // Create a schema with only data columns for the physical projection
293        let data_schema = Schema {
294            fields: data_fields,
295            metadata: projection.metadata.clone(),
296        };
297
298        // Calculate the physical projection from data columns only
299        let mut physical_projection = Projection::empty(base)
300            .union_schema(&data_schema)
301            .with_blob_version(blob_version);
302        physical_projection.with_row_id = with_row_id;
303        physical_projection.with_row_addr = with_row_addr;
304
305        // Build output expressions preserving the original order (including system columns)
306        let exprs = projection
307            .fields
308            .iter()
309            .map(|f| OutputColumn {
310                expr: Expr::Column(Column::from_name(&f.name)),
311                name: f.name.clone(),
312            })
313            .collect::<Vec<_>>();
314
315        Ok(Self {
316            physical_projection,
317            requested_output_expr: exprs,
318            must_add_row_offset,
319        })
320    }
321
322    pub fn full(base: Arc<dyn Projectable>, blob_version: BlobVersion) -> Result<Self> {
323        let physical_cols: Vec<&str> = base
324            .schema()
325            .fields
326            .iter()
327            .map(|f| f.name.as_ref())
328            .collect::<Vec<_>>();
329
330        let physical_projection = Projection::empty(base.clone())
331            .union_columns(&physical_cols, OnMissing::Ignore)?
332            .with_blob_version(blob_version);
333
334        let requested_output_expr = physical_cols
335            .into_iter()
336            .map(|col_name| OutputColumn {
337                expr: Expr::Column(Column::from_name(col_name)),
338                name: col_name.to_string(),
339            })
340            .collect();
341
342        Ok(Self {
343            physical_projection,
344            must_add_row_offset: false,
345            requested_output_expr,
346        })
347    }
348
349    /// Convert the projection to a list of physical expressions
350    ///
351    /// This is used to apply the final projection (including dynamic expressions) to the data.
352    pub fn to_physical_exprs(
353        &self,
354        current_schema: &ArrowSchema,
355    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
356        let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
357        self.requested_output_expr
358            .iter()
359            .map(|output_column| {
360                Ok((
361                    datafusion::physical_expr::create_physical_expr(
362                        &output_column.expr,
363                        physical_df_schema.as_ref(),
364                        &Default::default(),
365                    )?,
366                    output_column.name.clone(),
367                ))
368            })
369            .collect::<Result<Vec<_>>>()
370    }
371
372    /// Include the row id in the output
373    pub fn include_row_id(&mut self) {
374        self.physical_projection.with_row_id = true;
375        if !self
376            .requested_output_expr
377            .iter()
378            .any(|OutputColumn { name, .. }| name == ROW_ID)
379        {
380            self.requested_output_expr.push(OutputColumn {
381                expr: Expr::Column(Column::from_name(ROW_ID)),
382                name: ROW_ID.to_string(),
383            });
384        }
385    }
386
387    /// Include the row address in the output
388    pub fn include_row_addr(&mut self) {
389        self.physical_projection.with_row_addr = true;
390        if !self
391            .requested_output_expr
392            .iter()
393            .any(|OutputColumn { name, .. }| name == ROW_ADDR)
394        {
395            self.requested_output_expr.push(OutputColumn {
396                expr: Expr::Column(Column::from_name(ROW_ADDR)),
397                name: ROW_ADDR.to_string(),
398            });
399        }
400    }
401
402    /// Check if the projection has any output columns
403    ///
404    /// This doesn't mean there is a physical projection.  For example, we may someday support
405    /// something like `SELECT 1 AS foo` which would have an output column (foo) but no physical projection
406    pub fn has_output_cols(&self) -> bool {
407        !self.requested_output_expr.is_empty()
408    }
409
410    pub fn output_schema(&self) -> Result<ArrowSchema> {
411        let exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
412        let physical_schema = self.physical_projection.to_arrow_schema();
413        let fields = exprs
414            .iter()
415            .map(|(expr, name)| {
416                Ok(ArrowField::new(
417                    name,
418                    expr.data_type(&physical_schema)?,
419                    expr.nullable(&physical_schema)?,
420                ))
421            })
422            .collect::<Result<Vec<_>>>()?;
423        Ok(ArrowSchema::new(fields))
424    }
425
426    #[instrument(skip_all, level = "debug")]
427    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
428        let src = Arc::new(OneShotExec::from_batch(batch));
429        let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
430        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
431        // Run dummy plan to execute projection, do not log the plan run
432        let stream = execute_plan(
433            projection,
434            LanceExecutionOptions {
435                skip_logging: true,
436                ..Default::default()
437            },
438        )?;
439        let batches = stream.try_collect::<Vec<_>>().await?;
440        if batches.len() != 1 {
441            Err(Error::Internal {
442                message: "Expected exactly one batch".to_string(),
443                location: location!(),
444            })
445        } else {
446            Ok(batches.into_iter().next().unwrap())
447        }
448    }
449}