use crate::exec::result::Row;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Bounded collector that keeps the `k` best-scoring rows seen so far.
///
/// Rows are fed in one at a time; at most `k` of them are held in memory at
/// any moment, together with a pair of counters used for statistics.
#[allow(dead_code)]
pub struct StreamingTopK {
    /// Maximum number of rows retained.
    k: usize,
    /// Currently retained rows. `ScoredRow` orders itself in reverse, so the
    /// top of this heap is the retained row with the lowest score.
    heap: BinaryHeap<ScoredRow>,
    /// Number of rows offered so far, whether or not they were retained.
    processed_count: usize,
    /// Number of rows that were admitted to the heap at the moment they were
    /// offered (including rows that were evicted again later).
    would_keep_count: usize,
}
/// A row paired with the score it is ranked by.
///
/// Only `score` takes part in comparisons between two `ScoredRow`s; the row
/// itself is carried along as payload.
#[derive(Clone)]
#[allow(dead_code)]
struct ScoredRow {
    /// Ranking key.
    score: f64,
    /// Payload handed back to the caller once ranking is finished.
    row: Row,
}
// Ordering of `ScoredRow` is intentionally *reversed* with respect to the
// score: a lower score compares as `Greater`. `BinaryHeap` is a max-heap, so
// with this inversion its top element is the row with the lowest score.
//
// `f64::total_cmp` is used instead of `partial_cmp(..).unwrap_or(Equal)`
// because the latter is not a total order once a NaN is involved (NaN would be
// "equal" to every value while those values differ from each other), which
// violates the `Ord` contract the heap depends on.
impl Ord for ScoredRow {
    /// Compares two rows by score only, with the operands swapped so that the
    /// smaller score is the greater element.
    fn cmp(&self, other: &Self) -> Ordering {
        other.score.total_cmp(&self.score)
    }
}

impl PartialOrd for ScoredRow {
    /// Always `Some`, delegating to the total order defined by [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Eq for ScoredRow {}

impl PartialEq for ScoredRow {
    /// Two rows are equal exactly when `cmp` reports `Ordering::Equal`.
    ///
    /// Deriving equality from `cmp` keeps `PartialEq`, `Eq` and `Ord`
    /// consistent with one another; a plain `==` on the floats would make a
    /// NaN-scored row unequal to itself.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl StreamingTopK {
    /// Upper bound on the number of heap slots reserved up front by `new`.
    ///
    /// `k` may be arbitrarily large (up to `usize::MAX`); reserving `k` slots
    /// eagerly would overflow or attempt a huge allocation. Beyond this bound
    /// the heap simply grows on demand.
    const MAX_PREALLOCATED_ROWS: usize = 1024;

    /// Creates an empty collector that retains at most `k` rows.
    ///
    /// With `k == 0` every row offered to [`add`](Self::add) is counted but
    /// none is retained.
    #[allow(dead_code)]
    pub fn new(k: usize) -> Self {
        let reserved = k.min(Self::MAX_PREALLOCATED_ROWS);
        Self {
            heap: BinaryHeap::with_capacity(reserved),
            k,
            processed_count: 0,
            would_keep_count: 0,
        }
    }

    /// Offers `row` with the given `score` to the collector.
    ///
    /// While fewer than `k` rows are held the row is always retained.
    /// Afterwards it replaces the lowest-scoring retained row only if its
    /// score is strictly greater; ties keep the row that arrived first.
    ///
    /// A NaN `score` is ranked as `f64::NEG_INFINITY`. Storing NaN as-is would
    /// let it reach the top of the heap, where every later `score > NaN`
    /// check is false and no further row could ever be admitted.
    #[allow(dead_code)]
    pub fn add(&mut self, row: Row, score: f64) {
        self.processed_count += 1;

        let score = if score.is_nan() {
            f64::NEG_INFINITY
        } else {
            score
        };

        if self.heap.len() < self.k {
            self.heap.push(ScoredRow { row, score });
            self.would_keep_count += 1;
            return;
        }

        // The heap is full (or `k == 0`, in which case there is no top element).
        // Overwriting the top through `peek_mut` restores the heap order when
        // the guard is dropped, so the heap never holds more than `k` rows.
        if let Some(mut weakest) = self.heap.peek_mut() {
            if score > weakest.score {
                *weakest = ScoredRow { row, score };
                self.would_keep_count += 1;
            }
        }
    }

    /// Returns the lowest score among the retained rows, or `None` when no
    /// row is retained.
    #[allow(dead_code)]
    pub fn min_score(&self) -> Option<f64> {
        self.heap.peek().map(|item| item.score)
    }

    /// Number of rows currently retained.
    #[allow(dead_code)]
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Returns `true` when no row is currently retained.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Snapshot of the collector's counters.
    #[allow(dead_code)]
    pub fn stats(&self) -> TopKStats {
        TopKStats {
            processed_count: self.processed_count,
            kept_count: self.heap.len(),
            would_keep_count: self.would_keep_count,
            k: self.k,
        }
    }

    /// Consumes the collector and returns the retained rows, highest score
    /// first. The relative order of rows with equal scores is unspecified.
    #[allow(dead_code)]
    pub fn into_results(self) -> Vec<Row> {
        Self::rows_by_descending_score(self.heap.into_vec())
    }

    /// Returns clones of the retained rows, highest score first, leaving the
    /// collector untouched. The relative order of rows with equal scores is
    /// unspecified.
    #[allow(dead_code)]
    pub fn results(&self) -> Vec<Row> {
        Self::rows_by_descending_score(self.heap.iter().cloned().collect())
    }

    /// Sorts `scored` by descending score and strips the scores off.
    fn rows_by_descending_score(mut scored: Vec<ScoredRow>) -> Vec<Row> {
        scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        scored.into_iter().map(|sr| sr.row).collect()
    }
}
/// Counters describing how much work a top-K pass did and how much it kept.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct TopKStats {
    /// Total number of rows offered.
    pub processed_count: usize,
    /// Number of rows retained at the time the snapshot was taken.
    pub kept_count: usize,
    /// Number of rows that were admitted when offered, including rows that
    /// were evicted again afterwards.
    pub would_keep_count: usize,
    /// Configured maximum number of retained rows.
    pub k: usize,
}

impl TopKStats {
    /// Fraction of processed rows that are retained:
    /// `kept_count / processed_count`.
    ///
    /// Note that, despite the name, a *smaller* value means more rows were
    /// dropped. Returns `1.0` when nothing has been processed.
    #[allow(dead_code)]
    pub fn memory_savings_ratio(&self) -> f64 {
        if self.processed_count == 0 {
            return 1.0;
        }
        self.kept_count as f64 / self.processed_count as f64
    }

    /// Fraction of processed rows that are not retained:
    /// `(processed_count - kept_count) / processed_count`.
    ///
    /// Returns `0.0` when nothing has been processed. The fields are public,
    /// so `kept_count` may exceed `processed_count` in a hand-built value; the
    /// subtraction saturates at zero instead of underflowing in that case.
    #[allow(dead_code)]
    pub fn discard_ratio(&self) -> f64 {
        if self.processed_count == 0 {
            return 0.0;
        }
        let discarded = self.processed_count.saturating_sub(self.kept_count);
        discarded as f64 / self.processed_count as f64
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Value;
    use std::collections::HashMap;

    /// Builds a row whose `id` and `score` columns and whose text score are
    /// taken from the arguments; every other field is left empty.
    fn create_test_row(id: usize, score: f64) -> Row {
        let values = HashMap::from([
            ("id".to_string(), Value::Number(id as f64)),
            ("score".to_string(), Value::Number(score)),
        ]);
        Row {
            values,
            positional_values: vec![],
            source_entities: HashMap::new(),
            text_score: Some(score),
            highlight_snippet: None,
        }
    }

    /// Feeds `scores` into `topk` in order, giving the rows ids `first_id`,
    /// `first_id + 1`, and so on.
    fn feed(topk: &mut StreamingTopK, first_id: usize, scores: &[f64]) {
        for (offset, &score) in scores.iter().enumerate() {
            topk.add(create_test_row(first_id + offset, score), score);
        }
    }

    /// Extracts the text score of every row, panicking if one is missing.
    fn text_scores(rows: &[Row]) -> Vec<f64> {
        rows.iter()
            .map(|row| row.get_text_score().unwrap())
            .collect()
    }

    #[test]
    fn test_streaming_topk_basic() {
        let mut topk = StreamingTopK::new(3);
        feed(&mut topk, 1, &[10.0, 20.0, 5.0, 15.0]);

        let results = topk.into_results();
        assert_eq!(results.len(), 3);

        // The lowest score (5.0) must have been dropped, the rest sorted descending.
        let scores = text_scores(&results);
        assert_eq!(scores[0], 20.0);
        assert_eq!(scores[1], 15.0);
        assert_eq!(scores[2], 10.0);
    }

    #[test]
    fn test_streaming_topk_with_many_rows() {
        let mut topk = StreamingTopK::new(10);
        for i in 0..100 {
            let score = (i * 7 % 100) as f64;
            topk.add(create_test_row(i, score), score);
        }

        let results = topk.into_results();
        assert_eq!(results.len(), 10);

        // Every adjacent pair must be in non-increasing score order.
        for pair in results.windows(2) {
            let earlier = pair[0].get_text_score().unwrap();
            let later = pair[1].get_text_score().unwrap();
            assert!(earlier >= later);
        }
    }

    #[test]
    fn test_streaming_topk_min_score() {
        let mut topk = StreamingTopK::new(3);
        feed(&mut topk, 1, &[10.0, 20.0, 5.0]);
        assert_eq!(topk.min_score(), Some(5.0));

        // A better row pushes out the 5.0 row, so the threshold rises.
        feed(&mut topk, 4, &[15.0]);
        assert_eq!(topk.min_score(), Some(10.0));
    }

    #[test]
    fn test_streaming_topk_stats() {
        let mut topk = StreamingTopK::new(5);
        for i in 0..20 {
            topk.add(create_test_row(i, i as f64), i as f64);
        }

        let stats = topk.stats();
        assert_eq!(stats.processed_count, 20);
        assert_eq!(stats.kept_count, 5);
        assert_eq!(stats.k, 5);
        assert!((stats.memory_savings_ratio() - 0.25).abs() < 0.01);
        assert!((stats.discard_ratio() - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_streaming_topk_empty() {
        let topk = StreamingTopK::new(10);
        assert!(topk.is_empty());
        assert_eq!(topk.len(), 0);
        assert_eq!(topk.min_score(), None);

        let results = topk.into_results();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_streaming_topk_fewer_than_k() {
        let mut topk = StreamingTopK::new(10);
        for i in 0..5 {
            topk.add(create_test_row(i, i as f64), i as f64);
        }

        let results = topk.into_results();
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_streaming_topk_duplicate_scores() {
        let mut topk = StreamingTopK::new(3);
        feed(&mut topk, 1, &[10.0, 10.0, 10.0, 5.0]);

        let results = topk.into_results();
        assert_eq!(results.len(), 3);

        let scores = text_scores(&results);
        assert_eq!(scores[0], 10.0);
        assert_eq!(scores[1], 10.0);
        assert_eq!(scores[2], 10.0);
    }
}