lance_datafusion/
projection.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_array::RecordBatch;
5use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::DFSchema;
8use datafusion_physical_expr::{expressions, PhysicalExpr};
9use futures::TryStreamExt;
10use snafu::location;
11use std::{
12    collections::{HashMap, HashSet},
13    sync::Arc,
14};
15
16use lance_core::{
17    datatypes::{OnMissing, Projectable, Projection, Schema},
18    Error, Result,
19};
20
21use crate::{
22    exec::{execute_plan, LanceExecutionOptions, OneShotExec},
23    planner::Planner,
24};
25
26#[derive(Clone, Debug)]
27pub struct ProjectionPlan {
28    /// The base thing we are projecting from (e.g. a dataset)
29    base: Arc<dyn Projectable>,
30    /// The physical schema that must be loaded from the dataset
31    pub physical_projection: Projection,
32
33    /// If present, expressions that represent the output columns.  These expressions
34    /// run on the output of the physical projection.
35    ///
36    /// If not present, the output is the physical projection.
37    ///
38    /// Note: this doesn't include _distance, and _rowid
39    pub requested_output_expr: Option<Vec<(Expr, String)>>,
40}
41
42impl ProjectionPlan {
43    /// Create a new projection plan which projects all columns and does not include any expressions
44    pub fn new(base: Arc<dyn Projectable>) -> Self {
45        let physical_projection = Projection::full(base.clone());
46        Self {
47            base,
48            physical_projection,
49            requested_output_expr: None,
50        }
51    }
52
53    /// Set the projection from SQL expressions
54    pub fn project_from_expressions(
55        &mut self,
56        columns: &[(impl AsRef<str>, impl AsRef<str>)],
57    ) -> Result<()> {
58        // Save off values of with_row_id / with_row_addr
59        let had_row_id = self.physical_projection.with_row_id;
60        let had_row_addr = self.physical_projection.with_row_addr;
61
62        // First, look at the expressions to figure out which physical columns are needed
63        let full_schema = Arc::new(Projection::full(self.base.clone()).to_arrow_schema());
64        let planner = Planner::new(full_schema);
65        let mut output = HashMap::new();
66        let mut physical_cols_set = HashSet::new();
67        let mut physical_cols = vec![];
68        for (output_name, raw_expr) in columns {
69            if output.contains_key(output_name.as_ref()) {
70                return Err(Error::io(
71                    format!("Duplicate column name: {}", output_name.as_ref()),
72                    location!(),
73                ));
74            }
75            let expr = planner.parse_expr(raw_expr.as_ref())?;
76            for col in Planner::column_names_in_expr(&expr) {
77                if physical_cols_set.contains(&col) {
78                    continue;
79                }
80                physical_cols.push(col.clone());
81                physical_cols_set.insert(col);
82            }
83            output.insert(output_name.as_ref().to_string(), expr);
84        }
85
86        // Now, calculate the physical projection from the columns referenced by the expressions
87        //
88        // If a column is missing it might be a metadata column (_rowid, _distance, etc.) and so
89        // we ignore it.  We don't need to load that column from disk at least, which is all we are
90        // trying to calculate here.
91        let mut physical_projection = Projection::empty(self.base.clone())
92            .union_columns(&physical_cols, OnMissing::Ignore)?;
93
94        // Restore the row_id and row_addr flags
95        physical_projection.with_row_id = had_row_id;
96        physical_projection.with_row_addr = had_row_addr;
97
98        self.physical_projection = physical_projection;
99
100        // Save off the expressions (they will be evaluated later to run the projection)
101        let mut output_cols = vec![];
102        for (name, _) in columns {
103            output_cols.push((output[name.as_ref()].clone(), name.as_ref().to_string()));
104        }
105        self.requested_output_expr = Some(output_cols);
106
107        Ok(())
108    }
109
110    /// Set the projection from a schema
111    ///
112    /// This plan will have no complex expressions
113    pub fn project_from_schema(&mut self, projection: &Schema) {
114        let had_row_id = self.physical_projection.with_row_id;
115        let had_row_addr = self.physical_projection.with_row_addr;
116
117        let mut physical_projection = Projection::empty(self.base.clone()).union_schema(projection);
118
119        physical_projection.with_row_id = had_row_id;
120        physical_projection.with_row_addr = had_row_addr;
121
122        self.physical_projection = physical_projection;
123    }
124
125    pub fn to_physical_exprs(
126        &self,
127        current_schema: &ArrowSchema,
128    ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
129        let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
130        if let Some(output_expr) = &self.requested_output_expr {
131            output_expr
132                .iter()
133                .map(|(expr, name)| {
134                    Ok((
135                        datafusion::physical_expr::create_physical_expr(
136                            expr,
137                            physical_df_schema.as_ref(),
138                            &Default::default(),
139                        )?,
140                        name.clone(),
141                    ))
142                })
143                .collect::<Result<Vec<_>>>()
144        } else {
145            let projection_schema = self.physical_projection.to_schema();
146            projection_schema
147                .fields
148                .iter()
149                .map(|f| {
150                    Ok((
151                        expressions::col(f.name.as_str(), physical_df_schema.as_arrow())?.clone(),
152                        f.name.clone(),
153                    ))
154                })
155                .collect::<Result<Vec<_>>>()
156        }
157    }
158
159    pub fn include_row_id(&mut self) {
160        self.physical_projection.with_row_id = true;
161    }
162
163    pub fn include_row_addr(&mut self) {
164        self.physical_projection.with_row_addr = true;
165    }
166
167    pub fn output_schema(&self) -> Result<ArrowSchema> {
168        let exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
169        let physical_schema = self.physical_projection.to_arrow_schema();
170        let fields = exprs
171            .iter()
172            .map(|(expr, name)| {
173                Ok(ArrowField::new(
174                    name,
175                    expr.data_type(&physical_schema)?,
176                    expr.nullable(&physical_schema)?,
177                ))
178            })
179            .collect::<Result<Vec<_>>>()?;
180        Ok(ArrowSchema::new(fields))
181    }
182
183    pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
184        if self.requested_output_expr.is_none() {
185            return Ok(batch);
186        }
187        let src = Arc::new(OneShotExec::from_batch(batch));
188        let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
189        let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
190        let stream = execute_plan(projection, LanceExecutionOptions::default())?;
191        let batches = stream.try_collect::<Vec<_>>().await?;
192        if batches.len() != 1 {
193            Err(Error::Internal {
194                message: "Expected exactly one batch".to_string(),
195                location: location!(),
196            })
197        } else {
198            Ok(batches.into_iter().next().unwrap())
199        }
200    }
201}