use arrow_array::{BinaryArray, RecordBatch};
use datafusion::common::ScalarValue;
use lance_core::{Error, Result};
/// Maps a signed integer to an unsigned one whose natural ordering matches the
/// signed ordering: `i64::MIN` becomes `0` and `i64::MAX` becomes `u64::MAX`.
///
/// Flipping the sign bit moves negative values below non-negative ones while
/// leaving the relative order within each half unchanged.
#[inline]
fn encode_signed(v: i64) -> u64 {
    // `i64::MIN` has only the sign bit set, so XOR with it flips exactly that bit.
    let flipped = v ^ i64::MIN;
    flipped as u64
}
/// Appends `bytes` to `out` in a self-delimiting, order-preserving form.
///
/// Each `0x00` in the input is written as `0x00 0xFF`, and the value ends with
/// the terminator `0x00 0x00`. An escaped zero is always followed by `0xFF`, so
/// the terminator cannot occur inside the payload. A value also sorts before any
/// longer value it is a prefix of, because the terminator's second byte (`0x00`)
/// is below the `0xFF` escape byte.
fn encode_bytes(out: &mut Vec<u8>, bytes: &[u8]) {
    const ESCAPED_ZERO: [u8; 2] = [0x00, 0xFF];
    const TERMINATOR: [u8; 2] = [0x00, 0x00];

    // Splitting on zero bytes yields the runs between them; rejoin the runs with
    // the escape sequence standing in for each removed zero.
    for (run_idx, run) in bytes.split(|&byte| byte == 0x00).enumerate() {
        if run_idx > 0 {
            out.extend_from_slice(&ESCAPED_ZERO);
        }
        out.extend_from_slice(run);
    }
    out.extend_from_slice(&TERMINATOR);
}
/// Appends the order-preserving encoding of a single key value to `out`.
///
/// Layout: one presence byte (`0x00` for null, `0x01` otherwise). For non-null
/// values, a type-specific payload follows:
/// - signed integers, dates, times and timestamps: 8 bytes big-endian with the
///   sign bit flipped (see `encode_signed`);
/// - unsigned integers: 8 bytes big-endian;
/// - booleans: a single `0`/`1` byte;
/// - strings and binaries: escaped bytes plus terminator (see `encode_bytes`).
///
/// Every payload is either fixed-width or self-delimiting, so encodings of the
/// fields of a tuple can be concatenated. Nulls sort before every non-null value.
/// Dictionary scalars are encoded as the value they wrap, so a dictionary-encoded
/// key column produces the same bytes as its plain counterpart.
///
/// # Errors
///
/// Returns an invalid-input error for a non-null value of any other type.
fn encode_value(out: &mut Vec<u8>, value: &ScalarValue) -> Result<()> {
    // `ScalarValue::try_from_array` on a dictionary array yields a `Dictionary`
    // wrapper; unwrap it before writing the presence byte so nulls and values are
    // handled exactly as for the underlying type.
    if let ScalarValue::Dictionary(_, inner) = value {
        return encode_value(out, inner);
    }
    if value.is_null() {
        out.push(0x00);
        return Ok(());
    }
    out.push(0x01);
    // Widen every signed value to i64 so all signed payloads share one 8-byte layout.
    macro_rules! be_signed {
        ($v:expr) => {
            out.extend_from_slice(&encode_signed($v as i64).to_be_bytes())
        };
    }
    match value {
        ScalarValue::Int8(Some(v)) => be_signed!(*v),
        ScalarValue::Int16(Some(v)) => be_signed!(*v),
        ScalarValue::Int32(Some(v)) => be_signed!(*v),
        ScalarValue::Int64(Some(v)) => be_signed!(*v),
        ScalarValue::Date32(Some(v)) => be_signed!(*v),
        ScalarValue::Date64(Some(v)) => be_signed!(*v),
        ScalarValue::Time32Second(Some(v)) | ScalarValue::Time32Millisecond(Some(v)) => {
            be_signed!(*v)
        }
        ScalarValue::Time64Microsecond(Some(v)) | ScalarValue::Time64Nanosecond(Some(v)) => {
            be_signed!(*v)
        }
        // Timestamps are stored as i64 offsets from the epoch; the timezone only
        // affects display, so it does not participate in ordering.
        ScalarValue::TimestampSecond(Some(v), _)
        | ScalarValue::TimestampMillisecond(Some(v), _)
        | ScalarValue::TimestampMicrosecond(Some(v), _)
        | ScalarValue::TimestampNanosecond(Some(v), _) => be_signed!(*v),
        ScalarValue::UInt8(Some(v)) => out.extend_from_slice(&(*v as u64).to_be_bytes()),
        ScalarValue::UInt16(Some(v)) => out.extend_from_slice(&(*v as u64).to_be_bytes()),
        ScalarValue::UInt32(Some(v)) => out.extend_from_slice(&(*v as u64).to_be_bytes()),
        ScalarValue::UInt64(Some(v)) => out.extend_from_slice(&v.to_be_bytes()),
        ScalarValue::Boolean(Some(b)) => out.push(*b as u8),
        ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s)) => {
            encode_bytes(out, s.as_bytes())
        }
        ScalarValue::Binary(Some(b))
        | ScalarValue::LargeBinary(Some(b))
        | ScalarValue::FixedSizeBinary(_, Some(b)) => encode_bytes(out, b),
        other => {
            return Err(Error::invalid_input(format!(
                "Unsupported primary-key column type for composite key: {other:?}"
            )));
        }
    }
    Ok(())
}
/// Encodes a tuple of primary-key values into a single byte key.
///
/// Each value is written in order with a presence byte and a self-delimiting
/// payload. A byte-wise comparison of two keys built from same-typed tuples
/// therefore orders them like the tuples compared field by field, with nulls
/// first. Distinct tuples always produce distinct keys.
///
/// # Errors
///
/// Returns an error if any non-null value has a type the encoder does not support.
pub fn encode_pk_tuple(values: &[ScalarValue]) -> Result<Vec<u8>> {
    // 9 bytes = presence byte + 8-byte payload, the size of the common numeric case.
    let initial = Vec::with_capacity(values.len() * 9);
    values
        .iter()
        .try_fold(initial, |mut key, value| -> Result<Vec<u8>> {
            encode_value(&mut key, value)?;
            Ok(key)
        })
}
/// Encodes the primary-key values of row `row` of `batch` into a single key.
///
/// Columns are visited in `pk_indices` order. Each cell is converted to a
/// `ScalarValue` and encoded with the same routine `encode_pk_tuple` uses, so the
/// result matches `encode_pk_tuple` over those values.
///
/// # Panics
///
/// Panics if any index in `pk_indices` is not a column of `batch`.
fn encode_pk_row(batch: &RecordBatch, pk_indices: &[usize], row: usize) -> Result<Vec<u8>> {
    let columns = batch.columns();
    pk_indices.iter().try_fold(
        Vec::with_capacity(pk_indices.len() * 9),
        |mut key, &col_idx| -> Result<Vec<u8>> {
            let scalar = ScalarValue::try_from_array(&columns[col_idx], row)?;
            encode_value(&mut key, &scalar)?;
            Ok(key)
        },
    )
}
pub fn encode_pk_batch(batch: &RecordBatch, pk_indices: &[usize]) -> Result<BinaryArray> {
let mut keys: Vec<Vec<u8>> = Vec::with_capacity(batch.num_rows());
for row in 0..batch.num_rows() {
keys.push(encode_pk_row(batch, pk_indices, row)?);
}
Ok(BinaryArray::from_iter_values(keys.iter()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int32Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds an `(Int32, Utf8)` key tuple.
    fn tuple(a: i32, b: &str) -> Vec<ScalarValue> {
        vec![ScalarValue::Int32(Some(a)), ScalarValue::from(b)]
    }

    /// Encodes a single-value tuple, panicking if the value's type is unsupported.
    fn key_of(value: ScalarValue) -> Vec<u8> {
        encode_pk_tuple(&[value]).unwrap()
    }

    /// Sorting encoded keys must reproduce tuple order, and distinct tuples must
    /// never collide.
    #[test]
    fn encoding_is_order_preserving_and_injective() {
        let tuples = [
            tuple(1, "a"),
            tuple(1, "ab"),
            tuple(1, "b"),
            tuple(2, "a"),
            tuple(-1, "z"),
        ];
        let mut encoded: Vec<(Vec<u8>, &Vec<ScalarValue>)> = tuples
            .iter()
            .map(|t| (encode_pk_tuple(t).unwrap(), t))
            .collect();
        encoded.sort_by(|x, y| x.0.cmp(&y.0));
        let order: Vec<_> = encoded.iter().map(|(_, t)| (*t).clone()).collect();
        assert_eq!(
            order,
            vec![
                tuple(-1, "z"),
                tuple(1, "a"),
                tuple(1, "ab"),
                tuple(1, "b"),
                tuple(2, "a"),
            ]
        );
        let mut keys: Vec<Vec<u8>> = tuples.iter().map(|t| encode_pk_tuple(t).unwrap()).collect();
        keys.sort();
        keys.dedup();
        assert_eq!(keys.len(), 5);
    }

    /// A null field sorts before any non-null value and never equals one.
    #[test]
    fn null_sorts_first_and_is_distinct() {
        let null_a = vec![ScalarValue::Int32(None), ScalarValue::from("a")];
        let one_a = tuple(1, "a");
        assert!(encode_pk_tuple(&null_a).unwrap() < encode_pk_tuple(&one_a).unwrap());
        assert_ne!(
            encode_pk_tuple(&null_a).unwrap(),
            encode_pk_tuple(&one_a).unwrap()
        );
    }

    /// An empty binary must sort before a binary consisting of a single zero byte.
    #[test]
    fn prefix_safety_with_embedded_zero() {
        let with_zero = vec![ScalarValue::Binary(Some(vec![0x00]))];
        let empty = vec![ScalarValue::Binary(Some(vec![]))];
        assert!(encode_pk_tuple(&empty).unwrap() < encode_pk_tuple(&with_zero).unwrap());
    }

    /// Batch encoding must agree with per-tuple encoding, row by row.
    #[test]
    fn encode_pk_batch_matches_per_tuple_encoding() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![2, 1])),
                Arc::new(StringArray::from(vec!["a", "b"])),
            ],
        )
        .unwrap();
        let encoded = encode_pk_batch(&batch, &[0, 1]).unwrap();
        assert_eq!(encoded.value(0), encode_pk_tuple(&tuple(2, "a")).unwrap());
        assert_eq!(encoded.value(1), encode_pk_tuple(&tuple(1, "b")).unwrap());
        assert!(encoded.value(1) < encoded.value(0));
    }

    /// The sign-bit flip must order values correctly across the negative/positive
    /// boundary and at both extremes.
    #[test]
    fn signed_integers_sort_numerically_across_sign_boundary() {
        let keys: Vec<Vec<u8>> = [i64::MIN, -1, 0, 1, i64::MAX]
            .iter()
            .map(|&v| key_of(ScalarValue::Int64(Some(v))))
            .collect();
        assert!(keys.windows(2).all(|w| w[0] < w[1]));
    }

    /// Unsigned integers and booleans must sort in their natural order.
    #[test]
    fn unsigned_integers_and_booleans_sort_numerically() {
        let unsigned: Vec<Vec<u8>> = [0u64, 1, u64::MAX]
            .iter()
            .map(|&v| key_of(ScalarValue::UInt64(Some(v))))
            .collect();
        assert!(unsigned.windows(2).all(|w| w[0] < w[1]));
        assert!(
            key_of(ScalarValue::Boolean(Some(false))) < key_of(ScalarValue::Boolean(Some(true)))
        );
    }

    /// ("a\0", "b") and ("a", "\0b") concatenate to the same raw bytes. Escaping and
    /// terminators must keep them distinct and ordered by the first field.
    #[test]
    fn composite_keys_with_embedded_zero_are_distinct_and_ordered() {
        let longer_first = vec![
            ScalarValue::Binary(Some(vec![0x61, 0x00])),
            ScalarValue::Binary(Some(vec![0x62])),
        ];
        let shorter_first = vec![
            ScalarValue::Binary(Some(vec![0x61])),
            ScalarValue::Binary(Some(vec![0x00, 0x62])),
        ];
        let longer_key = encode_pk_tuple(&longer_first).unwrap();
        let shorter_key = encode_pk_tuple(&shorter_first).unwrap();
        assert_ne!(longer_key, shorter_key);
        // [0x61] is a strict prefix of [0x61, 0x00], so it sorts first.
        assert!(shorter_key < longer_key);
    }

    /// Non-null values of unsupported types are reported as errors, not silently encoded.
    #[test]
    fn unsupported_type_is_rejected() {
        assert!(encode_pk_tuple(&[ScalarValue::Float64(Some(1.0))]).is_err());
    }
}