use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::expressions::{Column, Literal};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::scalar::ScalarValue;
use lance_core::{ROW_ADDR, ROW_ID, Result, is_system_column};
/// Name of the distance column.
///
/// Within this module the column is treated as auto-managed: it is stripped
/// from scanner projections, and when it appears in a canonical schema it is
/// typed as a nullable `Float32`.
pub const DISTANCE_COLUMN: &str = "_distance";
/// Reports whether an explicit projection names the `ROW_ID` column.
///
/// Returns `false` when no projection was supplied (`None`) or when the
/// supplied projection does not contain `ROW_ID`.
pub fn wants_row_id(projection: Option<&[String]>) -> bool {
    projection.is_some_and(|columns| columns.iter().any(|name| name.as_str() == ROW_ID))
}
/// Reports whether an explicit projection names the `ROW_ADDR` column.
///
/// Returns `false` when no projection was supplied (`None`) or when the
/// supplied projection does not contain `ROW_ADDR`.
pub fn wants_row_address(projection: Option<&[String]>) -> bool {
    match projection {
        Some(columns) => columns.iter().any(|name| name.as_str() == ROW_ADDR),
        None => false,
    }
}
/// Returns `true` for columns this module handles itself rather than
/// forwarding to the scanner: the distance column and any system column.
fn is_auto_managed(col: &str) -> bool {
    if col == DISTANCE_COLUMN {
        return true;
    }
    is_system_column(col)
}
/// Builds the list of column names to request from the underlying scanner.
///
/// * With an explicit `user_projection`, the requested columns are kept in
///   order, minus any auto-managed ones (the distance column and system
///   columns).
/// * Without one, every field of `base_schema` is selected in schema order.
///
/// In both cases each entry of `pk_columns` that is not already selected is
/// appended at the end, in the order given.
pub fn build_scanner_projection(
    user_projection: Option<&[String]>,
    base_schema: &SchemaRef,
    pk_columns: &[String],
) -> Vec<String> {
    let mut selected: Vec<String> = match user_projection {
        Some(requested) => requested
            .iter()
            .filter(|name| !is_auto_managed(name.as_str()))
            .cloned()
            .collect(),
        None => base_schema
            .fields()
            .iter()
            .map(|field| field.name().to_owned())
            .collect(),
    };

    // Checked one key at a time so a key repeated in `pk_columns` is only
    // added once.
    for pk in pk_columns {
        let already_present = selected.iter().any(|name| name == pk);
        if !already_present {
            selected.push(pk.to_owned());
        }
    }

    selected
}
/// Computes the canonical output schema for a projection.
///
/// Column order is: the user's projection as given (or all of `base_schema`
/// when `user_projection` is `None`), then any `pk_columns` not yet present,
/// then the distance column if `include_distance` is set and it was not
/// already listed.
///
/// Field types are resolved as follows:
/// * the distance column becomes a nullable `Float32`, and is omitted
///   entirely when `include_distance` is `false`;
/// * system columns become nullable `UInt64`;
/// * every other name is copied from `base_schema`; names that
///   `base_schema` does not contain are silently left out.
pub fn canonical_output_schema(
    user_projection: Option<&[String]>,
    base_schema: &SchemaRef,
    pk_columns: &[String],
    include_distance: bool,
) -> SchemaRef {
    let mut names: Vec<String> = match user_projection {
        Some(requested) => requested.to_vec(),
        None => base_schema
            .fields()
            .iter()
            .map(|field| field.name().clone())
            .collect(),
    };

    for pk in pk_columns {
        if names.iter().all(|existing| existing != pk) {
            names.push(pk.clone());
        }
    }

    let distance_listed = names.iter().any(|name| name == DISTANCE_COLUMN);
    if include_distance && !distance_listed {
        names.push(DISTANCE_COLUMN.to_string());
    }

    let mut fields: Vec<Arc<Field>> = Vec::with_capacity(names.len());
    for name in &names {
        if name == DISTANCE_COLUMN {
            // A user-listed distance column is dropped unless distance output
            // was actually requested.
            if include_distance {
                fields.push(Arc::new(Field::new(
                    DISTANCE_COLUMN,
                    DataType::Float32,
                    true,
                )));
            }
        } else if is_system_column(name) {
            fields.push(Arc::new(Field::new(name.clone(), DataType::UInt64, true)));
        } else if let Ok(field) = base_schema.field_with_name(name) {
            fields.push(Arc::new(field.clone()));
        }
    }

    Arc::new(Schema::new(fields))
}
/// Computes the internal variant of the canonical schema.
///
/// This is the result of [`canonical_output_schema`] for the same arguments
/// with two extra trailing fields, in this order: `MEMTABLE_GEN_COLUMN`
/// (non-nullable `UInt64`) and `FRESHNESS_COLUMN` (nullable `UInt64`).
pub fn canonical_internal_schema(
    user_projection: Option<&[String]>,
    base_schema: &SchemaRef,
    pk_columns: &[String],
    include_distance: bool,
) -> SchemaRef {
    use crate::dataset::mem_wal::scanner::exec::{FRESHNESS_COLUMN, MEMTABLE_GEN_COLUMN};

    let output =
        canonical_output_schema(user_projection, base_schema, pk_columns, include_distance);

    let mut fields: Vec<Arc<Field>> = Vec::with_capacity(output.fields().len() + 2);
    fields.extend(output.fields().iter().cloned());
    fields.extend([
        Arc::new(Field::new(MEMTABLE_GEN_COLUMN, DataType::UInt64, false)),
        Arc::new(Field::new(FRESHNESS_COLUMN, DataType::UInt64, true)),
    ]);

    Arc::new(Schema::new(fields))
}
/// Wraps `plan` in a projection that replaces the listed columns with NULLs.
///
/// Every column of the input schema is kept, in order and under its original
/// name. Columns whose name appears in `names` are replaced by a NULL literal
/// built from the column's own data type; all others are passed through.
/// Entries of `names` that match no input column have no effect.
///
/// # Errors
///
/// Returns an internal error if a NULL literal cannot be built for a column's
/// data type, or if the `ProjectionExec` cannot be constructed.
pub fn null_columns(
    plan: Arc<dyn ExecutionPlan>,
    names: &[&str],
) -> Result<Arc<dyn ExecutionPlan>> {
    let schema = plan.schema();

    let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = schema
        .fields()
        .iter()
        .enumerate()
        .map(
            |(position, field)| -> Result<(Arc<dyn PhysicalExpr>, String)> {
                let column_name = field.name();
                let should_null = names
                    .iter()
                    .any(|candidate| *candidate == column_name.as_str());

                let expr: Arc<dyn PhysicalExpr> = if should_null {
                    let null_value = ScalarValue::try_from(field.data_type()).map_err(|e| {
                        lance_core::Error::internal(format!(
                            "Cannot build NULL literal for {}: {}",
                            field.data_type(),
                            e
                        ))
                    })?;
                    Arc::new(Literal::new(null_value))
                } else {
                    Arc::new(Column::new(column_name, position))
                };

                Ok((expr, column_name.clone()))
            },
        )
        .collect::<Result<Vec<_>>>()?;

    ProjectionExec::try_new(exprs, plan)
        .map(|exec| Arc::new(exec) as Arc<dyn ExecutionPlan>)
        .map_err(|e| {
            lance_core::Error::internal(format!(
                "Failed to build null_columns ProjectionExec: {}",
                e
            ))
        })
}
pub fn project_to_canonical(
plan: Arc<dyn ExecutionPlan>,
target_schema: &SchemaRef,
) -> Result<Arc<dyn ExecutionPlan>> {
let input_schema = plan.schema();
let mut project_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
Vec::with_capacity(target_schema.fields().len());
for field in target_schema.fields() {
let name = field.name();
let expr: Arc<dyn PhysicalExpr> = match input_schema.column_with_name(name) {
Some((idx, _)) => Arc::new(Column::new(name, idx)),
None if is_system_column(name) => Arc::new(Literal::new(ScalarValue::UInt64(None))),
None if name == DISTANCE_COLUMN => Arc::new(Literal::new(ScalarValue::Float32(None))),
None => {
return Err(lance_core::Error::internal(format!(
"Column '{}' missing from canonical projection source schema (have: {:?})",
name,
input_schema
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<Vec<_>>()
)));
}
};
project_exprs.push((expr, name.clone()));
}
let projection_exec = ProjectionExec::try_new(project_exprs, plan).map_err(|e| {
lance_core::Error::internal(format!("Failed to build canonical ProjectionExec: {}", e))
})?;
Ok(Arc::new(projection_exec))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::Schema as ArrowSchema;

    /// Base schema fixture: non-nullable `id`, nullable `name` and `vector`.
    fn schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("vector", DataType::Float32, true),
        ];
        Arc::new(ArrowSchema::new(fields))
    }

    /// Turns string literals into the owned `String`s the helpers accept.
    fn strings(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    /// Field names of `out`, in schema order.
    fn field_names(out: &SchemaRef) -> Vec<&str> {
        out.fields().iter().map(|f| f.name().as_str()).collect()
    }

    #[test]
    fn scanner_projection_strips_system_and_distance() {
        let base = schema();
        let pks = strings(&["id"]);
        let user = strings(&["_distance", "vector", "_rowid", "_rowaddr"]);

        let cols = build_scanner_projection(Some(user.as_slice()), &base, &pks);

        assert_eq!(cols, strings(&["vector", "id"]));
    }

    #[test]
    fn scanner_projection_default_uses_base_schema() {
        let base = schema();
        let pks = strings(&["id"]);

        let cols = build_scanner_projection(None, &base, &pks);

        assert_eq!(cols, strings(&["id", "name", "vector"]));
    }

    #[test]
    fn canonical_schema_honors_user_order_for_distance() {
        let base = schema();
        let pks = strings(&["id"]);
        let user = strings(&["_distance", "vector"]);

        let out = canonical_output_schema(Some(user.as_slice()), &base, &pks, true);

        assert_eq!(field_names(&out), vec!["_distance", "vector", "id"]);
        let distance = out.field_with_name("_distance").unwrap();
        assert_eq!(distance.data_type(), &DataType::Float32);
    }

    #[test]
    fn canonical_schema_includes_system_cols_as_nullable_uint64() {
        let base = schema();
        let pks = strings(&["id"]);
        let user = strings(&["vector", "_rowid", "_rowaddr", "_rowoffset"]);

        let out = canonical_output_schema(Some(user.as_slice()), &base, &pks, false);

        assert_eq!(
            field_names(&out),
            vec!["vector", "_rowid", "_rowaddr", "_rowoffset", "id"]
        );
        for sys in ["_rowid", "_rowaddr", "_rowoffset"] {
            let field = out.field_with_name(sys).unwrap();
            assert_eq!(field.data_type(), &DataType::UInt64);
            assert!(field.is_nullable(), "{sys} must be nullable for NULL fill");
        }
    }

    #[test]
    fn canonical_schema_appends_distance_when_missing() {
        let base = schema();
        let pks = strings(&["id"]);
        let user = strings(&["vector"]);

        let out = canonical_output_schema(Some(user.as_slice()), &base, &pks, true);

        assert_eq!(field_names(&out), vec!["vector", "id", "_distance"]);
    }

    #[test]
    fn canonical_schema_drops_distance_when_not_requested() {
        let base = schema();
        let pks = strings(&["id"]);
        let user = strings(&["_distance", "vector"]);

        let out = canonical_output_schema(Some(user.as_slice()), &base, &pks, false);

        assert_eq!(field_names(&out), vec!["vector", "id"]);
    }
}