use std::sync::Arc;
use datafusion_common::Result;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::filter::FilterExec;
/// Column names that [`TemporalIndexRule`] treats as temporal.
const TEMPORAL_COLUMNS: &[&str] = &[
    "created_at_ms",
    "updated_at_ms",
    "accessed_at_ms",
    "time_start_ms",
    "time_end_ms",
];

/// Optimizer rule that recognises filter predicates over the columns listed
/// in [`TEMPORAL_COLUMNS`].
#[derive(Debug, Default)]
pub struct TemporalIndexRule;

impl TemporalIndexRule {
    /// Creates a new rule instance. The rule carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Returns `true` when the rendered form of `expr` mentions one of the
    /// [`TEMPORAL_COLUMNS`] as a whole identifier.
    ///
    /// The check is textual: `expr` is formatted through `Display` and searched.
    /// A hit only counts when the name is not embedded in a longer identifier,
    /// so `created_at_ms`, `created_at_ms@1` and `t.created_at_ms` match while
    /// `not_created_at_ms_backup` or `time_end_ms2` do not.
    ///
    /// Because the check is purely textual, a quoted string literal that spells
    /// a temporal column name is still reported as a reference.
    fn references_temporal_column(expr: &dyn std::fmt::Display) -> bool {
        let rendered = expr.to_string();
        TEMPORAL_COLUMNS
            .iter()
            .any(|column| Self::contains_identifier(&rendered, column))
    }

    /// Returns `true` when `ident` occurs in `haystack` delimited on both sides
    /// by either the string boundary or a non-identifier character.
    fn contains_identifier(haystack: &str, ident: &str) -> bool {
        /// Characters that can be part of a column identifier.
        fn is_ident_char(c: char) -> bool {
            c.is_alphanumeric() || c == '_'
        }

        // `match_indices` yields non-overlapping hits. That is sufficient here:
        // a hit starting inside a previous one would be preceded by an
        // identifier character and rejected anyway.
        haystack.match_indices(ident).any(|(start, hit)| {
            let end = start + hit.len();
            let glued_before = haystack[..start]
                .chars()
                .next_back()
                .is_some_and(is_ident_char);
            let glued_after = haystack[end..].chars().next().is_some_and(is_ident_char);
            !glued_before && !glued_after
        })
    }
}
impl PhysicalOptimizerRule for TemporalIndexRule {
    /// Walks `plan` top-down and emits a debug log entry for every
    /// `FilterExec` whose predicate references a temporal column.
    ///
    /// No node is rewritten: every visited node is handed back as
    /// `Transformed::no`, so the returned plan is the input plan.
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &datafusion_common::config::ConfigOptions,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let visited = plan.transform_down(|node| {
            // Only filter nodes are of interest; everything else falls through.
            if let Some(filter_exec) = node.as_any().downcast_ref::<FilterExec>() {
                let predicate = filter_exec.predicate();
                if Self::references_temporal_column(predicate) {
                    tracing::debug!(
                        predicate = %predicate,
                        "temporal_index_rule: identified temporal filter candidate"
                    );
                }
            }
            Ok(Transformed::no(node))
        })?;
        Ok(visited.data)
    }

    /// Name under which the rule is reported.
    fn name(&self) -> &str {
        "TemporalIndexRule"
    }

    /// Always `true`; this rule never replaces a node.
    fn schema_check(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_datasource::memory::MemorySourceConfig;
    use datafusion_physical_expr::expressions::{col, IsNotNullExpr};

    /// Two-row batch with an `id` string column and a `created_at_ms` integer column.
    fn episodic_batch() -> RecordBatch {
        let ids = StringArray::from(vec!["e1", "e2"]);
        let created = Int64Array::from(vec![1000, 2000]);
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("created_at_ms", DataType::Int64, false),
        ];
        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(ids), Arc::new(created)],
        )
        .unwrap()
    }

    /// Builds `FilterExec(<column> IS NOT NULL)` over an in-memory source
    /// holding [`episodic_batch`].
    fn not_null_filter_on(column: &str) -> Arc<dyn ExecutionPlan> {
        let batch = episodic_batch();
        let schema = batch.schema();
        let source = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let column_expr = col(column, &source.schema()).unwrap();
        let predicate = IsNotNullExpr::new(column_expr);
        Arc::new(FilterExec::try_new(Arc::new(predicate), source).unwrap())
    }

    /// Runs the rule with default configuration and returns the resulting plan.
    fn run_rule(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        let options = ConfigOptions::new();
        TemporalIndexRule::new().optimize(plan, &options).unwrap()
    }

    #[test]
    fn identifies_temporal_column() {
        let temporal = ["created_at_ms > 1000", "time_start_ms BETWEEN 100 AND 200"];
        for text in temporal {
            assert!(TemporalIndexRule::references_temporal_column(&text));
        }
        assert!(!TemporalIndexRule::references_temporal_column(
            &"namespace = 'default'"
        ));
    }

    #[test]
    fn passthrough_non_temporal() {
        let optimized = run_rule(not_null_filter_on("id"));
        assert_eq!(optimized.name(), "FilterExec");
    }

    #[test]
    fn identifies_temporal_filter_in_plan() {
        let optimized = run_rule(not_null_filter_on("created_at_ms"));
        assert_eq!(optimized.name(), "FilterExec");

        // The rule only observes the plan, so the filter must survive intact.
        let opt_filter = optimized
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("should still be FilterExec");
        let pred_str = opt_filter.predicate().to_string();
        assert!(
            pred_str.contains("created_at_ms"),
            "predicate should reference created_at_ms, got: {pred_str}"
        );
    }
}