lance_datafusion/
projection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_array::RecordBatch;
5use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::DFSchema;
8use datafusion_physical_expr::{expressions, PhysicalExpr};
9use futures::TryStreamExt;
10use snafu::location;
11use std::{
12    collections::{HashMap, HashSet},
13    sync::Arc,
14};
15
16use lance_core::{
17    datatypes::{OnMissing, Projectable, Projection, Schema},
18    Error, Result,
19};
20
21use crate::{
22    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
23    planner::Planner,
24};
25
/// A plan describing how to turn data loaded from a base into the columns a user asked for.
///
/// The plan has two parts: the physical projection (which columns must be loaded) and,
/// optionally, a list of named expressions that are evaluated on top of the loaded columns
/// to produce the final output.
#[derive(Clone, Debug)]
pub struct ProjectionPlan {
    /// The base thing we are projecting from (e.g. a dataset)
    base: Arc<dyn Projectable>,
    /// The physical schema that must be loaded from the dataset
    pub physical_projection: Projection,

    /// True if the user wants the row id in the final output
    ///
    /// Note: this is related to, but slightly different from, physical_projection.with_row_id
    /// which only tracks if the row id is needed.
    ///
    /// desires_row_id implies with_row_id is true
    /// However, it is possible to have desires_row_id=false and with_row_id=true (e.g. when
    /// the row id is needed to perform a late materialization take)
    pub desires_row_id: bool,
    /// True if the user wants the row address in the final output
    ///
    /// Note: this is related to, but slightly different from, physical_projection.with_row_addr
    /// which only tracks if the row address is needed.
    ///
    /// desires_row_addr implies with_row_addr is true
    /// However, it is possible to have desires_row_addr=false and with_row_addr=true (e.g. during
    /// a count query)
    pub desires_row_addr: bool,

    /// If present, expressions that represent the output columns.  These expressions
    /// run on the output of the physical projection.
    ///
    /// If not present, the output is the physical projection.
    ///
    /// Each entry is an `(expression, output column name)` pair.
    ///
    /// Note: this doesn't include _distance, and _rowid
    pub requested_output_expr: Option<Vec<(Expr, String)>>,
}
60
61impl ProjectionPlan {
62    /// Create a new projection plan which projects all columns and does not include any expressions
63    pub fn new(base: Arc<dyn Projectable>) -> Self {
64        let physical_projection = Projection::full(base.clone());
65        Self {
66            base,
67            physical_projection,
68            requested_output_expr: None,
69            desires_row_addr: false,
70            desires_row_id: false,
71        }
72    }
73
74    /// Set the projection from SQL expressions
75    pub fn project_from_expressions(
76        &mut self,
77        columns: &[(impl AsRef<str>, impl AsRef<str>)],
78    ) -> Result<()> {
79        // Save off values of with_row_id / with_row_addr
80        let had_row_id = self.physical_projection.with_row_id;
81        let had_row_addr = self.physical_projection.with_row_addr;
82
83        // First, look at the expressions to figure out which physical columns are needed
84        let full_schema = Arc::new(Projection::full(self.base.clone()).to_arrow_schema());
85        let planner = Planner::new(full_schema);
86        let mut output = HashMap::new();
87        let mut physical_cols_set = HashSet::new();
88        let mut physical_cols = vec![];
89        for (output_name, raw_expr) in columns {
90            if output.contains_key(output_name.as_ref()) {
91                return Err(Error::io(
92                    format!("Duplicate column name: {}", output_name.as_ref()),
93                    location!(),
94                ));
95            }
96            let expr = planner.parse_expr(raw_expr.as_ref())?;
97            for col in Planner::column_names_in_expr(&expr) {
98                if physical_cols_set.contains(&col) {
99                    continue;
100                }
101                physical_cols.push(col.clone());
102                physical_cols_set.insert(col);
103            }
104            output.insert(output_name.as_ref().to_string(), expr);
105        }
106
107        // Now, calculate the physical projection from the columns referenced by the expressions
108        //
109        // If a column is missing it might be a metadata column (_rowid, _distance, etc.) and so
110        // we ignore it.  We don't need to load that column from disk at least, which is all we are
111        // trying to calculate here.
112        let mut physical_projection = Projection::empty(self.base.clone())
113            .union_columns(&physical_cols, OnMissing::Ignore)?;
114
115        // Restore the row_id and row_addr flags
116        physical_projection.with_row_id = had_row_id;
117        physical_projection.with_row_addr = had_row_addr;
118
119        self.physical_projection = physical_projection;
120
121        // Save off the expressions (they will be evaluated later to run the projection)
122        let mut output_cols = vec![];
123        for (name, _) in columns {
124            output_cols.push((output[name.as_ref()].clone(), name.as_ref().to_string()));
125        }
126        self.requested_output_expr = Some(output_cols);
127
128        Ok(())
129    }
130
131    /// Set the projection from a schema
132    ///
133    /// This plan will have no complex expressions
134    pub fn project_from_schema(&mut self, projection: &Schema) {
135        let had_row_id = self.physical_projection.with_row_id;
136        let had_row_addr = self.physical_projection.with_row_addr;
137
138        let mut physical_projection = Projection::empty(self.base.clone()).union_schema(projection);
139
140        physical_projection.with_row_id = had_row_id;
141        physical_projection.with_row_addr = had_row_addr;
142
143        self.physical_projection = physical_projection;
144    }
145
146    /// Convert the projection to a list of physical expressions
147    ///
148    /// This is used to apply the final projection (including dynamic expressions) to the data.
149    pub fn to_physical_exprs(
150        &self,
151        current_schema: &ArrowSchema,
152    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
153        let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
154        if let Some(output_expr) = &self.requested_output_expr {
155            output_expr
156                .iter()
157                .map(|(expr, name)| {
158                    Ok((
159                        datafusion::physical_expr::create_physical_expr(
160                            expr,
161                            physical_df_schema.as_ref(),
162                            &Default::default(),
163                        )?,
164                        name.clone(),
165                    ))
166                })
167                .collect::<Result<Vec<_>>>()
168        } else {
169            let projection_schema = self.physical_projection.to_schema();
170            projection_schema
171                .fields
172                .iter()
173                .map(|f| {
174                    Ok((
175                        expressions::col(f.name.as_str(), physical_df_schema.as_arrow())?.clone(),
176                        f.name.clone(),
177                    ))
178                })
179                .collect::<Result<Vec<_>>>()
180        }
181    }
182
183    /// Include the row id in the output
184    pub fn include_row_id(&mut self) {
185        self.physical_projection.with_row_id = true;
186        self.desires_row_id = true;
187    }
188
189    /// Include the row address in the output
190    pub fn include_row_addr(&mut self) {
191        self.physical_projection.with_row_addr = true;
192        self.desires_row_addr = true;
193    }
194
195    /// Check if the projection has any output columns
196    ///
197    /// This doesn't mean there is a physical projection.  For example, we may someday support
198    /// something like `SELECT 1 AS foo` which would have an output column (foo) but no physical projection
199    pub fn has_output_cols(&self) -> bool {
200        if self.desires_row_id || self.desires_row_addr {
201            return true;
202        }
203        if let Some(exprs) = &self.requested_output_expr {
204            if !exprs.is_empty() {
205                return true;
206            }
207        }
208        self.physical_projection.has_non_meta_cols()
209    }
210
211    pub fn output_schema(&self) -> Result<ArrowSchema> {
212        let exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
213        let physical_schema = self.physical_projection.to_arrow_schema();
214        let fields = exprs
215            .iter()
216            .map(|(expr, name)| {
217                Ok(ArrowField::new(
218                    name,
219                    expr.data_type(&physical_schema)?,
220                    expr.nullable(&physical_schema)?,
221                ))
222            })
223            .collect::<Result<Vec<_>>>()?;
224        Ok(ArrowSchema::new(fields))
225    }
226
227    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
228        if self.requested_output_expr.is_none() {
229            return Ok(batch);
230        }
231        let src = Arc::new(OneShotExec::from_batch(batch));
232        let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
233        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
234        let stream = execute_plan(projection, LanceExecutionOptions::default())?;
235        let batches = stream.try_collect::<Vec<_>>().await?;
236        if batches.len() != 1 {
237            Err(Error::Internal {
238                message: "Expected exactly one batch".to_string(),
239                location: location!(),
240            })
241        } else {
242            Ok(batches.into_iter().next().unwrap())
243        }
244    }
245}