use std::collections::VecDeque;
use std::sync::Arc;

use arrow_array::{
    Array, Float64Array, Int64Array, RecordBatch, StringArray, TimestampMicrosecondArray,
};
use arrow_schema::{DataType, Field, Schema};
use fxhash::FxHashMap;

use super::{
    Event, Operator, OperatorContext, OperatorError, OperatorState, Output, OutputVec, Timer,
};

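/// Configuration for a [`LagLeadOperator`]: which LAG/LEAD functions to
/// evaluate, how events are partitioned, and a cap on tracked partitions.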
#[derive(Debug, Clone)]
pub struct LagLeadConfig {
    pub operator_id: String,
    pub functions: Vec<LagLeadFunctionSpec>,
    pub partition_columns: Vec<String>,
    pub max_partitions: usize,
}

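/// Specification of a single LAG or LEAD function: the source column it reads,
/// how many rows back (LAG) or ahead (LEAD) it looks, the default used when no
/// such row exists, and the name of the output column it produces.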
#[derive(Debug, Clone)]
pub struct LagLeadFunctionSpec {
    pub is_lag: bool,
    pub source_column: String,
    pub offset: usize,
    pub default_value: Option<f64>,
    pub output_column: String,
}

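/// Per-partition state: a bounded history of past values for LAG lookups and a
/// queue of events still waiting for future values to resolve LEAD.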
#[derive(Debug, Clone)]
struct PartitionState {
    lag_history: VecDeque<f64>,
    lead_pending: VecDeque<PendingLeadEvent>,
}

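/// An event buffered until enough subsequent events arrive to resolve its LEAD
/// value; `remaining` counts how many more events are still needed.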
#[derive(Debug, Clone)]
struct PendingLeadEvent {
    event: Event,
    remaining: usize,
    value: f64,
}

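/// Operational counters exposed by the operator for observability.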
#[derive(Debug, Default)]
pub struct LagLeadMetrics {
    pub events_processed: u64,
    pub lag_lookups: u64,
    pub lead_buffered: u64,
    pub lead_flushed: u64,
    pub partitions_active: u64,
}

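/// Streaming LAG/LEAD window operator.
///
/// For each partition it keeps a bounded history of past values (LAG) and a
/// queue of events awaiting future values (LEAD). LAG results are emitted
/// immediately; LEAD results are emitted once enough later events have
/// arrived, or with their defaults when pending events are flushed by a timer.
///
/// A minimal usage sketch (the operator id, column names, and partition cap
/// below are illustrative, not taken from a real pipeline):
///
/// ```ignore
/// let mut op = LagLeadOperator::new(LagLeadConfig {
///     operator_id: "lag_price".to_string(),
///     functions: vec![LagLeadFunctionSpec {
///         is_lag: true,
///         source_column: "price".to_string(),
///         offset: 1,
///         default_value: None,
///         output_column: "prev_price".to_string(),
///     }],
///     partition_columns: vec!["symbol".to_string()],
///     max_partitions: 10_000,
/// });
/// // Feed events via `Operator::process`; flush pending LEAD rows via `on_timer`.
/// ```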
pub struct LagLeadOperator {
    operator_id: String,
    functions: Vec<LagLeadFunctionSpec>,
    partition_columns: Vec<String>,
    partitions: FxHashMap<Vec<u8>, PartitionState>,
    max_partitions: usize,
    metrics: LagLeadMetrics,
}

impl LagLeadOperator {
    #[must_use]
    pub fn new(config: LagLeadConfig) -> Self {
        Self {
            operator_id: config.operator_id,
            functions: config.functions,
            partition_columns: config.partition_columns,
            partitions: FxHashMap::default(),
            max_partitions: config.max_partitions,
            metrics: LagLeadMetrics::default(),
        }
    }

    #[must_use]
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }

    #[must_use]
    pub fn metrics(&self) -> &LagLeadMetrics {
        &self.metrics
    }

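    /// Builds a per-partition key by encoding the first-row values of the
    /// configured partition columns, with a marker byte per column for
    /// null or missing values.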
    fn extract_partition_key(&self, event: &Event) -> Vec<u8> {
        let batch = &event.data;
        let schema = batch.schema();
        let mut key = Vec::new();

        for col_name in &self.partition_columns {
            let Ok(col_idx) = schema.index_of(col_name) else {
                key.push(0x00);
                continue;
            };

            let array = batch.column(col_idx);

            if array.is_null(0) {
                key.push(0x00);
                continue;
            }

            key.push(0x01);

            match array.data_type() {
                DataType::Int64 => {
                    let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
                    key.extend_from_slice(&arr.value(0).to_le_bytes());
                }
                DataType::Utf8 => {
                    let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
                    key.extend_from_slice(arr.value(0).as_bytes());
                    key.push(0x00);
                }
                DataType::Float64 => {
                    let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
                    key.extend_from_slice(&arr.value(0).to_bits().to_le_bytes());
                }
                _ => {
                    key.push(0x00);
                }
            }
        }

        key
    }

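    /// Reads the first-row value of `column` as `f64`, returning NaN for
    /// missing columns, nulls, or unsupported types.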
    fn extract_column_value(event: &Event, column: &str) -> f64 {
        let batch = &event.data;
        let schema = batch.schema();
        let Ok(col_idx) = schema.index_of(column) else {
            return f64::NAN;
        };

        let array = batch.column(col_idx);
        if array.is_null(0) {
            return f64::NAN;
        }

        match array.data_type() {
            DataType::Float64 => {
                let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
                arr.value(0)
            }
            DataType::Int64 => {
                let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
                #[allow(clippy::cast_precision_loss)]
                {
                    arr.value(0) as f64
                }
            }
            DataType::Timestamp(_, _) => {
                let arr = array
                    .as_any()
                    .downcast_ref::<TimestampMicrosecondArray>()
                    .unwrap();
                #[allow(clippy::cast_precision_loss)]
                {
                    arr.value(0) as f64
                }
            }
            _ => f64::NAN,
        }
    }

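    /// Looks up each LAG function's value `offset` positions back in the
    /// partition history, falling back to the configured default (or NaN).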
    fn compute_lag_values(functions: &[LagLeadFunctionSpec], state: &PartitionState) -> Vec<f64> {
        functions
            .iter()
            .filter(|f| f.is_lag)
            .map(|func| {
                let history = &state.lag_history;
                if history.len() >= func.offset {
                    let idx = history.len() - func.offset;
                    history[idx]
                } else {
                    func.default_value.unwrap_or(f64::NAN)
                }
            })
            .collect()
    }

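    /// Appends one Float64 output column per function to the event's batch,
    /// drawing from `lag_values` / `lead_values` in function order.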
    fn build_output(
        functions: &[LagLeadFunctionSpec],
        event: &Event,
        lag_values: &[f64],
        lead_values: &[f64],
    ) -> Event {
        let original_batch = &event.data;
        let mut fields: Vec<Field> = original_batch
            .schema()
            .fields()
            .iter()
            .map(|f| f.as_ref().clone())
            .collect();
        let mut columns: Vec<Arc<dyn Array>> = (0..original_batch.num_columns())
            .map(|i| original_batch.column(i).clone())
            .collect();

        let mut lag_idx = 0;
        let mut lead_idx = 0;

        for func in functions {
            let value = if func.is_lag {
                let v = lag_values.get(lag_idx).copied().unwrap_or(f64::NAN);
                lag_idx += 1;
                v
            } else {
                let v = lead_values.get(lead_idx).copied().unwrap_or(f64::NAN);
                lead_idx += 1;
                v
            };

            fields.push(Field::new(&func.output_column, DataType::Float64, true));
            columns.push(Arc::new(Float64Array::from(vec![value])));
        }

        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, columns)
            .unwrap_or_else(|_| RecordBatch::new_empty(Arc::new(Schema::empty())));
        Event::new(event.timestamp, batch)
    }

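    /// Core per-event path: resolves LAG values from partition history, buffers
    /// the event for LEAD resolution, and emits every event whose LEAD offsets
    /// are now satisfied. Events for new partitions are dropped once
    /// `max_partitions` is reached.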
    #[allow(clippy::too_many_lines)]
    fn process_event(&mut self, event: &Event) -> OutputVec {
        let partition_key = self.extract_partition_key(event);

        if !self.partitions.contains_key(&partition_key)
            && self.partitions.len() >= self.max_partitions
        {
            return OutputVec::new();
        }

        let has_lag = self.functions.iter().any(|f| f.is_lag);
        let has_lead = self.functions.iter().any(|f| !f.is_lag);

        let max_lag_offset = self
            .functions
            .iter()
            .filter(|f| f.is_lag)
            .map(|f| f.offset)
            .max()
            .unwrap_or(1);
        let max_lead_offset = self
            .functions
            .iter()
            .filter(|f| !f.is_lag)
            .map(|f| f.offset)
            .max()
            .unwrap_or(1);
        let lag_source_col = self
            .functions
            .iter()
            .find(|f| f.is_lag)
            .map(|f| f.source_column.clone());
        let lead_source_col = self
            .functions
            .iter()
            .find(|f| !f.is_lag)
            .map(|f| f.source_column.clone());
        let lead_func_specs: Vec<(usize, Option<f64>)> = self
            .functions
            .iter()
            .filter(|f| !f.is_lag)
            .map(|f| (f.offset, f.default_value))
            .collect();

        let state = self
            .partitions
            .entry(partition_key)
            .or_insert_with(|| PartitionState {
                lag_history: VecDeque::new(),
                lead_pending: VecDeque::new(),
            });

        let mut outputs = OutputVec::new();

        let lag_values = if has_lag {
            Self::compute_lag_values(&self.functions, state)
        } else {
            vec![]
        };

        if has_lag {
            if let Some(col) = &lag_source_col {
                let value = Self::extract_column_value(event, col);
                state.lag_history.push_back(value);
                while state.lag_history.len() > max_lag_offset {
                    state.lag_history.pop_front();
                }
            }
        }

        if has_lead {
            let value = if let Some(col) = &lead_source_col {
                Self::extract_column_value(event, col)
            } else {
                f64::NAN
            };

            for pending in &mut state.lead_pending {
                pending.remaining = pending.remaining.saturating_sub(1);
            }

            state.lead_pending.push_back(PendingLeadEvent {
                event: event.clone(),
                remaining: max_lead_offset,
                value,
            });
            self.metrics.lead_buffered += 1;

            let mut resolved_events = Vec::new();
            while state.lead_pending.front().is_some_and(|p| p.remaining == 0) {
                let resolved = state.lead_pending.pop_front().unwrap();
                let lead_values: Vec<f64> = lead_func_specs
                    .iter()
                    .map(|(offset, default)| {
                        if *offset <= state.lead_pending.len() {
                            state.lead_pending[*offset - 1].value
                        } else {
                            default.unwrap_or(f64::NAN)
                        }
                    })
                    .collect();
                resolved_events.push((resolved, lead_values));
            }

            for (resolved, lead_values) in resolved_events {
                let output =
                    Self::build_output(&self.functions, &resolved.event, &lag_values, &lead_values);
                outputs.push(Output::Event(output));
                self.metrics.lead_flushed += 1;
            }
        } else {
            let output = Self::build_output(&self.functions, event, &lag_values, &[]);
            outputs.push(Output::Event(output));
        }

        self.metrics.events_processed += 1;
        if has_lag {
            self.metrics.lag_lookups += 1;
        }
        self.metrics.partitions_active = self.partitions.len() as u64;

        outputs
    }

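    /// Drains every partition's pending LEAD events, emitting them with their
    /// configured defaults since no further input will resolve the offsets.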
    fn flush_pending_leads(&mut self) -> OutputVec {
        let mut outputs = OutputVec::new();

        let lead_defaults: Vec<f64> = self
            .functions
            .iter()
            .filter(|f| !f.is_lag)
            .map(|func| func.default_value.unwrap_or(f64::NAN))
            .collect();
        let lead_output_columns: Vec<String> = self
            .functions
            .iter()
            .filter(|f| !f.is_lag)
            .map(|f| f.output_column.clone())
            .collect();

        let mut flushed_count = 0u64;

        for state in self.partitions.values_mut() {
            while let Some(pending) = state.lead_pending.pop_front() {
                let original_batch = &pending.event.data;
                let mut fields: Vec<Field> = original_batch
                    .schema()
                    .fields()
                    .iter()
                    .map(|f| f.as_ref().clone())
                    .collect();
                let mut columns: Vec<Arc<dyn Array>> = (0..original_batch.num_columns())
                    .map(|i| original_batch.column(i).clone())
                    .collect();

                for (col_name, &default) in lead_output_columns.iter().zip(lead_defaults.iter()) {
                    fields.push(Field::new(col_name, DataType::Float64, true));
                    columns.push(Arc::new(Float64Array::from(vec![default])));
                }

                let schema = Arc::new(Schema::new(fields));
                if let Ok(batch) = RecordBatch::try_new(schema, columns) {
                    let output_event = Event::new(pending.event.timestamp, batch);
                    outputs.push(Output::Event(output_event));
                    flushed_count += 1;
                }
            }
        }

        self.metrics.lead_flushed += flushed_count;
        outputs
    }
}

impl Operator for LagLeadOperator {
    fn process(&mut self, event: &Event, _ctx: &mut OperatorContext) -> OutputVec {
        self.process_event(event)
    }

    fn on_timer(&mut self, _timer: Timer, _ctx: &mut OperatorContext) -> OutputVec {
        self.flush_pending_leads()
    }

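    /// Serializes all partition state (length-prefixed keys, LAG history, and
    /// pending LEAD entries) into a flat little-endian byte buffer. Pending
    /// events' record batches are not serialized; only their timestamp,
    /// remaining count, and value survive a restore.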
    fn checkpoint(&self) -> OperatorState {
        let mut data = Vec::new();

        let num_partitions = self.partitions.len() as u64;
        data.extend_from_slice(&num_partitions.to_le_bytes());

        for (key, state) in &self.partitions {
            let key_len = key.len() as u64;
            data.extend_from_slice(&key_len.to_le_bytes());
            data.extend_from_slice(key);

            let history_len = state.lag_history.len() as u64;
            data.extend_from_slice(&history_len.to_le_bytes());
            for &val in &state.lag_history {
                data.extend_from_slice(&val.to_le_bytes());
            }

            let pending_len = state.lead_pending.len() as u64;
            data.extend_from_slice(&pending_len.to_le_bytes());
            for pending in &state.lead_pending {
                data.extend_from_slice(&pending.event.timestamp.to_le_bytes());
                data.extend_from_slice(&(pending.remaining as u64).to_le_bytes());
                data.extend_from_slice(&pending.value.to_le_bytes());
            }
        }

        OperatorState {
            operator_id: self.operator_id.clone(),
            data,
        }
    }

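    /// Rebuilds partition state from a buffer produced by `checkpoint`.
    /// Pending LEAD events are restored with empty record batches, since the
    /// original batch contents are not part of the checkpoint format.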
    #[allow(clippy::cast_possible_truncation)]
    fn restore(&mut self, state: OperatorState) -> Result<(), OperatorError> {
        if state.data.len() < 8 {
            return Err(OperatorError::SerializationFailed(
                "LagLead checkpoint data too short".to_string(),
            ));
        }

        let mut offset = 0;

        let num_partitions = u64::from_le_bytes(
            state.data[offset..offset + 8]
                .try_into()
                .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
        ) as usize;
        offset += 8;

        self.partitions.clear();

        for _ in 0..num_partitions {
            if offset + 8 > state.data.len() {
                return Err(OperatorError::SerializationFailed(
                    "LagLead checkpoint truncated".to_string(),
                ));
            }

            let key_len = u64::from_le_bytes(
                state.data[offset..offset + 8]
                    .try_into()
                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
            ) as usize;
            offset += 8;

            let partition_key = state.data[offset..offset + key_len].to_vec();
            offset += key_len;

            let history_len = u64::from_le_bytes(
                state.data[offset..offset + 8]
                    .try_into()
                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
            ) as usize;
            offset += 8;

            let mut lag_history = VecDeque::with_capacity(history_len);
            for _ in 0..history_len {
                let val = f64::from_le_bytes(
                    state.data[offset..offset + 8]
                        .try_into()
                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
                );
                offset += 8;
                lag_history.push_back(val);
            }

            let pending_len = u64::from_le_bytes(
                state.data[offset..offset + 8]
                    .try_into()
                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
            ) as usize;
            offset += 8;

            let mut lead_pending = VecDeque::with_capacity(pending_len);
            for _ in 0..pending_len {
                let timestamp = i64::from_le_bytes(
                    state.data[offset..offset + 8]
                        .try_into()
                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
                );
                offset += 8;

                let remaining = u64::from_le_bytes(
                    state.data[offset..offset + 8]
                        .try_into()
                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
                ) as usize;
                offset += 8;

                let value = f64::from_le_bytes(
                    state.data[offset..offset + 8]
                        .try_into()
                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
                );
                offset += 8;

                let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
                lead_pending.push_back(PendingLeadEvent {
                    event: Event::new(timestamp, batch),
                    remaining,
                    value,
                });
            }

            self.partitions.insert(
                partition_key,
                PartitionState {
                    lag_history,
                    lead_pending,
                },
            );
        }

        Ok(())
    }
}

#[cfg(test)]
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::operator::TimerKey;
    use crate::state::InMemoryStore;
    use crate::time::{BoundedOutOfOrdernessGenerator, TimerService};

    fn make_trade(timestamp: i64, symbol: &str, price: f64) -> Event {
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("price", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![symbol])),
                Arc::new(Float64Array::from(vec![price])),
            ],
        )
        .unwrap();
        Event::new(timestamp, batch)
    }

    fn create_test_context<'a>(
        timers: &'a mut TimerService,
        state: &'a mut dyn crate::state::StateStore,
        watermark_gen: &'a mut dyn crate::time::WatermarkGenerator,
    ) -> OperatorContext<'a> {
        OperatorContext {
            event_time: 0,
            processing_time: 0,
            timers,
            state,
            watermark_generator: watermark_gen,
            operator_index: 0,
        }
    }

    fn lag_config(offset: usize) -> LagLeadConfig {
        LagLeadConfig {
            operator_id: "test_lag".to_string(),
            functions: vec![LagLeadFunctionSpec {
                is_lag: true,
                source_column: "price".to_string(),
                offset,
                default_value: None,
                output_column: "prev_price".to_string(),
            }],
            partition_columns: vec!["symbol".to_string()],
            max_partitions: 100,
        }
    }

    fn lead_config(offset: usize) -> LagLeadConfig {
        LagLeadConfig {
            operator_id: "test_lead".to_string(),
            functions: vec![LagLeadFunctionSpec {
                is_lag: false,
                source_column: "price".to_string(),
                offset,
                default_value: Some(0.0),
                output_column: "next_price".to_string(),
            }],
            partition_columns: vec!["symbol".to_string()],
            max_partitions: 100,
        }
    }

    #[test]
    fn test_lag_first_event_returns_nan() {
        let mut op = LagLeadOperator::new(lag_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let event = make_trade(1, "AAPL", 150.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&event, &mut ctx);

        assert_eq!(outputs.len(), 1);
        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("prev_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert!(arr.value(0).is_nan());
        } else {
            panic!("Expected Event output");
        }
    }

    #[test]
    fn test_lag_second_event_returns_previous() {
        let mut op = LagLeadOperator::new(lag_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "AAPL", 150.0);
        let e2 = make_trade(2, "AAPL", 155.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&e2, &mut ctx);

        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("prev_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), 150.0);
        }
    }

    #[test]
    fn test_lag_with_default() {
        let mut op = LagLeadOperator::new(LagLeadConfig {
            operator_id: "test".to_string(),
            functions: vec![LagLeadFunctionSpec {
                is_lag: true,
                source_column: "price".to_string(),
                offset: 1,
                default_value: Some(-1.0),
                output_column: "prev_price".to_string(),
            }],
            partition_columns: vec!["symbol".to_string()],
            max_partitions: 100,
        });
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let event = make_trade(1, "AAPL", 150.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&event, &mut ctx);

        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("prev_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), -1.0);
        }
    }

    #[test]
    fn test_lag_offset_2() {
        let mut op = LagLeadOperator::new(lag_config(2));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let events = [
            make_trade(1, "AAPL", 100.0),
            make_trade(2, "AAPL", 110.0),
            make_trade(3, "AAPL", 120.0),
        ];

        for e in &events[..2] {
            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
            op.process(e, &mut ctx);
        }

        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&events[2], &mut ctx);

        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("prev_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), 100.0);
        }
    }

    #[test]
    fn test_lag_separate_partitions() {
        let mut op = LagLeadOperator::new(lag_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let a1 = make_trade(1, "AAPL", 150.0);
        let a2 = make_trade(3, "AAPL", 155.0);
        let g1 = make_trade(2, "GOOG", 2800.0);

        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&a1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&g1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&a2, &mut ctx);

        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("prev_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), 150.0);
        }
        assert_eq!(op.partition_count(), 2);
    }

    #[test]
    fn test_lead_buffers_events() {
        let mut op = LagLeadOperator::new(lead_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "AAPL", 150.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&e1, &mut ctx);

        assert!(outputs.is_empty());
    }

    #[test]
    fn test_lead_resolves_on_next_event() {
        let mut op = LagLeadOperator::new(lead_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "AAPL", 150.0);
        let e2 = make_trade(2, "AAPL", 155.0);

        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&e2, &mut ctx);

        assert_eq!(outputs.len(), 1);
        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("next_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), 155.0);
        }
    }

    #[test]
    fn test_lead_flush_on_watermark() {
        let mut op = LagLeadOperator::new(lead_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "AAPL", 150.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);

        let timer = Timer {
            key: TimerKey::default(),
            timestamp: 100,
        };
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.on_timer(timer, &mut ctx);

        assert_eq!(outputs.len(), 1);
        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("next_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), 0.0);
        }
    }

    #[test]
    fn test_max_partitions() {
        let mut op = LagLeadOperator::new(LagLeadConfig {
            operator_id: "test".to_string(),
            functions: vec![LagLeadFunctionSpec {
                is_lag: true,
                source_column: "price".to_string(),
                offset: 1,
                default_value: None,
                output_column: "prev_price".to_string(),
            }],
            partition_columns: vec!["symbol".to_string()],
            max_partitions: 2,
        });
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "AAPL", 150.0);
        let e2 = make_trade(2, "GOOG", 2800.0);
        let e3 = make_trade(3, "MSFT", 300.0);

        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&e3, &mut ctx);

        assert!(outputs.is_empty());
        assert_eq!(op.partition_count(), 2);
    }

    #[test]
    fn test_checkpoint_restore() {
        let mut op = LagLeadOperator::new(lag_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let events = vec![
            make_trade(1, "AAPL", 100.0),
            make_trade(2, "AAPL", 110.0),
            make_trade(3, "GOOG", 2800.0),
        ];
        for e in &events {
            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
            op.process(e, &mut ctx);
        }

        let checkpoint = op.checkpoint();
        assert_eq!(checkpoint.operator_id, "test_lag");

        let mut op2 = LagLeadOperator::new(lag_config(1));
        op2.restore(checkpoint).unwrap();
        assert_eq!(op2.partition_count(), 2);
    }

    #[test]
    fn test_metrics() {
        let mut op = LagLeadOperator::new(lag_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let e1 = make_trade(1, "AAPL", 150.0);
        let e2 = make_trade(2, "AAPL", 155.0);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&e2, &mut ctx);

        assert_eq!(op.metrics().events_processed, 2);
        assert_eq!(op.metrics().lag_lookups, 2);
    }

    #[test]
    fn test_lead_separate_partitions() {
        let mut op = LagLeadOperator::new(lead_config(1));
        let mut timers = TimerService::new();
        let mut state = InMemoryStore::new();
        let mut wm = BoundedOutOfOrdernessGenerator::new(0);

        let a1 = make_trade(1, "AAPL", 150.0);
        let g1 = make_trade(2, "GOOG", 2800.0);
        let a2 = make_trade(3, "AAPL", 155.0);

        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&a1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        op.process(&g1, &mut ctx);
        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
        let outputs = op.process(&a2, &mut ctx);

        assert_eq!(outputs.len(), 1);
        if let Output::Event(e) = &outputs[0] {
            let arr = e
                .data
                .column_by_name("next_price")
                .unwrap()
                .as_any()
                .downcast_ref::<Float64Array>()
                .unwrap();
            assert_eq!(arr.value(0), 155.0);
        }
    }

    #[test]
    fn test_empty_operator() {
        let op = LagLeadOperator::new(lag_config(1));
        assert_eq!(op.partition_count(), 0);
        assert_eq!(op.metrics().events_processed, 0);
    }
}