use crate::velesql::{WindowFunction, WindowFunctionType, WindowOrderBy};
use crate::SearchResult;
use std::collections::BTreeMap;
/// Per-row inputs needed to evaluate one window function.
///
/// `evaluate` captures every snapshot before any ranking is written back into
/// the results, so values injected by one window function are never visible
/// to another.
struct WindowSnapshot {
    /// `order_by_values[row][k]` is the value of the k-th ORDER BY column for `row`.
    order_by_values: Vec<Vec<serde_json::Value>>,
    /// Partition key for each row (empty string when there is no PARTITION BY).
    partition_keys: Vec<String>,
}
impl WindowSnapshot {
    /// Extracts the ORDER BY values and the partition key of every row in `results`
    /// for the OVER clause of `wf`.
    fn capture(results: &[SearchResult], wf: &WindowFunction) -> Self {
        let clause = &wf.over_clause;
        let mut order_by_values = Vec::with_capacity(results.len());
        let mut partition_keys = Vec::with_capacity(results.len());
        for row in results {
            let sort_values = clause
                .order_by
                .iter()
                .map(|spec| extract_sort_value(row, &spec.column))
                .collect::<Vec<_>>();
            order_by_values.push(sort_values);
            partition_keys.push(partition_key(row, &clause.partition_by));
        }
        Self {
            order_by_values,
            partition_keys,
        }
    }
}
/// Evaluates every window function in `window_functions` and writes each
/// function's ranking into the payload of the matching result under the
/// function's alias (or its default alias).
///
/// All input snapshots are taken before any ranking is written, so a later
/// window function never observes values injected by an earlier one.
///
/// # Errors
///
/// Currently never fails; always returns `Ok(())`.
pub fn evaluate(
    results: &mut [SearchResult],
    window_functions: &[WindowFunction],
) -> crate::Result<()> {
    let mut snapshots = Vec::with_capacity(window_functions.len());
    for wf in window_functions {
        snapshots.push(WindowSnapshot::capture(results, wf));
    }
    for (snapshot, wf) in snapshots.iter().zip(window_functions) {
        apply_single_window(results, wf, snapshot);
    }
    Ok(())
}
/// Ranks `results` for one window function using its pre-captured `snapshot`
/// and injects the rank into each result's payload.
///
/// Rows are grouped by partition key, sorted within each partition by the
/// ORDER BY values, ranked, and then written back. A warning is logged when the
/// OVER clause has no ORDER BY, because the ranking then just follows the
/// incoming row order.
fn apply_single_window(
    results: &mut [SearchResult],
    wf: &WindowFunction,
    snapshot: &WindowSnapshot,
) {
    let order_by = &wf.over_clause.order_by;
    let alias = match wf.alias.as_deref() {
        Some(name) => name,
        None => wf.function_type.default_alias(),
    };
    if order_by.is_empty() {
        tracing::warn!(
            function = alias,
            "Window function OVER clause has no ORDER BY; ranking order is non-deterministic"
        );
    }
    let partitions = build_partitions_from_snapshot(&snapshot.partition_keys);
    for members in partitions.values() {
        let ordered = sort_partition_from_snapshots(members, &snapshot.order_by_values, order_by);
        let ranked = compute_rankings(
            &ordered,
            &snapshot.order_by_values,
            order_by,
            wf.function_type,
        );
        // Every row belongs to exactly one partition, so each result is written once.
        for (row, rank) in ranked {
            inject_ranking(&mut results[row], alias, rank);
        }
    }
}
/// Groups row indices by partition key.
///
/// The returned map is ordered by key, and each index list is in ascending row
/// order (the order rows appear in `partition_keys`).
fn build_partitions_from_snapshot(partition_keys: &[String]) -> BTreeMap<String, Vec<usize>> {
    partition_keys
        .iter()
        .enumerate()
        .fold(BTreeMap::new(), |mut groups, (row, key)| {
            groups.entry(key.clone()).or_insert_with(Vec::new).push(row);
            groups
        })
}
/// Builds the partition key of `result` for the PARTITION BY `columns`.
///
/// Each column's value is rendered by `extract_payload_value` (JSON text), and
/// the values are joined with the ASCII unit separator `\x1F`. An empty column
/// list yields the empty string, which puts all rows in a single partition.
/// NOTE(review): collision-freedom relies on serde_json escaping control
/// characters such as `\x1F` inside strings — confirm if the renderer changes.
fn partition_key(result: &SearchResult, columns: &[String]) -> String {
    let mut key = String::new();
    for (position, column) in columns.iter().enumerate() {
        if position > 0 {
            key.push('\x1F');
        }
        key.push_str(&extract_payload_value(result, column));
    }
    key
}
/// Returns the row indices of one partition sorted by the ORDER BY columns.
///
/// Columns are compared left to right, and the first non-equal column decides.
/// DESC columns reverse the comparison. The sort is stable, so fully tied rows
/// keep their original relative order.
fn sort_partition_from_snapshots(
    indices: &[usize],
    snapshots: &[Vec<serde_json::Value>],
    order_by: &[WindowOrderBy],
) -> Vec<usize> {
    let mut ordered: Vec<usize> = indices.iter().copied().collect();
    // Must remain a stable sort: ROW_NUMBER relies on it to break ties deterministically.
    ordered.sort_by(|&lhs, &rhs| {
        order_by
            .iter()
            .enumerate()
            .map(|(column, spec)| {
                let ordering = compare_json_values(&snapshots[lhs][column], &snapshots[rhs][column]);
                if spec.descending {
                    ordering.reverse()
                } else {
                    ordering
                }
            })
            .find(|ordering| ordering.is_ne())
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    ordered
}
/// Dispatches to the ranking routine for `fn_type`.
///
/// `sorted_indices` must already be in window order. The function returns
/// `(row_index, rank)` pairs.
fn compute_rankings(
    sorted_indices: &[usize],
    snapshots: &[Vec<serde_json::Value>],
    order_by: &[WindowOrderBy],
    fn_type: WindowFunctionType,
) -> Vec<(usize, u64)> {
    match fn_type {
        WindowFunctionType::Rank => compute_rank(sorted_indices, snapshots, order_by),
        WindowFunctionType::DenseRank => compute_dense_rank(sorted_indices, snapshots, order_by),
        // ROW_NUMBER ignores ties, so it needs only the sorted order.
        WindowFunctionType::RowNumber => compute_row_numbers(sorted_indices),
    }
}
/// ROW_NUMBER(): assigns 1, 2, 3, ... in sorted order.
///
/// Tied rows still receive distinct numbers.
fn compute_row_numbers(sorted_indices: &[usize]) -> Vec<(usize, u64)> {
    sorted_indices.iter().copied().zip(1u64..).collect()
}
/// RANK(): peers share the 1-based position of the first row in their group.
///
/// Gaps follow ties, e.g. 1, 1, 3.
fn compute_rank(
    sorted_indices: &[usize],
    snapshots: &[Vec<serde_json::Value>],
    order_by: &[WindowOrderBy],
) -> Vec<(usize, u64)> {
    let mut group_start: u64 = 1;
    sorted_indices
        .iter()
        .enumerate()
        .map(|(pos, &row)| {
            if is_new_group(pos, sorted_indices, snapshots, order_by, row) {
                group_start = pos as u64 + 1;
            }
            (row, group_start)
        })
        .collect()
}
/// DENSE_RANK(): peers share a rank, and the rank increases by one per new group.
///
/// There are no gaps, e.g. 1, 1, 2.
fn compute_dense_rank(
    sorted_indices: &[usize],
    snapshots: &[Vec<serde_json::Value>],
    order_by: &[WindowOrderBy],
) -> Vec<(usize, u64)> {
    let mut group_number: u64 = 1;
    sorted_indices
        .iter()
        .enumerate()
        .map(|(pos, &row)| {
            if is_new_group(pos, sorted_indices, snapshots, order_by, row) {
                group_number += 1;
            }
            (row, group_number)
        })
        .collect()
}
/// Reports whether the row at `position` in `sorted_indices` (row index `idx`)
/// starts a new peer group.
///
/// A new group starts when the row is not tied with its predecessor on every
/// ORDER BY column. The first row never counts as "new"; callers seed the
/// initial rank themselves.
fn is_new_group(
    position: usize,
    sorted_indices: &[usize],
    snapshots: &[Vec<serde_json::Value>],
    order_by: &[WindowOrderBy],
    idx: usize,
) -> bool {
    match position.checked_sub(1) {
        None => false,
        Some(prev_position) => {
            let prev_row = sorted_indices[prev_position];
            !snapshots_tied(&snapshots[idx], &snapshots[prev_row], order_by)
        }
    }
}
/// Returns `true` when two rows compare equal on every ORDER BY column.
///
/// The sort direction is irrelevant for equality, so `order_by` supplies only
/// the column count.
fn snapshots_tied(
    snap_a: &[serde_json::Value],
    snap_b: &[serde_json::Value],
    order_by: &[WindowOrderBy],
) -> bool {
    (0..order_by.len()).all(|column| {
        compare_json_values(&snap_a[column], &snap_b[column]) == std::cmp::Ordering::Equal
    })
}
/// Resolves an ORDER BY column for `result`.
///
/// The pseudo-column `similarity` maps to the result's score. Any other name is
/// looked up in the payload, so a payload field literally named `similarity` is
/// shadowed.
fn extract_sort_value(result: &SearchResult, column: &str) -> serde_json::Value {
    match column {
        "similarity" => serde_json::Value::from(f64::from(result.score)),
        path => extract_nested_value(result, path),
    }
}
/// Looks up `column` in the result's payload, treating `.` as a path separator
/// (e.g. `"meta.author"`).
///
/// Returns a clone of the found value, or `Null` when the payload is absent or
/// any path segment cannot be resolved. Because of the splitting, a top-level
/// key that itself contains `.` cannot be addressed.
fn extract_nested_value(result: &SearchResult, column: &str) -> serde_json::Value {
    let Some(payload) = result.point.payload.as_ref() else {
        return serde_json::Value::Null;
    };
    // A column without '.' is a single-segment path, so one walk covers both cases.
    column
        .split('.')
        .try_fold(payload, |node, segment| node.get(segment))
        .cloned()
        .unwrap_or(serde_json::Value::Null)
}
/// Renders a (possibly dotted) payload column as JSON text, for use in
/// partition keys. A missing value renders as `null`.
fn extract_payload_value(result: &SearchResult, column: &str) -> String {
    let value = extract_nested_value(result, column);
    value.to_string()
}
fn compare_json_values(a: &serde_json::Value, b: &serde_json::Value) -> std::cmp::Ordering {
match (a, b) {
(serde_json::Value::Null, serde_json::Value::Null) => std::cmp::Ordering::Equal,
(serde_json::Value::Null, _) => std::cmp::Ordering::Greater, (_, serde_json::Value::Null) => std::cmp::Ordering::Less,
_ => match (a.as_f64(), b.as_f64()) {
(Some(fa), Some(fb)) => fa.partial_cmp(&fb).unwrap_or(std::cmp::Ordering::Equal),
_ => a.to_string().cmp(&b.to_string()),
},
}
}
/// Writes `value` into the result's payload under `alias`.
///
/// A missing payload is replaced with an empty JSON object first. If the
/// payload exists but is not an object, nothing is written.
fn inject_ranking(result: &mut SearchResult, alias: &str, value: u64) {
    let payload = result
        .point
        .payload
        .get_or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
    if let Some(fields) = payload.as_object_mut() {
        fields.insert(alias.to_string(), serde_json::json!(value));
    }
}