use datafusion::common::ScalarValue;
use datafusion::logical_expr::Expr;
use std::collections::HashMap;
use tracing::{debug, info, trace, warn};
/// Pairs a coordinate name with a scalar value.
///
/// NOTE(review): nothing in this file constructs or reads this type; it
/// presumably represents a single `coord = value` filter — confirm against
/// callers before relying on that.
#[derive(Debug, Clone)]
pub struct CoordFilter {
/// Name of the coordinate column.
pub coord_name: String,
/// Scalar value associated with the coordinate.
pub value: ScalarValue,
}
/// Set of equality filters keyed by coordinate name.
///
/// Each entry maps a coordinate column name to the literal value that the
/// column was compared against with `=`.
#[derive(Debug, Clone, Default)]
pub struct CoordFilters {
    /// Coordinate name mapped to the value it is required to equal.
    pub filters: HashMap<String, ScalarValue>,
}

impl CoordFilters {
    /// Creates an empty filter set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of coordinates that carry a filter.
    pub fn len(&self) -> usize {
        self.filters.len()
    }

    /// Returns `true` when no coordinate carries a filter.
    pub fn is_empty(&self) -> bool {
        self.filters.is_empty()
    }

    /// Looks up the filter value recorded for `coord_name`, if any.
    pub fn get(&self, coord_name: &str) -> Option<&ScalarValue> {
        self.filters.get(coord_name)
    }
}
/// Collects coordinate equality filters from a list of filter expressions.
///
/// Every expression in `filters` is walked; `column = literal` comparisons
/// whose column appears in `coord_names` end up in the returned set. Anything
/// else is ignored. The outcome is logged at `info` (filters found) or
/// `debug` (none found).
pub fn parse_coord_filters(filters: &[Expr], coord_names: &[String]) -> CoordFilters {
    let mut extracted = CoordFilters::new();
    filters
        .iter()
        .for_each(|expr| extract_equality_filters(expr, coord_names, &mut extracted));

    if extracted.is_empty() {
        debug!("No coordinate equality filters found for pushdown");
        return extracted;
    }

    info!(
        num_filters = extracted.len(),
        filters = ?extracted.filters.keys().collect::<Vec<_>>(),
        "Extracted coordinate filters for pushdown"
    );
    extracted
}
/// Walks `expr` and records every `coordinate = literal` comparison in `result`.
///
/// * `AND` nodes are descended on both sides.
/// * `=` nodes are recorded when one side resolves to a column listed in
///   `coord_names` and the other to a literal; a later filter on the same
///   coordinate replaces an earlier one.
/// * `CAST` wrappers around the whole expression are looked through.
/// * Every other expression is skipped (logged at `trace`).
fn extract_equality_filters(expr: &Expr, coord_names: &[String], result: &mut CoordFilters) {
    use datafusion::logical_expr::Operator;

    match expr {
        Expr::BinaryExpr(conjunction) if conjunction.op == Operator::And => {
            for side in [&conjunction.left, &conjunction.right] {
                extract_equality_filters(side, coord_names, result);
            }
        }
        Expr::BinaryExpr(comparison) if comparison.op == Operator::Eq => {
            let Some((column, literal)) =
                extract_column_literal_eq(&comparison.left, &comparison.right)
            else {
                // Not a column/literal pair: nothing to push down.
                return;
            };

            if !coord_names.contains(&column) {
                trace!(
                    column = %column,
                    "Equality filter on non-coordinate column, skipping"
                );
                return;
            }

            debug!(
                coord = %column,
                value = %literal,
                "Found coordinate equality filter"
            );
            result.filters.insert(column, literal);
        }
        Expr::Cast(wrapper) => extract_equality_filters(&wrapper.expr, coord_names, result),
        unsupported => {
            trace!(expr_type = %unsupported.variant_name(), "Skipping non-equality filter expression");
        }
    }
}
/// Interprets the two sides of an equality as a `(column name, literal)` pair.
///
/// Accepts both `column = literal` and `literal = column`; returns `None`
/// when neither orientation matches.
fn extract_column_literal_eq(left: &Expr, right: &Expr) -> Option<(String, ScalarValue)> {
    let column_then_literal = extract_column_name(left).zip(extract_literal(right));

    column_then_literal.or_else(|| {
        // Mirrored form: the literal is on the left-hand side.
        extract_literal(left)
            .zip(extract_column_name(right))
            .map(|(literal, column)| (column, literal))
    })
}
/// Returns the column name referenced by `expr`, looking through casts.
///
/// Any number of nested `CAST` / `TRY_CAST` wrappers is peeled off; the cast
/// target types are ignored. Returns `None` when the innermost expression is
/// not a plain column reference.
fn extract_column_name(expr: &Expr) -> Option<String> {
    let mut current = expr;
    loop {
        current = match current {
            Expr::Column(column) => return Some(column.name.clone()),
            Expr::Cast(cast) => cast.expr.as_ref(),
            Expr::TryCast(try_cast) => try_cast.expr.as_ref(),
            _ => return None,
        };
    }
}
/// Returns the scalar carried by `expr` when it is a literal or a cast literal.
///
/// * A bare literal is returned as-is.
/// * `CAST(literal AS type)` is evaluated via `ScalarValue::cast_to`; if the
///   cast fails the result is `None`. Only a single cast level directly
///   around a literal is handled.
///
/// Dictionary-encoded scalars are unwrapped to their inner value in both cases.
fn extract_literal(expr: &Expr) -> Option<ScalarValue> {
    match expr {
        Expr::Literal(scalar, _) => Some(unwrap_dictionary_value(scalar.clone())),
        Expr::Cast(cast) => match cast.expr.as_ref() {
            Expr::Literal(scalar, _) => scalar
                .cast_to(&cast.data_type)
                .ok()
                .map(unwrap_dictionary_value),
            _ => None,
        },
        _ => None,
    }
}
/// Strips every `ScalarValue::Dictionary` layer and returns the innermost value.
///
/// Non-dictionary scalars are returned unchanged.
fn unwrap_dictionary_value(value: ScalarValue) -> ScalarValue {
    let mut current = value;
    loop {
        match current {
            ScalarValue::Dictionary(_, inner) => current = *inner,
            plain => return plain,
        }
    }
}
/// Translates coordinate filters into one half-open index range per coordinate.
///
/// `coord_values[i]` must hold the values of `coord_names[i]`. For each
/// coordinate:
///
/// * no filter → the full range `(0, len)`;
/// * filter whose value is located at index `idx` → `(idx, idx + 1)`.
///
/// Returns `None` as soon as a filtered value cannot be located by
/// `find_value_index` (which also reports `None` for literal types it does
/// not handle); the caller is expected to treat that as an empty result.
///
/// # Panics
///
/// Panics if `coord_values` has fewer entries than `coord_names`.
pub fn calculate_coord_ranges(
    filters: &CoordFilters,
    coord_names: &[String],
    coord_values: &[CoordValuesRef<'_>],
) -> Option<Vec<(usize, usize)>> {
    coord_names
        .iter()
        .enumerate()
        .map(|(position, name)| {
            let values = &coord_values[position];

            let Some(filter_value) = filters.get(name) else {
                // Unfiltered coordinate: keep every index.
                return Some((0, values.len()));
            };

            match find_value_index(values, filter_value) {
                Some(idx) => {
                    debug!(
                        coord = %name,
                        filter_value = %filter_value,
                        index = idx,
                        "Found filter value at index"
                    );
                    Some((idx, idx + 1))
                }
                None => {
                    warn!(
                        coord = %name,
                        filter_value = %filter_value,
                        "Filter value not found in coordinate - query will return no results"
                    );
                    None
                }
            }
        })
        // Collecting into `Option<Vec<_>>` stops at the first `None`.
        .collect()
}
/// Borrowed view of one coordinate's values, tagged by element type.
///
/// The enum only holds shared slices, so it is cheap to copy and can be
/// printed for diagnostics.
#[derive(Debug, Clone, Copy)]
pub enum CoordValuesRef<'a> {
    /// Coordinate values stored as 64-bit signed integers.
    Int64(&'a [i64]),
    /// Coordinate values stored as 32-bit floats.
    Float32(&'a [f32]),
    /// Coordinate values stored as 64-bit floats.
    Float64(&'a [f64]),
}

impl<'a> CoordValuesRef<'a> {
    /// Returns the number of coordinate values in the underlying slice.
    pub fn len(&self) -> usize {
        match self {
            CoordValuesRef::Int64(v) => v.len(),
            CoordValuesRef::Float32(v) => v.len(),
            CoordValuesRef::Float64(v) => v.len(),
        }
    }

    /// Returns `true` when the underlying slice has no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Widens any non-null integer literal to `i128`.
///
/// `i128` holds every `i64` and `u64` value exactly, so no literal is
/// truncated. Returns `None` for null or non-integer scalars.
fn integer_literal_value(target: &ScalarValue) -> Option<i128> {
    match target {
        ScalarValue::Int8(Some(v)) => Some(i128::from(*v)),
        ScalarValue::Int16(Some(v)) => Some(i128::from(*v)),
        ScalarValue::Int32(Some(v)) => Some(i128::from(*v)),
        ScalarValue::Int64(Some(v)) => Some(i128::from(*v)),
        ScalarValue::UInt8(Some(v)) => Some(i128::from(*v)),
        ScalarValue::UInt16(Some(v)) => Some(i128::from(*v)),
        ScalarValue::UInt32(Some(v)) => Some(i128::from(*v)),
        ScalarValue::UInt64(Some(v)) => Some(i128::from(*v)),
        _ => None,
    }
}

/// Returns the index of the first coordinate value equal to `target`.
///
/// Supported combinations:
///
/// * float literal vs. float coordinates — the literal is converted to the
///   coordinate's float width and compared with an absolute tolerance of that
///   width's `EPSILON`;
/// * integer literal (any signed/unsigned width) vs. `Int64` coordinates —
///   exact comparison, so an out-of-range `UInt64` simply matches nothing;
/// * integer literal vs. float coordinates — the literal is converted to the
///   coordinate's float type and compared with the same tolerance.
///
/// Returns `None` both when no coordinate matches and when the literal is
/// null or of a type not listed above (the latter is logged at `debug`).
/// Callers cannot tell those two cases apart.
fn find_value_index(values: &CoordValuesRef<'_>, target: &ScalarValue) -> Option<usize> {
    match (values, target) {
        (CoordValuesRef::Float32(vals), ScalarValue::Float32(Some(v))) => {
            vals.iter().position(|x| (x - v).abs() < f32::EPSILON)
        }
        (CoordValuesRef::Float32(vals), ScalarValue::Float64(Some(v))) => {
            // Narrow the literal so it is compared at the coordinate's precision.
            let v32 = *v as f32;
            vals.iter().position(|x| (x - v32).abs() < f32::EPSILON)
        }
        (CoordValuesRef::Float64(vals), ScalarValue::Float64(Some(v))) => {
            vals.iter().position(|x| (x - v).abs() < f64::EPSILON)
        }
        (CoordValuesRef::Float64(vals), ScalarValue::Float32(Some(v))) => {
            let v64 = f64::from(*v);
            vals.iter().position(|x| (x - v64).abs() < f64::EPSILON)
        }
        // Everything else is either an integer literal or unsupported.
        _ => match (values, integer_literal_value(target)) {
            (CoordValuesRef::Int64(vals), Some(v)) => {
                vals.iter().position(|x| i128::from(*x) == v)
            }
            (CoordValuesRef::Float32(vals), Some(v)) => {
                let vf = v as f32;
                vals.iter().position(|x| (x - vf).abs() < f32::EPSILON)
            }
            (CoordValuesRef::Float64(vals), Some(v)) => {
                let vf = v as f64;
                vals.iter().position(|x| (x - vf).abs() < f64::EPSILON)
            }
            (_, None) => {
                debug!(
                    target_type = ?std::mem::discriminant(target),
                    "Unsupported filter value type for coordinate lookup"
                );
                None
            }
        },
    }
}
/// Returns the number of rows selected by a set of per-coordinate ranges.
///
/// Each `(start, end)` pair is a half-open index range; the result is the
/// product of the range lengths. An empty slice yields `1` (the empty
/// product).
///
/// The arithmetic saturates instead of overflowing: an inverted range
/// (`start > end`) counts as empty, matching `std::ops::Range`, and a product
/// that does not fit in `usize` is clamped to `usize::MAX`.
pub fn calculate_filtered_rows(coord_ranges: &[(usize, usize)]) -> usize {
    coord_ranges
        .iter()
        .map(|&(start, end)| end.saturating_sub(start))
        .fold(1usize, usize::saturating_mul)
}
/// Converts `(start, end)` index pairs into `Range<u64>` values.
///
/// Order is preserved and each pair maps to `start..end` unchanged apart
/// from the widening to `u64`.
pub fn coord_ranges_to_array_ranges(coord_ranges: &[(usize, usize)]) -> Vec<std::ops::Range<u64>> {
    let mut array_ranges = Vec::with_capacity(coord_ranges.len());
    for &(lo, hi) in coord_ranges {
        array_ranges.push((lo as u64)..(hi as u64));
    }
    array_ranges
}
// Unit tests for filter extraction, coordinate lookup and row counting.
#[cfg(test)]
mod tests {
use super::*;
use datafusion::prelude::*;
// A single `column = literal` filter on a known coordinate is extracted.
#[test]
fn test_parse_simple_equality() {
let coord_names = vec!["time".to_string(), "lat".to_string()];
let filter = col("time").eq(lit(100i64));
let filters = parse_coord_filters(&[filter], &coord_names);
assert_eq!(filters.len(), 1);
assert!(filters.get("time").is_some());
}
// Both sides of an `AND` are visited, yielding one filter per coordinate.
#[test]
fn test_parse_and_filters() {
let coord_names = vec!["time".to_string(), "hybrid".to_string(), "lat".to_string()];
let filter = col("time")
.eq(lit(100i64))
.and(col("hybrid").eq(lit(50i64)));
let filters = parse_coord_filters(&[filter], &coord_names);
assert_eq!(filters.len(), 2);
assert!(filters.get("time").is_some());
assert!(filters.get("hybrid").is_some());
}
// Equality on a column that is not listed as a coordinate is ignored.
#[test]
fn test_ignore_non_coord_columns() {
let coord_names = vec!["time".to_string()];
let filter = col("temperature").eq(lit(20i64));
let filters = parse_coord_filters(&[filter], &coord_names);
assert!(filters.is_empty());
}
// Int64 lookup returns the matching position, or `None` when absent.
#[test]
fn test_find_value_index() {
let vals = vec![10i64, 20, 30, 40, 50];
let values_ref = CoordValuesRef::Int64(&vals);
assert_eq!(
find_value_index(&values_ref, &ScalarValue::Int64(Some(30))),
Some(2)
);
assert_eq!(
find_value_index(&values_ref, &ScalarValue::Int64(Some(100))),
None
);
}
// Row count is the product of the individual range lengths.
#[test]
fn test_calculate_filtered_rows() {
let ranges = vec![(5, 6), (10, 11), (0, 721), (0, 1440)];
let rows = calculate_filtered_rows(&ranges);
assert_eq!(rows, 1 * 1 * 721 * 1440);
}
}