use std::sync::Arc;
use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Field, SchemaRef};
use datafusion::common::{ScalarValue, ToDFSchema};
use datafusion::physical_plan::limit::GlobalLimitExec;
use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream};
use datafusion::prelude::{Expr, SessionContext};
use futures::TryStreamExt;
use lance_core::{Error, ROW_ID, Result};
use lance_datafusion::expr::safe_coerce_scalar;
use lance_datafusion::planner::Planner;
use lance_linalg::distance::DistanceType;
use super::exec::{BTreeIndexExec, FtsIndexExec, MemTableScanExec, VectorIndexExec};
use crate::dataset::mem_wal::write::{BatchStore, IndexStore};
/// Parameters for an approximate-nearest-neighbor (vector) search.
///
/// Built via [`MemTableScanner::nearest`] and tuned with the scanner's
/// fluent setters (`nprobes`, `ef`, `refine`, `distance_range`, ...).
#[derive(Debug, Clone)]
pub struct VectorQuery {
    /// Name of the vector column to search.
    pub column: String,
    /// The query vector to find neighbors of.
    pub query_vector: Arc<dyn Array>,
    /// Number of nearest neighbors to return.
    pub k: usize,
    /// Number of IVF partitions to probe (lower bound).
    pub nprobes: usize,
    /// Optional upper bound on partitions probed; `None` means no cap beyond `nprobes`.
    pub maximum_nprobes: Option<usize>,
    /// Distance metric override; `None` uses the index's default.
    pub distance_type: Option<DistanceType>,
    /// HNSW-style search breadth override, if applicable.
    pub ef: Option<usize>,
    /// Re-rank `refine_factor * k` candidates with exact distances when set.
    pub refine_factor: Option<u32>,
    /// Exclude results with distance below this bound, if set.
    pub distance_lower_bound: Option<f32>,
    /// Exclude results with distance above this bound, if set.
    pub distance_upper_bound: Option<f32>,
}
/// The kind of full-text-search query to run.
#[derive(Debug, Clone)]
pub enum FtsQueryType {
    /// Match any of the query's terms (standard term matching).
    Match {
        query: String,
    },
    /// Match the terms as a phrase, allowing up to `slop` positions of
    /// reordering/insertion between terms.
    Phrase {
        query: String,
        slop: u32,
    },
    /// Boolean combination of term lists.
    Boolean {
        /// Terms that must all be present.
        must: Vec<String>,
        /// Terms that boost score when present.
        should: Vec<String>,
        /// Terms that must be absent.
        must_not: Vec<String>,
    },
    /// Fuzzy term matching within an edit distance.
    Fuzzy {
        query: String,
        /// Maximum edit distance; `None` lets the index choose automatically.
        fuzziness: Option<u32>,
        /// Cap on the number of term expansions considered.
        max_expansions: usize,
    },
}
/// A full-text-search query against a single column.
#[derive(Debug, Clone)]
pub struct FtsQuery {
    /// Name of the indexed text column to search.
    pub column: String,
    /// What kind of matching to perform.
    pub query_type: FtsQueryType,
    /// WAND pruning factor in `[0.0, 1.0]`; see [`FtsQuery::with_wand_factor`].
    pub wand_factor: f32,
}
/// Default cap on term expansions for fuzzy queries.
pub const DEFAULT_MAX_EXPANSIONS: usize = 50;
/// Default WAND factor (1.0 = no score-based pruning loss).
pub const DEFAULT_WAND_FACTOR: f32 = 1.0;
impl FtsQuery {
pub fn match_query(column: impl Into<String>, query: impl Into<String>) -> Self {
Self {
column: column.into(),
query_type: FtsQueryType::Match {
query: query.into(),
},
wand_factor: DEFAULT_WAND_FACTOR,
}
}
pub fn phrase(column: impl Into<String>, query: impl Into<String>, slop: u32) -> Self {
Self {
column: column.into(),
query_type: FtsQueryType::Phrase {
query: query.into(),
slop,
},
wand_factor: DEFAULT_WAND_FACTOR,
}
}
pub fn boolean(
column: impl Into<String>,
must: Vec<String>,
should: Vec<String>,
must_not: Vec<String>,
) -> Self {
Self {
column: column.into(),
query_type: FtsQueryType::Boolean {
must,
should,
must_not,
},
wand_factor: DEFAULT_WAND_FACTOR,
}
}
pub fn fuzzy(column: impl Into<String>, query: impl Into<String>) -> Self {
Self {
column: column.into(),
query_type: FtsQueryType::Fuzzy {
query: query.into(),
fuzziness: None,
max_expansions: DEFAULT_MAX_EXPANSIONS,
},
wand_factor: DEFAULT_WAND_FACTOR,
}
}
pub fn fuzzy_with_distance(
column: impl Into<String>,
query: impl Into<String>,
fuzziness: u32,
) -> Self {
Self {
column: column.into(),
query_type: FtsQueryType::Fuzzy {
query: query.into(),
fuzziness: Some(fuzziness),
max_expansions: DEFAULT_MAX_EXPANSIONS,
},
wand_factor: DEFAULT_WAND_FACTOR,
}
}
pub fn fuzzy_with_options(
column: impl Into<String>,
query: impl Into<String>,
fuzziness: Option<u32>,
max_expansions: usize,
) -> Self {
Self {
column: column.into(),
query_type: FtsQueryType::Fuzzy {
query: query.into(),
fuzziness,
max_expansions,
},
wand_factor: DEFAULT_WAND_FACTOR,
}
}
pub fn with_wand_factor(mut self, wand_factor: f32) -> Self {
self.wand_factor = wand_factor.clamp(0.0, 1.0);
self
}
}
/// A simple scalar predicate that a BTree index can answer directly.
///
/// Extracted from a DataFusion filter expression by
/// `MemTableScanner::extract_btree_predicate`.
#[derive(Debug, Clone)]
pub enum ScalarPredicate {
    /// `column = value`
    Eq { column: String, value: ScalarValue },
    /// `column` within an open/half-open/closed range; `None` means unbounded
    /// on that side.
    Range {
        column: String,
        lower: Option<ScalarValue>,
        upper: Option<ScalarValue>,
    },
    /// `column IN (values...)`
    In {
        column: String,
        values: Vec<ScalarValue>,
    },
}
impl ScalarPredicate {
    /// Returns the column name this predicate constrains.
    pub fn column(&self) -> &str {
        match self {
            Self::Eq { column, .. }
            | Self::Range { column, .. }
            | Self::In { column, .. } => column,
        }
    }
}
/// Fluent scan builder over an in-memory table (`BatchStore` + `IndexStore`).
///
/// Configure with the setter methods, then execute with `try_into_stream`,
/// `try_into_batch`, or `count_rows`.
pub struct MemTableScanner {
    // Source of the record batches to scan.
    batch_store: Arc<BatchStore>,
    // Indexes (btree / vector / fts) available for accelerated plans.
    indexes: Arc<IndexStore>,
    // Full schema of the stored data (without _rowid/_rowaddr).
    schema: SchemaRef,
    // Visibility snapshot: only batches up to this indexed position are scanned.
    max_visible_batch_position: usize,
    // Optional column projection (special row-id column stripped out).
    projection: Option<Vec<String>>,
    // Optional filter expression applied during the scan.
    filter: Option<Expr>,
    // Optional row limit and offset.
    limit: Option<usize>,
    offset: Option<usize>,
    // Optional vector search; takes precedence over other plans.
    nearest: Option<VectorQuery>,
    // Optional full-text search; second in plan precedence.
    full_text_query: Option<FtsQuery>,
    // When false, index-accelerated scalar plans are disabled.
    use_index: bool,
    // Requested output batch size.
    // NOTE(review): not currently consumed by plan construction — confirm intent.
    batch_size: Option<usize>,
    // Whether to append the _rowid column to the output.
    with_row_id: bool,
    // Whether to append the _rowaddr column to the output.
    with_row_address: bool,
}
impl MemTableScanner {
    /// Creates a scanner over `batch_store` with the given `schema`.
    ///
    /// The visible range of batches is snapshotted here from
    /// `indexes.max_indexed_batch_position()`, so batches appended after
    /// construction (or appended but not yet indexed) are not returned.
    pub fn new(batch_store: Arc<BatchStore>, indexes: Arc<IndexStore>, schema: SchemaRef) -> Self {
        let max_visible_batch_position = indexes.max_indexed_batch_position();
        Self {
            batch_store,
            indexes,
            schema,
            max_visible_batch_position,
            projection: None,
            filter: None,
            limit: None,
            offset: None,
            nearest: None,
            full_text_query: None,
            use_index: true,
            batch_size: None,
            with_row_id: false,
            with_row_address: false,
        }
    }

    /// Selects the columns to return.
    ///
    /// The special `_rowid` and `_rowaddr` columns are not part of the stored
    /// schema; when they appear in `columns` they are stripped from the
    /// projection and the corresponding flag is enabled instead, so they are
    /// appended to the output schema like `with_row_id` / `with_row_address`.
    pub fn project(&mut self, columns: &[&str]) -> &mut Self {
        use super::exec::ROW_ADDRESS_COLUMN;
        let mut filtered_columns = Vec::new();
        for col in columns {
            if *col == ROW_ID {
                self.with_row_id = true;
            } else if *col == ROW_ADDRESS_COLUMN {
                // Fix: previously `_rowaddr` fell through to the regular-column
                // path, ended up in `projection`, and later failed the
                // schema lookup in `compute_projection_indices`. Treat it the
                // same way as `_rowid`.
                self.with_row_address = true;
            } else {
                filtered_columns.push(col.to_string());
            }
        }
        if !filtered_columns.is_empty() || self.with_row_id || self.with_row_address {
            self.projection = Some(filtered_columns);
        }
        self
    }

    /// Includes the `_rowid` column in the output.
    pub fn with_row_id(&mut self) -> &mut Self {
        self.with_row_id = true;
        self
    }

    /// Includes the `_rowaddr` column in the output.
    pub fn with_row_address(&mut self) -> &mut Self {
        self.with_row_address = true;
        self
    }

    /// Sets the scan filter from a SQL expression string parsed against the
    /// scanner's schema.
    ///
    /// # Errors
    /// Returns `Error::invalid_input` if the schema cannot be converted to a
    /// DataFusion schema or the expression fails to parse.
    pub fn filter(&mut self, filter_expr: &str) -> Result<&mut Self> {
        let ctx = SessionContext::new();
        let df_schema = self
            .schema
            .clone()
            .to_dfschema()
            .map_err(|e| Error::invalid_input(format!("Failed to create DFSchema: {}", e)))?;
        let expr = ctx.parse_sql_expr(filter_expr, &df_schema).map_err(|e| {
            Error::invalid_input(format!("Failed to parse filter expression: {}", e))
        })?;
        self.filter = Some(expr);
        Ok(self)
    }

    /// Sets the scan filter from an already-built DataFusion expression.
    pub fn filter_expr(&mut self, expr: Expr) -> &mut Self {
        self.filter = Some(expr);
        self
    }

    /// Limits the number of returned rows, optionally skipping `offset` rows
    /// first.
    pub fn limit(&mut self, limit: usize, offset: Option<usize>) -> &mut Self {
        self.limit = Some(limit);
        self.offset = offset;
        self
    }

    /// Configures a k-nearest-neighbor vector search on `column`.
    ///
    /// Defaults: 1 probe, index-default distance metric, no refine and no
    /// distance bounds; tune with the setters below.
    pub fn nearest(&mut self, column: &str, query: Arc<dyn Array>, k: usize) -> &mut Self {
        self.nearest = Some(VectorQuery {
            column: column.to_string(),
            query_vector: query,
            k,
            nprobes: 1,
            maximum_nprobes: None,
            distance_type: None,
            ef: None,
            refine_factor: None,
            distance_lower_bound: None,
            distance_upper_bound: None,
        });
        self
    }

    /// Sets both the minimum and maximum number of IVF partitions to probe.
    /// Requires `nearest` to have been called first.
    pub fn nprobes(&mut self, n: usize) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.nprobes = n;
            q.maximum_nprobes = Some(n);
        } else {
            log::warn!("nprobes is not set because nearest has not been called yet");
        }
        self
    }

    /// Sets only the minimum number of partitions to probe.
    pub fn minimum_nprobes(&mut self, n: usize) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.nprobes = n;
        } else {
            log::warn!("minimum_nprobes is not set because nearest has not been called yet");
        }
        self
    }

    /// Sets only the maximum number of partitions to probe.
    pub fn maximum_nprobes(&mut self, n: usize) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.maximum_nprobes = Some(n);
        } else {
            log::warn!("maximum_nprobes is not set because nearest has not been called yet");
        }
        self
    }

    /// Overrides the distance metric used by the vector search.
    pub fn distance_metric(&mut self, metric: DistanceType) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.distance_type = Some(metric);
        } else {
            log::warn!("distance_metric is not set because nearest has not been called yet");
        }
        self
    }

    /// Sets the graph-search breadth (`ef`) for the vector search.
    pub fn ef(&mut self, ef: usize) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.ef = Some(ef);
        } else {
            log::warn!("ef is not set because nearest has not been called yet");
        }
        self
    }

    /// Enables exact-distance re-ranking over `factor * k` candidates.
    pub fn refine(&mut self, factor: u32) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.refine_factor = Some(factor);
        } else {
            log::warn!("refine is not set because nearest has not been called yet");
        }
        self
    }

    /// Restricts vector-search results to distances within `[lower, upper]`;
    /// `None` leaves that side unbounded.
    pub fn distance_range(&mut self, lower: Option<f32>, upper: Option<f32>) -> &mut Self {
        if let Some(ref mut q) = self.nearest {
            q.distance_lower_bound = lower;
            q.distance_upper_bound = upper;
        } else {
            log::warn!("distance_range is not set because nearest has not been called yet");
        }
        self
    }

    /// Configures a term-match full-text search on `column`.
    pub fn full_text_search(&mut self, column: &str, query: &str) -> &mut Self {
        self.full_text_query = Some(FtsQuery::match_query(column, query));
        self
    }

    /// Configures a phrase full-text search with the given `slop`.
    pub fn full_text_phrase(&mut self, column: &str, phrase: &str, slop: u32) -> &mut Self {
        self.full_text_query = Some(FtsQuery::phrase(column, phrase, slop));
        self
    }

    /// Configures a boolean (must / should / must-not) full-text search.
    pub fn full_text_boolean(
        &mut self,
        column: &str,
        must: Vec<String>,
        should: Vec<String>,
        must_not: Vec<String>,
    ) -> &mut Self {
        self.full_text_query = Some(FtsQuery::boolean(column, must, should, must_not));
        self
    }

    /// Configures a fuzzy full-text search with automatic fuzziness.
    pub fn full_text_fuzzy(&mut self, column: &str, query: &str) -> &mut Self {
        self.full_text_query = Some(FtsQuery::fuzzy(column, query));
        self
    }

    /// Configures a fuzzy full-text search with an explicit edit distance.
    pub fn full_text_fuzzy_with_distance(
        &mut self,
        column: &str,
        query: &str,
        fuzziness: u32,
    ) -> &mut Self {
        self.full_text_query = Some(FtsQuery::fuzzy_with_distance(column, query, fuzziness));
        self
    }

    /// Configures a fuzzy full-text search with full fuzziness/expansion
    /// control.
    pub fn full_text_fuzzy_with_options(
        &mut self,
        column: &str,
        query: &str,
        fuzziness: Option<u32>,
        max_expansions: usize,
    ) -> &mut Self {
        self.full_text_query = Some(FtsQuery::fuzzy_with_options(
            column,
            query,
            fuzziness,
            max_expansions,
        ));
        self
    }

    /// Sets the WAND factor on the configured full-text query, clamped into
    /// `[0.0, 1.0]`. Requires a full-text query to have been configured.
    pub fn fts_wand_factor(&mut self, wand_factor: f32) -> &mut Self {
        if let Some(ref mut q) = self.full_text_query {
            q.wand_factor = wand_factor.clamp(0.0, 1.0);
        } else {
            log::warn!(
                "fts_wand_factor is not set because full_text_query has not been called yet"
            );
        }
        self
    }

    /// Enables or disables index-accelerated scalar plans (BTree).
    pub fn use_index(&mut self, use_index: bool) -> &mut Self {
        self.use_index = use_index;
        self
    }

    /// Sets the requested output batch size.
    // NOTE(review): this field is not consumed by plan construction below —
    // confirm whether exec nodes read it elsewhere.
    pub fn batch_size(&mut self, size: usize) -> &mut Self {
        self.batch_size = Some(size);
        self
    }

    /// Builds the physical plan and executes it, returning a record-batch
    /// stream.
    pub async fn try_into_stream(&self) -> Result<SendableRecordBatchStream> {
        let plan = self.create_plan().await?;
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        plan.execute(0, task_ctx)
            .map_err(|e| Error::io(format!("Failed to execute plan: {}", e)))
    }

    /// Executes the scan and concatenates all results into a single batch.
    /// Returns an empty batch with the output schema when nothing matches.
    pub async fn try_into_batch(&self) -> Result<RecordBatch> {
        let stream = self.try_into_stream().await?;
        let batches: Vec<RecordBatch> = stream
            .try_collect()
            .await
            .map_err(|e| Error::io(format!("Failed to collect batches: {}", e)))?;
        if batches.is_empty() {
            return Ok(RecordBatch::new_empty(self.output_schema()));
        }
        arrow_select::concat::concat_batches(&self.output_schema(), &batches)
            .map_err(|e| Error::io(format!("Failed to concatenate batches: {}", e)))
    }

    /// Counts the rows this scan would return.
    ///
    /// Folds over the stream instead of collecting every batch into memory
    /// first (the previous implementation buffered all batches just to sum
    /// their row counts).
    pub async fn count_rows(&self) -> Result<u64> {
        let stream = self.try_into_stream().await?;
        stream
            .try_fold(0u64, |acc, batch| async move {
                Ok(acc + batch.num_rows() as u64)
            })
            .await
            .map_err(|e| Error::io(format!("Failed to count rows: {}", e)))
    }

    /// Schema of the scan output: projected data columns plus `_rowid` /
    /// `_rowaddr` when enabled.
    pub fn output_schema(&self) -> SchemaRef {
        use super::exec::ROW_ADDRESS_COLUMN;
        let mut fields = self.projected_fields();
        if self.with_row_id {
            fields.push(Field::new(ROW_ID, DataType::UInt64, true));
        }
        if self.with_row_address {
            fields.push(Field::new(ROW_ADDRESS_COLUMN, DataType::UInt64, true));
        }
        Arc::new(arrow_schema::Schema::new(fields))
    }

    /// Schema of just the projected data columns (no row-id / row-address).
    /// Used by index exec nodes that append their own extra columns.
    fn base_output_schema(&self) -> SchemaRef {
        Arc::new(arrow_schema::Schema::new(self.projected_fields()))
    }

    /// Resolves the projection (or, absent one, the full schema) into owned
    /// `Field`s. Columns not found in the schema are silently skipped, which
    /// matches the previous behavior of both schema builders.
    fn projected_fields(&self) -> Vec<Field> {
        if let Some(ref projection) = self.projection {
            projection
                .iter()
                .filter_map(|name| self.schema.field_with_name(name).ok().cloned())
                .collect()
        } else {
            self.schema
                .fields()
                .iter()
                .map(|f| f.as_ref().clone())
                .collect()
        }
    }

    /// Chooses the physical plan, in precedence order:
    /// 1. vector search, 2. full-text search, 3. BTree-index scalar query
    /// (when enabled and applicable), 4. full scan.
    pub async fn create_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
        if let Some(ref vector_query) = self.nearest {
            return self.plan_vector_search(vector_query).await;
        }
        if let Some(ref fts_query) = self.full_text_query {
            return self.plan_fts_search(fts_query).await;
        }
        if self.use_index
            && let Some(predicate) = self.extract_btree_predicate()
            && self.has_btree_index(predicate.column())
        {
            return self.plan_btree_query(&predicate).await;
        }
        self.plan_full_scan().await
    }

    /// Builds a full table scan, pushing the filter into the scan node and
    /// wrapping with a limit/offset when requested.
    async fn plan_full_scan(&self) -> Result<Arc<dyn ExecutionPlan>> {
        let projection_indices = self.compute_projection_indices()?;
        let (filter_predicate, filter_expr) = if let Some(ref filter) = self.filter {
            let planner = Planner::new(self.schema.clone());
            let optimized = planner.optimize_expr(filter.clone())?;
            let predicate = planner.create_physical_expr(&optimized)?;
            (Some(predicate), Some(optimized))
        } else {
            (None, None)
        };
        let scan = MemTableScanExec::with_filter(
            self.batch_store.clone(),
            self.max_visible_batch_position,
            projection_indices,
            self.output_schema(),
            self.schema.clone(),
            self.with_row_id,
            self.with_row_address,
            filter_predicate,
            filter_expr,
        );
        let mut plan: Arc<dyn ExecutionPlan> = Arc::new(scan);
        if let Some(limit) = self.limit {
            plan = Arc::new(GlobalLimitExec::new(
                plan,
                self.offset.unwrap_or(0),
                Some(limit),
            ));
        }
        Ok(plan)
    }

    /// Builds a BTree-index query plan, falling back to a full scan if the
    /// column has no BTree index.
    async fn plan_btree_query(
        &self,
        predicate: &ScalarPredicate,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if !self.has_btree_index(predicate.column()) {
            return self.plan_full_scan().await;
        }
        let max_visible = self.max_visible_batch_position;
        let projection_indices = self.compute_projection_indices()?;
        let index_exec = BTreeIndexExec::new(
            self.batch_store.clone(),
            self.indexes.clone(),
            predicate.clone(),
            max_visible,
            projection_indices,
            self.output_schema(),
            self.with_row_id,
            self.with_row_address,
        )?;
        self.apply_post_index_ops(Arc::new(index_exec)).await
    }

    /// Builds a vector-index search plan, falling back to a full scan if the
    /// column has no vector index.
    async fn plan_vector_search(&self, query: &VectorQuery) -> Result<Arc<dyn ExecutionPlan>> {
        if !self.has_vector_index(&query.column) {
            return self.plan_full_scan().await;
        }
        let max_visible = self.max_visible_batch_position;
        let projection_indices = self.compute_projection_indices()?;
        let index_exec = VectorIndexExec::new(
            self.batch_store.clone(),
            self.indexes.clone(),
            query.clone(),
            max_visible,
            projection_indices,
            // Base schema only: the exec node appends its own extra columns.
            self.base_output_schema(),
            self.with_row_id,
        )?;
        self.apply_post_index_ops(Arc::new(index_exec)).await
    }

    /// Builds a full-text-search plan, falling back to a full scan if the
    /// column has no FTS index.
    async fn plan_fts_search(&self, query: &FtsQuery) -> Result<Arc<dyn ExecutionPlan>> {
        if !self.has_fts_index(&query.column) {
            return self.plan_full_scan().await;
        }
        let max_visible = self.max_visible_batch_position;
        let projection_indices = self.compute_projection_indices()?;
        let index_exec = FtsIndexExec::new(
            self.batch_store.clone(),
            self.indexes.clone(),
            query.clone(),
            max_visible,
            projection_indices,
            self.base_output_schema(),
            self.with_row_id,
        )?;
        self.apply_post_index_ops(Arc::new(index_exec)).await
    }

    /// Wraps an index plan with the limit/offset node when configured.
    async fn apply_post_index_ops(
        &self,
        plan: Arc<dyn ExecutionPlan>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut result = plan;
        if let Some(limit) = self.limit {
            result = Arc::new(GlobalLimitExec::new(
                result,
                self.offset.unwrap_or(0),
                Some(limit),
            ));
        }
        Ok(result)
    }

    /// Maps the projected column names to schema indices.
    ///
    /// # Errors
    /// Returns `Error::invalid_input` if any projected column is not in the
    /// schema (special row-id columns never reach this point — `project`
    /// strips them).
    fn compute_projection_indices(&self) -> Result<Option<Vec<usize>>> {
        if let Some(ref columns) = self.projection {
            let indices: Result<Vec<usize>> = columns
                .iter()
                .map(|name| {
                    self.schema
                        .column_with_name(name)
                        .map(|(idx, _)| idx)
                        .ok_or_else(|| {
                            Error::invalid_input(format!("Column '{}' not found in schema", name))
                        })
                })
                .collect();
            Ok(Some(indices?))
        } else {
            Ok(None)
        }
    }

    /// Tries to recognize the filter as a BTree-answerable predicate:
    /// `col = lit`, `col </<=/> />= lit`, or `col IN (lits...)`.
    ///
    /// Only the `column OP literal` orientation is recognized; literals that
    /// cannot be coerced to the column type abort the extraction.
    fn extract_btree_predicate(&self) -> Option<ScalarPredicate> {
        let filter = self.filter.as_ref()?;
        match filter {
            Expr::BinaryExpr(binary) => {
                if let (Expr::Column(col), Expr::Literal(lit, _)) =
                    (binary.left.as_ref(), binary.right.as_ref())
                {
                    let coerced_lit = self.coerce_literal_to_column(&col.name, lit)?;
                    match binary.op {
                        datafusion::logical_expr::Operator::Eq => {
                            return Some(ScalarPredicate::Eq {
                                column: col.name.clone(),
                                value: coerced_lit,
                            });
                        }
                        // NOTE: Lt/LtEq (and Gt/GtEq) map to the same unbounded
                        // Range — the index is expected to over-select and the
                        // exact boundary handling happens downstream.
                        datafusion::logical_expr::Operator::Lt
                        | datafusion::logical_expr::Operator::LtEq => {
                            return Some(ScalarPredicate::Range {
                                column: col.name.clone(),
                                lower: None,
                                upper: Some(coerced_lit),
                            });
                        }
                        datafusion::logical_expr::Operator::Gt
                        | datafusion::logical_expr::Operator::GtEq => {
                            return Some(ScalarPredicate::Range {
                                column: col.name.clone(),
                                lower: Some(coerced_lit),
                                upper: None,
                            });
                        }
                        _ => {}
                    }
                }
            }
            Expr::InList(in_list) => {
                if let Expr::Column(col) = in_list.expr.as_ref() {
                    let values: Vec<ScalarValue> = in_list
                        .list
                        .iter()
                        .filter_map(|e| {
                            if let Expr::Literal(lit, _) = e {
                                self.coerce_literal_to_column(&col.name, lit)
                            } else {
                                None
                            }
                        })
                        .collect();
                    // Only use the index if every list element was a coercible
                    // literal; otherwise fall back to a filtered scan.
                    if values.len() == in_list.list.len() {
                        return Some(ScalarPredicate::In {
                            column: col.name.clone(),
                            values,
                        });
                    }
                }
            }
            _ => {}
        }
        None
    }

    /// Coerces a literal to the column's data type, returning `None` when the
    /// column is unknown or the coercion is unsafe/lossy.
    fn coerce_literal_to_column(&self, column: &str, lit: &ScalarValue) -> Option<ScalarValue> {
        let field = self.schema.field_with_name(column).ok()?;
        let target_type = field.data_type();
        if &lit.data_type() == target_type {
            return Some(lit.clone());
        }
        safe_coerce_scalar(lit, target_type)
    }

    /// True when `column` has a BTree index registered.
    fn has_btree_index(&self, column: &str) -> bool {
        self.indexes.get_btree_by_column(column).is_some()
    }

    /// True when `column` has an IVF-PQ vector index registered.
    fn has_vector_index(&self, column: &str) -> bool {
        self.indexes.get_ivf_pq_by_column(column).is_some()
    }

    /// True when `column` has a full-text-search index registered.
    fn has_fts_index(&self, column: &str) -> bool {
        self.indexes.get_fts_by_column(column).is_some()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MemTableScanner`: scan/projection/limit behavior,
    //! batch-visibility filtering, and `_rowid` / `_rowaddr` output columns.
    use super::*;
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    // Two-column schema used by all tests: id (non-null i32), name (nullable utf8).
    fn create_test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    // Builds a batch of `count` rows with ids [start_id, start_id + count)
    // and names "name_<id>".
    fn create_test_batch(schema: &Schema, start_id: i32, count: usize) -> RecordBatch {
        let ids: Vec<i32> = (start_id..start_id + count as i32).collect();
        let names: Vec<String> = ids.iter().map(|id| format!("name_{}", id)).collect();
        RecordBatch::try_new(
            Arc::new(schema.clone()),
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .unwrap()
    }

    // Appends one batch per (start_id, count) pair to `batch_store` and
    // registers each with the index store so it is visible to scanners.
    fn create_index_store_with_batches(
        batch_store: &Arc<BatchStore>,
        schema: &Schema,
        batches: &[(i32, usize)], ) -> Arc<IndexStore> {
        let mut index_store = IndexStore::new();
        index_store.add_btree("id_idx".to_string(), 0, "id".to_string());
        let mut row_offset = 0u64;
        for (batch_pos, (start_id, count)) in batches.iter().enumerate() {
            let batch = create_test_batch(schema, *start_id, *count);
            batch_store.append(batch.clone()).unwrap();
            index_store
                .insert_with_batch_position(&batch, row_offset, Some(batch_pos))
                .unwrap();
            row_offset += *count as u64;
        }
        Arc::new(index_store)
    }

    // A plain scan returns every visible row.
    #[tokio::test]
    async fn test_scanner_basic_scan() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_rows(), 10);
    }

    // A batch appended to the store but never indexed (batch3) must not be
    // visible: only the 20 indexed rows are returned.
    #[tokio::test]
    async fn test_scanner_visibility_filtering() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let mut index_store = IndexStore::new();
        index_store.add_btree("id_idx".to_string(), 0, "id".to_string());
        let batch1 = create_test_batch(&schema, 0, 10);
        batch_store.append(batch1.clone()).unwrap();
        index_store
            .insert_with_batch_position(&batch1, 0, Some(0))
            .unwrap();
        let batch2 = create_test_batch(&schema, 10, 10);
        batch_store.append(batch2.clone()).unwrap();
        index_store
            .insert_with_batch_position(&batch2, 10, Some(1))
            .unwrap();
        let batch3 = create_test_batch(&schema, 20, 10);
        batch_store.append(batch3).unwrap();
        let indexes = Arc::new(index_store);
        let scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_rows(), 20);
    }

    // Projection narrows the output to the requested columns.
    #[tokio::test]
    async fn test_scanner_projection() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.project(&["id"]);
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_columns(), 1);
        assert_eq!(result.schema().field(0).name(), "id");
    }

    // `limit` caps the number of returned rows.
    #[tokio::test]
    async fn test_scanner_limit() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 100)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.limit(10, None);
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_rows(), 10);
    }

    // `count_rows` agrees with the number of visible rows.
    #[tokio::test]
    async fn test_scanner_count_rows() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 50)]);
        let scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        let count = scanner.count_rows().await.unwrap();
        assert_eq!(count, 50);
    }

    // `with_row_id` appends a `_rowid` UInt64 column with sequential ids.
    #[tokio::test]
    async fn test_scanner_with_row_id() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.with_row_id();
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 3);
        assert_eq!(output_schema.field(0).name(), "id");
        assert_eq!(output_schema.field(1).name(), "name");
        assert_eq!(output_schema.field(2).name(), "_rowid");
        assert_eq!(output_schema.field(2).data_type(), &DataType::UInt64);
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_columns(), 3);
        assert_eq!(result.schema().field(2).name(), "_rowid");
        let row_ids = result
            .column(2)
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        assert_eq!(row_ids.len(), 10);
        for i in 0..10 {
            assert_eq!(row_ids.value(i), i as u64);
        }
    }

    // `_rowid` requested via project() is honored alongside a normal column.
    #[tokio::test]
    async fn test_scanner_project_with_row_id() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.project(&["id", "_rowid"]);
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 2);
        assert_eq!(output_schema.field(0).name(), "id");
        assert_eq!(output_schema.field(1).name(), "_rowid");
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_columns(), 2);
        assert_eq!(result.schema().field(0).name(), "id");
        assert_eq!(result.schema().field(1).name(), "_rowid");
    }

    // Row ids remain globally sequential across multiple stored batches.
    #[tokio::test]
    async fn test_scanner_row_id_across_batches() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 5), (5, 5)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.with_row_id();
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_rows(), 10);
        let row_ids = result
            .column(2)
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        for i in 0..10 {
            assert_eq!(row_ids.value(i), i as u64);
        }
    }

    // output_schema only includes `_rowid` after with_row_id() is called.
    #[test]
    fn test_output_schema_with_row_id() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = Arc::new(IndexStore::new());
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema);
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 2);
        assert!(output_schema.field_with_name("_rowid").is_err());
        scanner.with_row_id();
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 3);
        assert!(output_schema.field_with_name("_rowid").is_ok());
    }

    // project() strips `_rowid` out of the projection list and sets the flag.
    #[test]
    fn test_project_extracts_row_id() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = Arc::new(IndexStore::new());
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema);
        scanner.project(&["id", "_rowid"]);
        assert!(scanner.with_row_id);
        assert_eq!(scanner.projection, Some(vec!["id".to_string()]));
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 2);
        assert_eq!(output_schema.field(0).name(), "id");
        assert_eq!(output_schema.field(1).name(), "_rowid");
    }

    // Plan text includes `_rowid` in the projection when enabled.
    #[tokio::test]
    async fn test_scan_plan_with_row_id() {
        use crate::utils::test::assert_plan_node_equals;
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.with_row_id();
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "MemTableScanExec: projection=[id, name, _rowid], with_row_id=true",
        )
        .await
        .unwrap();
    }

    // Plan text reflects a narrowed projection that still carries `_rowid`.
    #[tokio::test]
    async fn test_scan_plan_projection_with_row_id() {
        use crate::utils::test::assert_plan_node_equals;
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.project(&["id", "_rowid"]);
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "MemTableScanExec: projection=[id, _rowid], with_row_id=true",
        )
        .await
        .unwrap();
    }

    // Plan text shows with_row_id=false when no row id was requested.
    #[tokio::test]
    async fn test_scan_plan_without_row_id() {
        use crate::utils::test::assert_plan_node_equals;
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "MemTableScanExec: projection=[id, name], with_row_id=false",
        )
        .await
        .unwrap();
    }

    // output_schema only includes `_rowaddr` after with_row_address().
    #[test]
    fn test_output_schema_with_row_address() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = Arc::new(IndexStore::new());
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema);
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 2);
        assert!(output_schema.field_with_name("_rowaddr").is_err());
        scanner.with_row_address();
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 3);
        assert!(output_schema.field_with_name("_rowaddr").is_ok());
    }

    // `with_row_address` appends a `_rowaddr` UInt64 column; for a single
    // batch starting at offset 0 the addresses equal the row positions.
    #[tokio::test]
    async fn test_scanner_with_row_address() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.with_row_address();
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 3);
        assert_eq!(output_schema.field(0).name(), "id");
        assert_eq!(output_schema.field(1).name(), "name");
        assert_eq!(output_schema.field(2).name(), "_rowaddr");
        assert_eq!(output_schema.field(2).data_type(), &DataType::UInt64);
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_columns(), 3);
        assert_eq!(result.schema().field(2).name(), "_rowaddr");
        let row_addrs = result
            .column(2)
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        assert_eq!(row_addrs.len(), 10);
        for i in 0..10 {
            assert_eq!(row_addrs.value(i), i as u64);
        }
    }

    // Plan text includes `_rowaddr` and the with_row_address flag.
    #[tokio::test]
    async fn test_scan_plan_with_row_address() {
        use crate::utils::test::assert_plan_node_equals;
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 10)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.with_row_address();
        let plan = scanner.create_plan().await.unwrap();
        assert_plan_node_equals(
            plan,
            "MemTableScanExec: projection=[id, name, _rowaddr], with_row_id=false, with_row_address=true",
        )
        .await
        .unwrap();
    }

    // Both extra columns can be requested; `_rowid` precedes `_rowaddr`.
    #[tokio::test]
    async fn test_scanner_with_both_row_id_and_row_address() {
        let schema = create_test_schema();
        let batch_store = Arc::new(BatchStore::with_capacity(100));
        let indexes = create_index_store_with_batches(&batch_store, &schema, &[(0, 5)]);
        let mut scanner = MemTableScanner::new(batch_store, indexes, schema.clone());
        scanner.with_row_id();
        scanner.with_row_address();
        let output_schema = scanner.output_schema();
        assert_eq!(output_schema.fields().len(), 4);
        assert_eq!(output_schema.field(2).name(), "_rowid");
        assert_eq!(output_schema.field(3).name(), "_rowaddr");
        let result = scanner.try_into_batch().await.unwrap();
        assert_eq!(result.num_columns(), 4);
        let row_ids = result
            .column(2)
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        let row_addrs = result
            .column(3)
            .as_any()
            .downcast_ref::<arrow_array::UInt64Array>()
            .unwrap();
        for i in 0..5 {
            assert_eq!(row_ids.value(i), i as u64);
            assert_eq!(row_addrs.value(i), i as u64);
        }
    }
}