use std::cmp::Ordering;
use std::collections::BinaryHeap;
use sochdb_core::{SochRow, SochValue};
/// Direction in which a single sort key is ordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortDirection {
/// Keep the natural ordering produced by the value comparison.
Ascending,
/// Reverse the ordering produced by the value comparison.
Descending,
}
/// One sort key of an ordering specification.
#[derive(Debug, Clone)]
pub struct OrderByColumn {
/// Column this key sorts on, referenced by position or by name.
pub column: ColumnRef,
/// Direction applied to this key.
pub direction: SortDirection,
/// Null-placement flag that is handed to the value comparison when a
/// comparator is built from the specification.
pub nulls_first: bool,
}
/// Reference to a column, either by position or by name.
#[derive(Debug, Clone)]
pub enum ColumnRef {
    /// Position of the column; used as-is when resolving.
    Index(usize),
    /// Name of the column; looked up in the column list when resolving.
    Name(String),
}

impl ColumnRef {
    /// Turns this reference into a column position.
    ///
    /// * `Index` references are returned unchanged; they are **not** checked
    ///   against the length of `columns`.
    /// * `Name` references yield the position of the first entry in `columns`
    ///   equal to the name, or `None` when no entry matches.
    pub fn resolve(&self, columns: &[String]) -> Option<usize> {
        let wanted = match self {
            Self::Index(position) => return Some(*position),
            Self::Name(name) => name,
        };
        columns.iter().position(|candidate| candidate == wanted)
    }
}
/// Ordered list of sort keys; earlier entries take precedence over later ones
/// when a comparator is built from the specification.
#[derive(Debug, Clone)]
pub struct OrderBySpec {
/// Sort keys in priority order.
pub columns: Vec<OrderByColumn>,
}
impl OrderBySpec {
    /// Creates a specification with a single sort key.
    ///
    /// The key is created with `nulls_first = false`.
    pub fn single(column: ColumnRef, direction: SortDirection) -> Self {
        Self {
            columns: vec![OrderByColumn {
                column,
                direction,
                nulls_first: false,
            }],
        }
    }

    /// Appends a lower-priority sort key (with `nulls_first = false`) and
    /// returns the extended specification.
    pub fn then_by(mut self, column: ColumnRef, direction: SortDirection) -> Self {
        self.columns.push(OrderByColumn {
            column,
            direction,
            nulls_first: false,
        });
        self
    }

    /// Builds a row comparator for this specification.
    ///
    /// Column references are resolved against `column_names` once, up front.
    /// Keys whose reference cannot be resolved are skipped, so they do not
    /// take part in the comparison. Keys are evaluated in order and the first
    /// non-equal result decides; rows equal on every key compare `Equal`.
    ///
    /// `nulls_first` is honoured independently of the sort direction: nulls
    /// are placed first when it is `true` and last when it is `false`, for
    /// both `Ascending` and `Descending` keys.
    pub fn comparator(&self, column_names: &[String]) -> impl Fn(&SochRow, &SochRow) -> Ordering {
        let resolved: Vec<(usize, SortDirection, bool)> = self
            .columns
            .iter()
            .filter_map(|col| {
                let idx = col.column.resolve(column_names)?;
                // The per-key result is reversed as a whole for descending
                // keys, which would also flip where nulls end up. Pre-invert
                // the flag so that after the reversal nulls land where the
                // caller asked for them.
                let null_flag = match col.direction {
                    SortDirection::Ascending => col.nulls_first,
                    SortDirection::Descending => !col.nulls_first,
                };
                Some((idx, col.direction, null_flag))
            })
            .collect();

        move |a: &SochRow, b: &SochRow| {
            resolved
                .iter()
                .map(|&(idx, direction, null_flag)| {
                    let ordering = compare_values(a.values.get(idx), b.values.get(idx), null_flag);
                    match direction {
                        SortDirection::Ascending => ordering,
                        SortDirection::Descending => ordering.reverse(),
                    }
                })
                .find(|ordering| *ordering != Ordering::Equal)
                .unwrap_or(Ordering::Equal)
        }
    }

    /// Returns `true` when this specification is a prefix of the given index
    /// definition: every key is a `Name` reference whose name and direction
    /// equal the index column at the same position.
    ///
    /// Positional (`Index`) keys never match, and a specification with more
    /// keys than the index has columns never matches.
    pub fn matches_index(&self, index_columns: &[(String, SortDirection)]) -> bool {
        if self.columns.len() > index_columns.len() {
            return false;
        }
        self.columns
            .iter()
            .zip(index_columns)
            .all(|(col, (idx_col, idx_dir))| match &col.column {
                ColumnRef::Name(name) => name == idx_col && col.direction == *idx_dir,
                ColumnRef::Index(_) => false,
            })
    }
}
/// Compares two optional cell values, placing absent values according to
/// `nulls_first`.
///
/// Three kinds of operand are distinguished: a missing cell (`None`), an
/// explicit `SochValue::Null`, and a real value. With `nulls_first` the order
/// is missing < Null < value; without it the order is value < Null < missing.
/// Two real values are delegated to `compare_soch_values`.
fn compare_values(a: Option<&SochValue>, b: Option<&SochValue>, nulls_first: bool) -> Ordering {
    // 0 = missing cell, 1 = explicit Null, 2 = real value.
    let absence_rank = |slot: Option<&SochValue>| -> u8 {
        match slot {
            None => 0,
            Some(SochValue::Null) => 1,
            Some(_) => 2,
        }
    };
    let (rank_a, rank_b) = (absence_rank(a), absence_rank(b));

    if let (2, 2, Some(lhs), Some(rhs)) = (rank_a, rank_b, a, b) {
        return compare_soch_values(lhs, rhs);
    }

    let by_rank = rank_a.cmp(&rank_b);
    if nulls_first {
        by_rank
    } else {
        by_rank.reverse()
    }
}
/// Compares two non-null values of the same variant.
///
/// `Int`, `UInt`, `Text` and `Bool` use their natural ordering. `Float` uses
/// the numeric ordering with NaN placed after every number (and equal to
/// other NaNs); treating NaN as equal to everything would not be a consistent
/// ordering, which sorting and heap code depend on.
///
/// NOTE(review): pairs of different variants (and variants not listed here)
/// still compare `Equal`, which is only consistent when a column holds a
/// single variant — confirm that callers guarantee this.
fn compare_soch_values(a: &SochValue, b: &SochValue) -> Ordering {
    match (a, b) {
        (SochValue::Int(x), SochValue::Int(y)) => x.cmp(y),
        (SochValue::UInt(x), SochValue::UInt(y)) => x.cmp(y),
        (SochValue::Float(x), SochValue::Float(y)) => x
            .partial_cmp(y)
            // `partial_cmp` is `None` only when at least one side is NaN.
            .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan())),
        (SochValue::Text(x), SochValue::Text(y)) => x.cmp(y),
        (SochValue::Bool(x), SochValue::Bool(y)) => x.cmp(y),
        _ => Ordering::Equal,
    }
}
/// Bounded heap that retains the best `k` items seen so far under a
/// caller-supplied comparator.
///
/// With `want_smallest = true` the `k` smallest items are kept, otherwise the
/// `k` largest. Internally the items form a binary heap stored in a `Vec`
/// whose root is the *worst* retained item, i.e. the next one to be evicted.
///
/// The comparator is owned by the heap and called directly, so the heap can
/// be freely moved between pushes and needs no `unsafe` code.
pub struct TopKHeap<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Binary heap laid out in a vector; index 0 is the eviction candidate.
    heap: Vec<T>,
    /// Maximum number of items retained.
    k: usize,
    /// Ordering supplied by the caller.
    comparator: F,
    /// `true` keeps the smallest items, `false` keeps the largest.
    want_smallest: bool,
}

impl<T, F> TopKHeap<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Upper bound on the number of slots reserved up front. `k` may be
    /// arbitrarily large (for example a saturated `limit + offset`), so the
    /// initial reservation is capped and the vector grows on demand.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an empty heap that retains at most `k` items.
    pub fn new(k: usize, comparator: F, want_smallest: bool) -> Self {
        Self {
            heap: Vec::with_capacity(k.min(Self::MAX_PREALLOC)),
            k,
            comparator,
            want_smallest,
        }
    }

    /// Offers `value` to the heap.
    ///
    /// While fewer than `k` items are held the value is always kept. Once the
    /// heap is full the value replaces the current worst item only when it is
    /// strictly better; otherwise it is dropped. With `k == 0` every value is
    /// dropped.
    pub fn push(&mut self, value: T) {
        if self.k == 0 {
            return;
        }
        if self.heap.len() < self.k {
            self.heap.push(value);
            let last = self.heap.len() - 1;
            self.sift_up(last);
            return;
        }
        // The heap is full and `k > 0`, so a root exists.
        let beats_worst = self.eviction_order(&value, &self.heap[0]) == Ordering::Less;
        if beats_worst {
            self.heap[0] = value;
            self.sift_down(0);
        }
    }

    /// Returns the worst item currently retained (the one a new item must
    /// beat once the heap is full), or `None` when the heap is empty.
    pub fn threshold(&self) -> Option<&T> {
        self.heap.first()
    }

    /// Returns `true` once `k` items are held.
    pub fn is_full(&self) -> bool {
        self.heap.len() >= self.k
    }

    /// Consumes the heap and returns the retained items, best first:
    /// ascending under the comparator when keeping the smallest, descending
    /// when keeping the largest.
    pub fn into_sorted_vec(self) -> Vec<T> {
        let Self {
            heap: mut values,
            comparator,
            want_smallest,
            ..
        } = self;
        if want_smallest {
            values.sort_by(|a, b| comparator(a, b));
        } else {
            values.sort_by(|a, b| comparator(b, a));
        }
        values
    }

    /// Number of items currently retained.
    pub fn len(&self) -> usize {
        self.heap.len()
    }

    /// Returns `true` when no item is retained.
    pub fn is_empty(&self) -> bool {
        self.heap.is_empty()
    }

    /// Ordering used for the heap layout: `Greater` means "evicted earlier".
    /// The root of the heap is the maximum under this ordering.
    fn eviction_order(&self, a: &T, b: &T) -> Ordering {
        let natural = (self.comparator)(a, b);
        if self.want_smallest {
            natural
        } else {
            natural.reverse()
        }
    }

    /// Moves the item at `child` towards the root until its parent is not
    /// smaller under the eviction ordering.
    fn sift_up(&mut self, mut child: usize) {
        while child > 0 {
            let parent = (child - 1) / 2;
            if self.eviction_order(&self.heap[child], &self.heap[parent]) != Ordering::Greater {
                break;
            }
            self.heap.swap(child, parent);
            child = parent;
        }
    }

    /// Moves the item at `parent` towards the leaves until neither child is
    /// greater under the eviction ordering.
    fn sift_down(&mut self, mut parent: usize) {
        let len = self.heap.len();
        loop {
            let left = 2 * parent + 1;
            if left >= len {
                break;
            }
            let right = left + 1;
            let mut top_child = left;
            if right < len
                && self.eviction_order(&self.heap[right], &self.heap[left]) == Ordering::Greater
            {
                top_child = right;
            }
            if self.eviction_order(&self.heap[top_child], &self.heap[parent]) != Ordering::Greater {
                break;
            }
            self.heap.swap(parent, top_child);
            parent = top_child;
        }
    }
}
/// How an ordered, limited query is evaluated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStrategy {
    /// Chosen whenever a matching index is reported.
    IndexPushdown,
    /// Chosen for small limits, limits below the square root of the row
    /// estimate, or when no usable row estimate exists.
    StreamingTopK,
    /// Chosen when the limit is large relative to the row estimate.
    FullSort,
}

impl ExecutionStrategy {
    /// Limits at or below this value always use the streaming strategy.
    const SMALL_LIMIT: usize = 100;

    /// Picks a strategy.
    ///
    /// * A matching index always wins.
    /// * Without a positive row estimate the streaming strategy is used.
    /// * Otherwise streaming is used when `limit` is at most 100 or smaller
    ///   than the square root of the estimate; a full sort is used in all
    ///   remaining cases.
    pub fn choose(
        has_matching_index: bool,
        estimated_rows: Option<usize>,
        limit: usize,
    ) -> Self {
        if has_matching_index {
            return ExecutionStrategy::IndexPushdown;
        }
        match estimated_rows {
            Some(n) if n > 0 => {
                let small_enough =
                    limit <= Self::SMALL_LIMIT || (limit as f64) < (n as f64).sqrt();
                if small_enough {
                    ExecutionStrategy::StreamingTopK
                } else {
                    ExecutionStrategy::FullSort
                }
            }
            _ => ExecutionStrategy::StreamingTopK,
        }
    }

    /// Renders a human-readable cost estimate for `n` input rows and a limit
    /// of `k`.
    ///
    /// The numeric estimate saturates at `usize::MAX` instead of overflowing,
    /// so very large `n` or `k` (for example a saturated `limit + offset`)
    /// cannot cause an arithmetic panic.
    pub fn complexity(&self, n: usize, k: usize) -> String {
        match self {
            ExecutionStrategy::IndexPushdown => {
                let log_n = (n as f64).log2() as usize;
                format!("O(log {} + {}) = O({})", n, k, log_n.saturating_add(k))
            }
            ExecutionStrategy::StreamingTopK => {
                let log_k = (k as f64).log2().max(1.0) as usize;
                format!("O({} * log {}) ≈ O({})", n, k, n.saturating_mul(log_k))
            }
            ExecutionStrategy::FullSort => {
                let log_n = (n as f64).log2().max(1.0) as usize;
                format!("O({} * log {}) ≈ O({})", n, n, n.saturating_mul(log_n))
            }
        }
    }
}
/// Counters describing one execution of an ordered, limited query.
#[derive(Debug, Clone, Default)]
pub struct OrderByLimitStats {
/// Strategy that was used; `None` in a default-constructed value.
pub strategy: Option<ExecutionStrategy>,
/// Number of input rows consumed.
pub input_rows: usize,
/// Number of rows returned after offset and limit were applied.
pub output_rows: usize,
/// Number of rows offered to the top-K heap by the streaming strategy.
pub heap_operations: usize,
/// NOTE(review): no code visible in this file updates this counter — confirm
/// whether it is populated elsewhere.
pub comparisons: usize,
/// Number of leading rows dropped because of the offset.
pub offset_skipped: usize,
}
/// Executes an ordering specification together with a limit and an offset
/// over a stream of rows, using a strategy fixed at construction time.
pub struct OrderByLimitExecutor {
// Sort keys used to build the row comparator.
order_by: OrderBySpec,
// Maximum number of rows returned.
limit: usize,
// Number of leading ordered rows skipped before returning any.
offset: usize,
// Column names used to resolve name-based sort keys.
column_names: Vec<String>,
// Strategy selected in `new`.
strategy: ExecutionStrategy,
}
impl OrderByLimitExecutor {
pub fn new(
order_by: OrderBySpec,
limit: usize,
offset: usize,
column_names: Vec<String>,
has_matching_index: bool,
estimated_rows: Option<usize>,
) -> Self {
let effective_limit = limit.saturating_add(offset);
let strategy = ExecutionStrategy::choose(has_matching_index, estimated_rows, effective_limit);
Self {
order_by,
limit,
offset,
column_names,
strategy,
}
}
pub fn strategy(&self) -> ExecutionStrategy {
self.strategy
}
pub fn execute<I>(&self, rows: I) -> (Vec<SochRow>, OrderByLimitStats)
where
I: Iterator<Item = SochRow>,
{
let mut stats = OrderByLimitStats {
strategy: Some(self.strategy),
..Default::default()
};
let effective_limit = self.limit.saturating_add(self.offset);
let result = match self.strategy {
ExecutionStrategy::IndexPushdown => {
let collected: Vec<_> = rows.take(effective_limit).collect();
stats.input_rows = collected.len();
collected
}
ExecutionStrategy::StreamingTopK => {
self.execute_streaming(rows, effective_limit, &mut stats)
}
ExecutionStrategy::FullSort => {
self.execute_full_sort(rows, effective_limit, &mut stats)
}
};
let final_result: Vec<_> = result
.into_iter()
.skip(self.offset)
.take(self.limit)
.collect();
stats.offset_skipped = self.offset.min(stats.input_rows);
stats.output_rows = final_result.len();
(final_result, stats)
}
fn execute_streaming<I>(
&self,
rows: I,
k: usize,
stats: &mut OrderByLimitStats,
) -> Vec<SochRow>
where
I: Iterator<Item = SochRow>,
{
let comparator = self.order_by.comparator(&self.column_names);
let mut heap = TopKHeap::new(k, comparator, true);
for row in rows {
stats.input_rows += 1;
stats.heap_operations += 1;
heap.push(row);
}
heap.into_sorted_vec()
}
fn execute_full_sort<I>(
&self,
rows: I,
k: usize,
stats: &mut OrderByLimitStats,
) -> Vec<SochRow>
where
I: Iterator<Item = SochRow>,
{
let comparator = self.order_by.comparator(&self.column_names);
let mut all_rows: Vec<_> = rows.collect();
stats.input_rows = all_rows.len();
all_rows.sort_by(&comparator);
all_rows.truncate(k);
all_rows
}
}
/// Collects the first `k` items from a stream that the caller feeds in
/// batches, ordering the items *within* each batch by a secondary comparator.
///
/// The caller marks batch boundaries through the
/// `same_index_key_as_previous` argument of [`IndexAwareTopK::push`]; batches
/// are appended to the result in arrival order.
pub struct IndexAwareTopK<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Items of the batch currently being collected.
    current_batch: Vec<T>,
    /// Finalised output, at most `k` items.
    result: Vec<T>,
    /// Maximum number of items in the result.
    k: usize,
    /// Ordering applied inside a batch.
    secondary_cmp: F,
}

impl<T, F> IndexAwareTopK<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Upper bound on the number of result slots reserved up front. `k` may
    /// be arbitrarily large, so the reservation is capped and the vector
    /// grows on demand.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an empty collector that returns at most `k` items.
    pub fn new(k: usize, secondary_cmp: F) -> Self {
        Self {
            current_batch: Vec::new(),
            result: Vec::with_capacity(k.min(Self::MAX_PREALLOC)),
            k,
            secondary_cmp,
        }
    }

    /// Adds `item` to the stream.
    ///
    /// When `same_index_key_as_previous` is `false` the batch collected so
    /// far is finalised first and `item` starts a new batch. Once the result
    /// already holds `k` items further items cannot affect it, so they are
    /// dropped instead of being buffered.
    pub fn push(&mut self, item: T, same_index_key_as_previous: bool) {
        if !same_index_key_as_previous {
            self.finalize_batch();
        }
        if self.is_complete() {
            return;
        }
        self.current_batch.push(item);
    }

    /// Sorts the pending batch with the secondary comparator and moves as
    /// many of its leading items into the result as still fit; the rest of
    /// the batch is discarded.
    fn finalize_batch(&mut self) {
        if self.current_batch.is_empty() {
            return;
        }
        let remaining = self.k.saturating_sub(self.result.len());
        if remaining == 0 {
            // Nothing can be taken, so sorting would be wasted work.
            self.current_batch.clear();
            return;
        }
        self.current_batch.sort_by(&self.secondary_cmp);
        let to_take = remaining.min(self.current_batch.len());
        self.result.extend(self.current_batch.drain(..to_take));
        self.current_batch.clear();
    }

    /// Returns `true` once the result holds `k` items. Items of a batch that
    /// has not been finalised yet are not counted.
    pub fn is_complete(&self) -> bool {
        self.result.len() >= self.k
    }

    /// Finalises the pending batch and returns the collected items.
    pub fn into_result(mut self) -> Vec<T> {
        self.finalize_batch();
        self.result
    }
}
/// Top-K collector for rows ordered by a single column, keyed on a value
/// extracted once per row.
pub struct SingleColumnTopK {
// Retained entries; the heap top is the entry a new row has to beat.
heap: BinaryHeap<SingleColEntry>,
// Maximum number of rows retained.
k: usize,
// Position of the sort column inside each row's values.
col_idx: usize,
// true keeps the rows with the smallest keys, false those with the largest.
ascending: bool,
}
/// Heap entry pairing a row with the sort key extracted from it.
struct SingleColEntry {
// The full row carried along with its key.
row: SochRow,
// Key extracted from the sort column of `row`.
key: OrderableValue,
// Copied from the owning collector; when false the entry ordering is reversed.
ascending: bool,
}
/// Owned sort key with an ordering implementation, built from a `SochValue`.
#[derive(Clone)]
enum OrderableValue {
/// Signed integer key.
Int(i64),
/// Unsigned integer key.
UInt(u64),
/// Floating-point key.
Float(f64),
/// Text key.
Text(String),
/// Boolean key.
Bool(bool),
/// Null key; also used for values that have no other representation here.
Null,
}
impl From<&SochValue> for OrderableValue {
    /// Copies the payload of an integer, float, text or boolean value into
    /// the matching key variant. `SochValue::Null` and every other variant
    /// become `OrderableValue::Null`.
    fn from(v: &SochValue) -> Self {
        match v {
            SochValue::Bool(flag) => Self::Bool(*flag),
            SochValue::Int(signed) => Self::Int(*signed),
            SochValue::UInt(unsigned) => Self::UInt(*unsigned),
            SochValue::Float(number) => Self::Float(*number),
            SochValue::Text(text) => Self::Text(text.clone()),
            // Covers `SochValue::Null` as well as any variant without a key form.
            _ => Self::Null,
        }
    }
}
impl PartialEq for OrderableValue {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
}
}
impl Eq for OrderableValue {}
impl PartialOrd for OrderableValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderableValue {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(OrderableValue::Null, OrderableValue::Null) => Ordering::Equal,
(OrderableValue::Null, _) => Ordering::Greater, (_, OrderableValue::Null) => Ordering::Less,
(OrderableValue::Int(a), OrderableValue::Int(b)) => a.cmp(b),
(OrderableValue::UInt(a), OrderableValue::UInt(b)) => a.cmp(b),
(OrderableValue::Float(a), OrderableValue::Float(b)) => {
a.partial_cmp(b).unwrap_or(Ordering::Equal)
}
(OrderableValue::Text(a), OrderableValue::Text(b)) => a.cmp(b),
(OrderableValue::Bool(a), OrderableValue::Bool(b)) => a.cmp(b),
_ => Ordering::Equal, }
}
}
/// Entries are equal when their keys are equal; the row and the direction
/// flag are not inspected.
impl PartialEq for SingleColEntry {
    fn eq(&self, other: &Self) -> bool {
        self.key.eq(&other.key)
    }
}

impl Eq for SingleColEntry {}

impl PartialOrd for SingleColEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Ord for SingleColEntry {
    /// Orders entries by key; the result is reversed when this entry's
    /// `ascending` flag is `false`.
    fn cmp(&self, other: &Self) -> Ordering {
        let by_key = Ord::cmp(&self.key, &other.key);
        match self.ascending {
            true => by_key,
            false => by_key.reverse(),
        }
    }
}
impl SingleColumnTopK {
pub fn new(k: usize, col_idx: usize, ascending: bool) -> Self {
Self {
heap: BinaryHeap::with_capacity(k + 1),
k,
col_idx,
ascending,
}
}
pub fn push(&mut self, row: SochRow) {
if self.k == 0 {
return;
}
let key = row.values
.get(self.col_idx)
.map(OrderableValue::from)
.unwrap_or(OrderableValue::Null);
let entry = SingleColEntry {
row,
key,
ascending: self.ascending,
};
if self.heap.len() < self.k {
self.heap.push(entry);
} else if let Some(top) = self.heap.peek() {
let should_replace = if self.ascending {
entry.key < top.key
} else {
entry.key > top.key
};
if should_replace {
self.heap.pop();
self.heap.push(entry);
}
}
}
pub fn into_sorted_vec(self) -> Vec<SochRow> {
let mut entries: Vec<_> = self.heap.into_iter().collect();
entries.sort_by(|a, b| {
let base = a.key.cmp(&b.key);
if self.ascending { base } else { base.reverse() }
});
entries.into_iter().map(|e| e.row).collect()
}
pub fn len(&self) -> usize {
self.heap.len()
}
pub fn is_empty(&self) -> bool {
self.heap.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row from the given values.
    fn make_row(values: Vec<SochValue>) -> SochRow {
        SochRow::new(values)
    }

    #[test]
    fn test_strategy_selection() {
        // A matching index always wins.
        assert_eq!(
            ExecutionStrategy::choose(true, Some(1_000_000), 10),
            ExecutionStrategy::IndexPushdown
        );
        // Small limit over a large input streams.
        assert_eq!(
            ExecutionStrategy::choose(false, Some(1_000_000), 10),
            ExecutionStrategy::StreamingTopK
        );
        // Large limit relative to the input sorts everything.
        assert_eq!(
            ExecutionStrategy::choose(false, Some(1000), 500),
            ExecutionStrategy::FullSort
        );
        // Limits up to 100 always stream.
        assert_eq!(
            ExecutionStrategy::choose(false, Some(100), 90),
            ExecutionStrategy::StreamingTopK
        );
    }

    #[test]
    fn test_order_by_spec_comparator() {
        let spec = OrderBySpec::single(
            ColumnRef::Name("priority".to_string()),
            SortDirection::Ascending,
        );
        let columns = vec!["id".to_string(), "priority".to_string(), "name".to_string()];
        let cmp = spec.comparator(&columns);
        let row1 = make_row(vec![
            SochValue::Int(1),
            SochValue::Int(5),
            SochValue::Text("A".to_string()),
        ]);
        let row2 = make_row(vec![
            SochValue::Int(2),
            SochValue::Int(3),
            SochValue::Text("B".to_string()),
        ]);
        assert_eq!(cmp(&row2, &row1), Ordering::Less);
    }

    #[test]
    fn test_topk_heap_ascending() {
        let cmp = |a: &i32, b: &i32| a.cmp(b);
        let mut heap = TopKHeap::new(3, cmp, true);
        for i in [5, 2, 8, 1, 9, 3, 7, 4, 6] {
            heap.push(i);
        }
        let result = heap.into_sorted_vec();
        assert_eq!(result, vec![1, 2, 3]);
    }

    #[test]
    fn test_topk_heap_descending() {
        let cmp = |a: &i32, b: &i32| a.cmp(b);
        let mut heap = TopKHeap::new(3, cmp, false);
        for i in [5, 2, 8, 1, 9, 3, 7, 4, 6] {
            heap.push(i);
        }
        let result = heap.into_sorted_vec();
        assert_eq!(result, vec![9, 8, 7]);
    }

    #[test]
    fn test_executor_streaming() {
        let columns = vec!["priority".to_string(), "name".to_string()];
        let order_by = OrderBySpec::single(
            ColumnRef::Name("priority".to_string()),
            SortDirection::Ascending,
        );
        let executor = OrderByLimitExecutor::new(
            order_by,
            3,
            0,
            columns.clone(),
            false,
            Some(10),
        );
        let rows = vec![
            make_row(vec![SochValue::Int(5), SochValue::Text("E".to_string())]),
            make_row(vec![SochValue::Int(3), SochValue::Text("C".to_string())]),
            make_row(vec![SochValue::Int(1), SochValue::Text("A".to_string())]),
            make_row(vec![SochValue::Int(4), SochValue::Text("D".to_string())]),
            make_row(vec![SochValue::Int(2), SochValue::Text("B".to_string())]),
        ];
        let (result, stats) = executor.execute(rows.into_iter());
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values[0], SochValue::Int(1));
        assert_eq!(result[1].values[0], SochValue::Int(2));
        assert_eq!(result[2].values[0], SochValue::Int(3));
        assert_eq!(stats.input_rows, 5);
        assert_eq!(stats.output_rows, 3);
    }

    #[test]
    fn test_executor_with_offset() {
        let columns = vec!["priority".to_string()];
        let order_by = OrderBySpec::single(
            ColumnRef::Name("priority".to_string()),
            SortDirection::Ascending,
        );
        let executor = OrderByLimitExecutor::new(
            order_by,
            2,
            2,
            columns,
            false,
            Some(10),
        );
        let rows = vec![
            make_row(vec![SochValue::Int(5)]),
            make_row(vec![SochValue::Int(3)]),
            make_row(vec![SochValue::Int(1)]),
            make_row(vec![SochValue::Int(4)]),
            make_row(vec![SochValue::Int(2)]),
        ];
        let (result, _) = executor.execute(rows.into_iter());
        // Ordered input is 1, 2, 3, 4, 5; skipping two leaves 3 and 4.
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].values[0], SochValue::Int(3));
        assert_eq!(result[1].values[0], SochValue::Int(4));
    }

    #[test]
    fn test_single_column_topk() {
        let mut topk = SingleColumnTopK::new(3, 0, true);
        topk.push(make_row(vec![SochValue::Int(5)]));
        topk.push(make_row(vec![SochValue::Int(3)]));
        topk.push(make_row(vec![SochValue::Int(1)]));
        topk.push(make_row(vec![SochValue::Int(4)]));
        topk.push(make_row(vec![SochValue::Int(2)]));
        let result = topk.into_sorted_vec();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values[0], SochValue::Int(1));
        assert_eq!(result[1].values[0], SochValue::Int(2));
        assert_eq!(result[2].values[0], SochValue::Int(3));
    }

    #[test]
    fn test_correctness_vs_buggy_implementation() {
        let columns = vec!["priority".to_string()];
        let order_by = OrderBySpec::single(
            ColumnRef::Name("priority".to_string()),
            SortDirection::Ascending,
        );
        let rows: Vec<_> = [5, 2, 8, 1, 9, 3]
            .iter()
            .map(|&p| make_row(vec![SochValue::Int(p)]))
            .collect();

        // Naive approach: take the first three rows without ordering them.
        // It returns the rows in arrival order, which is not the top three.
        let buggy: Vec<_> = rows.iter().take(3).cloned().collect();
        assert_eq!(buggy[0].values[0], SochValue::Int(5));
        assert_eq!(buggy[1].values[0], SochValue::Int(2));
        assert_eq!(buggy[2].values[0], SochValue::Int(8));

        let executor = OrderByLimitExecutor::new(
            order_by,
            3,
            0,
            columns,
            false,
            Some(6),
        );
        let (correct, _) = executor.execute(rows.into_iter());
        assert_eq!(correct[0].values[0], SochValue::Int(1));
        assert_eq!(correct[1].values[0], SochValue::Int(2));
        assert_eq!(correct[2].values[0], SochValue::Int(3));
    }

    #[test]
    fn test_multi_column_order_by() {
        let columns = vec!["priority".to_string(), "created_at".to_string()];
        let order_by = OrderBySpec::single(
            ColumnRef::Name("priority".to_string()),
            SortDirection::Ascending,
        )
        .then_by(
            ColumnRef::Name("created_at".to_string()),
            SortDirection::Descending,
        );
        let executor = OrderByLimitExecutor::new(
            order_by,
            3,
            0,
            columns,
            false,
            Some(5),
        );
        let rows = vec![
            make_row(vec![SochValue::Int(1), SochValue::Int(100)]),
            make_row(vec![SochValue::Int(1), SochValue::Int(200)]),
            make_row(vec![SochValue::Int(2), SochValue::Int(150)]),
            make_row(vec![SochValue::Int(1), SochValue::Int(150)]),
            make_row(vec![SochValue::Int(3), SochValue::Int(100)]),
        ];
        let (result, _) = executor.execute(rows.into_iter());
        // Priority 1 rows come first, newest `created_at` first among them.
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values[0], SochValue::Int(1));
        assert_eq!(result[0].values[1], SochValue::Int(200));
        assert_eq!(result[1].values[0], SochValue::Int(1));
        assert_eq!(result[1].values[1], SochValue::Int(150));
        assert_eq!(result[2].values[0], SochValue::Int(1));
        assert_eq!(result[2].values[1], SochValue::Int(100));
    }
}