Skip to main content

laminar_core/operator/
partitioned_topk.rs

1//! # Partitioned Top-K Operator
2//!
3//! Per-group top-K supporting the `ROW_NUMBER() OVER (PARTITION BY ... ORDER BY ...) WHERE rn <= N`
4//! pattern. Each partition key gets an independent top-K heap.
5//!
6//! ## Memory Bounds
7//!
8//! Memory is O(P * K) where P = distinct partitions, K = per-partition limit.
9//! A `max_partitions` safety limit prevents unbounded partition growth.
10//!
11//! ## Emit Strategies
12//!
13//! Same as the global top-K: `OnUpdate`, `OnWatermark`, or `Periodic`.
14//! Changelog records are emitted per-partition.
15
16use super::topk::{
17    encode_f64, encode_i64, encode_not_null, encode_null, encode_utf8, TopKEmitStrategy,
18    TopKSortColumn,
19};
20use super::window::ChangelogRecord;
21use super::{
22    Event, Operator, OperatorContext, OperatorError, OperatorState, Output, OutputVec, Timer,
23};
24use arrow_array::{Array, Float64Array, Int64Array, StringArray, TimestampMicrosecondArray};
25use arrow_schema::DataType;
26use fxhash::FxHashMap;
27
/// Names a single column that contributes to the partition key.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PartitionColumn {
    /// Name of the column as it appears in the event schema.
    pub column_name: String,
}

impl PartitionColumn {
    /// Builds a partition column from anything convertible into a `String`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let column_name = name.into();
        Self { column_name }
    }
}
44
45/// An entry in a per-partition top-K heap.
46#[derive(Debug, Clone)]
47struct PartitionEntry {
48    /// Memcomparable sort key.
49    sort_key: Vec<u8>,
50    /// The original event.
51    event: Event,
52}
53
54/// Partitioned top-K operator.
55///
56/// Maintains independent top-K heaps per partition key.
57/// Supports the `ROW_NUMBER() OVER (PARTITION BY ... ORDER BY ...) WHERE rn <= N` pattern.
58pub struct PartitionedTopKOperator {
59    /// Operator identifier for checkpointing.
60    operator_id: String,
61    /// Number of top entries per partition.
62    k: usize,
63    /// Partition key columns.
64    partition_columns: Vec<PartitionColumn>,
65    /// Sort column specifications.
66    sort_columns: Vec<TopKSortColumn>,
67    /// Per-partition top-K heaps, keyed by partition key bytes.
68    partitions: FxHashMap<Vec<u8>, Vec<PartitionEntry>>,
69    /// Emission strategy.
70    emit_strategy: TopKEmitStrategy,
71    /// Pending changelog records (for OnWatermark/Periodic strategies).
72    pending_changes: Vec<ChangelogRecord>,
73    /// Monotonic sequence counter.
74    sequence_counter: u64,
75    /// Maximum number of partitions (memory safety).
76    max_partitions: usize,
77}
78
79impl PartitionedTopKOperator {
80    /// Creates a new partitioned top-K operator.
81    #[must_use]
82    pub fn new(
83        operator_id: String,
84        k: usize,
85        partition_columns: Vec<PartitionColumn>,
86        sort_columns: Vec<TopKSortColumn>,
87        emit_strategy: TopKEmitStrategy,
88        max_partitions: usize,
89    ) -> Self {
90        Self {
91            operator_id,
92            k,
93            partition_columns,
94            sort_columns,
95            partitions: FxHashMap::default(),
96            emit_strategy,
97            pending_changes: Vec::new(),
98            sequence_counter: 0,
99            max_partitions,
100        }
101    }
102
103    /// Returns the number of active partitions.
104    #[must_use]
105    pub fn partition_count(&self) -> usize {
106        self.partitions.len()
107    }
108
109    /// Returns the total number of entries across all partitions.
110    #[must_use]
111    pub fn total_entries(&self) -> usize {
112        self.partitions.values().map(Vec::len).sum()
113    }
114
115    /// Returns the number of entries in a specific partition.
116    #[must_use]
117    pub fn partition_size(&self, partition_key: &[u8]) -> usize {
118        self.partitions.get(partition_key).map_or(0, Vec::len)
119    }
120
121    /// Returns the number of pending changelog records.
122    #[must_use]
123    pub fn pending_changes_count(&self) -> usize {
124        self.pending_changes.len()
125    }
126
127    /// Extracts the partition key from an event.
128    fn extract_partition_key(&self, event: &Event) -> Vec<u8> {
129        let batch = &event.data;
130        let schema = batch.schema();
131        let mut key = Vec::new();
132
133        for col in &self.partition_columns {
134            let Ok(col_idx) = schema.index_of(&col.column_name) else {
135                // Missing column: use null marker
136                key.push(0x00);
137                continue;
138            };
139
140            let array = batch.column(col_idx);
141
142            if array.is_null(0) {
143                key.push(0x00); // null marker
144                continue;
145            }
146
147            key.push(0x01); // non-null marker
148
149            match array.data_type() {
150                DataType::Int64 => {
151                    let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
152                    key.extend_from_slice(&arr.value(0).to_le_bytes());
153                }
154                DataType::Utf8 => {
155                    let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
156                    let val = arr.value(0);
157                    key.extend_from_slice(val.as_bytes());
158                    key.push(0x00); // null terminator
159                }
160                DataType::Float64 => {
161                    let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
162                    key.extend_from_slice(&arr.value(0).to_bits().to_le_bytes());
163                }
164                _ => {
165                    key.push(0x00); // unsupported type marker
166                }
167            }
168        }
169
170        key
171    }
172
173    /// Extracts a memcomparable sort key from an event.
174    fn extract_sort_key(&self, event: &Event) -> Vec<u8> {
175        let batch = &event.data;
176        let schema = batch.schema();
177        let mut key = Vec::new();
178
179        for col_spec in &self.sort_columns {
180            let Ok(col_idx) = schema.index_of(&col_spec.column_name) else {
181                encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
182                continue;
183            };
184
185            let array = batch.column(col_idx);
186
187            if array.is_null(0) {
188                encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
189                continue;
190            }
191
192            match array.data_type() {
193                DataType::Int64 => {
194                    let arr = array.as_any().downcast_ref::<Int64Array>().unwrap();
195                    encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
196                    encode_i64(arr.value(0), col_spec.descending, &mut key);
197                }
198                DataType::Float64 => {
199                    let arr = array.as_any().downcast_ref::<Float64Array>().unwrap();
200                    encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
201                    encode_f64(arr.value(0), col_spec.descending, &mut key);
202                }
203                DataType::Utf8 => {
204                    let arr = array.as_any().downcast_ref::<StringArray>().unwrap();
205                    encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
206                    encode_utf8(arr.value(0), col_spec.descending, &mut key);
207                }
208                DataType::Timestamp(_, _) => {
209                    let arr = array
210                        .as_any()
211                        .downcast_ref::<TimestampMicrosecondArray>()
212                        .unwrap();
213                    encode_not_null(col_spec.nulls_first, col_spec.descending, &mut key);
214                    encode_i64(arr.value(0), col_spec.descending, &mut key);
215                }
216                _ => {
217                    encode_null(col_spec.nulls_first, col_spec.descending, &mut key);
218                }
219            }
220        }
221
222        key
223    }
224
225    /// Processes an event for a specific partition, returning changelog records.
226    fn process_partition(
227        &mut self,
228        partition_key: Vec<u8>,
229        event: &Event,
230        emit_timestamp: i64,
231    ) -> Vec<ChangelogRecord> {
232        let sort_key = self.extract_sort_key(event);
233
234        let entries = self.partitions.entry(partition_key).or_default();
235
236        // Check if event enters this partition's top-K
237        if entries.len() >= self.k {
238            if let Some(worst) = entries.last() {
239                if sort_key >= worst.sort_key {
240                    return Vec::new(); // Doesn't enter top-K
241                }
242            }
243        }
244
245        // Find insertion position (binary search)
246        let insert_pos = entries
247            .binary_search_by(|entry| entry.sort_key.as_slice().cmp(&sort_key))
248            .unwrap_or_else(|pos| pos);
249
250        let new_entry = PartitionEntry {
251            sort_key,
252            event: event.clone(),
253        };
254        entries.insert(insert_pos, new_entry);
255
256        let mut changes = Vec::new();
257
258        // Generate insert changelog
259        changes.push(ChangelogRecord::insert(event.clone(), emit_timestamp));
260
261        // Generate rank change retractions for shifted entries
262        for entry in entries
263            .iter()
264            .take(entries.len().min(self.k))
265            .skip(insert_pos + 1)
266        {
267            let shifted_event = &entry.event;
268            let (before, after) = ChangelogRecord::update(
269                shifted_event.clone(),
270                shifted_event.clone(),
271                emit_timestamp,
272            );
273            changes.push(before);
274            changes.push(after);
275        }
276
277        // Evict worst entry if over capacity
278        if entries.len() > self.k {
279            let evicted = entries.pop().unwrap();
280            changes.push(ChangelogRecord::delete(evicted.event, emit_timestamp));
281        }
282
283        self.sequence_counter += 1;
284        changes
285    }
286
287    /// Flushes pending changelog records as Output.
288    fn flush_pending(&mut self) -> OutputVec {
289        let mut outputs = OutputVec::new();
290        for record in self.pending_changes.drain(..) {
291            outputs.push(Output::Changelog(record));
292        }
293        outputs
294    }
295}
296
297impl Operator for PartitionedTopKOperator {
298    fn process(&mut self, event: &Event, _ctx: &mut OperatorContext) -> OutputVec {
299        let partition_key = self.extract_partition_key(event);
300
301        // Check max partitions limit
302        if !self.partitions.contains_key(&partition_key)
303            && self.partitions.len() >= self.max_partitions
304        {
305            // Reject: too many partitions
306            return OutputVec::new();
307        }
308
309        let emit_timestamp = event.timestamp;
310        let changes = self.process_partition(partition_key, event, emit_timestamp);
311
312        match &self.emit_strategy {
313            TopKEmitStrategy::OnUpdate => {
314                let mut outputs = OutputVec::new();
315                for record in changes {
316                    outputs.push(Output::Changelog(record));
317                }
318                outputs
319            }
320            TopKEmitStrategy::OnWatermark | TopKEmitStrategy::Periodic(_) => {
321                self.pending_changes.extend(changes);
322                OutputVec::new()
323            }
324        }
325    }
326
327    fn on_timer(&mut self, _timer: Timer, _ctx: &mut OperatorContext) -> OutputVec {
328        match &self.emit_strategy {
329            TopKEmitStrategy::Periodic(_) => self.flush_pending(),
330            _ => OutputVec::new(),
331        }
332    }
333
334    fn checkpoint(&self) -> OperatorState {
335        let mut data = Vec::new();
336
337        // Write partition count
338        let num_partitions = self.partitions.len() as u64;
339        data.extend_from_slice(&num_partitions.to_le_bytes());
340
341        // Write sequence counter
342        data.extend_from_slice(&self.sequence_counter.to_le_bytes());
343
344        // Write each partition
345        for (key, entries) in &self.partitions {
346            // Partition key length + bytes
347            let key_len = key.len() as u64;
348            data.extend_from_slice(&key_len.to_le_bytes());
349            data.extend_from_slice(key);
350
351            // Entry count
352            let entry_count = entries.len() as u64;
353            data.extend_from_slice(&entry_count.to_le_bytes());
354
355            // Each entry: sort_key_len + sort_key + timestamp
356            for entry in entries {
357                let sk_len = entry.sort_key.len() as u64;
358                data.extend_from_slice(&sk_len.to_le_bytes());
359                data.extend_from_slice(&entry.sort_key);
360                data.extend_from_slice(&entry.event.timestamp.to_le_bytes());
361            }
362        }
363
364        OperatorState {
365            operator_id: self.operator_id.clone(),
366            data,
367        }
368    }
369
370    #[allow(clippy::cast_possible_truncation)] // Checkpoint wire format uses u64 for counts
371    fn restore(&mut self, state: OperatorState) -> Result<(), OperatorError> {
372        if state.data.len() < 16 {
373            return Err(OperatorError::SerializationFailed(
374                "PartitionedTopK checkpoint data too short".to_string(),
375            ));
376        }
377
378        let mut offset = 0;
379
380        let num_partitions = u64::from_le_bytes(
381            state.data[offset..offset + 8]
382                .try_into()
383                .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
384        ) as usize;
385        offset += 8;
386
387        self.sequence_counter = u64::from_le_bytes(
388            state.data[offset..offset + 8]
389                .try_into()
390                .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
391        );
392        offset += 8;
393
394        self.partitions.clear();
395
396        for _ in 0..num_partitions {
397            if offset + 8 > state.data.len() {
398                return Err(OperatorError::SerializationFailed(
399                    "PartitionedTopK checkpoint truncated".to_string(),
400                ));
401            }
402            let key_len = u64::from_le_bytes(
403                state.data[offset..offset + 8]
404                    .try_into()
405                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
406            ) as usize;
407            offset += 8;
408
409            if offset + key_len + 8 > state.data.len() {
410                return Err(OperatorError::SerializationFailed(
411                    "PartitionedTopK checkpoint truncated at key".to_string(),
412                ));
413            }
414            let partition_key = state.data[offset..offset + key_len].to_vec();
415            offset += key_len;
416
417            let entry_count = u64::from_le_bytes(
418                state.data[offset..offset + 8]
419                    .try_into()
420                    .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
421            ) as usize;
422            offset += 8;
423
424            let mut entries = Vec::with_capacity(entry_count);
425            for _ in 0..entry_count {
426                if offset + 8 > state.data.len() {
427                    return Err(OperatorError::SerializationFailed(
428                        "PartitionedTopK checkpoint truncated at entry".to_string(),
429                    ));
430                }
431                let sk_len = u64::from_le_bytes(
432                    state.data[offset..offset + 8]
433                        .try_into()
434                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
435                ) as usize;
436                offset += 8;
437
438                if offset + sk_len + 8 > state.data.len() {
439                    return Err(OperatorError::SerializationFailed(
440                        "PartitionedTopK checkpoint truncated at sort key".to_string(),
441                    ));
442                }
443                let sort_key = state.data[offset..offset + sk_len].to_vec();
444                offset += sk_len;
445
446                let timestamp = i64::from_le_bytes(
447                    state.data[offset..offset + 8]
448                        .try_into()
449                        .map_err(|e| OperatorError::SerializationFailed(format!("{e}")))?,
450                );
451                offset += 8;
452
453                let batch = arrow_array::RecordBatch::new_empty(std::sync::Arc::new(
454                    arrow_schema::Schema::empty(),
455                ));
456                entries.push(PartitionEntry {
457                    sort_key,
458                    event: Event::new(timestamp, batch),
459                });
460            }
461
462            self.partitions.insert(partition_key, entries);
463        }
464
465        Ok(())
466    }
467}
468
469#[cfg(test)]
470#[allow(clippy::uninlined_format_args)]
471#[allow(clippy::cast_precision_loss)]
472mod tests {
473    use super::super::window::CdcOperation;
474    use super::*;
475    use crate::state::InMemoryStore;
476    use crate::time::{BoundedOutOfOrdernessGenerator, TimerService};
477    use arrow_array::{Float64Array, Int64Array, RecordBatch, StringArray};
478    use arrow_schema::{DataType, Field, Schema};
479    use std::sync::Arc;
480
481    fn make_trade(timestamp: i64, category: &str, price: f64) -> Event {
482        let schema = Arc::new(Schema::new(vec![
483            Field::new("category", DataType::Utf8, false),
484            Field::new("price", DataType::Float64, false),
485        ]));
486        let batch = RecordBatch::try_new(
487            schema,
488            vec![
489                Arc::new(StringArray::from(vec![category])),
490                Arc::new(Float64Array::from(vec![price])),
491            ],
492        )
493        .unwrap();
494        Event::new(timestamp, batch)
495    }
496
497    fn make_trade_int(timestamp: i64, category: &str, value: i64) -> Event {
498        let schema = Arc::new(Schema::new(vec![
499            Field::new("category", DataType::Utf8, false),
500            Field::new("value", DataType::Int64, false),
501        ]));
502        let batch = RecordBatch::try_new(
503            schema,
504            vec![
505                Arc::new(StringArray::from(vec![category])),
506                Arc::new(Int64Array::from(vec![value])),
507            ],
508        )
509        .unwrap();
510        Event::new(timestamp, batch)
511    }
512
513    fn create_test_context<'a>(
514        timers: &'a mut TimerService,
515        state: &'a mut dyn crate::state::StateStore,
516        watermark_gen: &'a mut dyn crate::time::WatermarkGenerator,
517    ) -> OperatorContext<'a> {
518        OperatorContext {
519            event_time: 0,
520            processing_time: 0,
521            timers,
522            state,
523            watermark_generator: watermark_gen,
524            operator_index: 0,
525        }
526    }
527
528    fn create_partitioned_topk(k: usize, max_partitions: usize) -> PartitionedTopKOperator {
529        PartitionedTopKOperator::new(
530            "test_ptopk".to_string(),
531            k,
532            vec![PartitionColumn::new("category")],
533            vec![TopKSortColumn::descending("price")],
534            TopKEmitStrategy::OnUpdate,
535            max_partitions,
536        )
537    }
538
539    #[test]
540    fn test_partitioned_topk_single_partition() {
541        let mut op = create_partitioned_topk(3, 100);
542        let mut timers = TimerService::new();
543        let mut state = InMemoryStore::new();
544        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
545
546        let trades = vec![
547            make_trade(1, "A", 100.0),
548            make_trade(2, "A", 200.0),
549            make_trade(3, "A", 150.0),
550        ];
551
552        for trade in &trades {
553            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
554            op.process(trade, &mut ctx);
555        }
556
557        assert_eq!(op.partition_count(), 1);
558        assert_eq!(op.total_entries(), 3);
559    }
560
561    #[test]
562    fn test_partitioned_topk_multiple_partitions() {
563        let mut op = create_partitioned_topk(2, 100);
564        let mut timers = TimerService::new();
565        let mut state = InMemoryStore::new();
566        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
567
568        let trades = vec![
569            make_trade(1, "A", 100.0),
570            make_trade(2, "B", 200.0),
571            make_trade(3, "A", 150.0),
572            make_trade(4, "B", 250.0),
573        ];
574
575        for trade in &trades {
576            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
577            op.process(trade, &mut ctx);
578        }
579
580        assert_eq!(op.partition_count(), 2);
581        assert_eq!(op.total_entries(), 4);
582    }
583
584    #[test]
585    fn test_partitioned_topk_eviction_in_partition() {
586        let mut op = create_partitioned_topk(2, 100);
587        let mut timers = TimerService::new();
588        let mut state = InMemoryStore::new();
589        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
590
591        // Fill partition "A" to capacity
592        let e1 = make_trade(1, "A", 200.0);
593        let e2 = make_trade(2, "A", 150.0);
594        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
595        op.process(&e1, &mut ctx);
596        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
597        op.process(&e2, &mut ctx);
598
599        // Better entry evicts worst in partition
600        let e3 = make_trade(3, "A", 300.0);
601        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
602        let outputs = op.process(&e3, &mut ctx);
603
604        // Should evict price=150 and keep 300, 200
605        assert_eq!(op.total_entries(), 2);
606        assert!(!outputs.is_empty());
607
608        // Verify we have Insert + Delete among outputs
609        let mut has_insert = false;
610        let mut has_delete = false;
611        for output in &outputs {
612            if let Output::Changelog(rec) = output {
613                match rec.operation {
614                    CdcOperation::Insert => has_insert = true,
615                    CdcOperation::Delete => has_delete = true,
616                    _ => {}
617                }
618            }
619        }
620        assert!(has_insert);
621        assert!(has_delete);
622    }
623
624    #[test]
625    fn test_partitioned_topk_no_cross_partition_eviction() {
626        let mut op = create_partitioned_topk(2, 100);
627        let mut timers = TimerService::new();
628        let mut state = InMemoryStore::new();
629        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
630
631        // Fill partition "A" to capacity
632        let e1 = make_trade(1, "A", 200.0);
633        let e2 = make_trade(2, "A", 150.0);
634        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
635        op.process(&e1, &mut ctx);
636        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
637        op.process(&e2, &mut ctx);
638
639        // New entry in partition "B" does NOT evict from "A"
640        let e3 = make_trade(3, "B", 50.0);
641        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
642        op.process(&e3, &mut ctx);
643
644        assert_eq!(op.partition_count(), 2);
645        assert_eq!(op.total_entries(), 3); // 2 in A + 1 in B
646    }
647
648    #[test]
649    fn test_partitioned_topk_emit_on_update() {
650        let mut op = create_partitioned_topk(3, 100);
651        let mut timers = TimerService::new();
652        let mut state = InMemoryStore::new();
653        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
654
655        let trade = make_trade(1, "A", 100.0);
656        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
657        let outputs = op.process(&trade, &mut ctx);
658
659        // OnUpdate: should emit immediately
660        assert!(!outputs.is_empty());
661        match &outputs[0] {
662            Output::Changelog(rec) => {
663                assert_eq!(rec.operation, CdcOperation::Insert);
664            }
665            _ => panic!("Expected Changelog output"),
666        }
667    }
668
669    #[test]
670    fn test_partitioned_topk_emit_on_watermark() {
671        let mut op = PartitionedTopKOperator::new(
672            "test_ptopk".to_string(),
673            2,
674            vec![PartitionColumn::new("category")],
675            vec![TopKSortColumn::descending("price")],
676            TopKEmitStrategy::OnWatermark,
677            100,
678        );
679
680        let mut timers = TimerService::new();
681        let mut state = InMemoryStore::new();
682        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
683
684        let trade = make_trade(1, "A", 100.0);
685        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
686        let outputs = op.process(&trade, &mut ctx);
687
688        // OnWatermark: should buffer, not emit
689        assert!(outputs.is_empty());
690        assert!(op.pending_changes_count() > 0);
691    }
692
693    #[test]
694    fn test_partitioned_topk_empty_partition() {
695        let op = create_partitioned_topk(3, 100);
696        assert_eq!(op.partition_count(), 0);
697        assert_eq!(op.total_entries(), 0);
698    }
699
700    #[test]
701    fn test_partitioned_topk_max_partitions() {
702        let mut op = create_partitioned_topk(2, 2); // max 2 partitions
703        let mut timers = TimerService::new();
704        let mut state = InMemoryStore::new();
705        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
706
707        // Create 2 partitions
708        let e1 = make_trade(1, "A", 100.0);
709        let e2 = make_trade(2, "B", 200.0);
710        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
711        op.process(&e1, &mut ctx);
712        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
713        op.process(&e2, &mut ctx);
714
715        // Third partition rejected
716        let e3 = make_trade(3, "C", 300.0);
717        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
718        let outputs = op.process(&e3, &mut ctx);
719
720        assert!(outputs.is_empty());
721        assert_eq!(op.partition_count(), 2);
722    }
723
724    #[test]
725    fn test_partitioned_topk_k_equals_one() {
726        let mut op = create_partitioned_topk(1, 100);
727        let mut timers = TimerService::new();
728        let mut state = InMemoryStore::new();
729        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
730
731        let e1 = make_trade(1, "A", 100.0);
732        let e2 = make_trade(2, "A", 200.0);
733        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
734        op.process(&e1, &mut ctx);
735        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
736        op.process(&e2, &mut ctx);
737
738        // Only best entry kept per partition
739        assert_eq!(op.total_entries(), 1);
740    }
741
742    #[test]
743    fn test_partitioned_topk_multi_column_partition_key() {
744        let mut op = PartitionedTopKOperator::new(
745            "test_ptopk".to_string(),
746            3,
747            vec![
748                PartitionColumn::new("category"),
749                PartitionColumn::new("value"),
750            ],
751            vec![TopKSortColumn::descending("value")],
752            TopKEmitStrategy::OnUpdate,
753            100,
754        );
755
756        let mut timers = TimerService::new();
757        let mut state = InMemoryStore::new();
758        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
759
760        let e1 = make_trade_int(1, "A", 100);
761        let e2 = make_trade_int(2, "A", 200);
762        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
763        op.process(&e1, &mut ctx);
764        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
765        op.process(&e2, &mut ctx);
766
767        // Each (category, value) combo is a unique partition
768        assert_eq!(op.partition_count(), 2);
769    }
770
771    #[test]
772    fn test_partitioned_topk_multi_column_sort() {
773        let mut op = PartitionedTopKOperator::new(
774            "test_ptopk".to_string(),
775            3,
776            vec![PartitionColumn::new("category")],
777            vec![TopKSortColumn::descending("price")],
778            TopKEmitStrategy::OnUpdate,
779            100,
780        );
781
782        let mut timers = TimerService::new();
783        let mut state = InMemoryStore::new();
784        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
785
786        let trades = vec![
787            make_trade(1, "A", 100.0),
788            make_trade(2, "A", 300.0),
789            make_trade(3, "A", 200.0),
790        ];
791
792        for trade in &trades {
793            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
794            op.process(trade, &mut ctx);
795        }
796
797        assert_eq!(op.total_entries(), 3);
798    }
799
800    #[test]
801    fn test_partitioned_topk_checkpoint_restore() {
802        let mut op = create_partitioned_topk(3, 100);
803        let mut timers = TimerService::new();
804        let mut state = InMemoryStore::new();
805        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
806
807        let trades = vec![
808            make_trade(1, "A", 100.0),
809            make_trade(2, "B", 200.0),
810            make_trade(3, "A", 150.0),
811        ];
812
813        for trade in &trades {
814            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
815            op.process(trade, &mut ctx);
816        }
817
818        let checkpoint = op.checkpoint();
819        assert_eq!(checkpoint.operator_id, "test_ptopk");
820
821        let mut op2 = create_partitioned_topk(3, 100);
822        op2.restore(checkpoint).unwrap();
823
824        assert_eq!(op2.partition_count(), 2);
825        assert_eq!(op2.total_entries(), 3);
826    }
827
828    #[test]
829    fn test_partitioned_topk_rank_changes() {
830        let mut op = create_partitioned_topk(3, 100);
831        let mut timers = TimerService::new();
832        let mut state = InMemoryStore::new();
833        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
834
835        // Insert two entries
836        let e1 = make_trade(1, "A", 100.0);
837        let e2 = make_trade(2, "A", 200.0);
838        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
839        op.process(&e1, &mut ctx);
840        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
841        op.process(&e2, &mut ctx);
842
843        // Insert between them causes rank change for price=100
844        let e3 = make_trade(3, "A", 150.0);
845        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
846        let outputs = op.process(&e3, &mut ctx);
847
848        // Should have Insert + UpdateBefore + UpdateAfter
849        let mut has_update_before = false;
850        let mut has_update_after = false;
851        for output in &outputs {
852            if let Output::Changelog(rec) = output {
853                match rec.operation {
854                    CdcOperation::UpdateBefore => has_update_before = true,
855                    CdcOperation::UpdateAfter => has_update_after = true,
856                    _ => {}
857                }
858            }
859        }
860        assert!(has_update_before);
861        assert!(has_update_after);
862    }
863
864    #[test]
865    fn test_partitioned_topk_row_number_pattern() {
866        // Simulates ROW_NUMBER() OVER (PARTITION BY category ORDER BY price DESC) WHERE rn <= 2
867        let mut op = create_partitioned_topk(2, 100);
868        let mut timers = TimerService::new();
869        let mut state = InMemoryStore::new();
870        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
871
872        let trades = vec![
873            make_trade(1, "tech", 100.0),
874            make_trade(2, "tech", 200.0),
875            make_trade(3, "tech", 150.0), // evicts 100
876            make_trade(4, "finance", 300.0),
877            make_trade(5, "finance", 250.0),
878        ];
879
880        for trade in &trades {
881            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
882            op.process(trade, &mut ctx);
883        }
884
885        assert_eq!(op.partition_count(), 2);
886        assert_eq!(op.total_entries(), 4); // 2 in tech + 2 in finance
887    }
888
889    #[test]
890    fn test_partitioned_topk_string_partition_key() {
891        let mut op = create_partitioned_topk(3, 100);
892        let mut timers = TimerService::new();
893        let mut state = InMemoryStore::new();
894        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
895
896        let trades = vec![
897            make_trade(1, "electronics", 100.0),
898            make_trade(2, "clothing", 200.0),
899            make_trade(3, "electronics", 150.0),
900        ];
901
902        for trade in &trades {
903            let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
904            op.process(trade, &mut ctx);
905        }
906
907        assert_eq!(op.partition_count(), 2);
908    }
909
910    #[test]
911    fn test_partitioned_topk_null_partition_key() {
912        let mut op = create_partitioned_topk(3, 100);
913        let mut timers = TimerService::new();
914        let mut state = InMemoryStore::new();
915        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
916
917        // Create event with null category
918        let schema = Arc::new(Schema::new(vec![
919            Field::new("category", DataType::Utf8, true),
920            Field::new("price", DataType::Float64, false),
921        ]));
922        let batch = RecordBatch::try_new(
923            schema,
924            vec![
925                Arc::new(StringArray::new_null(1)),
926                Arc::new(Float64Array::from(vec![100.0])),
927            ],
928        )
929        .unwrap();
930        let null_event = Event::new(1, batch);
931
932        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
933        op.process(&null_event, &mut ctx);
934
935        // Null partition key should still create a partition
936        assert_eq!(op.partition_count(), 1);
937    }
938
939    #[test]
940    fn test_partitioned_topk_changelog_per_partition() {
941        let mut op = create_partitioned_topk(2, 100);
942        let mut timers = TimerService::new();
943        let mut state = InMemoryStore::new();
944        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
945
946        // Insert in partition A
947        let e1 = make_trade(1, "A", 100.0);
948        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
949        let out_a = op.process(&e1, &mut ctx);
950
951        // Insert in partition B
952        let e2 = make_trade(2, "B", 200.0);
953        let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
954        let out_b = op.process(&e2, &mut ctx);
955
956        // Both should independently emit Insert changelog
957        assert_eq!(out_a.len(), 1);
958        assert_eq!(out_b.len(), 1);
959    }
960
961    #[test]
962    fn test_partitioned_topk_large_partitions() {
963        let mut op = create_partitioned_topk(5, 1000);
964        let mut timers = TimerService::new();
965        let mut state = InMemoryStore::new();
966        let mut wm = BoundedOutOfOrdernessGenerator::new(0);
967
968        // Create many partitions with a few entries each
969        for i in 0..50 {
970            let category = format!("cat_{}", i);
971            for j in 0..3 {
972                let trade = make_trade(i * 100 + j, &category, j as f64 * 10.0);
973                let mut ctx = create_test_context(&mut timers, &mut state, &mut wm);
974                op.process(&trade, &mut ctx);
975            }
976        }
977
978        assert_eq!(op.partition_count(), 50);
979        assert_eq!(op.total_entries(), 150); // 50 partitions * 3 entries each
980    }
981}