1use std::time::Duration;
4
5use arrow::array::{Array, StringArray};
6use arrow::record_batch::RecordBatch;
7use krishiv_plan::cep::{
8 CepKeyState, CompiledPattern, PartitionedCepMatcher, Pattern, SequentialPatternMatcher,
9};
10use krishiv_plan::{ExecutionKind, LogicalPlan, NodeOp, PlanNode};
11
12use crate::{SqlError, SqlResult};
13
14#[derive(Debug, Clone)]
28pub struct MatchRecognizeStatement {
29 pub source_table: String,
30 pub key_column: String,
31 pub event_time_column: String,
32 pub pattern: CompiledPattern,
33}
34
35pub fn parse_match_recognize(sql: &str) -> SqlResult<Option<MatchRecognizeStatement>> {
39 let trimmed = sql.trim().trim_end_matches(';');
40 let upper = trimmed.to_ascii_uppercase();
41 let Some(mr_pos) = upper.find(" MATCH_RECOGNIZE ") else {
42 return Ok(None);
43 };
44 let from_pos = upper.find(" FROM ").ok_or_else(|| SqlError::Unsupported {
45 feature: "MATCH_RECOGNIZE requires SELECT ... FROM <table>".into(),
46 })?;
47 let source_table = trimmed[from_pos + 6..mr_pos].trim().to_string();
48 if source_table.is_empty() {
49 return Err(SqlError::EmptyTableName);
50 }
51
52 let body_start = trimmed[mr_pos..]
53 .find('(')
54 .ok_or_else(|| SqlError::Unsupported {
55 feature: "MATCH_RECOGNIZE requires parenthesized body".into(),
56 })?
57 + mr_pos
58 + 1;
59 let body_end = trimmed.rfind(')').ok_or_else(|| SqlError::Unsupported {
60 feature: "MATCH_RECOGNIZE requires closing ')'".into(),
61 })?;
62 let body = &trimmed[body_start..body_end];
63 let body_upper = body.to_ascii_uppercase();
64
65 let key_column = extract_after_keyword(body, &body_upper, "PARTITION BY", "ORDER BY")?;
66 let event_time_column = extract_after_keyword(body, &body_upper, "ORDER BY", "PATTERN")?;
67 let pattern_body = extract_parenthesized_after(body, &body_upper, "PATTERN")?;
68 let stages = pattern_body
69 .split_whitespace()
70 .filter(|s| !s.is_empty())
71 .collect::<Vec<_>>();
72 if stages.is_empty() {
73 return Err(SqlError::Unsupported {
74 feature: "MATCH_RECOGNIZE PATTERN must contain at least one stage".into(),
75 });
76 }
77 let first_stage = stages
78 .first()
79 .copied()
80 .ok_or_else(|| SqlError::Unsupported {
81 feature: "MATCH_RECOGNIZE PATTERN stage list is empty".into(),
82 })?;
83 let mut pattern = Pattern::begin(first_stage);
84 for stage in stages.iter().skip(1) {
85 pattern = pattern.followed_by(*stage);
86 }
87 if let Some(window_ms) = parse_within_ms(body, &body_upper)? {
88 pattern = pattern.within(Duration::from_millis(window_ms));
89 }
90 let pattern = pattern.compile().map_err(|e| SqlError::Unsupported {
91 feature: format!("MATCH_RECOGNIZE pattern: {e}"),
92 })?;
93
94 Ok(Some(MatchRecognizeStatement {
95 source_table,
96 key_column,
97 event_time_column,
98 pattern,
99 }))
100}
101
102pub fn plan_match_recognize(stmt: MatchRecognizeStatement, query: &str) -> LogicalPlan {
104 let stage_names = stmt
105 .pattern
106 .stages
107 .iter()
108 .map(|stage| stage.name.clone())
109 .collect::<Vec<_>>();
110 LogicalPlan::new("match-recognize", ExecutionKind::Streaming).with_node(
111 PlanNode::new(
112 "match-recognize",
113 format!(
114 "MATCH_RECOGNIZE source={} partition_by={} order_by={} pattern=({}) within_ms={}",
115 stmt.source_table,
116 stmt.key_column,
117 stmt.event_time_column,
118 stage_names.join(" "),
119 stmt.pattern.window_ms
120 ),
121 ExecutionKind::Streaming,
122 )
123 .with_op(NodeOp::Other {
124 description: format!("cep:{query}"),
125 }),
126 )
127}
128
129pub fn execute_match_recognize(
139 stmt: MatchRecognizeStatement,
140 source_batches: &[RecordBatch],
141) -> SqlResult<Vec<RecordBatch>> {
142 use arrow::array::Int64Array;
143 use std::collections::HashMap;
144
145 if source_batches.is_empty() {
146 return Ok(Vec::new());
147 }
148
149 let schema = source_batches
151 .first()
152 .ok_or_else(|| SqlError::Unsupported {
153 feature: "source_batches is empty".into(),
154 })?
155 .schema();
156 let key_idx = schema
157 .index_of(&stmt.key_column)
158 .map_err(|_| SqlError::Unsupported {
159 feature: format!(
160 "MATCH_RECOGNIZE: key column '{}' not found",
161 stmt.key_column
162 ),
163 })?;
164 let time_idx = schema
165 .index_of(&stmt.event_time_column)
166 .map_err(|_| SqlError::Unsupported {
167 feature: format!(
168 "MATCH_RECOGNIZE: event time column '{}' not found",
169 stmt.event_time_column
170 ),
171 })?;
172
173 let mut events: Vec<(String, i64, usize, usize)> = Vec::new();
178 for (batch_idx, batch) in source_batches.iter().enumerate() {
179 let key_col = batch.column(key_idx);
180 let time_col = batch
181 .column(time_idx)
182 .as_any()
183 .downcast_ref::<Int64Array>()
184 .ok_or_else(|| SqlError::Unsupported {
185 feature: format!(
186 "MATCH_RECOGNIZE: event time column '{}' must be Int64",
187 stmt.event_time_column
188 ),
189 })?;
190 let key_str = key_col
191 .as_any()
192 .downcast_ref::<StringArray>()
193 .ok_or_else(|| SqlError::Unsupported {
194 feature: format!(
195 "MATCH_RECOGNIZE: partition key column '{}' must be Utf8 (got {})",
196 stmt.key_column,
197 key_col.data_type(),
198 ),
199 })?;
200 for i in 0..batch.num_rows() {
201 let key = if key_str.is_null(i) {
202 continue;
203 } else {
204 key_str.value(i).to_string()
205 };
206 if time_col.is_null(i) {
207 continue;
208 }
209 events.push((key, time_col.value(i), batch_idx, i));
210 }
211 }
212 events.sort_by_key(|(_, t, _, _)| *t);
213
214 let matcher = SequentialPatternMatcher::new(stmt.pattern.clone());
216 let mut key_states: HashMap<String, CepKeyState> = HashMap::new();
217 let mut output: Vec<RecordBatch> = Vec::new();
218
219 let stage_names: Vec<&str> = stmt
220 .pattern
221 .stages
222 .iter()
223 .map(|s| s.name.as_str())
224 .collect();
225
226 for (key, event_time, batch_idx, row_idx) in &events {
227 let Some(batch) = source_batches.get(*batch_idx) else {
231 continue;
232 };
233 let row = batch.slice(*row_idx, 1);
234 let state = key_states.entry(key.clone()).or_default();
235 let partial_key_before = state
240 .partial
241 .as_ref()
242 .map(|p| (p.stage_index, p.start_time_ms));
243 for &stage in &stage_names {
244 let completed = matcher.process_event(state, stage, row.clone(), *event_time);
245 if !completed.is_empty() {
246 for matched_rows in completed {
247 if let Ok(concat) = arrow::compute::concat_batches(&schema, &matched_rows) {
248 output.push(concat);
249 }
250 }
251 break;
252 }
253 let partial_key_after = state
256 .partial
257 .as_ref()
258 .map(|p| (p.stage_index, p.start_time_ms));
259 if partial_key_after != partial_key_before {
260 break;
261 }
262 }
263 }
264
265 Ok(output)
266}
267
268pub fn execute_streaming_match_recognize(
280 stmt: &MatchRecognizeStatement,
281 new_batches: &[RecordBatch],
282 state: &mut PartitionedCepMatcher<String>,
283) -> SqlResult<Vec<RecordBatch>> {
284 use arrow::array::Int64Array;
285
286 if new_batches.is_empty() {
287 return Ok(Vec::new());
288 }
289
290 let schema = new_batches
291 .first()
292 .ok_or_else(|| SqlError::Unsupported {
293 feature: "new_batches is empty".into(),
294 })?
295 .schema();
296 let key_idx = schema
297 .index_of(&stmt.key_column)
298 .map_err(|_| SqlError::Unsupported {
299 feature: format!(
300 "MATCH_RECOGNIZE: key column '{}' not found",
301 stmt.key_column
302 ),
303 })?;
304 let time_idx = schema
305 .index_of(&stmt.event_time_column)
306 .map_err(|_| SqlError::Unsupported {
307 feature: format!(
308 "MATCH_RECOGNIZE: event time column '{}' not found",
309 stmt.event_time_column
310 ),
311 })?;
312
313 let mut events: Vec<(String, i64, usize, usize)> = Vec::new();
315 for (batch_idx, batch) in new_batches.iter().enumerate() {
316 let key_col = batch.column(key_idx);
317 let time_col = batch
318 .column(time_idx)
319 .as_any()
320 .downcast_ref::<Int64Array>()
321 .ok_or_else(|| SqlError::Unsupported {
322 feature: format!(
323 "MATCH_RECOGNIZE: event time column '{}' must be Int64",
324 stmt.event_time_column
325 ),
326 })?;
327 let key_str = key_col
328 .as_any()
329 .downcast_ref::<StringArray>()
330 .ok_or_else(|| SqlError::Unsupported {
331 feature: format!(
332 "MATCH_RECOGNIZE: partition key column '{}' must be Utf8 (got {})",
333 stmt.key_column,
334 key_col.data_type(),
335 ),
336 })?;
337 for i in 0..batch.num_rows() {
338 let key = if key_str.is_null(i) {
339 continue;
340 } else {
341 key_str.value(i).to_string()
342 };
343 if time_col.is_null(i) {
344 continue;
345 }
346 events.push((key, time_col.value(i), batch_idx, i));
347 }
348 }
349 events.sort_by_key(|(_, t, _, _)| *t);
350
351 let stage_names: Vec<&str> = stmt
352 .pattern
353 .stages
354 .iter()
355 .map(|s| s.name.as_str())
356 .collect();
357
358 let mut output: Vec<RecordBatch> = Vec::new();
359 let mut max_event_time: Option<i64> = None;
360
361 for (key, event_time, batch_idx, row_idx) in &events {
362 max_event_time = Some(max_event_time.unwrap_or(*event_time).max(*event_time));
363 let Some(batch) = new_batches.get(*batch_idx) else {
364 continue;
365 };
366 let row = batch.slice(*row_idx, 1);
367 for &stage in &stage_names {
368 let completed = state.process_event(key.clone(), stage, row.clone(), *event_time);
369 if !completed.is_empty() {
370 for matched_rows in completed {
371 if let Ok(concat) = arrow::compute::concat_batches(&schema, &matched_rows) {
372 output.push(concat);
373 }
374 }
375 break;
376 }
377 }
378 }
379
380 if let Some(max_ts) = max_event_time {
383 let evict_before = max_ts - 2 * stmt.pattern.window_ms as i64;
384 state.evict_keys_before(evict_before);
385 }
386
387 Ok(output)
388}
389
390fn extract_after_keyword(
391 body: &str,
392 body_upper: &str,
393 start_keyword: &str,
394 end_keyword: &str,
395) -> SqlResult<String> {
396 let start = body_upper
397 .find(start_keyword)
398 .ok_or_else(|| SqlError::Unsupported {
399 feature: format!("MATCH_RECOGNIZE requires {start_keyword}"),
400 })?
401 + start_keyword.len();
402 let end = body_upper[start..]
403 .find(end_keyword)
404 .ok_or_else(|| SqlError::Unsupported {
405 feature: format!("MATCH_RECOGNIZE requires {end_keyword}"),
406 })?
407 + start;
408 let value = body[start..end].trim().to_string();
409 if value.is_empty() {
410 return Err(SqlError::Unsupported {
411 feature: format!("MATCH_RECOGNIZE empty {start_keyword}"),
412 });
413 }
414 Ok(value)
415}
416
417fn extract_parenthesized_after(body: &str, body_upper: &str, keyword: &str) -> SqlResult<String> {
418 let start = body_upper
419 .find(keyword)
420 .ok_or_else(|| SqlError::Unsupported {
421 feature: format!("MATCH_RECOGNIZE requires {keyword}"),
422 })?
423 + keyword.len();
424 let open = body[start..]
425 .find('(')
426 .ok_or_else(|| SqlError::Unsupported {
427 feature: format!("MATCH_RECOGNIZE {keyword} requires '('"),
428 })?
429 + start;
430 let close = body[open + 1..]
431 .find(')')
432 .ok_or_else(|| SqlError::Unsupported {
433 feature: format!("MATCH_RECOGNIZE {keyword} requires ')'"),
434 })?
435 + open
436 + 1;
437 Ok(body[open + 1..close].trim().to_string())
438}
439
440fn parse_within_ms(body: &str, body_upper: &str) -> SqlResult<Option<u64>> {
441 let Some(start) = body_upper.find("WITHIN") else {
442 return Ok(None);
443 };
444 let mut parts = body[start + "WITHIN".len()..].split_whitespace();
445 let value = parts
446 .next()
447 .ok_or_else(|| SqlError::Unsupported {
448 feature: "MATCH_RECOGNIZE WITHIN requires a value".into(),
449 })?
450 .parse::<u64>()
451 .map_err(|_| SqlError::Unsupported {
452 feature: "MATCH_RECOGNIZE WITHIN value must be an integer".into(),
453 })?;
454 let unit = parts.next().unwrap_or("MILLISECONDS").to_ascii_uppercase();
455 let multiplier = match unit.as_str() {
456 "MILLISECOND" | "MILLISECONDS" | "MS" => 1,
457 "SECOND" | "SECONDS" | "S" => 1_000,
458 "MINUTE" | "MINUTES" | "M" => 60_000,
459 other => {
460 return Err(SqlError::Unsupported {
461 feature: format!("MATCH_RECOGNIZE unsupported WITHIN unit {other}"),
462 });
463 }
464 };
465 Ok(Some(value.saturating_mul(multiplier)))
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a `(user_id: Utf8 NOT NULL, ts: Int64 NOT NULL)` batch from parallel slices.
    fn make_batch_with_key_ts(keys: &[&str], times: &[i64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("user_id", DataType::Utf8, false),
            Field::new("ts", DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(keys.to_vec())) as _,
                Arc::new(Int64Array::from(times.to_vec())) as _,
            ],
        )
        .unwrap()
    }

    /// Wraps `pattern` in a statement over `events`, partitioned by `user_id`, ordered by `ts`.
    fn events_statement(pattern: CompiledPattern) -> MatchRecognizeStatement {
        MatchRecognizeStatement {
            source_table: "events".to_string(),
            key_column: "user_id".to_string(),
            event_time_column: "ts".to_string(),
            pattern,
        }
    }

    /// Compiles the two-stage `A -> B` pattern bounded by `window`.
    fn a_then_b(window: Duration) -> CompiledPattern {
        Pattern::begin("A")
            .followed_by("B")
            .within(window)
            .compile()
            .unwrap()
    }

    #[test]
    fn execute_match_recognize_three_stage_pattern_produces_match() {
        let pattern = Pattern::begin("A")
            .followed_by("B")
            .followed_by("C")
            .within(Duration::from_secs(60))
            .compile()
            .unwrap();
        let batch =
            make_batch_with_key_ts(&["u1", "u1", "u1", "u2"], &[1_000, 2_000, 3_000, 9_000]);

        let result = execute_match_recognize(events_statement(pattern), &[batch]).unwrap();
        assert_eq!(result.len(), 1, "expected one completed A→B→C match for u1");
        assert_eq!(
            result[0].num_rows(),
            3,
            "match should span all three stage events"
        );
    }

    #[test]
    fn execute_match_recognize_no_match_when_window_expired() {
        let batch = make_batch_with_key_ts(&["u1", "u1"], &[0, 200]);
        let stmt = events_statement(a_then_b(Duration::from_millis(100)));

        let result = execute_match_recognize(stmt, &[batch]).unwrap();
        assert!(result.is_empty(), "expired window must not produce a match");
    }

    #[test]
    fn execute_match_recognize_empty_source_returns_empty() {
        let stmt = events_statement(a_then_b(Duration::from_secs(10)));
        let result = execute_match_recognize(stmt, &[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn execute_match_recognize_two_keys_both_complete() {
        let batch =
            make_batch_with_key_ts(&["u1", "u2", "u1", "u2"], &[1_000, 1_500, 2_000, 2_500]);
        let stmt = events_statement(a_then_b(Duration::from_secs(60)));

        let result = execute_match_recognize(stmt, &[batch]).unwrap();
        assert_eq!(
            result.len(),
            2,
            "both u1 and u2 must independently complete the A→B pattern"
        );
        for matched in &result {
            assert_eq!(
                matched.num_rows(),
                2,
                "each match must contain 2 events (one for stage A, one for stage B)"
            );
        }
    }

    #[test]
    fn execute_match_recognize_boundary_event_at_exact_window_matches() {
        let batch = make_batch_with_key_ts(&["u1", "u1"], &[0, 100]);
        let stmt = events_statement(a_then_b(Duration::from_millis(100)));

        let result = execute_match_recognize(stmt, &[batch]).unwrap();
        assert_eq!(
            result.len(),
            1,
            "event at exactly start_time + window_ms (t=100) must still match (strict > check)"
        );
    }

    #[test]
    fn execute_match_recognize_one_ms_past_window_does_not_match() {
        let batch = make_batch_with_key_ts(&["u1", "u1"], &[0, 101]);
        let stmt = events_statement(a_then_b(Duration::from_millis(100)));

        let result = execute_match_recognize(stmt, &[batch]).unwrap();
        assert!(
            result.is_empty(),
            "event 1 ms past window_ms must not match (expired partial)"
        );
    }

    #[test]
    fn execute_match_recognize_skips_rows_with_null_key() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("user_id", DataType::Utf8, true),
            Field::new("ts", DataType::Int64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("u1"), None, Some("u1")])) as _,
                Arc::new(Int64Array::from(vec![1_000_i64, 1_500, 2_000])) as _,
            ],
        )
        .unwrap();
        let stmt = events_statement(a_then_b(Duration::from_secs(60)));

        let result = execute_match_recognize(stmt, &[batch]).unwrap();
        assert_eq!(result.len(), 1, "the null-key row must be ignored");
        assert_eq!(result[0].num_rows(), 2);
    }

    #[test]
    fn execute_match_recognize_rejects_non_int64_event_time() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("user_id", DataType::Utf8, false),
            Field::new("ts", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec!["u1"])) as _,
                Arc::new(StringArray::from(vec!["1000"])) as _,
            ],
        )
        .unwrap();
        let stmt = events_statement(a_then_b(Duration::from_secs(1)));

        assert!(execute_match_recognize(stmt, &[batch]).is_err());
    }

    /// Registration bookkeeping only; this test does not run a CEP query.
    #[test]
    fn streaming_source_registration_is_tracked() {
        let engine = crate::SqlEngine::new();
        engine
            .register_streaming_source_name("live_events")
            .unwrap();
        assert!(
            engine.is_streaming_source("live_events"),
            "live_events must be identified as a streaming source"
        );
        assert!(
            !engine.is_streaming_source("batch_table"),
            "batch_table must not be streaming"
        );
    }

    #[test]
    fn parses_match_recognize_subset() {
        let stmt = parse_match_recognize(
            "SELECT * FROM events MATCH_RECOGNIZE (PARTITION BY user_id ORDER BY ts PATTERN (A B) WITHIN 10 SECONDS)",
        )
        .unwrap()
        .unwrap();
        assert_eq!(stmt.source_table, "events");
        assert_eq!(stmt.key_column, "user_id");
        assert_eq!(stmt.event_time_column, "ts");
        assert_eq!(stmt.pattern.stages.len(), 2);
        assert_eq!(stmt.pattern.window_ms, 10_000);
    }

    #[test]
    fn parse_returns_none_without_match_recognize() {
        assert!(parse_match_recognize("SELECT * FROM events")
            .unwrap()
            .is_none());
    }

    #[test]
    fn parses_within_units() {
        let window_of = |within: &str| {
            let sql = format!(
                "SELECT * FROM events MATCH_RECOGNIZE (PARTITION BY user_id ORDER BY ts PATTERN (A B) {within})"
            );
            parse_match_recognize(&sql)
                .unwrap()
                .unwrap()
                .pattern
                .window_ms
        };
        assert_eq!(window_of("WITHIN 500"), 500, "unit defaults to milliseconds");
        assert_eq!(window_of("WITHIN 2 MINUTES"), 120_000);
    }

    #[test]
    fn rejects_unknown_within_unit() {
        let sql = "SELECT * FROM events MATCH_RECOGNIZE (PARTITION BY user_id ORDER BY ts PATTERN (A B) WITHIN 5 FORTNIGHTS)";
        assert!(parse_match_recognize(sql).is_err());
    }
}