use std::sync::Arc;
use arrow::row::{RowConverter, SortField};
use arrow_array::{ArrayRef, RecordBatch};
use rustc_hash::FxHashMap;
use crate::lookup::source::LookupError;
/// Matches row-encoded primary keys against fetched record batches.
///
/// Holds the row converter used to encode/decode primary-key values together
/// with the names of the primary-key columns to read from each batch.
pub struct KeyAligner {
/// Row converter built from the primary-key sort fields passed to `new`.
converter: RowConverter,
/// Primary-key column names, looked up by name in each fetched batch's schema.
pk_columns: Vec<String>,
}
impl KeyAligner {
/// Creates an aligner for the given primary-key layout.
///
/// `pk_sort_fields` describes how each primary-key column is row-encoded and
/// `pk_columns` names those columns; the two lists pair up positionally, so
/// they must have the same length.
///
/// # Errors
///
/// Returns `LookupError::Internal` when:
/// - `pk_columns` is empty,
/// - the number of sort fields differs from the number of column names, or
/// - the row converter rejects the supplied sort fields.
pub fn new(
    pk_sort_fields: Vec<SortField>,
    pk_columns: Vec<String>,
) -> Result<Self, LookupError> {
    if pk_columns.is_empty() {
        return Err(LookupError::Internal(
            "primary_key_columns must not be empty".into(),
        ));
    }
    // Fail fast on a mismatched layout: otherwise the error only appears later,
    // when `align` hands the converter the wrong number of columns for every
    // non-empty batch.
    if pk_sort_fields.len() != pk_columns.len() {
        return Err(LookupError::Internal(format!(
            "primary key sort field count ({}) does not match column count ({})",
            pk_sort_fields.len(),
            pk_columns.len()
        )));
    }
    let converter = RowConverter::new(pk_sort_fields)
        .map_err(|e| LookupError::Internal(format!("row converter: {e}")))?;
    Ok(Self {
        converter,
        pk_columns,
    })
}
/// Returns the primary-key column names this aligner was constructed with.
#[must_use]
pub fn pk_columns(&self) -> &[String] {
    self.pk_columns.as_slice()
}
/// Decodes row-encoded keys back into one array per converter sort field.
///
/// Each entry of `keys` is parsed with the converter's row parser and the
/// resulting rows are converted back into columns.
///
/// NOTE(review): keys are assumed to have been produced by a row converter
/// with the same sort fields — confirm against callers.
///
/// # Errors
///
/// Returns `LookupError::Internal` if the converter fails to convert the rows.
pub fn decode_keys(&self, keys: &[&[u8]]) -> Result<Vec<ArrayRef>, LookupError> {
    let parser = self.converter.parser();
    let rows = keys.iter().copied().map(|raw| parser.parse(raw));
    match self.converter.convert_rows(rows) {
        Ok(columns) => Ok(columns),
        Err(e) => Err(LookupError::Internal(format!("decode keys: {e}"))),
    }
}
/// Aligns fetched rows to the requested keys, one output slot per key.
///
/// The primary-key columns of every non-empty batch in `fetched` are
/// row-encoded and indexed by their encoded bytes. The result has the same
/// length and order as `keys`: slot `i` is a single-row slice of the batch
/// containing `keys[i]`, or `None` when no fetched row has that key. When the
/// same key occurs in several fetched rows, the first occurrence wins.
///
/// # Errors
///
/// Returns `LookupError::Internal` if a non-empty batch lacks one of the
/// primary-key columns or if row-encoding its key columns fails.
pub fn align(
    &self,
    keys: &[&[u8]],
    fetched: &[RecordBatch],
) -> Result<Vec<Option<RecordBatch>>, LookupError> {
    // Encoded key bytes -> (batch index, row index) of the first row seen.
    let mut first_seen: FxHashMap<Vec<u8>, (usize, usize)> = FxHashMap::default();

    for (batch_idx, batch) in fetched.iter().enumerate() {
        // Empty batches contribute nothing and are not schema-checked.
        if batch.num_rows() == 0 {
            continue;
        }

        // Gather the key columns in the configured order.
        let schema = batch.schema();
        let mut pk_arrays: Vec<ArrayRef> = Vec::with_capacity(self.pk_columns.len());
        for name in &self.pk_columns {
            let col_idx = match schema.index_of(name) {
                Ok(col_idx) => col_idx,
                Err(_) => {
                    return Err(LookupError::Internal(format!(
                        "pk column not found in result: {name}"
                    )));
                }
            };
            pk_arrays.push(Arc::clone(batch.column(col_idx)));
        }

        let rows = match self.converter.convert_columns(&pk_arrays) {
            Ok(rows) => rows,
            Err(e) => {
                return Err(LookupError::Internal(format!("encode result keys: {e}")));
            }
        };

        // `or_insert` keeps the earliest location for duplicated keys.
        for (row_idx, row) in rows.iter().enumerate() {
            first_seen
                .entry(row.as_ref().to_vec())
                .or_insert((batch_idx, row_idx));
        }
    }

    let mut aligned = Vec::with_capacity(keys.len());
    for key in keys {
        let hit = first_seen
            .get(*key)
            .map(|&(batch_idx, row_idx)| fetched[batch_idx].slice(row_idx, 1));
        aligned.push(hit);
    }
    Ok(aligned)
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::Int64Array;
use arrow_schema::{DataType, Field, Schema};
/// Builds an aligner keyed on a single `Int64` column named `id`.
fn aligner() -> KeyAligner {
    let sort_fields = vec![SortField::new(DataType::Int64)];
    let pk_names = vec![String::from("id")];
    KeyAligner::new(sort_fields, pk_names).unwrap()
}
fn batch(ids: &[i64]) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(ids.to_vec()))]).unwrap()
}
fn encode(ids: &[i64]) -> Vec<Vec<u8>> {
let conv = RowConverter::new(vec![SortField::new(DataType::Int64)]).unwrap();
let rows = conv
.convert_columns(&[Arc::new(Int64Array::from(ids.to_vec()))])
.unwrap();
(0..ids.len())
.map(|i| rows.row(i).as_ref().to_vec())
.collect()
}
/// Keys requested out of order, with a miss (99) and a repeated key (2), must
/// come back slot-for-slot in request order.
#[test]
fn aligns_out_of_order_with_misses_and_dups() {
    // Reads the `id` value of the single-row batch in an aligned slot, if any.
    fn first_id(slot: &Option<RecordBatch>) -> Option<i64> {
        let hit = slot.as_ref()?;
        let ids = hit
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        Some(ids.value(0))
    }

    let fetched = [batch(&[2, 5])];
    let encoded = encode(&[5, 2, 99, 2]);
    let key_refs: Vec<&[u8]> = encoded.iter().map(|k| k.as_slice()).collect();

    let aligned = aligner().align(&key_refs, &fetched).unwrap();

    let expected = [Some(5), Some(2), None, Some(2)];
    for (pos, want) in expected.iter().enumerate() {
        assert_eq!(first_id(&aligned[pos]), *want);
    }
}
/// Encoded keys decode back into the original `id` values.
#[test]
fn decode_round_trips_to_pk_columns() {
    let encoded = encode(&[7, 8]);
    let key_refs: Vec<&[u8]> = encoded.iter().map(|k| k.as_slice()).collect();

    let decoded = aligner().decode_keys(&key_refs).unwrap();

    let first_column = &decoded[0];
    let ids = first_column
        .as_any()
        .downcast_ref::<Int64Array>()
        .unwrap();
    assert_eq!(ids.values(), &[7, 8]);
}
}