1use crate::rule::Rule;
2use crate::state::PredicateResult;
3use anyhow::{Context, Result};
4use arrow::array::Array;
5use arrow::record_batch::RecordBatch;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use std::sync::Arc;
9
/// Compiles rule predicates against an Arrow schema and evaluates them over
/// record batches.
#[async_trait]
pub trait RuleEvaluator: Send + Sync {
    /// Parses and prepares `rule`'s predicate against `schema`, returning a
    /// reusable compiled form.
    fn compile(&self, rule: Rule, schema: &arrow::datatypes::Schema) -> Result<CompiledRuleEdge>;

    /// Evaluates every compiled rule against `batch`, returning one
    /// `(verdict, matched-rows batch)` pair per rule, in the same order as
    /// `rules`.
    ///
    /// `window_batches` carries the retained history for windowed rules and
    /// is indexed in parallel with `rules`.
    async fn evaluate_batch(
        &self,
        batch: &RecordBatch,
        rules: &[CompiledRuleEdge],
        window_batches: &[Vec<RecordBatch>],
    ) -> Result<Vec<(PredicateResult, Option<RecordBatch>)>>;
}
20
/// [`RuleEvaluator`] implementation backed by a DataFusion [`SessionContext`].
pub struct DataFusionEvaluator {
    // Session used to register per-rule temporary tables and run queries.
    ctx: SessionContext,
}
24
25impl Default for DataFusionEvaluator {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl DataFusionEvaluator {
32 pub fn new() -> Self {
33 Self {
34 ctx: SessionContext::new(),
35 }
36 }
37
38 fn contains_aggregates(predicate: &str) -> bool {
41 let upper = predicate.to_uppercase();
43
44 upper.contains("AVG(")
47 || upper.contains("COUNT(")
48 || upper.contains("SUM(")
49 || upper.contains("MIN(")
50 || upper.contains("MAX(")
51 || upper.contains("STDDEV(")
52 || upper.contains("VARIANCE(")
53 || upper.contains("STDDEV_POP(")
54 || upper.contains("STDDEV_SAMP(")
55 || upper.contains("VAR_POP(")
56 || upper.contains("VAR_SAMP(")
57 }
58}
59
/// A rule whose predicate has been parsed and prepared for repeated
/// evaluation against record batches.
#[derive(Clone)]
pub struct CompiledRuleEdge {
    // The original rule definition (predicate text, window settings, ...).
    pub rule: Rule,
    // Pre-parsed predicate expression, used on the non-aggregate fast path.
    pub logical_expr: datafusion::logical_expr::Expr,
    // SQL projection template built at compile time.
    // NOTE(review): not read by `evaluate_batch` in this file (the aggregate
    // path re-formats the SQL from `rule.predicate`) — confirm other callers
    // use it, or it may be dead weight.
    pub compiled_sql: String,
    // Whether the predicate appears to contain aggregate functions
    // (heuristic substring scan; see `contains_aggregates`).
    pub has_aggregates: bool,
}
67
68#[async_trait]
69impl RuleEvaluator for DataFusionEvaluator {
70 fn compile(&self, rule: Rule, schema: &arrow::datatypes::Schema) -> Result<CompiledRuleEdge> {
71 let df_schema = datafusion::common::DFSchema::try_from(schema.clone())?;
72
73 let logical_expr = self
76 .ctx
77 .parse_sql_expr(&rule.predicate, &df_schema)
78 .context("Failed to parse rule predicate")?;
79
80 let has_aggregates = Self::contains_aggregates(&rule.predicate);
83
84 let compiled_sql = format!("SELECT ({}) as match_result", rule.predicate);
86
87 Ok(CompiledRuleEdge {
88 rule,
89 logical_expr,
90 compiled_sql,
91 has_aggregates,
92 })
93 }
94
95 async fn evaluate_batch(
96 &self,
97 batch: &RecordBatch,
98 rules: &[CompiledRuleEdge],
99 window_batches: &[Vec<RecordBatch>],
100 ) -> Result<Vec<(PredicateResult, Option<RecordBatch>)>> {
101 let mut results = Vec::new();
102
103 for (i, rule) in rules.iter().enumerate() {
104 let active_batches = if rule.rule.window_seconds.is_some() {
105 let mut all = window_batches[i].clone();
106 all.push(batch.clone());
107 all
108 } else {
109 vec![batch.clone()]
110 };
111
112 if active_batches.is_empty() {
113 results.push((PredicateResult::False, None));
114 continue;
115 }
116
117 let combined_batch = if active_batches.len() == 1 {
119 active_batches[0].clone()
120 } else {
121 let mut arrays = Vec::new();
123 for batch in &active_batches {
124 for col_idx in 0..batch.num_columns() {
125 if arrays.len() <= col_idx {
126 arrays.push(Vec::new());
127 }
128 arrays[col_idx].push(batch.column(col_idx).clone());
129 }
130 }
131 let concatenated_arrays: Vec<Arc<dyn arrow::array::Array>> = arrays
132 .into_iter()
133 .map(|cols| {
134 let refs: Vec<&dyn arrow::array::Array> =
136 cols.iter().map(|a| a.as_ref()).collect();
137 arrow::compute::concat(&refs).expect("Failed to concatenate arrays")
138 })
139 .collect();
140 RecordBatch::try_new(batch.schema(), concatenated_arrays)?
141 };
142
143 let table_name = format!("rule_input_{}", i);
146 let df = self.ctx.read_batches(vec![combined_batch.clone()])?;
147 self.ctx.register_table(&table_name, df.into_view())?;
148
149 let result_batches = if rule.has_aggregates {
150 let sql = format!(
153 "SELECT ({}) as match_result FROM {}",
154 rule.rule.predicate, table_name
155 );
156 self.ctx.sql(&sql).await?.collect().await?
157 } else {
158 let select_expr = vec![rule.logical_expr.clone().alias("match_result")];
160 let select_df = self.ctx.table(&table_name).await?.select(select_expr)?;
161
162 select_df.collect().await?
163 };
164
165 let mut is_true = false;
169 let mut matched_rows: Vec<usize> = Vec::new();
170
171 if !result_batches.is_empty() {
172 let col = result_batches[0]
173 .column(0)
174 .as_any()
175 .downcast_ref::<arrow::array::BooleanArray>();
176 if let Some(bool_col) = col {
177 if rule.has_aggregates {
178 if !bool_col.is_empty() && !bool_col.is_null(0) && bool_col.value(0) {
180 is_true = true;
181 matched_rows = (0..combined_batch.num_rows()).collect();
183 }
184 } else {
185 for row_idx in 0..bool_col.len() {
187 if !bool_col.is_null(row_idx) && bool_col.value(row_idx) {
188 is_true = true;
189 matched_rows.push(row_idx);
190 }
191 }
192 }
193 }
194 }
195
196 let matched_batch = if is_true && !matched_rows.is_empty() {
198 let matched_indices = arrow::array::UInt32Array::from(
200 matched_rows.iter().map(|&i| i as u32).collect::<Vec<_>>(),
201 );
202 let filtered_columns: Result<Vec<Arc<dyn arrow::array::Array>>, _> = combined_batch
204 .columns()
205 .iter()
206 .map(|col| arrow::compute::take(col, &matched_indices, None))
207 .collect();
208 let filtered_batch =
209 RecordBatch::try_new(combined_batch.schema(), filtered_columns?)?;
210 Some(filtered_batch)
211 } else {
212 None
213 };
214
215 if is_true {
216 results.push((PredicateResult::True, matched_batch));
217 } else {
218 results.push((PredicateResult::False, None));
219 }
220
221 self.ctx.deregister_table(&table_name)?;
222 }
223
224 Ok(results)
225 }
226}
227
228pub fn infer_json_schema(value: &serde_json::Value) -> arrow::datatypes::Schema {
229 match value {
230 serde_json::Value::Array(arr) => {
231 if arr.is_empty() {
232 return arrow::datatypes::Schema::empty();
233 }
234 let mut fields = Vec::new();
235 if let Some(serde_json::Value::Object(map)) = arr.first() {
236 for (k, v) in map {
237 let dt = match v {
238 serde_json::Value::Number(n) if n.is_i64() => {
239 arrow::datatypes::DataType::Int32
240 }
241 serde_json::Value::Number(_) => arrow::datatypes::DataType::Float64,
242 serde_json::Value::Bool(_) => arrow::datatypes::DataType::Boolean,
243 _ => arrow::datatypes::DataType::Utf8,
244 };
245 fields.push(arrow::datatypes::Field::new(k, dt, true));
246 }
247 }
248 arrow::datatypes::Schema::new(fields)
249 }
250 _ => arrow::datatypes::Schema::empty(),
251 }
252}