hirn_exec/rules/
temporal_index.rs1use std::sync::Arc;
8
9use datafusion_common::Result;
10use datafusion_common::tree_node::{Transformed, TreeNode};
11use datafusion_physical_optimizer::PhysicalOptimizerRule;
12use datafusion_physical_plan::ExecutionPlan;
13use datafusion_physical_plan::filter::FilterExec;
14
/// Column names this rule treats as temporal.
///
/// NOTE(review): the `_ms` suffix presumably denotes epoch milliseconds — confirm
/// against the table schemas before relying on it for index selection.
const TEMPORAL_COLUMNS: &[&str] = &[
    "created_at_ms",
    "updated_at_ms",
    "accessed_at_ms",
    "time_start_ms",
    "time_end_ms",
];

/// Physical optimizer rule that looks for `FilterExec` nodes whose predicate
/// references one of `TEMPORAL_COLUMNS`.
///
/// In its current form the rule only reports candidates (via a `debug` trace)
/// and returns every plan node unchanged.
#[derive(Debug, Default)]
pub struct TemporalIndexRule;

impl TemporalIndexRule {
    /// Creates the rule; equivalent to `TemporalIndexRule::default()`.
    pub fn new() -> Self {
        Self
    }

    /// Returns `true` if the rendered form of `expr` mentions any column in
    /// `TEMPORAL_COLUMNS` as a whole identifier.
    ///
    /// Matching is done on the `Display` output of the expression (physical
    /// column references typically render as `name@index`, e.g.
    /// `created_at_ms@1 > 1000`). A column name only counts when it is not part
    /// of a longer identifier, so `created_at_ms_v2` or `last_created_at_ms` do
    /// not match. Text inside string literals is not distinguished from column
    /// references.
    fn references_temporal_column(expr: &dyn std::fmt::Display) -> bool {
        let rendered = expr.to_string();
        TEMPORAL_COLUMNS
            .iter()
            .any(|col| Self::contains_identifier(&rendered, col))
    }

    /// Returns `true` if `ident` occurs in `haystack` with no identifier
    /// character (alphanumeric or `_`) immediately before or after it.
    fn contains_identifier(haystack: &str, ident: &str) -> bool {
        // An empty pattern would "match" at every position; treat it as absent.
        if ident.is_empty() {
            return false;
        }
        let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
        haystack.match_indices(ident).any(|(start, matched)| {
            // `start` and `start + matched.len()` are char boundaries, so slicing is safe.
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + matched.len()..].chars().next();
            !matches!(before, Some(c) if is_ident_char(c))
                && !matches!(after, Some(c) if is_ident_char(c))
        })
    }
}
43
44impl PhysicalOptimizerRule for TemporalIndexRule {
45 fn optimize(
46 &self,
47 plan: Arc<dyn ExecutionPlan>,
48 _config: &datafusion_common::config::ConfigOptions,
49 ) -> Result<Arc<dyn ExecutionPlan>> {
50 plan.transform_down(|node| {
51 let Some(filter) = node.as_any().downcast_ref::<FilterExec>() else {
53 return Ok(Transformed::no(node));
54 };
55
56 let predicate = filter.predicate();
57 if !Self::references_temporal_column(predicate) {
58 return Ok(Transformed::no(node));
59 }
60
61 tracing::debug!(
68 predicate = %predicate,
69 "temporal_index_rule: identified temporal filter candidate"
70 );
71
72 Ok(Transformed::no(node))
73 })
74 .map(|t| t.data)
75 }
76
77 fn name(&self) -> &str {
78 "TemporalIndexRule"
79 }
80
81 fn schema_check(&self) -> bool {
82 true
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion_common::config::ConfigOptions;
    use datafusion_datasource::memory::MemorySourceConfig;

    /// Two-row batch with a non-nullable `id` (Utf8) column and a non-nullable
    /// `created_at_ms` (Int64) column.
    fn episodic_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("created_at_ms", DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["e1", "e2"])),
                Arc::new(Int64Array::from(vec![1000, 2000])),
            ],
        )
        .unwrap()
    }

    /// Builds `FilterExec(<column> IS NOT NULL)` over an in-memory scan of
    /// `episodic_batch()`.
    fn is_not_null_filter(column: &str) -> Arc<dyn ExecutionPlan> {
        let batch = episodic_batch();
        let schema = batch.schema();
        let mem = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let column_expr =
            datafusion_physical_expr::expressions::col(column, &mem.schema()).unwrap();
        let predicate = datafusion_physical_expr::expressions::IsNotNullExpr::new(column_expr);
        Arc::new(FilterExec::try_new(Arc::new(predicate), mem).unwrap())
    }

    /// Runs the rule on a shared handle to `plan` with default config options.
    fn run_rule(plan: &Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
        TemporalIndexRule::new()
            .optimize(Arc::clone(plan), &ConfigOptions::new())
            .unwrap()
    }

    #[test]
    fn identifies_temporal_column() {
        for expr in [
            "created_at_ms > 1000",
            "updated_at_ms < 5",
            "accessed_at_ms IS NOT NULL",
            "time_start_ms BETWEEN 100 AND 200",
            "time_end_ms >= 0",
        ] {
            assert!(
                TemporalIndexRule::references_temporal_column(&expr),
                "expected a temporal match for: {expr}"
            );
        }
        assert!(!TemporalIndexRule::references_temporal_column(
            &"namespace = 'default'"
        ));
    }

    #[test]
    fn passthrough_non_temporal() {
        let filter = is_not_null_filter("id");
        let optimized = run_rule(&filter);

        assert_eq!(optimized.name(), "FilterExec");
        // The rule never rewrites, so the very same plan node must come back.
        assert!(Arc::ptr_eq(&optimized, &filter), "rule must not rewrite the plan");

        let opt_filter = optimized
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("should still be FilterExec");
        assert!(!TemporalIndexRule::references_temporal_column(
            opt_filter.predicate()
        ));
    }

    #[test]
    fn identifies_temporal_filter_in_plan() {
        let filter = is_not_null_filter("created_at_ms");
        let optimized = run_rule(&filter);

        assert_eq!(optimized.name(), "FilterExec");
        assert!(Arc::ptr_eq(&optimized, &filter), "rule must not rewrite the plan");

        let opt_filter = optimized
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("should still be FilterExec");
        let pred_str = opt_filter.predicate().to_string();
        assert!(
            pred_str.contains("created_at_ms"),
            "predicate should reference created_at_ms, got: {pred_str}"
        );
        // The detector must fire on the real physical predicate, not just on strings.
        assert!(TemporalIndexRule::references_temporal_column(
            opt_filter.predicate()
        ));
    }
}