use datafusion::common::ScalarValue;
use datafusion::logical_expr::Operator;
use datafusion::prelude::Expr;
/// Primary-key range derived from pushed-down filters.
///
/// As built by `extract_key_bounds`, `start_key` is used as an inclusive lower
/// bound and `end_key` as an exclusive upper bound. `None` on either side means
/// that side of the range is unbounded.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct KeyBounds {
    /// Inclusive lower bound, or `None` to start from the first key.
    pub start_key: Option<Vec<u8>>,
    /// Exclusive upper bound, or `None` to scan through the last key.
    pub end_key: Option<Vec<u8>>,
}

// NOTE(review): presumably not every accessor below is used by the rest of the
// crate yet; confirm which ones are and narrow or drop this `allow`.
#[allow(dead_code)]
impl KeyBounds {
    /// Returns bounds with neither a start nor an end key (a full scan).
    pub fn unbounded() -> Self {
        Self::default()
    }

    /// Returns `true` if at least one side of the range is constrained.
    pub fn is_bounded(&self) -> bool {
        self.start_key.is_some() || self.end_key.is_some()
    }

    /// Returns the start key, or an empty slice when there is no lower bound.
    ///
    /// An explicitly empty start key and an absent one look the same here.
    pub fn start_key_slice(&self) -> &[u8] {
        self.start_key.as_deref().unwrap_or(&[])
    }

    /// Returns the end key, or an empty slice when there is no upper bound.
    ///
    /// An explicitly empty end key and an absent one look the same here.
    pub fn end_key_slice(&self) -> &[u8] {
        self.end_key.as_deref().unwrap_or(&[])
    }
}
/// Derives the primary-key range a scan has to cover from pushed-down `filters`.
///
/// Filters are visited in order; each one that compares `pk_column` against a
/// literal may constrain the returned [`KeyBounds`]. Filters that never mention
/// `pk_column` leave the bounds untouched, so an unbounded result means the
/// whole key space must be scanned.
pub fn extract_key_bounds(filters: &[Expr], pk_column: &str) -> KeyBounds {
    let bounds = filters.iter().fold(KeyBounds::unbounded(), |mut acc, filter| {
        extract_from_expr(filter, pk_column, &mut acc);
        acc
    });

    if !bounds.is_bounded() {
        return bounds;
    }

    // Only the key lengths are logged, never the raw key bytes.
    log::debug!(
        "Extracted key bounds for '{}': start={:?}, end={:?}",
        pk_column,
        bounds.start_key.as_ref().map(Vec::len),
        bounds.end_key.as_ref().map(Vec::len)
    );
    bounds
}
fn extract_from_expr(expr: &Expr, pk_column: &str, bounds: &mut KeyBounds) {
match expr {
Expr::BinaryExpr(binary) => {
if let Expr::Column(col) = binary.left.as_ref()
&& col.name() == pk_column
{
if let Some(value) = extract_scalar_bytes(binary.right.as_ref()) {
match binary.op {
Operator::Eq => {
bounds.start_key = Some(value.clone());
let mut end = value;
increment_key(&mut end);
bounds.end_key = Some(end);
}
Operator::Gt => {
let mut start = value;
increment_key(&mut start);
update_start_key(bounds, start);
}
Operator::GtEq => {
update_start_key(bounds, value);
}
Operator::Lt => {
update_end_key(bounds, value);
}
Operator::LtEq => {
let mut end = value;
increment_key(&mut end);
update_end_key(bounds, end);
}
_ => {}
}
}
}
if binary.op == Operator::And {
extract_from_expr(binary.left.as_ref(), pk_column, bounds);
extract_from_expr(binary.right.as_ref(), pk_column, bounds);
}
}
Expr::Between(between) => {
if let Expr::Column(col) = between.expr.as_ref()
&& col.name() == pk_column
&& let (Some(low), Some(high)) = (
extract_scalar_bytes(between.low.as_ref()),
extract_scalar_bytes(between.high.as_ref()),
)
{
update_start_key(bounds, low);
let mut end = high;
increment_key(&mut end);
update_end_key(bounds, end);
}
}
_ => {}
}
}
/// Encodes a non-null literal as the key bytes used for range comparisons.
///
/// - String literals (`Utf8`, `LargeUtf8`, `Utf8View`) contribute their UTF-8 bytes.
/// - Binary literals (`Binary`, `LargeBinary`, `BinaryView`, `FixedSizeBinary`)
///   contribute their raw bytes.
/// - 32/64-bit integers contribute their big-endian representation.
///
/// Any other expression, including `NULL` literals, yields `None`, meaning no
/// bound can be derived from it.
///
/// NOTE(review): big-endian two's-complement bytes of a *signed* integer do not
/// sort in numeric order under byte-wise comparison. Negative values have the
/// top bit set and sort after every non-negative value. Range filters on signed
/// keys can therefore skip matching rows when negative values are involved.
/// Confirm this against the storage key encoding.
fn extract_scalar_bytes(expr: &Expr) -> Option<Vec<u8>> {
    let Expr::Literal(scalar, _) = expr else {
        return None;
    };
    match scalar {
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => Some(s.as_bytes().to_vec()),
        ScalarValue::Binary(Some(b))
        | ScalarValue::LargeBinary(Some(b))
        | ScalarValue::BinaryView(Some(b))
        | ScalarValue::FixedSizeBinary(_, Some(b)) => Some(b.clone()),
        ScalarValue::Int64(Some(i)) => Some(i.to_be_bytes().to_vec()),
        ScalarValue::Int32(Some(i)) => Some(i.to_be_bytes().to_vec()),
        ScalarValue::UInt64(Some(i)) => Some(i.to_be_bytes().to_vec()),
        ScalarValue::UInt32(Some(i)) => Some(i.to_be_bytes().to_vec()),
        _ => None,
    }
}
/// Raises the lower bound to `key` when that is stricter than the current one.
///
/// A missing start key counts as unbounded, so any `key` replaces it.
/// Otherwise the larger of the two keys wins, and ties keep the existing key.
fn update_start_key(bounds: &mut KeyBounds, key: Vec<u8>) {
    let is_tighter = bounds
        .start_key
        .as_ref()
        .is_none_or(|current| key > *current);
    if is_tighter {
        bounds.start_key = Some(key);
    }
}
/// Lowers the upper bound to `key` when that is stricter than the current one.
///
/// A missing end key counts as unbounded, so any `key` replaces it.
/// Otherwise the smaller of the two keys wins, and ties keep the existing key.
fn update_end_key(bounds: &mut KeyBounds, key: Vec<u8>) {
    let is_tighter = bounds
        .end_key
        .as_ref()
        .is_none_or(|current| key < *current);
    if is_tighter {
        bounds.end_key = Some(key);
    }
}
/// Advances `key` so that it can serve as an exclusive upper bound that still
/// covers `key` itself.
///
/// When some byte is below `0xFF`, the result is the *prefix successor*: the
/// smallest byte string greater than every key that starts with `key`.
/// Trailing `0xFF` bytes cannot be incremented, so they are dropped and the
/// preceding byte is bumped instead (`b"ab\xff"` -> `b"ac"`).
///
/// When no byte can be incremented (empty or all-`0xFF` input), no finite
/// prefix successor exists. A `0x00` byte is appended instead, yielding the
/// immediate successor of `key`, which is still strictly greater than `key`.
fn increment_key(key: &mut Vec<u8>) {
    match key.iter().rposition(|&byte| byte != u8::MAX) {
        Some(last) => {
            key[last] += 1;
            key.truncate(last + 1);
        }
        // Wrapping the 0xFF bytes around to 0x00 here would produce a key that
        // sorts *before* the input and invert the scan range.
        None => key.push(0),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::logical_expr::lit;
    use datafusion::prelude::col;

    #[test]
    fn test_extract_eq_bounds() {
        let filters = vec![col("id").eq(lit("abc"))];
        let bounds = extract_key_bounds(&filters, "id");
        assert!(bounds.is_bounded());
        assert_eq!(bounds.start_key, Some(b"abc".to_vec()));
        // The exclusive end key must lie past "abc" so the key itself is scanned.
        assert_eq!(bounds.end_key, Some(b"abd".to_vec()));
    }

    #[test]
    fn test_extract_range_bounds() {
        let filters = vec![col("id").gt_eq(lit("a")), col("id").lt(lit("z"))];
        let bounds = extract_key_bounds(&filters, "id");
        assert!(bounds.is_bounded());
        assert_eq!(bounds.start_key, Some(b"a".to_vec()));
        assert_eq!(bounds.end_key, Some(b"z".to_vec()));
    }

    #[test]
    fn test_extract_lt_eq_includes_value() {
        let filters = vec![col("id").lt_eq(lit("m"))];
        let bounds = extract_key_bounds(&filters, "id");
        assert_eq!(bounds.start_key, None);
        assert_eq!(bounds.end_key, Some(b"n".to_vec()));
    }

    #[test]
    fn test_extract_between_bounds() {
        let filters = vec![col("id").between(lit("a"), lit("c"))];
        let bounds = extract_key_bounds(&filters, "id");
        assert_eq!(bounds.start_key, Some(b"a".to_vec()));
        assert_eq!(bounds.end_key, Some(b"d".to_vec()));
    }

    #[test]
    fn test_conjunctions_keep_tightest_bounds() {
        let filters = vec![
            col("id").gt_eq(lit("b")).and(col("id").lt(lit("x"))),
            col("id").gt_eq(lit("d")),
            col("id").lt(lit("y")),
        ];
        let bounds = extract_key_bounds(&filters, "id");
        assert_eq!(bounds.start_key, Some(b"d".to_vec()));
        assert_eq!(bounds.end_key, Some(b"x".to_vec()));
    }

    #[test]
    fn test_or_is_not_pushed_down() {
        let filters = vec![col("id").eq(lit("a")).or(col("id").eq(lit("b")))];
        let bounds = extract_key_bounds(&filters, "id");
        assert!(!bounds.is_bounded());
    }

    #[test]
    fn test_no_pk_filter() {
        let filters = vec![col("name").eq(lit("test"))];
        let bounds = extract_key_bounds(&filters, "id");
        assert!(!bounds.is_bounded());
    }
}