lance_datafusion/
projection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_array::RecordBatch;
5use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::DFSchema;
8use datafusion_physical_expr::{expressions, PhysicalExpr};
9use futures::TryStreamExt;
10use snafu::location;
11use std::{
12    collections::{HashMap, HashSet},
13    sync::Arc,
14};
15
16use lance_core::{
17    datatypes::{OnMissing, Projectable, Projection, Schema},
18    Error, Result,
19};
20
21use crate::{
22    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
23    planner::Planner,
24};
25
/// Describes how to turn data from a base source into the requested output columns.
///
/// A plan has two parts: the physical projection (which columns must be loaded from
/// the base) and, optionally, a list of expressions that are evaluated on top of the
/// loaded columns to produce the final output.
#[derive(Clone, Debug)]
pub struct ProjectionPlan {
    /// The base thing we are projecting from (e.g. a dataset)
    base: Arc<dyn Projectable>,
    /// The physical schema that must be loaded from the dataset
    pub physical_projection: Projection,
    /// True if the final schema of the datafusion execution plan is empty.
    /// e.g.
    ///   select count(*) from dataset
    pub final_projection_is_empty: bool,

    /// If present, the expressions that represent the output columns, each paired
    /// with the name of the output column it produces.  These expressions run on
    /// the output of the physical projection.
    ///
    /// If not present, the output is the physical projection.
    ///
    /// Note: this doesn't include _distance or _rowid
    pub requested_output_expr: Option<Vec<(Expr, String)>>,
}
45
46impl ProjectionPlan {
47    /// Create a new projection plan which projects all columns and does not include any expressions
48    pub fn new(base: Arc<dyn Projectable>) -> Self {
49        let physical_projection = Projection::full(base.clone());
50        Self {
51            base,
52            physical_projection,
53            final_projection_is_empty: false,
54            requested_output_expr: None,
55        }
56    }
57
58    /// Set the projection from SQL expressions
59    pub fn project_from_expressions(
60        &mut self,
61        columns: &[(impl AsRef<str>, impl AsRef<str>)],
62    ) -> Result<()> {
63        // Save off values of with_row_id / with_row_addr
64        let had_row_id = self.physical_projection.with_row_id;
65        let had_row_addr = self.physical_projection.with_row_addr;
66
67        // First, look at the expressions to figure out which physical columns are needed
68        let full_schema = Arc::new(Projection::full(self.base.clone()).to_arrow_schema());
69        let planner = Planner::new(full_schema);
70        let mut output = HashMap::new();
71        let mut physical_cols_set = HashSet::new();
72        let mut physical_cols = vec![];
73        for (output_name, raw_expr) in columns {
74            if output.contains_key(output_name.as_ref()) {
75                return Err(Error::io(
76                    format!("Duplicate column name: {}", output_name.as_ref()),
77                    location!(),
78                ));
79            }
80            let expr = planner.parse_expr(raw_expr.as_ref())?;
81            for col in Planner::column_names_in_expr(&expr) {
82                if physical_cols_set.contains(&col) {
83                    continue;
84                }
85                physical_cols.push(col.clone());
86                physical_cols_set.insert(col);
87            }
88            output.insert(output_name.as_ref().to_string(), expr);
89        }
90
91        // Now, calculate the physical projection from the columns referenced by the expressions
92        //
93        // If a column is missing it might be a metadata column (_rowid, _distance, etc.) and so
94        // we ignore it.  We don't need to load that column from disk at least, which is all we are
95        // trying to calculate here.
96        let mut physical_projection = Projection::empty(self.base.clone())
97            .union_columns(&physical_cols, OnMissing::Ignore)?;
98
99        // Restore the row_id and row_addr flags
100        physical_projection.with_row_id = had_row_id;
101        physical_projection.with_row_addr = had_row_addr;
102
103        self.physical_projection = physical_projection;
104
105        // Save off the expressions (they will be evaluated later to run the projection)
106        let mut output_cols = vec![];
107        for (name, _) in columns {
108            output_cols.push((output[name.as_ref()].clone(), name.as_ref().to_string()));
109        }
110        self.requested_output_expr = Some(output_cols);
111
112        Ok(())
113    }
114
115    /// Set the projection from a schema
116    ///
117    /// This plan will have no complex expressions
118    pub fn project_from_schema(&mut self, projection: &Schema) {
119        let had_row_id = self.physical_projection.with_row_id;
120        let had_row_addr = self.physical_projection.with_row_addr;
121
122        let mut physical_projection = Projection::empty(self.base.clone()).union_schema(projection);
123
124        physical_projection.with_row_id = had_row_id;
125        physical_projection.with_row_addr = had_row_addr;
126
127        self.physical_projection = physical_projection;
128    }
129
130    pub fn to_physical_exprs(
131        &self,
132        current_schema: &ArrowSchema,
133    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
134        let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
135        if let Some(output_expr) = &self.requested_output_expr {
136            output_expr
137                .iter()
138                .map(|(expr, name)| {
139                    Ok((
140                        datafusion::physical_expr::create_physical_expr(
141                            expr,
142                            physical_df_schema.as_ref(),
143                            &Default::default(),
144                        )?,
145                        name.clone(),
146                    ))
147                })
148                .collect::<Result<Vec<_>>>()
149        } else {
150            let projection_schema = self.physical_projection.to_schema();
151            projection_schema
152                .fields
153                .iter()
154                .map(|f| {
155                    Ok((
156                        expressions::col(f.name.as_str(), physical_df_schema.as_arrow())?.clone(),
157                        f.name.clone(),
158                    ))
159                })
160                .collect::<Result<Vec<_>>>()
161        }
162    }
163
164    pub fn include_row_id(&mut self) {
165        self.physical_projection.with_row_id = true;
166    }
167
168    pub fn include_row_addr(&mut self) {
169        self.physical_projection.with_row_addr = true;
170    }
171
172    pub fn output_schema(&self) -> Result<ArrowSchema> {
173        let exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
174        let physical_schema = self.physical_projection.to_arrow_schema();
175        let fields = exprs
176            .iter()
177            .map(|(expr, name)| {
178                Ok(ArrowField::new(
179                    name,
180                    expr.data_type(&physical_schema)?,
181                    expr.nullable(&physical_schema)?,
182                ))
183            })
184            .collect::<Result<Vec<_>>>()?;
185        Ok(ArrowSchema::new(fields))
186    }
187
188    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
189        if self.requested_output_expr.is_none() {
190            return Ok(batch);
191        }
192        let src = Arc::new(OneShotExec::from_batch(batch));
193        let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
194        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
195        let stream = execute_plan(projection, LanceExecutionOptions::default())?;
196        let batches = stream.try_collect::<Vec<_>>().await?;
197        if batches.len() != 1 {
198            Err(Error::Internal {
199                message: "Expected exactly one batch".to_string(),
200                location: location!(),
201            })
202        } else {
203            Ok(batches.into_iter().next().unwrap())
204        }
205    }
206}