lance_datafusion/
projection.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

4use arrow_array::RecordBatch;
5use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::{Column, DFSchema};
8use datafusion_physical_expr::PhysicalExpr;
9use futures::TryStreamExt;
10use snafu::location;
11use std::{
12    collections::{HashMap, HashSet},
13    sync::Arc,
14};
15
16use lance_core::{
17    datatypes::{OnMissing, Projectable, Projection, Schema},
18    Error, Result, ROW_ADDR, ROW_ID, ROW_OFFSET,
19};
20
21use crate::{
22    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
23    planner::Planner,
24};
25
26#[derive(Clone, Debug)]
27pub struct OutputColumn {
28    /// The expression that represents the output column
29    pub expr: Expr,
30    /// The name of the output column
31    pub name: String,
32}
33
34#[derive(Clone, Debug)]
35pub struct ProjectionPlan {
36    /// The physical schema that must be loaded from the dataset
37    pub physical_projection: Projection,
38
39    /// Needs the row address converted into a row offset
40    pub must_add_row_offset: bool,
41
42    /// The desired output columns
43    pub requested_output_expr: Vec<OutputColumn>,
44}
45
46impl ProjectionPlan {
47    fn add_system_columns(schema: &ArrowSchema) -> ArrowSchema {
48        let mut fields = Vec::from_iter(schema.fields.iter().cloned());
49        fields.push(Arc::new(ArrowField::new(ROW_ID, DataType::UInt64, true)));
50        fields.push(Arc::new(ArrowField::new(ROW_ADDR, DataType::UInt64, true)));
51        fields.push(Arc::new(ArrowField::new(
52            ROW_OFFSET,
53            DataType::UInt64,
54            true,
55        )));
56        ArrowSchema::new(fields)
57    }
58
59    /// Set the projection from SQL expressions
60    pub fn from_expressions(
61        base: Arc<dyn Projectable>,
62        columns: &[(impl AsRef<str>, impl AsRef<str>)],
63    ) -> Result<Self> {
64        // First, look at the expressions to figure out which physical columns are needed
65        let full_schema = Arc::new(Projection::full(base.clone()).to_arrow_schema());
66        let full_schema = Arc::new(Self::add_system_columns(&full_schema));
67        let planner = Planner::new(full_schema);
68        let mut output = HashMap::new();
69        let mut physical_cols_set = HashSet::new();
70        let mut physical_cols = vec![];
71        let mut needs_row_id = false;
72        let mut needs_row_addr = false;
73        let mut must_add_row_offset = false;
74        for (output_name, raw_expr) in columns {
75            if output.contains_key(output_name.as_ref()) {
76                return Err(Error::io(
77                    format!("Duplicate column name: {}", output_name.as_ref()),
78                    location!(),
79                ));
80            }
81
82            let expr = planner.parse_expr(raw_expr.as_ref())?;
83
84            // If the expression is a bare column reference to a system column, mark that we need it
85            if let Expr::Column(Column {
86                name,
87                relation: None,
88                ..
89            }) = &expr
90            {
91                if name == ROW_ID {
92                    needs_row_id = true;
93                } else if name == ROW_ADDR {
94                    needs_row_addr = true;
95                } else if name == ROW_OFFSET {
96                    must_add_row_offset = true;
97                }
98            }
99
100            for col in Planner::column_names_in_expr(&expr) {
101                if physical_cols_set.contains(&col) {
102                    continue;
103                }
104                physical_cols.push(col.clone());
105                physical_cols_set.insert(col);
106            }
107            output.insert(output_name.as_ref().to_string(), expr);
108        }
109
110        // Now, calculate the physical projection from the columns referenced by the expressions
111        //
112        // If a column is missing it might be a metadata column (_rowid, _distance, etc.) and so
113        // we ignore it.  We don't need to load that column from disk at least, which is all we are
114        // trying to calculate here.
115        let mut physical_projection =
116            Projection::empty(base.clone()).union_columns(&physical_cols, OnMissing::Ignore)?;
117
118        physical_projection.with_row_id = needs_row_id;
119        physical_projection.with_row_addr = needs_row_addr || must_add_row_offset;
120
121        // Save off the expressions (they will be evaluated later to run the projection)
122        let mut output_cols = vec![];
123        for (name, _) in columns {
124            output_cols.push(OutputColumn {
125                expr: output[name.as_ref()].clone(),
126                name: name.as_ref().to_string(),
127            });
128        }
129
130        Ok(Self {
131            physical_projection,
132            must_add_row_offset,
133            requested_output_expr: output_cols,
134        })
135    }
136
137    /// Set the projection from a schema
138    ///
139    /// This plan will have no complex expressions, the schema must be a subset of the dataset schema.
140    ///
141    /// With this approach it is possible to refer to portions of nested fields.
142    ///
143    /// For example, if the schema is:
144    ///
145    /// ```ignore
146    /// {
147    ///   "metadata": {
148    ///     "location": {
149    ///       "x": f32,
150    ///       "y": f32,
151    ///     },
152    ///     "age": i32,
153    ///   }
154    /// }
155    /// ```
156    ///
157    /// It is possible to project a partial schema that drops `y` like:
158    ///
159    /// ```ignore
160    /// {
161    ///   "metadata": {
162    ///     "location": {
163    ///       "x": f32,
164    ///     },
165    ///     "age": i32,
166    ///   }
167    /// }
168    /// ```
169    ///
170    /// This is something that cannot be done easily using expressions.
171    pub fn from_schema(base: Arc<dyn Projectable>, projection: &Schema) -> Result<Self> {
172        // Calculate the physical projection directly from the schema
173        //
174        // The _rowid and _rowaddr columns will be recognized and added to the physical projection
175        //
176        // Any columns with an id of -1 (e.g. _rowoffset) will be ignored
177        let physical_projection = Projection::empty(base).union_schema(projection);
178        let mut must_add_row_offset = false;
179        // Now calculate the output expressions.  This will only reorder top-level columns.  We don't
180        // support reordering nested fields.
181        let exprs = projection
182            .fields
183            .iter()
184            .map(|f| {
185                if f.name == ROW_ADDR {
186                    must_add_row_offset = true;
187                }
188                OutputColumn {
189                    expr: Expr::Column(Column::from_name(&f.name)),
190                    name: f.name.clone(),
191                }
192            })
193            .collect::<Vec<_>>();
194        Ok(Self {
195            physical_projection,
196            requested_output_expr: exprs,
197            must_add_row_offset,
198        })
199    }
200
201    pub fn full(base: Arc<dyn Projectable>) -> Result<Self> {
202        let physical_cols: Vec<&str> = base
203            .schema()
204            .fields
205            .iter()
206            .map(|f| f.name.as_ref())
207            .collect::<Vec<_>>();
208
209        let physical_projection =
210            Projection::empty(base.clone()).union_columns(&physical_cols, OnMissing::Ignore)?;
211
212        let requested_output_expr = physical_cols
213            .into_iter()
214            .map(|col_name| OutputColumn {
215                expr: Expr::Column(Column::from_name(col_name)),
216                name: col_name.to_string(),
217            })
218            .collect();
219
220        Ok(Self {
221            physical_projection,
222            must_add_row_offset: false,
223            requested_output_expr,
224        })
225    }
226
227    /// Convert the projection to a list of physical expressions
228    ///
229    /// This is used to apply the final projection (including dynamic expressions) to the data.
230    pub fn to_physical_exprs(
231        &self,
232        current_schema: &ArrowSchema,
233    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
234        let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
235        self.requested_output_expr
236            .iter()
237            .map(|output_column| {
238                Ok((
239                    datafusion::physical_expr::create_physical_expr(
240                        &output_column.expr,
241                        physical_df_schema.as_ref(),
242                        &Default::default(),
243                    )?,
244                    output_column.name.clone(),
245                ))
246            })
247            .collect::<Result<Vec<_>>>()
248    }
249
250    /// Include the row id in the output
251    pub fn include_row_id(&mut self) {
252        self.physical_projection.with_row_id = true;
253        if !self
254            .requested_output_expr
255            .iter()
256            .any(|OutputColumn { name, .. }| name == ROW_ID)
257        {
258            self.requested_output_expr.push(OutputColumn {
259                expr: Expr::Column(Column::from_name(ROW_ID)),
260                name: ROW_ID.to_string(),
261            });
262        }
263    }
264
265    /// Include the row address in the output
266    pub fn include_row_addr(&mut self) {
267        self.physical_projection.with_row_addr = true;
268        if !self
269            .requested_output_expr
270            .iter()
271            .any(|OutputColumn { name, .. }| name == ROW_ADDR)
272        {
273            self.requested_output_expr.push(OutputColumn {
274                expr: Expr::Column(Column::from_name(ROW_ADDR)),
275                name: ROW_ADDR.to_string(),
276            });
277        }
278    }
279
280    pub fn include_row_offset(&mut self) {
281        // Need row addr to get row offset
282        self.physical_projection.with_row_addr = true;
283        self.must_add_row_offset = true;
284        if !self
285            .requested_output_expr
286            .iter()
287            .any(|OutputColumn { name, .. }| name == ROW_OFFSET)
288        {
289            self.requested_output_expr.push(OutputColumn {
290                expr: Expr::Column(Column::from_name(ROW_OFFSET)),
291                name: ROW_OFFSET.to_string(),
292            });
293        }
294    }
295
296    /// Check if the projection has any output columns
297    ///
298    /// This doesn't mean there is a physical projection.  For example, we may someday support
299    /// something like `SELECT 1 AS foo` which would have an output column (foo) but no physical projection
300    pub fn has_output_cols(&self) -> bool {
301        !self.requested_output_expr.is_empty()
302    }
303
304    pub fn output_schema(&self) -> Result<ArrowSchema> {
305        let exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
306        let physical_schema = self.physical_projection.to_arrow_schema();
307        let fields = exprs
308            .iter()
309            .map(|(expr, name)| {
310                Ok(ArrowField::new(
311                    name,
312                    expr.data_type(&physical_schema)?,
313                    expr.nullable(&physical_schema)?,
314                ))
315            })
316            .collect::<Result<Vec<_>>>()?;
317        Ok(ArrowSchema::new(fields))
318    }
319
320    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
321        let src = Arc::new(OneShotExec::from_batch(batch));
322        let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
323        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
324        let stream = execute_plan(projection, LanceExecutionOptions::default())?;
325        let batches = stream.try_collect::<Vec<_>>().await?;
326        if batches.len() != 1 {
327            Err(Error::Internal {
328                message: "Expected exactly one batch".to_string(),
329                location: location!(),
330            })
331        } else {
332            Ok(batches.into_iter().next().unwrap())
333        }
334    }
335}