use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch};
use crate::supertable::{error::BuildError, options::SupertableOptions};
/// Upper bound on how many null offsets are collected into the `first_nulls`
/// field of `BuildError::VectorColumnHasNulls`.
const MAX_NULL_OFFSETS_IN_ERROR: usize = 5;
/// Splits `batch` into its scalar columns and borrowed views of its vector columns.
///
/// Returns a projected [`RecordBatch`] holding every column that is not named in
/// `options.vector_columns` (original order preserved), plus one `&[f32]` per
/// configured vector column, in the order of `options.vector_columns`. Each slice
/// is the flattened child-values buffer of the column's `FixedSizeListArray` and
/// borrows from `batch` rather than copying.
///
/// # Errors
///
/// * [`BuildError::BatchSchemaMismatch`] if the batch schema differs from
///   `options.schema`, or if the final projection fails.
/// * [`BuildError::VectorColumnMissing`] if a configured vector column is not in
///   the batch schema.
/// * [`BuildError::VectorColumnNotFixedSizeList`] if the column is not a
///   `FixedSizeListArray`, or its child values are not a `Float32Array`.
/// * [`BuildError::VectorColumnDimMismatch`] if the list size differs from the
///   configured `dim`.
/// * [`BuildError::VectorColumnHasNulls`] if the column has null rows (offsets are
///   row indices) or null child values (offsets are indices into the flattened
///   child array). At most `MAX_NULL_OFFSETS_IN_ERROR` offsets are reported.
pub(crate) fn split_vectors<'a>(
    batch: &'a RecordBatch,
    options: &SupertableOptions,
) -> Result<(RecordBatch, Vec<&'a [f32]>), BuildError> {
    // Take the schema handle once instead of re-cloning it for every lookup.
    let schema = batch.schema();
    if schema.as_ref() != options.schema.as_ref() {
        return Err(BuildError::BatchSchemaMismatch);
    }

    let mut vectors: Vec<&'a [f32]> = Vec::with_capacity(options.vector_columns.len());
    for vc in &options.vector_columns {
        let idx = schema
            .index_of(&vc.column)
            .map_err(|_| BuildError::VectorColumnMissing {
                column: vc.column.clone(),
            })?;
        let col = batch.column(idx);

        let fsl = col
            .as_any()
            .downcast_ref::<FixedSizeListArray>()
            .ok_or_else(|| BuildError::VectorColumnNotFixedSizeList {
                column: vc.column.clone(),
                dim: vc.dim,
                actual: format!("{:?}", col.data_type()),
            })?;

        // `value_length()` is an `i32`; a value that does not fit in `usize`
        // is mapped to `usize::MAX` so it is reported as a dimension mismatch.
        let list_size = usize::try_from(fsl.value_length()).unwrap_or(usize::MAX);
        if list_size != vc.dim {
            return Err(BuildError::VectorColumnDimMismatch {
                column: vc.column.clone(),
                expected: vc.dim,
                actual: list_size,
            });
        }

        // Null rows: offsets reported are row indices.
        if fsl.null_count() > 0 {
            let first_nulls = collect_first_nulls(fsl, MAX_NULL_OFFSETS_IN_ERROR);
            return Err(BuildError::VectorColumnHasNulls {
                column: vc.column.clone(),
                first_nulls,
            });
        }

        // A non-Float32 child is reported through the same variant, carrying the
        // column's full data type.
        let inner = fsl
            .values()
            .as_any()
            .downcast_ref::<Float32Array>()
            .ok_or_else(|| BuildError::VectorColumnNotFixedSizeList {
                column: vc.column.clone(),
                dim: vc.dim,
                actual: format!("{:?}", col.data_type()),
            })?;

        // Null lanes inside otherwise valid rows: offsets reported are indices
        // into the flattened child array, not row indices.
        if inner.null_count() > 0 {
            let first_nulls = collect_first_nulls_primitive(inner, MAX_NULL_OFFSETS_IN_ERROR);
            return Err(BuildError::VectorColumnHasNulls {
                column: vc.column.clone(),
                first_nulls,
            });
        }

        vectors.push(inner.values());
    }

    // The batch schema was verified equal to `options.schema` above, so scalar
    // columns can be selected by position. This avoids a second by-name lookup
    // (and its panic path), and keeps same-named fields distinct, since a name
    // lookup would resolve every duplicate to the first match.
    let scalar_indices: Vec<usize> = schema
        .fields()
        .iter()
        .enumerate()
        .filter(|(_, field)| {
            !options
                .vector_columns
                .iter()
                .any(|vc| vc.column == *field.name())
        })
        .map(|(position, _)| position)
        .collect();

    let scalar_batch = batch
        .project(&scalar_indices)
        .map_err(|_| BuildError::BatchSchemaMismatch)?;

    Ok((scalar_batch, vectors))
}
/// Returns the indices of the first null rows of `arr`, in ascending order.
///
/// At most `max` indices are returned; `max == 0` yields an empty vector.
/// The scan stops as soon as `max` nulls have been found.
fn collect_first_nulls(arr: &FixedSizeListArray, max: usize) -> Vec<usize> {
    // `take(max)` enforces the cap before an element is emitted, so the result
    // never exceeds `max`, and no buffer is pre-allocated from the caller's cap.
    (0..arr.len())
        .filter(|&row| arr.is_null(row))
        .take(max)
        .collect()
}
/// Returns the positions of the first null elements of `arr`, in ascending order.
///
/// Positions are indices into `arr` itself. At most `max` positions are
/// returned; `max == 0` yields an empty vector. The scan stops as soon as
/// `max` nulls have been found.
fn collect_first_nulls_primitive(arr: &Float32Array, max: usize) -> Vec<usize> {
    // `take(max)` enforces the cap before an element is emitted, so the result
    // never exceeds `max`, and no buffer is pre-allocated from the caller's cap.
    (0..arr.len())
        .filter(|&position| arr.is_null(position))
        .take(max)
        .collect()
}
// Unit tests for `split_vectors`: the happy path, every error variant it can
// return, scalar column ordering, and the borrowed (non-copying) vector views.
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow_array::{Array, Float32Array, LargeStringArray, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use super::*;
use crate::superfile::{
builder::{FtsConfig, VectorConfig},
vector::{distance::Metric, rerank_codec::RerankCodec},
};
// `FixedSizeList<Float32>` data type with a nullable child field named "item".
fn fixed_list_f32(dim: usize) -> DataType {
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
dim as i32,
)
}
// Two-column schema: non-nullable `title` (LargeUtf8) and non-nullable `emb`
// (fixed-size list of f32). Despite the name, there is no `id` field.
fn schema_id_title_emb(dim: usize) -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("title", DataType::LargeUtf8, false),
Field::new("emb", fixed_list_f32(dim), false),
]))
}
// Vector column config for `name` with the given `dim`; the remaining fields
// are fixed values that these tests never vary.
fn vc(name: &str, dim: usize) -> VectorConfig {
VectorConfig {
column: name.into(),
dim,
n_cent: 4,
rot_seed: 0,
metric: Metric::Cosine,
rerank_codec: RerankCodec::Fp32,
}
}
// Full-text-search column config for `name`.
fn fc(name: &str) -> FtsConfig {
FtsConfig {
column: name.into(),
}
}
use crate::test_helpers::default_tokenizer as tok;
// Wraps a flat `Vec<f32>` in a `FixedSizeListArray` of width `dim` with no
// null rows (validity buffer is `None`).
fn build_fsl(flat: Vec<f32>, dim: usize) -> FixedSizeListArray {
let item_field = Arc::new(Field::new("item", DataType::Float32, true));
let values = Float32Array::from(flat);
FixedSizeListArray::try_new(item_field, dim as i32, Arc::new(values), None)
.expect("build FixedSizeListArray")
}
// Builds an `n`-row batch for a (title, emb) schema. Titles are "doc {i}" and
// every vector lane holds its own flat index (`i * dim + j`) as an f32.
fn build_batch(schema: Arc<Schema>, n: usize, dim: usize) -> RecordBatch {
let titles = LargeStringArray::from((0..n).map(|i| format!("doc {i}")).collect::<Vec<_>>());
let mut flat = Vec::with_capacity(n * dim);
for i in 0..n {
for j in 0..dim {
flat.push((i * dim + j) as f32);
}
}
let fsl = build_fsl(flat, dim);
RecordBatch::try_new(schema, vec![Arc::new(titles), Arc::new(fsl)])
.expect("build RecordBatch")
}
// Happy path: the vector column is removed from the scalar batch and returned
// as one flat slice of `rows * dim` values.
#[test]
fn split_extracts_vectors_and_drops_columns() {
let dim = 16;
let schema = schema_id_title_emb(dim);
let opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
let batch = build_batch(schema, 4, dim);
let (scalar, vectors) = split_vectors(&batch, &opts).expect("split should succeed");
let names: Vec<_> = scalar
.schema()
.fields()
.iter()
.map(|f| f.name().clone())
.collect();
assert_eq!(names, vec!["title"]);
assert_eq!(scalar.num_rows(), 4);
assert_eq!(scalar.num_columns(), 1);
assert_eq!(vectors.len(), 1);
assert_eq!(vectors[0].len(), 4 * dim);
// Row 2, lane 3 sits at flat index 2 * 16 + 3 = 35, and `build_batch` stores
// each lane's flat index as its value.
assert_eq!(vectors[0][2 * dim + 3], 35.0);
}
// A batch whose schema differs from the options schema (here: only `emb`,
// no `title`) is rejected before any column is inspected.
#[test]
fn split_rejects_batch_with_wrong_schema() {
let dim = 16;
let schema = schema_id_title_emb(dim);
let opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
let other_schema = Arc::new(Schema::new(vec![Field::new(
"emb",
fixed_list_f32(dim),
false,
)]));
let fsl = build_fsl(vec![0.0; 2 * dim], dim);
let other_batch =
RecordBatch::try_new(other_schema, vec![Arc::new(fsl)]).expect("build batch");
let err = split_vectors(&other_batch, &opts).expect_err("expected error");
assert!(matches!(err, BuildError::BatchSchemaMismatch));
}
// A null row in the vector column is rejected, and the reported offset is
// the row index of the null.
#[test]
fn split_rejects_null_vector_row() {
let dim = 16;
// `emb` must be declared nullable here so the batch with a null row can be built.
let schema = Arc::new(Schema::new(vec![
Field::new("title", DataType::LargeUtf8, false),
Field::new("emb", fixed_list_f32(dim), true),
]));
let opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
use arrow::buffer::NullBuffer;
let item_field = Arc::new(Field::new("item", DataType::Float32, true));
let values = Float32Array::from(vec![0.0f32; 3 * dim]);
// Validity flags: `true` means valid, so only row 1 is null.
let nulls = NullBuffer::from(vec![true, false, true]); let fsl =
FixedSizeListArray::try_new(item_field, dim as i32, Arc::new(values), Some(nulls))
.expect("build FSL with nulls");
let titles = LargeStringArray::from(vec!["a", "b", "c"]);
let batch = RecordBatch::try_new(schema, vec![Arc::new(titles), Arc::new(fsl)])
.expect("build batch");
let err = split_vectors(&batch, &opts).expect_err("expected error");
match err {
BuildError::VectorColumnHasNulls {
column,
first_nulls,
} => {
assert_eq!(column, "emb");
assert_eq!(first_nulls, vec![1]);
}
other => panic!("expected VectorColumnHasNulls, got {:?}", other),
}
}
// With no vector columns configured, the scalar batch keeps every column and
// the returned vector list is empty.
#[test]
fn split_succeeds_with_zero_vector_columns() {
let schema = Arc::new(Schema::new(vec![Field::new(
"title",
DataType::LargeUtf8,
false,
)]));
let opts = SupertableOptions::new(schema.clone(), vec![fc("title")], vec![], Some(tok()))
.expect("valid options");
let titles = LargeStringArray::from(vec!["x", "y"]);
let batch = RecordBatch::try_new(schema, vec![Arc::new(titles)]).expect("build batch");
let (scalar, vectors) = split_vectors(&batch, &opts).expect("split should succeed");
assert_eq!(scalar.num_rows(), 2);
assert_eq!(scalar.num_columns(), 1);
assert_eq!(vectors.len(), 0);
}
// Scalar columns interleaved with vector columns keep their relative order,
// and vectors come back in the order of `vector_columns`.
#[test]
fn split_preserves_scalar_column_order() {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt64, false),
Field::new("emb", fixed_list_f32(16), false),
Field::new("b", DataType::UInt64, false),
Field::new("other", fixed_list_f32(16), false),
Field::new("c", DataType::UInt64, false),
]));
let opts = SupertableOptions::new(
schema.clone(),
vec![],
vec![vc("emb", 16), vc("other", 16)],
None,
)
.expect("valid options");
let n = 2;
let dim = 16;
let a = UInt64Array::from(vec![10u64, 20]);
let b = UInt64Array::from(vec![30u64, 40]);
let c = UInt64Array::from(vec![50u64, 60]);
// Distinct fill values (0.0 vs 1.0) let the assertions below tell the two
// vector columns apart.
let fsl1 = build_fsl(vec![0.0; n * dim], dim);
let fsl2 = build_fsl(vec![1.0; n * dim], dim);
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(a),
Arc::new(fsl1),
Arc::new(b),
Arc::new(fsl2),
Arc::new(c),
],
)
.expect("build batch");
let (scalar, vectors) = split_vectors(&batch, &opts).expect("split should succeed");
let names: Vec<_> = scalar
.schema()
.fields()
.iter()
.map(|f| f.name().clone())
.collect();
assert_eq!(names, vec!["a", "b", "c"]);
assert_eq!(vectors.len(), 2);
assert_eq!(vectors[0].len(), n * dim);
assert_eq!(vectors[1].len(), n * dim);
assert_eq!(vectors[0][0], 0.0);
assert_eq!(vectors[1][0], 1.0);
}
// A null lane inside an otherwise valid row is rejected. The reported offset
// is the index into the flattened child values, not the row index.
#[test]
fn split_rejects_inner_lane_null() {
let dim = 16;
let schema = schema_id_title_emb(dim);
let opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
const NULL_LANE_FLAT_INDEX: usize = 1;
let n_rows = 3;
let item_field = Arc::new(Field::new("item", DataType::Float32, true));
let mut builder = Float32Array::builder(n_rows * dim);
for i in 0..(n_rows * dim) {
if i == NULL_LANE_FLAT_INDEX {
builder.append_null();
} else {
builder.append_value(i as f32);
}
}
let values = builder.finish();
// No outer validity buffer: every row is valid, only one child value is null.
let fsl = FixedSizeListArray::try_new(item_field, dim as i32, Arc::new(values), None)
.expect("build FSL with inner null");
let titles = LargeStringArray::from(vec!["a", "b", "c"]);
let batch = RecordBatch::try_new(schema, vec![Arc::new(titles), Arc::new(fsl)])
.expect("build batch");
let err = split_vectors(&batch, &opts).expect_err("expected error");
match err {
BuildError::VectorColumnHasNulls {
column,
first_nulls,
} => {
assert_eq!(column, "emb");
assert_eq!(first_nulls, vec![NULL_LANE_FLAT_INDEX]);
}
other => panic!("expected VectorColumnHasNulls, got {:?}", other),
}
}
// A configured vector column name that is absent from the batch schema
// yields `VectorColumnMissing`.
#[test]
fn split_rejects_missing_vector_column() {
let dim = 16;
let schema = schema_id_title_emb(dim);
let mut opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
// Mutate the options after construction to create the inconsistent state.
opts.vector_columns[0].column = "not_a_column".into();
let batch = build_batch(schema, 2, dim);
let err = split_vectors(&batch, &opts).expect_err("expected error");
assert!(matches!(
err,
BuildError::VectorColumnMissing { column } if column == "not_a_column"
));
}
// Pointing the vector config at a column that is not a fixed-size list
// (`title`, a string column) yields `VectorColumnNotFixedSizeList`.
#[test]
fn split_rejects_non_fixed_size_list_column() {
let dim = 16;
let schema = schema_id_title_emb(dim);
let mut opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
// Mutate the options after construction to create the inconsistent state.
opts.vector_columns[0].column = "title".into();
let batch = build_batch(schema, 2, dim);
let err = split_vectors(&batch, &opts).expect_err("expected error");
assert!(matches!(
err,
BuildError::VectorColumnNotFixedSizeList { column, .. } if column == "title"
));
}
// A configured `dim` that differs from the column's list size yields
// `VectorColumnDimMismatch`, with `expected` taken from the config and
// `actual` from the array.
#[test]
fn split_rejects_dim_mismatch() {
let dim = 16;
const WRONG_DIM: usize = 32;
let schema = schema_id_title_emb(dim);
let mut opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
// Mutate the options after construction to create the inconsistent state.
opts.vector_columns[0].dim = WRONG_DIM;
let batch = build_batch(schema, 2, dim);
let err = split_vectors(&batch, &opts).expect_err("expected error");
assert!(matches!(
err,
BuildError::VectorColumnDimMismatch {
expected: WRONG_DIM,
actual,
column,
} if actual == dim && column == "emb"
));
}
// The returned slice must alias the batch's own value buffer: its data
// pointer equals that of the underlying `Float32Array`, so nothing was copied.
#[test]
fn split_returns_zero_copy_view_into_batch() {
let dim = 16;
let schema = schema_id_title_emb(dim);
let opts = SupertableOptions::new(
schema.clone(),
vec![fc("title")],
vec![vc("emb", dim)],
Some(tok()),
)
.expect("valid options");
let batch = build_batch(schema, 4, dim);
let (_scalar, vectors) = split_vectors(&batch, &opts).expect("split should succeed");
let original = batch
.column(1)
.as_any()
.downcast_ref::<FixedSizeListArray>()
.expect("FSL")
.values()
.as_any()
.downcast_ref::<Float32Array>()
.expect("Float32Array")
.values()
.as_ptr();
assert_eq!(vectors[0].as_ptr(), original);
}
}