Skip to main content

laminar_core/operator/
lag_lead.rs

1//! # LAG/LEAD Operator
2//!
3//! Per-row analytic window functions for streaming event processing.
4//!
5//! ## Streaming Semantics
6//!
7//! - **LAG(col, offset, default)**: Look back `offset` events in the partition's
8//!   history buffer. Emit immediately since history is always available.
9//! - **LEAD(col, offset, default)**: Buffer current event and wait for `offset`
10//!   future events in the partition. Flush remaining with defaults on watermark.
11//!
12//! ## Memory Bounds
13//!
14//! Memory is O(P * max(`lag_offset`, `lead_offset`)) where P = distinct partitions.
15//! A `max_partitions` limit prevents unbounded partition growth.
16
17use std::collections::VecDeque;
18use std::sync::Arc;
19
20use arrow_array::{
21    Array, Float64Array, Int64Array, RecordBatch, StringArray, TimestampMicrosecondArray,
22};
23use arrow_schema::{DataType, Field, Schema};
24use fxhash::FxHashMap;
25
26use super::{
27    Event, Operator, OperatorContext, OperatorError, OperatorState, Output, OutputVec, Timer,
28};
29
/// Configuration for a LAG/LEAD operator.
///
/// Consumed by `LagLeadOperator::new`, which moves every field into the
/// operator unchanged.
#[derive(Debug, Clone)]
pub struct LagLeadConfig {
    /// Operator identifier for checkpointing; copied into every
    /// `OperatorState` produced by `checkpoint`.
    pub operator_id: String,
    /// Individual function specifications. One `Float64` output column is
    /// appended per entry, in this order.
    pub functions: Vec<LagLeadFunctionSpec>,
    /// Partition key columns. A column missing from an event's schema is
    /// encoded the same way as a null value.
    pub partition_columns: Vec<String>,
    /// Maximum number of partitions (memory safety). Once this many partitions
    /// exist, events belonging to a not-yet-seen partition produce no output.
    pub max_partitions: usize,
}
42
/// Specification for a single LAG or LEAD function.
#[derive(Debug, Clone)]
pub struct LagLeadFunctionSpec {
    /// True for LAG, false for LEAD.
    pub is_lag: bool,
    /// Source column to read values from.
    ///
    /// NOTE(review): `process_event` reads only the source column of the
    /// first LAG spec (for all LAG functions) and of the first LEAD spec (for
    /// all LEAD functions); confirm whether mixed source columns are meant to
    /// be supported.
    pub source_column: String,
    /// Offset (number of rows to look back/ahead).
    pub offset: usize,
    /// Default value when no row is available (as f64 for simplicity).
    /// `None` is emitted as `f64::NAN`.
    pub default_value: Option<f64>,
    /// Output column name. The column is appended as a nullable `Float64`.
    pub output_column: String,
}
57
/// Per-partition state for LAG/LEAD processing.
///
/// One instance exists per distinct partition key; it is created lazily by
/// `process_event` and serialized by `checkpoint`.
#[derive(Debug, Clone)]
struct PartitionState {
    /// History buffer for LAG lookback (most recent at back). Values are
    /// appended and the front is trimmed by `process_event`.
    lag_history: VecDeque<f64>,
    /// Pending events for LEAD (waiting for future events), oldest at front.
    lead_pending: VecDeque<PendingLeadEvent>,
}
66
/// A pending event waiting for LEAD resolution.
#[derive(Debug, Clone)]
struct PendingLeadEvent {
    /// Original event. After `restore` this carries only the checkpointed
    /// timestamp with an empty record batch, because `checkpoint` does not
    /// serialize the batch itself.
    event: Event,
    /// Remaining events needed before this can be emitted. Decremented
    /// (saturating at zero) for each later event in the same partition.
    remaining: usize,
    /// The value from the source column at this event's position.
    value: f64,
}
77
/// Metrics for LAG/LEAD operator.
///
/// Counters are plain fields updated by the operator; they are not part of
/// the checkpoint.
#[derive(Debug, Default)]
pub struct LagLeadMetrics {
    /// Total events processed. Events dropped by the partition limit are not
    /// counted.
    pub events_processed: u64,
    /// LAG lookups performed. Incremented once per processed event when at
    /// least one LAG function is configured.
    pub lag_lookups: u64,
    /// LEAD events buffered.
    pub lead_buffered: u64,
    /// LEAD events flushed (resolved or watermark).
    pub lead_flushed: u64,
    /// Active partition count, refreshed after each processed event.
    pub partitions_active: u64,
}
92
/// LAG/LEAD streaming operator.
///
/// Implements per-partition LAG and LEAD analytic functions for streaming
/// event processing with checkpoint/restore support.
///
/// Only row 0 of each event's record batch is read when deriving the
/// partition key and the source-column value.
pub struct LagLeadOperator {
    /// Operator identifier for checkpointing.
    operator_id: String,
    /// Function specifications, in output-column order.
    functions: Vec<LagLeadFunctionSpec>,
    /// Partition key columns.
    partition_columns: Vec<String>,
    /// Per-partition state, keyed by the byte encoding produced by
    /// `extract_partition_key`.
    partitions: FxHashMap<Vec<u8>, PartitionState>,
    /// Maximum number of partitions.
    max_partitions: usize,
    /// Operator metrics.
    metrics: LagLeadMetrics,
}
111
112impl LagLeadOperator {
113    /// Creates a new LAG/LEAD operator from configuration.
114    #[must_use]
115    pub fn new(config: LagLeadConfig) -> Self {
116        Self {
117            operator_id: config.operator_id,
118            functions: config.functions,
119            partition_columns: config.partition_columns,
120            partitions: FxHashMap::default(),
121            max_partitions: config.max_partitions,
122            metrics: LagLeadMetrics::default(),
123        }
124    }
125
    /// Returns the number of active partitions, i.e. the number of distinct
    /// partition keys that currently have state.
    #[must_use]
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }
131
    /// Returns a reference to the metrics accumulated since construction.
    #[must_use]
    pub fn metrics(&self) -> &LagLeadMetrics {
        &self.metrics
    }
137
138    /// Extracts the partition key from an event.
139    fn extract_partition_key(&self, event: &Event) -> Vec<u8> {
140        let batch = &event.data;
141        let schema = batch.schema();
142        let mut key = Vec::new();
143
144        for col_name in &self.partition_columns {
145            let Ok(col_idx) = schema.index_of(col_name) else {
146                key.push(0x00); // missing column marker
147                continue;
148            };
149
150            let array = batch.column(col_idx);
151
152            if array.is_null(0) {
153                key.push(0x00);
154                continue;
155            }
156
157            key.push(0x01); // non-null marker
158
159            match array.data_type() {
160                DataType::Int64 => {
161                    let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
162                    key.extend_from_slice(&arr.value(0).to_le_bytes());
163                }
164                DataType::Utf8 => {
165                    let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
166                    key.extend_from_slice(arr.value(0).as_bytes());
167                    key.push(0x00); // null terminator
168                }
169                DataType::Float64 => {
170                    let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
171                    key.extend_from_slice(&arr.value(0).to_bits().to_le_bytes());
172                }
173                _ => {
174                    key.push(0x00);
175                }
176            }
177        }
178
179        key
180    }
181
182    /// Extracts a f64 value from a column in the event.
183    fn extract_column_value(event: &Event, column: &str) -> f64 {
184        let batch = &event.data;
185        let schema = batch.schema();
186        let Ok(col_idx) = schema.index_of(column) else {
187            return f64::NAN;
188        };
189
190        let array = batch.column(col_idx);
191        if array.is_null(0) {
192            return f64::NAN;
193        }
194
195        match array.data_type() {
196            DataType::Float64 => {
197                let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
198                arr.value(0)
199            }
200            DataType::Int64 => {
201                let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
202                #[allow(clippy::cast_precision_loss)]
203                {
204                    arr.value(0) as f64
205                }
206            }
207            DataType::Timestamp(_, _) => {
208                let arr = array
209                    .as_any()
210                    .downcast_ref::<TimestampMicrosecondArray>()
211                    .unwrap();
212                #[allow(clippy::cast_precision_loss)]
213                {
214                    arr.value(0) as f64
215                }
216            }
217            _ => f64::NAN,
218        }
219    }
220
221    /// Computes LAG values for a partition.
222    fn compute_lag_values(functions: &[LagLeadFunctionSpec], state: &PartitionState) -> Vec<f64> {
223        functions
224            .iter()
225            .filter(|f| f.is_lag)
226            .map(|func| {
227                let history = &state.lag_history;
228                if history.len() >= func.offset {
229                    let idx = history.len() - func.offset;
230                    history[idx]
231                } else {
232                    func.default_value.unwrap_or(f64::NAN)
233                }
234            })
235            .collect()
236    }
237
238    /// Builds an output event with computed values (static to avoid borrow issues).
239    fn build_output(
240        functions: &[LagLeadFunctionSpec],
241        event: &Event,
242        lag_values: &[f64],
243        lead_values: &[f64],
244    ) -> Event {
245        let original_batch = &event.data;
246        let mut fields: Vec<Field> = original_batch
247            .schema()
248            .fields()
249            .iter()
250            .map(|f| f.as_ref().clone())
251            .collect();
252        let mut columns: Vec<Arc<dyn Array>> = (0..original_batch.num_columns())
253            .map(|i| original_batch.column(i).clone())
254            .collect();
255
256        let mut lag_idx = 0;
257        let mut lead_idx = 0;
258
259        for func in functions {
260            let value = if func.is_lag {
261                let v = lag_values.get(lag_idx).copied().unwrap_or(f64::NAN);
262                lag_idx += 1;
263                v
264            } else {
265                let v = lead_values.get(lead_idx).copied().unwrap_or(f64::NAN);
266                lead_idx += 1;
267                v
268            };
269
270            fields.push(Field::new(&func.output_column, DataType::Float64, true));
271            columns.push(Arc::new(Float64Array::from(vec![value])));
272        }
273
274        let schema = Arc::new(Schema::new(fields));
275        let batch = RecordBatch::try_new(schema, columns)
276            .unwrap_or_else(|_| RecordBatch::new_empty(Arc::new(Schema::empty())));
277        Event::new(event.timestamp, batch)
278    }
279
280    /// Processes an event: computes LAG values, buffers LEAD events, resolves
281    /// pending LEAD events that now have enough future rows.
282    #[allow(clippy::too_many_lines)]
283    fn process_event(&mut self, event: &Event) -> OutputVec {
284        let partition_key = self.extract_partition_key(event);
285
286        // Check max partitions
287        if !self.partitions.contains_key(&partition_key)
288            && self.partitions.len() >= self.max_partitions
289        {
290            return OutputVec::new();
291        }
292
293        let has_lag = self.functions.iter().any(|f| f.is_lag);
294        let has_lead = self.functions.iter().any(|f| !f.is_lag);
295
296        // Pre-compute constants from functions to avoid borrowing self later
297        let max_lag_offset = self
298            .functions
299            .iter()
300            .filter(|f| f.is_lag)
301            .map(|f| f.offset)
302            .max()
303            .unwrap_or(1);
304        let max_lead_offset = self
305            .functions
306            .iter()
307            .filter(|f| !f.is_lag)
308            .map(|f| f.offset)
309            .max()
310            .unwrap_or(1);
311        let lag_source_col = self
312            .functions
313            .iter()
314            .find(|f| f.is_lag)
315            .map(|f| f.source_column.clone());
316        let lead_source_col = self
317            .functions
318            .iter()
319            .find(|f| !f.is_lag)
320            .map(|f| f.source_column.clone());
321        // Pre-collect LEAD function defaults/offsets to avoid borrow conflicts
322        let lead_func_specs: Vec<(usize, Option<f64>)> = self
323            .functions
324            .iter()
325            .filter(|f| !f.is_lag)
326            .map(|f| (f.offset, f.default_value))
327            .collect();
328
329        // Get or create partition state
330        let state = self
331            .partitions
332            .entry(partition_key)
333            .or_insert_with(|| PartitionState {
334                lag_history: VecDeque::new(),
335                lead_pending: VecDeque::new(),
336            });
337
338        let mut outputs = OutputVec::new();
339
340        // Process LAG: look back in history
341        let lag_values = if has_lag {
342            Self::compute_lag_values(&self.functions, state)
343        } else {
344            vec![]
345        };
346
347        // Update LAG history
348        if has_lag {
349            if let Some(col) = &lag_source_col {
350                let value = Self::extract_column_value(event, col);
351                state.lag_history.push_back(value);
352                while state.lag_history.len() > max_lag_offset {
353                    state.lag_history.pop_front();
354                }
355            }
356        }
357
358        if has_lead {
359            // Buffer this event for LEAD resolution
360            let value = if let Some(col) = &lead_source_col {
361                Self::extract_column_value(event, col)
362            } else {
363                f64::NAN
364            };
365
366            // Decrement remaining on all pending events
367            for pending in &mut state.lead_pending {
368                pending.remaining = pending.remaining.saturating_sub(1);
369            }
370
371            state.lead_pending.push_back(PendingLeadEvent {
372                event: event.clone(),
373                remaining: max_lead_offset,
374                value,
375            });
376            self.metrics.lead_buffered += 1;
377
378            // Emit resolved LEAD events (remaining == 0)
379            let mut resolved_events = Vec::new();
380            while state.lead_pending.front().is_some_and(|p| p.remaining == 0) {
381                let resolved = state.lead_pending.pop_front().unwrap();
382                let lead_values: Vec<f64> = lead_func_specs
383                    .iter()
384                    .map(|(offset, default)| {
385                        if *offset <= state.lead_pending.len() {
386                            state.lead_pending[*offset - 1].value
387                        } else {
388                            default.unwrap_or(f64::NAN)
389                        }
390                    })
391                    .collect();
392                resolved_events.push((resolved, lead_values));
393            }
394
395            for (resolved, lead_values) in resolved_events {
396                let output =
397                    Self::build_output(&self.functions, &resolved.event, &lag_values, &lead_values);
398                outputs.push(Output::Event(output));
399                self.metrics.lead_flushed += 1;
400            }
401        } else {
402            // No LEAD functions: emit immediately with LAG values
403            let output = Self::build_output(&self.functions, event, &lag_values, &[]);
404            outputs.push(Output::Event(output));
405        }
406
407        self.metrics.events_processed += 1;
408        if has_lag {
409            self.metrics.lag_lookups += 1;
410        }
411        self.metrics.partitions_active = self.partitions.len() as u64;
412
413        outputs
414    }
415
416    /// Flushes all pending LEAD events with default values.
417    /// Called on watermark advance.
418    fn flush_pending_leads(&mut self) -> OutputVec {
419        let mut outputs = OutputVec::new();
420
421        // Pre-compute lead defaults to avoid borrow conflicts
422        let lead_defaults: Vec<f64> = self
423            .functions
424            .iter()
425            .filter(|f| !f.is_lag)
426            .map(|func| func.default_value.unwrap_or(f64::NAN))
427            .collect();
428        let lead_output_columns: Vec<String> = self
429            .functions
430            .iter()
431            .filter(|f| !f.is_lag)
432            .map(|f| f.output_column.clone())
433            .collect();
434
435        let mut flushed_count = 0u64;
436
437        for state in self.partitions.values_mut() {
438            while let Some(pending) = state.lead_pending.pop_front() {
439                let original_batch = &pending.event.data;
440                let mut fields: Vec<Field> = original_batch
441                    .schema()
442                    .fields()
443                    .iter()
444                    .map(|f| f.as_ref().clone())
445                    .collect();
446                let mut columns: Vec<Arc<dyn Array>> = (0..original_batch.num_columns())
447                    .map(|i| original_batch.column(i).clone())
448                    .collect();
449
450                for (col_name, &default) in lead_output_columns.iter().zip(lead_defaults.iter()) {
451                    fields.push(Field::new(col_name, DataType::Float64, true));
452                    columns.push(Arc::new(Float64Array::from(vec![default])));
453                }
454
455                let schema = Arc::new(Schema::new(fields));
456                if let Ok(batch) = RecordBatch::try_new(schema, columns) {
457                    let output_event = Event::new(pending.event.timestamp, batch);
458                    outputs.push(Output::Event(output_event));
459                    flushed_count += 1;
460                }
461            }
462        }
463
464        self.metrics.lead_flushed += flushed_count;
465        outputs
466    }
467}
468
469impl Operator for LagLeadOperator {
    /// Handles one input event by delegating to `process_event`; the operator
    /// context is not used.
    fn process(&mut self, event: &Event, _ctx: &mut OperatorContext) -> OutputVec {
        self.process_event(event)
    }
473
    /// Flushes all buffered LEAD rows whenever any timer fires; neither the
    /// timer nor the context is inspected.
    fn on_timer(&mut self, _timer: Timer, _ctx: &mut OperatorContext) -> OutputVec {
        // Flush pending LEAD events on watermark/timer
        self.flush_pending_leads()
    }
478
479    fn checkpoint(&self) -> OperatorState {
480        let mut data = Vec::new();
481
482        // Write partition count
483        let num_partitions = self.partitions.len() as u64;
484        data.extend_from_slice(&num_partitions.to_le_bytes());
485
486        // Write each partition
487        for (key, state) in &self.partitions {
488            // Partition key
489            let key_len = key.len() as u64;
490            data.extend_from_slice(&key_len.to_le_bytes());
491            data.extend_from_slice(key);
492
493            // LAG history
494            let history_len = state.lag_history.len() as u64;
495            data.extend_from_slice(&history_len.to_le_bytes());
496            for &val in &state.lag_history {
497                data.extend_from_slice(&val.to_le_bytes());
498            }
499
500            // LEAD pending count
501            let pending_len = state.lead_pending.len() as u64;
502            data.extend_from_slice(&pending_len.to_le_bytes());
503            for pending in &state.lead_pending {
504                data.extend_from_slice(&pending.event.timestamp.to_le_bytes());
505                data.extend_from_slice(&(pending.remaining as u64).to_le_bytes());
506                data.extend_from_slice(&pending.value.to_le_bytes());
507            }
508        }
509
510        OperatorState {
511            operator_id: self.operator_id.clone(),
512            data,
513        }
514    }
515
516    #[allow(clippy::cast_possible_truncation)]
517    fn restore(&mut self, state: OperatorState) -> Result<(), OperatorError> {
518        if state.data.len() < 8 {
519            return Err(OperatorError::SerializationFailed(
520                "LagLead checkpoint data too short".to_string(),
521            ));
522        }
523
524        let mut offset = 0;
525
526        let num_partitions = u64::from_le_bytes(
527            state.data[offset..offset + 8]
528                .try_into()
529                .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
530        ) as usize;
531        offset += 8;
532
533        self.partitions.clear();
534
535        for _ in 0..num_partitions {
536            if offset + 8 > state.data.len() {
537                return Err(OperatorError::SerializationFailed(
538                    "LagLead checkpoint truncated".to_string(),
539                ));
540            }
541
542            // Read partition key
543            let key_len = u64::from_le_bytes(
544                state.data[offset..offset + 8]
545                    .try_into()
546                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
547            ) as usize;
548            offset += 8;
549
550            let partition_key = state.data[offset..offset + key_len].to_vec();
551            offset += key_len;
552
553            // Read LAG history
554            let history_len = u64::from_le_bytes(
555                state.data[offset..offset + 8]
556                    .try_into()
557                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
558            ) as usize;
559            offset += 8;
560
561            let mut lag_history = VecDeque::with_capacity(history_len);
562            for _ in 0..history_len {
563                let val = f64::from_le_bytes(
564                    state.data[offset..offset + 8]
565                        .try_into()
566                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
567                );
568                offset += 8;
569                lag_history.push_back(val);
570            }
571
572            // Read LEAD pending
573            let pending_len = u64::from_le_bytes(
574                state.data[offset..offset + 8]
575                    .try_into()
576                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
577            ) as usize;
578            offset += 8;
579
580            let mut lead_pending = VecDeque::with_capacity(pending_len);
581            for _ in 0..pending_len {
582                let timestamp = i64::from_le_bytes(
583                    state.data[offset..offset + 8]
584                        .try_into()
585                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
586                );
587                offset += 8;
588
589                let remaining = u64::from_le_bytes(
590                    state.data[offset..offset + 8]
591                        .try_into()
592                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
593                ) as usize;
594                offset += 8;
595
596                let value = f64::from_le_bytes(
597                    state.data[offset..offset + 8]
598                        .try_into()
599                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
600                );
601                offset += 8;
602
603                let batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
604                lead_pending.push_back(PendingLeadEvent {
605                    event: Event::new(timestamp, batch),
606                    remaining,
607                    value,
608                });
609            }
610
611            self.partitions.insert(
612                partition_key,
613                PartitionState {
614                    lag_history,
615                    lead_pending,
616                },
617            );
618        }
619
620        Ok(())
621    }
622}
623
624#[cfg(test)]
625#[allow(clippy::float_cmp)]
626mod tests {
627    use super::*;
628    use crate::operator::TimerKey;
629    use crate::state::InMemoryStore;
630    use crate::time::{BoundedOutOfOrdernessGenerator, TimerService};
631
632    fn make_trade(timestamp: i64, symbol: &str, price: f64) -> Event {
633        let schema = Arc::new(Schema::new(vec![
634            Field::new("symbol", DataType::Utf8, false),
635            Field::new("price", DataType::Float64, false),
636        ]));
637        let batch = RecordBatch::try_new(
638            schema,
639            vec![
640                Arc::new(StringArray::from(vec![symbol])),
641                Arc::new(Float64Array::from(vec![price])),
642            ],
643        )
644        .unwrap();
645        Event::new(timestamp, batch)
646    }
647
    /// Assembles an `OperatorContext` over the supplied timer service, state
    /// store and watermark generator, with both clocks and the operator index
    /// set to zero.
    fn create_test_context<'a>(
        timers: &'a mut TimerService,
        state: &'a mut dyn crate::state::StateStore,
        watermark_gen: &'a mut dyn crate::time::WatermarkGenerator,
    ) -> OperatorContext<'a> {
        OperatorContext {
            event_time: 0,
            processing_time: 0,
            timers,
            state,
            watermark_generator: watermark_gen,
            operator_index: 0,
        }
    }
662
663    fn lag_config(offset: usize) -> LagLeadConfig {
664        LagLeadConfig {
665            operator_id: "test_lag".to_string(),
666            functions: vec![LagLeadFunctionSpec {
667                is_lag: true,
668                source_column: "price".to_string(),
669                offset,
670                default_value: None,
671                output_column: "prev_price".to_string(),
672            }],
673            partition_columns: vec!["symbol".to_string()],
674            max_partitions: 100,
675        }
676    }
677
678    fn lead_config(offset: usize) -> LagLeadConfig {
679        LagLeadConfig {
680            operator_id: "test_lead".to_string(),
681            functions: vec![LagLeadFunctionSpec {
682                is_lag: false,
683                source_column: "price".to_string(),
684                offset,
685                default_value: Some(0.0),
686                output_column: "next_price".to_string(),
687            }],
688            partition_columns: vec!["symbol".to_string()],
689            max_partitions: 100,
690        }
691    }
692
693    #[test]
694    fn test_lag_first_event_returns_nan() {
695        let mut op = LagLeadOperator::new(lag_config(1));
696        let mut timers = TimerService::new();
697        let mut state = InMemoryStore::new();
698        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
699
700        let event = make_trade(1, "AAPL", 150.0);
701        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
702        let outputs = op.process(&event, &mut ctx);
703
704        assert_eq!(outputs.len(), 1);
705        if let Output::Event(e) = &outputs[0] {
706            let arr = e
707                .data
708                .column_by_name("prev_price")
709                .unwrap()
710                .as_any()
711                .downcast_ref::<Float64Array>()
712                .unwrap();
713            assert!(arr.value(0).is_nan());
714        } else {
715            panic!("Expected Event output");
716        }
717    }
718
719    #[test]
720    fn test_lag_second_event_returns_previous() {
721        let mut op = LagLeadOperator::new(lag_config(1));
722        let mut timers = TimerService::new();
723        let mut state = InMemoryStore::new();
724        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
725
726        let e1 = make_trade(1, "AAPL", 150.0);
727        let e2 = make_trade(2, "AAPL", 155.0);
728        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
729        op.process(&e1, &mut ctx);
730        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
731        let outputs = op.process(&e2, &mut ctx);
732
733        if let Output::Event(e) = &outputs[0] {
734            let arr = e
735                .data
736                .column_by_name("prev_price")
737                .unwrap()
738                .as_any()
739                .downcast_ref::<Float64Array>()
740                .unwrap();
741            assert_eq!(arr.value(0), 150.0);
742        }
743    }
744
745    #[test]
746    fn test_lag_with_default() {
747        let mut op = LagLeadOperator::new(LagLeadConfig {
748            operator_id: "test".to_string(),
749            functions: vec![LagLeadFunctionSpec {
750                is_lag: true,
751                source_column: "price".to_string(),
752                offset: 1,
753                default_value: Some(-1.0),
754                output_column: "prev_price".to_string(),
755            }],
756            partition_columns: vec!["symbol".to_string()],
757            max_partitions: 100,
758        });
759        let mut timers = TimerService::new();
760        let mut state = InMemoryStore::new();
761        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
762
763        let event = make_trade(1, "AAPL", 150.0);
764        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
765        let outputs = op.process(&event, &mut ctx);
766
767        if let Output::Event(e) = &outputs[0] {
768            let arr = e
769                .data
770                .column_by_name("prev_price")
771                .unwrap()
772                .as_any()
773                .downcast_ref::<Float64Array>()
774                .unwrap();
775            assert_eq!(arr.value(0), -1.0);
776        }
777    }
778
779    #[test]
780    fn test_lag_offset_2() {
781        let mut op = LagLeadOperator::new(lag_config(2));
782        let mut timers = TimerService::new();
783        let mut state = InMemoryStore::new();
784        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
785
786        let events = [
787            make_trade(1, "AAPL", 100.0),
788            make_trade(2, "AAPL", 110.0),
789            make_trade(3, "AAPL", 120.0),
790        ];
791
792        for e in &events[..2] {
793            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
794            op.process(e, &mut ctx);
795        }
796
797        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
798        let outputs = op.process(&events[2], &mut ctx);
799
800        if let Output::Event(e) = &outputs[0] {
801            let arr = e
802                .data
803                .column_by_name("prev_price")
804                .unwrap()
805                .as_any()
806                .downcast_ref::<Float64Array>()
807                .unwrap();
808            assert_eq!(arr.value(0), 100.0); // 2 positions back
809        }
810    }
811
812    #[test]
813    fn test_lag_separate_partitions() {
814        let mut op = LagLeadOperator::new(lag_config(1));
815        let mut timers = TimerService::new();
816        let mut state = InMemoryStore::new();
817        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
818
819        // AAPL events
820        let a1 = make_trade(1, "AAPL", 150.0);
821        let a2 = make_trade(3, "AAPL", 155.0);
822        // GOOG events
823        let g1 = make_trade(2, "GOOG", 2800.0);
824
825        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
826        op.process(&a1, &mut ctx);
827        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
828        op.process(&g1, &mut ctx);
829        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
830        let outputs = op.process(&a2, &mut ctx);
831
832        // AAPL should see previous AAPL price (150.0), not GOOG
833        if let Output::Event(e) = &outputs[0] {
834            let arr = e
835                .data
836                .column_by_name("prev_price")
837                .unwrap()
838                .as_any()
839                .downcast_ref::<Float64Array>()
840                .unwrap();
841            assert_eq!(arr.value(0), 150.0);
842        }
843        assert_eq!(op.partition_count(), 2);
844    }
845
846    #[test]
847    fn test_lead_buffers_events() {
848        let mut op = LagLeadOperator::new(lead_config(1));
849        let mut timers = TimerService::new();
850        let mut state = InMemoryStore::new();
851        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
852
853        let e1 = make_trade(1, "AAPL", 150.0);
854        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
855        let outputs = op.process(&e1, &mut ctx);
856
857        // First event should be buffered (no future event yet)
858        assert!(outputs.is_empty());
859    }
860
861    #[test]
862    fn test_lead_resolves_on_next_event() {
863        let mut op = LagLeadOperator::new(lead_config(1));
864        let mut timers = TimerService::new();
865        let mut state = InMemoryStore::new();
866        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
867
868        let e1 = make_trade(1, "AAPL", 150.0);
869        let e2 = make_trade(2, "AAPL", 155.0);
870
871        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
872        op.process(&e1, &mut ctx);
873        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
874        let outputs = op.process(&e2, &mut ctx);
875
876        // First event should now be emitted with LEAD = 155.0
877        assert_eq!(outputs.len(), 1);
878        if let Output::Event(e) = &outputs[0] {
879            let arr = e
880                .data
881                .column_by_name("next_price")
882                .unwrap()
883                .as_any()
884                .downcast_ref::<Float64Array>()
885                .unwrap();
886            assert_eq!(arr.value(0), 155.0);
887        }
888    }
889
890    #[test]
891    fn test_lead_flush_on_watermark() {
892        let mut op = LagLeadOperator::new(lead_config(1));
893        let mut timers = TimerService::new();
894        let mut state = InMemoryStore::new();
895        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
896
897        let e1 = make_trade(1, "AAPL", 150.0);
898        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
899        op.process(&e1, &mut ctx);
900
901        // Flush on timer/watermark
902        let timer = Timer {
903            key: TimerKey::default(),
904            timestamp: 100,
905        };
906        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
907        let outputs = op.on_timer(timer, &mut ctx);
908
909        // Should emit with default value (0.0)
910        assert_eq!(outputs.len(), 1);
911        if let Output::Event(e) = &outputs[0] {
912            let arr = e
913                .data
914                .column_by_name("next_price")
915                .unwrap()
916                .as_any()
917                .downcast_ref::<Float64Array>()
918                .unwrap();
919            assert_eq!(arr.value(0), 0.0);
920        }
921    }
922
923    #[test]
924    fn test_max_partitions() {
925        let mut op = LagLeadOperator::new(LagLeadConfig {
926            operator_id: "test".to_string(),
927            functions: vec![LagLeadFunctionSpec {
928                is_lag: true,
929                source_column: "price".to_string(),
930                offset: 1,
931                default_value: None,
932                output_column: "prev_price".to_string(),
933            }],
934            partition_columns: vec!["symbol".to_string()],
935            max_partitions: 2,
936        });
937        let mut timers = TimerService::new();
938        let mut state = InMemoryStore::new();
939        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
940
941        let e1 = make_trade(1, "AAPL", 150.0);
942        let e2 = make_trade(2, "GOOG", 2800.0);
943        let e3 = make_trade(3, "MSFT", 300.0);
944
945        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
946        op.process(&e1, &mut ctx);
947        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
948        op.process(&e2, &mut ctx);
949        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
950        let outputs = op.process(&e3, &mut ctx);
951
952        assert!(outputs.is_empty()); // MSFT rejected
953        assert_eq!(op.partition_count(), 2);
954    }
955
956    #[test]
957    fn test_checkpoint_restore() {
958        let mut op = LagLeadOperator::new(lag_config(1));
959        let mut timers = TimerService::new();
960        let mut state = InMemoryStore::new();
961        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
962
963        let events = vec![
964            make_trade(1, "AAPL", 100.0),
965            make_trade(2, "AAPL", 110.0),
966            make_trade(3, "GOOG", 2800.0),
967        ];
968        for e in &events {
969            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
970            op.process(e, &mut ctx);
971        }
972
973        let checkpoint = op.checkpoint();
974        assert_eq!(checkpoint.operator_id, "test_lag");
975
976        let mut op2 = LagLeadOperator::new(lag_config(1));
977        op2.restore(checkpoint).unwrap();
978        assert_eq!(op2.partition_count(), 2);
979    }
980
981    #[test]
982    fn test_metrics() {
983        let mut op = LagLeadOperator::new(lag_config(1));
984        let mut timers = TimerService::new();
985        let mut state = InMemoryStore::new();
986        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
987
988        let e1 = make_trade(1, "AAPL", 150.0);
989        let e2 = make_trade(2, "AAPL", 155.0);
990        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
991        op.process(&e1, &mut ctx);
992        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
993        op.process(&e2, &mut ctx);
994
995        assert_eq!(op.metrics().events_processed, 2);
996        assert_eq!(op.metrics().lag_lookups, 2);
997    }
998
999    #[test]
1000    fn test_lead_separate_partitions() {
1001        let mut op = LagLeadOperator::new(lead_config(1));
1002        let mut timers = TimerService::new();
1003        let mut state = InMemoryStore::new();
1004        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
1005
1006        let a1 = make_trade(1, "AAPL", 150.0);
1007        let g1 = make_trade(2, "GOOG", 2800.0);
1008        let a2 = make_trade(3, "AAPL", 155.0);
1009
1010        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1011        op.process(&a1, &mut ctx);
1012        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1013        op.process(&g1, &mut ctx);
1014        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
1015        let outputs = op.process(&a2, &mut ctx);
1016
1017        // AAPL's first event should resolve with next AAPL value
1018        assert_eq!(outputs.len(), 1);
1019        if let Output::Event(e) = &outputs[0] {
1020            let arr = e
1021                .data
1022                .column_by_name("next_price")
1023                .unwrap()
1024                .as_any()
1025                .downcast_ref::<Float64Array>()
1026                .unwrap();
1027            assert_eq!(arr.value(0), 155.0);
1028        }
1029    }
1030
1031    #[test]
1032    fn test_empty_operator() {
1033        let op = LagLeadOperator::new(lag_config(1));
1034        assert_eq!(op.partition_count(), 0);
1035        assert_eq!(op.metrics().events_processed, 0);
1036    }
1037}