1use crate::Error;
21use arrow::array::{
22 Array, ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
23};
24use arrow::compute::SortOptions;
25use arrow::record_batch::RecordBatch;
26use std::cmp::Ordering;
27use std::collections::BinaryHeap;
28use std::sync::Arc;
29
/// Direction in which [`TopKSelection::top_k`] ranks the values of the chosen column.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortOrder {
    /// Smallest values first: `top_k` yields the `k` smallest values.
    Ascending,
    /// Largest values first: `top_k` yields the `k` largest values.
    Descending,
}
38
39impl From<SortOrder> for SortOptions {
40 fn from(order: SortOrder) -> Self {
41 Self { descending: matches!(order, SortOrder::Descending), nulls_first: false }
42 }
43}
44
/// Selects the best-ranked rows of a batch according to a single column.
pub trait TopKSelection {
    /// Returns a new batch holding at most `k` rows of `self`, ordered by the
    /// values of the column at `column_index` in the direction given by `order`.
    ///
    /// Every column of a selected row is carried into the result, so rows are
    /// never split across columns.
    ///
    /// # Errors
    ///
    /// The `RecordBatch` implementation in this module returns an error when
    /// `k == 0`, when `column_index` is out of range, or when a column's data
    /// type is not supported.
    fn top_k(&self, column_index: usize, k: usize, order: SortOrder) -> crate::Result<RecordBatch>;
}
88
89impl TopKSelection for RecordBatch {
90 fn top_k(&self, column_index: usize, k: usize, order: SortOrder) -> crate::Result<RecordBatch> {
91 if k == 0 {
93 return Err(Error::InvalidInput("k must be greater than 0".to_string()));
94 }
95
96 if column_index >= self.num_columns() {
97 return Err(Error::InvalidInput(format!(
98 "Column index {} out of bounds (batch has {} columns)",
99 column_index,
100 self.num_columns()
101 )));
102 }
103
104 if k >= self.num_rows() {
106 return sort_all_rows(self, column_index, order);
107 }
108
109 let column = self.column(column_index);
111 let indices = select_top_k_indices(column, k, order)?;
112
113 build_batch_from_indices(self, &indices)
115 }
116}
117
118fn select_top_k_indices(
123 column: &ArrayRef,
124 k: usize,
125 order: SortOrder,
126) -> crate::Result<Vec<usize>> {
127 match column.data_type() {
128 arrow::datatypes::DataType::Int32 => {
129 let array = column.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
130 Error::Other("Failed to downcast Int32 column to Int32Array".to_string())
131 })?;
132 select_top_k_typed(array.len(), k, order, |i| array.is_null(i), |i| array.value(i))
133 }
134 arrow::datatypes::DataType::Int64 => {
135 let array = column.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
136 Error::Other("Failed to downcast Int64 column to Int64Array".to_string())
137 })?;
138 select_top_k_typed(array.len(), k, order, |i| array.is_null(i), |i| array.value(i))
139 }
140 arrow::datatypes::DataType::Float32 => {
141 let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
142 Error::Other("Failed to downcast Float32 column to Float32Array".to_string())
143 })?;
144 select_top_k_typed(array.len(), k, order, |i| array.is_null(i), |i| array.value(i))
145 }
146 arrow::datatypes::DataType::Float64 => {
147 let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
148 Error::Other("Failed to downcast Float64 column to Float64Array".to_string())
149 })?;
150 select_top_k_typed(array.len(), k, order, |i| array.is_null(i), |i| array.value(i))
151 }
152 dt => Err(Error::InvalidInput(format!("Top-K not supported for data type: {dt:?}"))),
153 }
154}
155
/// Heap entry ordered by the *reverse* of its `value`.
///
/// `BinaryHeap` is a max-heap, so wrapping entries in `MinHeapItem` puts the
/// entry with the smallest `value` at the root. `collect_top_k_descending`
/// relies on this to find and evict the weakest of its current candidates.
///
/// Equality and ordering look only at `value`; `index` is carried as payload.
///
/// NOTE(review): for values with no defined order (e.g. `f64::NAN`), `cmp`
/// reports `Equal` while `eq` reports `false`, so the `Ord`/`Eq` laws do not
/// hold for such inputs.
#[derive(Debug)]
struct MinHeapItem<V> {
    /// The ranked value.
    value: V,
    /// Row index the value was read from.
    index: usize,
}

impl<V: PartialOrd> PartialEq for MinHeapItem<V> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.value.partial_cmp(&other.value), Some(Ordering::Equal))
    }
}

impl<V: PartialOrd> Eq for MinHeapItem<V> {}

impl<V: PartialOrd> Ord for MinHeapItem<V> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped to invert the natural order; incomparable
        // values are treated as equal.
        match other.value.partial_cmp(&self.value) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}

impl<V: PartialOrd> PartialOrd for MinHeapItem<V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
183
/// Heap entry ordered by its `value` in natural order.
///
/// In a `BinaryHeap` (a max-heap) the entry with the largest `value` sits at
/// the root. `collect_top_k_ascending` relies on this to find and evict the
/// weakest of its current candidates.
///
/// Equality and ordering look only at `value`; `index` is carried as payload.
///
/// NOTE(review): for values with no defined order (e.g. `f64::NAN`), `cmp`
/// reports `Equal` while `eq` reports `false`, so the `Ord`/`Eq` laws do not
/// hold for such inputs.
#[derive(Debug)]
struct MaxHeapItem<V> {
    /// The ranked value.
    value: V,
    /// Row index the value was read from.
    index: usize,
}

impl<V: PartialOrd> PartialEq for MaxHeapItem<V> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.value.partial_cmp(&other.value), Some(Ordering::Equal))
    }
}

impl<V: PartialOrd> Eq for MaxHeapItem<V> {}

impl<V: PartialOrd> Ord for MaxHeapItem<V> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Incomparable values are treated as equal.
        match self.value.partial_cmp(&other.value) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}

impl<V: PartialOrd> PartialOrd for MaxHeapItem<V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
211
212fn collect_top_k_descending<V: PartialOrd>(
214 len: usize,
215 k: usize,
216 is_null: impl Fn(usize) -> bool,
217 get_value: impl Fn(usize) -> V,
218) -> Vec<usize> {
219 let mut heap: BinaryHeap<MinHeapItem<V>> = BinaryHeap::with_capacity(k);
220
221 for index in 0..len {
222 if !is_null(index) {
223 let value = get_value(index);
224 if heap.len() < k {
225 heap.push(MinHeapItem { value, index });
226 } else if let Some(top) = heap.peek() {
227 if value.partial_cmp(&top.value) == Some(Ordering::Greater) {
228 heap.pop();
229 heap.push(MinHeapItem { value, index });
230 }
231 }
232 }
233 }
234
235 let mut result: Vec<_> = heap.into_vec();
236 result.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap_or(Ordering::Equal));
237 result.into_iter().map(|item| item.index).collect()
238}
239
240fn collect_top_k_ascending<V: PartialOrd>(
242 len: usize,
243 k: usize,
244 is_null: impl Fn(usize) -> bool,
245 get_value: impl Fn(usize) -> V,
246) -> Vec<usize> {
247 let mut heap: BinaryHeap<MaxHeapItem<V>> = BinaryHeap::with_capacity(k);
248
249 for index in 0..len {
250 if !is_null(index) {
251 let value = get_value(index);
252 if heap.len() < k {
253 heap.push(MaxHeapItem { value, index });
254 } else if let Some(top) = heap.peek() {
255 if value.partial_cmp(&top.value) == Some(Ordering::Less) {
256 heap.pop();
257 heap.push(MaxHeapItem { value, index });
258 }
259 }
260 }
261 }
262
263 let mut result: Vec<_> = heap.into_vec();
264 result.sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal));
265 result.into_iter().map(|item| item.index).collect()
266}
267
268#[allow(clippy::unnecessary_wraps)]
270fn select_top_k_typed<V: PartialOrd>(
271 len: usize,
272 k: usize,
273 order: SortOrder,
274 is_null: impl Fn(usize) -> bool,
275 get_value: impl Fn(usize) -> V,
276) -> crate::Result<Vec<usize>> {
277 let indices = match order {
278 SortOrder::Descending => collect_top_k_descending(len, k, is_null, get_value),
279 SortOrder::Ascending => collect_top_k_ascending(len, k, is_null, get_value),
280 };
281 Ok(indices)
282}
283
284fn build_batch_from_indices(batch: &RecordBatch, indices: &[usize]) -> crate::Result<RecordBatch> {
286 use arrow::datatypes::DataType;
287
288 let mut new_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
289
290 for col_idx in 0..batch.num_columns() {
291 let column = batch.column(col_idx);
292
293 let new_array: ArrayRef = match column.data_type() {
294 DataType::Int32 => {
295 let array = column.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
296 Error::Other("Failed to downcast Int32 column to Int32Array".to_string())
297 })?;
298 let values: Vec<i32> = indices.iter().map(|&idx| array.value(idx)).collect();
299 Arc::new(Int32Array::from(values))
300 }
301 DataType::Int64 => {
302 let array = column.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
303 Error::Other("Failed to downcast Int64 column to Int64Array".to_string())
304 })?;
305 let values: Vec<i64> = indices.iter().map(|&idx| array.value(idx)).collect();
306 Arc::new(Int64Array::from(values))
307 }
308 DataType::Float32 => {
309 let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
310 Error::Other("Failed to downcast Float32 column to Float32Array".to_string())
311 })?;
312 let values: Vec<f32> = indices.iter().map(|&idx| array.value(idx)).collect();
313 Arc::new(Float32Array::from(values))
314 }
315 DataType::Float64 => {
316 let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
317 Error::Other("Failed to downcast Float64 column to Float64Array".to_string())
318 })?;
319 let values: Vec<f64> = indices.iter().map(|&idx| array.value(idx)).collect();
320 Arc::new(Float64Array::from(values))
321 }
322 DataType::Utf8 => {
323 let array = column.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
324 Error::Other("Failed to downcast Utf8 column to StringArray".to_string())
325 })?;
326 let values: Vec<&str> = indices.iter().map(|&idx| array.value(idx)).collect();
327 Arc::new(StringArray::from(values))
328 }
329 dt => {
330 return Err(Error::InvalidInput(format!(
331 "Top-K not implemented for column data type: {dt:?}"
332 )));
333 }
334 };
335
336 new_columns.push(new_array);
337 }
338
339 RecordBatch::try_new(batch.schema(), new_columns)
340 .map_err(|e| Error::StorageError(format!("Failed to create result batch: {e}")))
341}
342
343fn sort_all_rows(
345 batch: &RecordBatch,
346 column_index: usize,
347 order: SortOrder,
348) -> crate::Result<RecordBatch> {
349 use arrow::compute::sort_to_indices;
350
351 let sort_options = SortOptions::from(order);
352 let indices = sort_to_indices(batch.column(column_index).as_ref(), Some(sort_options), None)
353 .map_err(|e| Error::StorageError(format!("Failed to sort: {e}")))?;
354
355 let indices_array =
357 indices.as_any().downcast_ref::<arrow::array::UInt32Array>().ok_or_else(|| {
358 Error::Other(
359 "Failed to downcast sort indices to UInt32Array (expected from sort_to_indices)"
360 .to_string(),
361 )
362 })?;
363 let indices_vec: Vec<usize> =
364 (0..indices_array.len()).map(|i| indices_array.value(i) as usize).collect();
365
366 build_batch_from_indices(batch, &indices_vec)
367}
368
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::float_cmp,
    clippy::redundant_closure
)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Builds a two-column batch: `id` (row number, Int32) and `score` (`values`, Float64).
    fn create_test_batch(values: Vec<f64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("score", DataType::Float64, false),
        ]));

        let ids: Vec<i32> = (0..values.len() as i32).collect();

        RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(ids)), Arc::new(Float64Array::from(values))],
        )
        .unwrap()
    }

    /// Builds a single non-nullable column named `value` of type `data_type`.
    fn single_column_batch(data_type: DataType, array: ArrayRef) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("value", data_type, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![array]).unwrap()
    }

    #[test]
    fn test_top_k_descending_basic() {
        let batch = create_test_batch(vec![1.0, 5.0, 3.0, 9.0, 2.0]);
        let result = batch.top_k(1, 3, SortOrder::Descending).unwrap();

        assert_eq!(result.num_rows(), 3);

        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(scores.value(0), 9.0);
        assert_eq!(scores.value(1), 5.0);
        assert_eq!(scores.value(2), 3.0);
    }

    #[test]
    fn test_top_k_ascending_basic() {
        let batch = create_test_batch(vec![1.0, 5.0, 3.0, 9.0, 2.0]);
        let result = batch.top_k(1, 3, SortOrder::Ascending).unwrap();

        assert_eq!(result.num_rows(), 3);

        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(scores.value(0), 1.0);
        assert_eq!(scores.value(1), 2.0);
        assert_eq!(scores.value(2), 3.0);
    }

    #[test]
    fn test_top_k_k_equals_length() {
        // k == num_rows takes the full-sort path.
        let batch = create_test_batch(vec![3.0, 1.0, 2.0]);
        let result = batch.top_k(1, 3, SortOrder::Descending).unwrap();

        assert_eq!(result.num_rows(), 3);

        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(scores.value(0), 3.0);
        assert_eq!(scores.value(1), 2.0);
        assert_eq!(scores.value(2), 1.0);
    }

    #[test]
    fn test_top_k_k_greater_than_length() {
        let batch = create_test_batch(vec![3.0, 1.0, 2.0]);
        let result = batch.top_k(1, 10, SortOrder::Descending).unwrap();

        assert_eq!(result.num_rows(), 3);

        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(scores.value(0), 3.0);
        assert_eq!(scores.value(1), 2.0);
        assert_eq!(scores.value(2), 1.0);
    }

    #[test]
    fn test_top_k_k_zero_fails() {
        let batch = create_test_batch(vec![1.0, 2.0, 3.0]);
        let result = batch.top_k(1, 0, SortOrder::Descending);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must be greater than 0"));
    }

    #[test]
    fn test_top_k_invalid_column_index() {
        let batch = create_test_batch(vec![1.0, 2.0, 3.0]);
        let result = batch.top_k(99, 2, SortOrder::Descending);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of bounds"));
    }

    #[test]
    fn test_top_k_preserves_row_integrity() {
        let batch = create_test_batch(vec![1.0, 5.0, 3.0]);
        let result = batch.top_k(1, 2, SortOrder::Descending).unwrap();

        let ids = result.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();

        // Each selected score must stay paired with the id of its original row.
        assert_eq!(scores.value(0), 5.0);
        assert_eq!(ids.value(0), 1);

        assert_eq!(scores.value(1), 3.0);
        assert_eq!(ids.value(1), 2);
    }

    #[test]
    fn test_top_k_large_dataset() {
        let values: Vec<f64> = (0..1_000_000).map(|i| f64::from(i)).collect();
        let batch = create_test_batch(values);

        let start = std::time::Instant::now();
        let result = batch.top_k(1, 10, SortOrder::Descending).unwrap();
        let duration = start.elapsed();

        assert_eq!(result.num_rows(), 10);

        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
        for i in 0..10 {
            // Expect 999_999, 999_998, ... in descending order.
            assert_eq!(scores.value(i), 999_999.0 - i as f64);
        }

        // A wall-clock budget is only meaningful for optimized builds. Unoptimized
        // debug builds (the `cargo test` default) can be many times slower, which
        // would make this assertion fail without any algorithmic regression.
        if !cfg!(debug_assertions) {
            assert!(
                duration.as_millis() < 500,
                "Top-K took {}ms (expected <500ms)",
                duration.as_millis()
            );
        }
    }

    mod property_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_top_k_returns_k_rows(
                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
                k in 1usize..100
            ) {
                let batch = create_test_batch(values.clone());
                let result = batch.top_k(1, k, SortOrder::Descending).unwrap();

                let expected_rows = k.min(values.len());
                prop_assert_eq!(result.num_rows(), expected_rows);
            }

            #[test]
            fn prop_top_k_descending_is_sorted(
                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
                k in 1usize..100
            ) {
                let batch = create_test_batch(values);
                let result = batch.top_k(1, k, SortOrder::Descending).unwrap();

                let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();

                for i in 0..scores.len().saturating_sub(1) {
                    prop_assert!(
                        scores.value(i) >= scores.value(i + 1),
                        "Not in descending order: {} < {}",
                        scores.value(i),
                        scores.value(i + 1)
                    );
                }
            }

            #[test]
            fn prop_top_k_ascending_is_sorted(
                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
                k in 1usize..100
            ) {
                let batch = create_test_batch(values);
                let result = batch.top_k(1, k, SortOrder::Ascending).unwrap();

                let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();

                for i in 0..scores.len().saturating_sub(1) {
                    prop_assert!(
                        scores.value(i) <= scores.value(i + 1),
                        "Not in ascending order: {} > {}",
                        scores.value(i),
                        scores.value(i + 1)
                    );
                }
            }
        }
    }

    #[test]
    fn test_top_k_int32() {
        let batch = single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(vec![5, 2, 8, 1, 9, 3])),
        );

        let result = batch.top_k(0, 3, SortOrder::Descending).unwrap();
        assert_eq!(result.num_rows(), 3);

        let col = result.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(col.value(0), 9);
        assert_eq!(col.value(1), 8);
        assert_eq!(col.value(2), 5);
    }

    #[test]
    fn test_top_k_int32_ascending() {
        let batch = single_column_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(vec![5, 2, 8, 1, 9, 3])),
        );

        let result = batch.top_k(0, 3, SortOrder::Ascending).unwrap();
        assert_eq!(result.num_rows(), 3);

        let col = result.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(col.value(0), 1);
        assert_eq!(col.value(1), 2);
        assert_eq!(col.value(2), 3);
    }

    #[test]
    fn test_top_k_int64() {
        let batch = single_column_batch(
            DataType::Int64,
            Arc::new(Int64Array::from(vec![100i64, 200, 50, 300, 150])),
        );

        let result = batch.top_k(0, 2, SortOrder::Ascending).unwrap();
        assert_eq!(result.num_rows(), 2);

        let col = result.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(col.value(0), 50);
        assert_eq!(col.value(1), 100);
    }

    #[test]
    fn test_top_k_int64_descending() {
        let batch = single_column_batch(
            DataType::Int64,
            Arc::new(Int64Array::from(vec![100i64, 200, 50, 300, 150])),
        );

        let result = batch.top_k(0, 2, SortOrder::Descending).unwrap();
        assert_eq!(result.num_rows(), 2);

        let col = result.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(col.value(0), 300);
        assert_eq!(col.value(1), 200);
    }

    #[test]
    fn test_top_k_float32() {
        let batch = single_column_batch(
            DataType::Float32,
            Arc::new(Float32Array::from(vec![1.5f32, 2.7, 0.3, 4.2, 3.1])),
        );

        let result = batch.top_k(0, 3, SortOrder::Descending).unwrap();
        assert_eq!(result.num_rows(), 3);

        let col = result.column(0).as_any().downcast_ref::<Float32Array>().unwrap();
        assert!((col.value(0) - 4.2).abs() < 0.001);
        assert!((col.value(1) - 3.1).abs() < 0.001);
        assert!((col.value(2) - 2.7).abs() < 0.001);
    }

    #[test]
    fn test_top_k_float32_ascending() {
        let batch = single_column_batch(
            DataType::Float32,
            Arc::new(Float32Array::from(vec![1.5f32, 2.7, 0.3, 4.2, 3.1])),
        );

        let result = batch.top_k(0, 3, SortOrder::Ascending).unwrap();
        assert_eq!(result.num_rows(), 3);

        let col = result.column(0).as_any().downcast_ref::<Float32Array>().unwrap();
        assert!((col.value(0) - 0.3).abs() < 0.001);
        assert!((col.value(1) - 1.5).abs() < 0.001);
        assert!((col.value(2) - 2.7).abs() < 0.001);
    }

    #[test]
    fn test_top_k_unsupported_type() {
        // k < num_rows forces the heap path, which does not rank Utf8 columns.
        let batch =
            single_column_batch(DataType::Utf8, Arc::new(StringArray::from(vec!["a", "b", "c"])));

        let result = batch.top_k(0, 2, SortOrder::Descending);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Top-K not supported for data type"));
    }

    #[test]
    fn test_min_heap_item_eq() {
        // Equality ignores `index`.
        let item1 = MinHeapItem { value: 42i32, index: 0 };
        let item2 = MinHeapItem { value: 42i32, index: 1 };
        let item3 = MinHeapItem { value: 43i32, index: 2 };

        assert_eq!(item1, item2);
        assert_ne!(item1, item3);
    }

    #[test]
    fn test_min_heap_item_ord() {
        let item1 = MinHeapItem { value: 10i32, index: 0 };
        let item2 = MinHeapItem { value: 20i32, index: 1 };
        let item3 = MinHeapItem { value: 30i32, index: 2 };

        // MinHeapItem reverses the natural order of its values.
        assert!(item3 < item2);
        assert!(item2 < item1);
    }

    #[test]
    fn test_min_heap_item_partial_ord() {
        let item1 = MinHeapItem { value: 5i32, index: 0 };
        let item2 = MinHeapItem { value: 10i32, index: 1 };

        assert!(item1.partial_cmp(&item2) == Some(Ordering::Greater));
    }

    #[test]
    fn test_max_heap_item_eq() {
        let item1 = MaxHeapItem { value: 42i32, index: 0 };
        let item2 = MaxHeapItem { value: 42i32, index: 1 };
        let item3 = MaxHeapItem { value: 43i32, index: 2 };

        assert_eq!(item1, item2);
        assert_ne!(item1, item3);
    }

    #[test]
    fn test_max_heap_item_ord() {
        let item1 = MaxHeapItem { value: 10i32, index: 0 };
        let item2 = MaxHeapItem { value: 20i32, index: 1 };
        let item3 = MaxHeapItem { value: 30i32, index: 2 };

        // MaxHeapItem follows the natural order of its values.
        assert!(item3 > item2);
        assert!(item2 > item1);
    }

    #[test]
    fn test_max_heap_item_partial_ord() {
        let item1 = MaxHeapItem { value: 5i32, index: 0 };
        let item2 = MaxHeapItem { value: 10i32, index: 1 };

        assert!(item1.partial_cmp(&item2) == Some(Ordering::Less));
    }

    #[test]
    fn test_heap_item_with_floats() {
        let item1 = MinHeapItem { value: 1.5f64, index: 0 };
        let item2 = MinHeapItem { value: 2.5f64, index: 1 };

        assert_ne!(item1, item2);
        // Larger value compares as smaller in a MinHeapItem.
        assert!(item2 < item1);
    }

    #[test]
    fn test_heap_item_eq_method_with_floats() {
        let item1 = MaxHeapItem { value: 3.25f64, index: 0 };
        let item2 = MaxHeapItem { value: 3.25f64, index: 1 };
        let item3 = MaxHeapItem { value: 2.75f64, index: 2 };

        assert!(item1.eq(&item2));
        assert!(!item1.eq(&item3));
    }
}