lance_datafusion/
projection.rs1use arrow_array::RecordBatch;
5use arrow_schema::{Field as ArrowField, Schema as ArrowSchema};
6use datafusion::{logical_expr::Expr, physical_plan::projection::ProjectionExec};
7use datafusion_common::DFSchema;
8use datafusion_physical_expr::{expressions, PhysicalExpr};
9use futures::TryStreamExt;
10use snafu::location;
11use std::{
12 collections::{HashMap, HashSet},
13 sync::Arc,
14};
15
16use lance_core::{
17 datatypes::{OnMissing, Projectable, Projection, Schema},
18 Error, Result,
19};
20
21use crate::{
22 exec::{execute_plan, LanceExecutionOptions, OneShotExec},
23 planner::Planner,
24};
25
26#[derive(Clone, Debug)]
27pub struct ProjectionPlan {
28 base: Arc<dyn Projectable>,
30 pub physical_projection: Projection,
32
33 pub requested_output_expr: Option<Vec<(Expr, String)>>,
40}
41
42impl ProjectionPlan {
43 pub fn new(base: Arc<dyn Projectable>) -> Self {
45 let physical_projection = Projection::full(base.clone());
46 Self {
47 base,
48 physical_projection,
49 requested_output_expr: None,
50 }
51 }
52
53 pub fn project_from_expressions(
55 &mut self,
56 columns: &[(impl AsRef<str>, impl AsRef<str>)],
57 ) -> Result<()> {
58 let had_row_id = self.physical_projection.with_row_id;
60 let had_row_addr = self.physical_projection.with_row_addr;
61
62 let full_schema = Arc::new(Projection::full(self.base.clone()).to_arrow_schema());
64 let planner = Planner::new(full_schema);
65 let mut output = HashMap::new();
66 let mut physical_cols_set = HashSet::new();
67 let mut physical_cols = vec![];
68 for (output_name, raw_expr) in columns {
69 if output.contains_key(output_name.as_ref()) {
70 return Err(Error::io(
71 format!("Duplicate column name: {}", output_name.as_ref()),
72 location!(),
73 ));
74 }
75 let expr = planner.parse_expr(raw_expr.as_ref())?;
76 for col in Planner::column_names_in_expr(&expr) {
77 if physical_cols_set.contains(&col) {
78 continue;
79 }
80 physical_cols.push(col.clone());
81 physical_cols_set.insert(col);
82 }
83 output.insert(output_name.as_ref().to_string(), expr);
84 }
85
86 let mut physical_projection = Projection::empty(self.base.clone())
92 .union_columns(&physical_cols, OnMissing::Ignore)?;
93
94 physical_projection.with_row_id = had_row_id;
96 physical_projection.with_row_addr = had_row_addr;
97
98 self.physical_projection = physical_projection;
99
100 let mut output_cols = vec![];
102 for (name, _) in columns {
103 output_cols.push((output[name.as_ref()].clone(), name.as_ref().to_string()));
104 }
105 self.requested_output_expr = Some(output_cols);
106
107 Ok(())
108 }
109
110 pub fn project_from_schema(&mut self, projection: &Schema) {
114 let had_row_id = self.physical_projection.with_row_id;
115 let had_row_addr = self.physical_projection.with_row_addr;
116
117 let mut physical_projection = Projection::empty(self.base.clone()).union_schema(projection);
118
119 physical_projection.with_row_id = had_row_id;
120 physical_projection.with_row_addr = had_row_addr;
121
122 self.physical_projection = physical_projection;
123 }
124
125 pub fn to_physical_exprs(
126 &self,
127 current_schema: &ArrowSchema,
128 ) -> Result<Vec<(Arc<dyn PhysicalExpr>, String)>> {
129 let physical_df_schema = Arc::new(DFSchema::try_from(current_schema.clone())?);
130 if let Some(output_expr) = &self.requested_output_expr {
131 output_expr
132 .iter()
133 .map(|(expr, name)| {
134 Ok((
135 datafusion::physical_expr::create_physical_expr(
136 expr,
137 physical_df_schema.as_ref(),
138 &Default::default(),
139 )?,
140 name.clone(),
141 ))
142 })
143 .collect::<Result<Vec<_>>>()
144 } else {
145 let projection_schema = self.physical_projection.to_schema();
146 projection_schema
147 .fields
148 .iter()
149 .map(|f| {
150 Ok((
151 expressions::col(f.name.as_str(), physical_df_schema.as_arrow())?.clone(),
152 f.name.clone(),
153 ))
154 })
155 .collect::<Result<Vec<_>>>()
156 }
157 }
158
159 pub fn include_row_id(&mut self) {
160 self.physical_projection.with_row_id = true;
161 }
162
163 pub fn include_row_addr(&mut self) {
164 self.physical_projection.with_row_addr = true;
165 }
166
167 pub fn output_schema(&self) -> Result<ArrowSchema> {
168 let exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
169 let physical_schema = self.physical_projection.to_arrow_schema();
170 let fields = exprs
171 .iter()
172 .map(|(expr, name)| {
173 Ok(ArrowField::new(
174 name,
175 expr.data_type(&physical_schema)?,
176 expr.nullable(&physical_schema)?,
177 ))
178 })
179 .collect::<Result<Vec<_>>>()?;
180 Ok(ArrowSchema::new(fields))
181 }
182
183 pub async fn project_batch(&self, batch: RecordBatch) -> Result<RecordBatch> {
184 if self.requested_output_expr.is_none() {
185 return Ok(batch);
186 }
187 let src = Arc::new(OneShotExec::from_batch(batch));
188 let physical_exprs = self.to_physical_exprs(&self.physical_projection.to_arrow_schema())?;
189 let projection = Arc::new(ProjectionExec::try_new(physical_exprs, src)?);
190 let stream = execute_plan(projection, LanceExecutionOptions::default())?;
191 let batches = stream.try_collect::<Vec<_>>().await?;
192 if batches.len() != 1 {
193 Err(Error::Internal {
194 message: "Expected exactly one batch".to_string(),
195 location: location!(),
196 })
197 } else {
198 Ok(batches.into_iter().next().unwrap())
199 }
200 }
201}