1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
//! Sorted view implementation for efficient quantile queries.
use crate::{ReqError, Result, SearchCriteria, TotalOrd};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// A sorted view of all items in the sketch with their cumulative weights.
///
/// This provides an efficient representation for quantile and rank queries
/// by maintaining items in sorted order with precomputed cumulative weights.
///
/// Invariants maintained by the constructor:
/// * `items` is sorted ascending (per `TotalOrd`) and contains unique values.
/// * `cumulative_weights[i]` is the total weight of `items[0..=i]`.
/// * `total_weight` equals the last entry of `cumulative_weights` (0 if empty).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(
        bound = "T: Clone + TotalOrd + PartialEq + serde::Serialize + serde::de::DeserializeOwned"
    )
)]
pub struct SortedView<T> {
    /// Items in sorted order (unique; duplicates are merged at construction)
    items: Vec<T>,
    /// Cumulative weights for each item (monotonically non-decreasing)
    cumulative_weights: Vec<u64>,
    /// Total weight of all items
    total_weight: u64,
}
impl<T> SortedView<T>
where
    T: Clone + TotalOrd + PartialEq,
{
    /// Creates a new sorted view from weighted items.
    ///
    /// # Arguments
    /// * `weighted_items` - Vector of (item, weight) pairs
    ///
    /// The items are sorted, adjacent equal items (per `TotalOrd`) are merged
    /// into a single entry with combined weight, and cumulative weights are
    /// computed.
    pub fn new(mut weighted_items: Vec<(T, u64)>) -> Self {
        if weighted_items.is_empty() {
            return Self {
                items: Vec::new(),
                cumulative_weights: Vec::new(),
                total_weight: 0,
            };
        }
        // Sort by item value - use unstable sort for better performance
        // (no allocation, and equal items are merged below anyway).
        weighted_items.sort_unstable_by(|a, b| a.0.total_cmp(&b.0));
        let mut items: Vec<T> = Vec::with_capacity(weighted_items.len());
        let mut cumulative_weights = Vec::with_capacity(weighted_items.len());
        let mut cumulative_weight = 0u64;
        for (item, weight) in weighted_items {
            // Merge runs of equal items: fold this weight into the last
            // entry's cumulative weight instead of pushing a duplicate.
            if let Some(last) = items.last() {
                if matches!(last.total_cmp(&item), std::cmp::Ordering::Equal) {
                    cumulative_weight += weight;
                    let last_idx = cumulative_weights.len() - 1;
                    cumulative_weights[last_idx] = cumulative_weight;
                    continue;
                }
            }
            cumulative_weight += weight;
            items.push(item);
            cumulative_weights.push(cumulative_weight);
        }
        Self {
            items,
            cumulative_weights,
            total_weight: cumulative_weight,
        }
    }

    /// Returns the number of distinct items in the sorted view.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns true if the sorted view is empty.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns the total weight of all items.
    pub fn total_weight(&self) -> u64 {
        self.total_weight
    }

    /// Returns the approximate rank of the given item, without interpolation.
    ///
    /// # Arguments
    /// * `item` - The item to find the rank for
    /// * `criteria` - Whether to include the item's own weight in the rank
    ///
    /// # Returns
    /// A normalized rank in [0.0, 1.0]
    ///
    /// # Errors
    /// Returns [`ReqError::EmptySketch`] if the view contains no items.
    pub fn rank_no_interpolation(&self, item: &T, criteria: SearchCriteria) -> Result<f64> {
        if self.is_empty() {
            return Err(ReqError::EmptySketch);
        }
        // `partition_point` returns the first index at which the predicate is
        // false, i.e. the count of leading items that satisfy it.
        let pos = match criteria {
            // Inclusive: count items <= item (item's own weight included).
            SearchCriteria::Inclusive => {
                self.items.partition_point(|x| x.total_cmp(item).is_le())
            }
            // Exclusive: count items < item (item's own weight excluded).
            SearchCriteria::Exclusive => {
                self.items.partition_point(|x| x.total_cmp(item).is_lt())
            }
        };
        if pos == 0 {
            Ok(0.0)
        } else {
            Ok(self.cumulative_weights[pos - 1] as f64 / self.total_weight as f64)
        }
    }

    /// Returns the approximate quantile for the given normalized rank.
    ///
    /// # Arguments
    /// * `rank` - A normalized rank in [0.0, 1.0]
    /// * `criteria` - Search criteria for quantile selection
    ///
    /// # Returns
    /// The item at approximately the given rank
    ///
    /// # Errors
    /// Returns [`ReqError::EmptySketch`] if the view is empty and
    /// [`ReqError::InvalidRank`] if `rank` is outside [0.0, 1.0].
    pub fn quantile(&self, rank: f64, criteria: SearchCriteria) -> Result<T> {
        if self.is_empty() {
            return Err(ReqError::EmptySketch);
        }
        if !(0.0..=1.0).contains(&rank) {
            return Err(ReqError::InvalidRank(rank));
        }
        // Edge cases: the extreme ranks map to the minimum / maximum item
        // regardless of search criteria.
        if rank == 0.0 {
            return Ok(self.items[0].clone());
        }
        if rank == 1.0 {
            return Ok(self.items[self.items.len() - 1].clone());
        }
        // Convert the normalized rank into a target cumulative weight.
        // Inclusive rounds up so the matching item's own weight counts;
        // exclusive truncates.
        let target_weight = match criteria {
            SearchCriteria::Inclusive => (rank * self.total_weight as f64).ceil() as u64,
            SearchCriteria::Exclusive => (rank * self.total_weight as f64) as u64,
        };
        let index = match criteria {
            // Equivalent to C++ lower_bound: first index where
            // cumulative_weight >= target.
            SearchCriteria::Inclusive => self
                .cumulative_weights
                .partition_point(|&w| w < target_weight),
            // Equivalent to C++ upper_bound: first index where
            // cumulative_weight > target.
            SearchCriteria::Exclusive => self
                .cumulative_weights
                .partition_point(|&w| w <= target_weight),
        };
        // Clamp: the target weight may exceed every cumulative weight.
        if index >= self.items.len() {
            return Ok(self.items[self.items.len() - 1].clone());
        }
        Ok(self.items[index].clone())
    }

    /// Returns the Probability Mass Function (PMF) for the given split points.
    ///
    /// # Arguments
    /// * `split_points` - Array of split points that divide the domain;
    ///   must be unique and monotonically increasing
    /// * `criteria` - Search criteria for boundary handling
    ///
    /// # Returns
    /// Array of `split_points.len() + 1` probabilities, one per interval
    /// defined by the split points; the values sum to 1.0.
    ///
    /// # Errors
    /// Returns [`ReqError::EmptySketch`] for an empty view and
    /// [`ReqError::InvalidSplitPoints`] for unsorted or duplicate split points.
    pub fn pmf(&self, split_points: &[T], criteria: SearchCriteria) -> Result<Vec<f64>> {
        if self.is_empty() {
            return Err(ReqError::EmptySketch);
        }
        self.validate_split_points(split_points)?;
        let mut result = Vec::with_capacity(split_points.len() + 1);
        let mut prev_rank = 0.0;
        // Each interval's mass is the difference between consecutive ranks.
        for split_point in split_points {
            let rank = self.rank_no_interpolation(split_point, criteria)?;
            result.push(rank - prev_rank);
            prev_rank = rank;
        }
        // Add the final (open-ended) interval.
        result.push(1.0 - prev_rank);
        Ok(result)
    }

    /// Returns the Cumulative Distribution Function (CDF) for the given
    /// split points.
    ///
    /// # Arguments
    /// * `split_points` - Array of split points that divide the domain;
    ///   must be unique and monotonically increasing
    /// * `criteria` - Search criteria for boundary handling
    ///
    /// # Returns
    /// Array of `split_points.len() + 1` cumulative probabilities, one per
    /// split point plus a final entry that is exactly 1.0.
    ///
    /// # Errors
    /// Returns [`ReqError::EmptySketch`] for an empty view and
    /// [`ReqError::InvalidSplitPoints`] for unsorted or duplicate split points.
    pub fn cdf(&self, split_points: &[T], criteria: SearchCriteria) -> Result<Vec<f64>> {
        if self.is_empty() {
            return Err(ReqError::EmptySketch);
        }
        self.validate_split_points(split_points)?;
        // Compute the rank at each split point directly instead of summing
        // PMF differences: this avoids the redundant validation pass inside
        // `pmf` and does not accumulate floating-point rounding error.
        let mut result = Vec::with_capacity(split_points.len() + 1);
        for split_point in split_points {
            result.push(self.rank_no_interpolation(split_point, criteria)?);
        }
        // The final entry covers the entire domain.
        result.push(1.0);
        Ok(result)
    }

    /// Returns an iterator over the items in sorted order.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.items.iter()
    }

    /// Returns an iterator over (item, cumulative_weight) pairs.
    pub fn iter_with_weights(&self) -> impl Iterator<Item = (&T, u64)> {
        self.items
            .iter()
            .zip(self.cumulative_weights.iter().copied())
    }

    // Private helper methods

    /// Verifies that split points are unique and strictly increasing.
    fn validate_split_points(&self, split_points: &[T]) -> Result<()> {
        // Any adjacent pair that is not strictly increasing is invalid.
        if split_points
            .windows(2)
            .any(|pair| pair[0].total_cmp(&pair[1]).is_ge())
        {
            return Err(ReqError::InvalidSplitPoints(
                "Split points must be unique and monotonically increasing".to_string(),
            ));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a view over the odd values 1..=9, each carrying unit weight.
    fn create_test_view() -> SortedView<i32> {
        SortedView::new(vec![(1, 1), (3, 1), (5, 1), (7, 1), (9, 1)])
    }

    /// Approximate-equality check for normalized ranks / probabilities.
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-10
    }

    #[test]
    fn test_sorted_view_creation() {
        let view = create_test_view();
        assert!(!view.is_empty());
        assert_eq!(view.len(), 5);
        assert_eq!(view.total_weight(), 5);
    }

    #[test]
    fn test_rank_queries() -> Result<()> {
        let view = create_test_view();
        // Exact matches.
        assert!(approx(
            view.rank_no_interpolation(&1, SearchCriteria::Inclusive)?,
            0.2
        ));
        assert!(approx(
            view.rank_no_interpolation(&1, SearchCriteria::Exclusive)?,
            0.0
        ));
        // Values falling between stored items.
        assert!(approx(
            view.rank_no_interpolation(&2, SearchCriteria::Inclusive)?,
            0.2
        ));
        assert!(approx(
            view.rank_no_interpolation(&6, SearchCriteria::Inclusive)?,
            0.6
        ));
        // Below the minimum and above the maximum.
        assert!(approx(
            view.rank_no_interpolation(&0, SearchCriteria::Inclusive)?,
            0.0
        ));
        assert!(approx(
            view.rank_no_interpolation(&10, SearchCriteria::Inclusive)?,
            1.0
        ));
        Ok(())
    }

    #[test]
    fn test_quantile_queries() -> Result<()> {
        let view = create_test_view();
        // Extreme ranks map to min / max.
        assert_eq!(view.quantile(0.0, SearchCriteria::Inclusive)?, 1);
        assert_eq!(view.quantile(1.0, SearchCriteria::Inclusive)?, 9);
        // The median of 1,3,5,7,9 should land near the middle.
        let median = view.quantile(0.5, SearchCriteria::Inclusive)?;
        assert!((3..=7).contains(&median));
        // Quartiles should bracket the median.
        let q25 = view.quantile(0.25, SearchCriteria::Inclusive)?;
        let q75 = view.quantile(0.75, SearchCriteria::Inclusive)?;
        assert!(q25 <= median && median <= q75);
        Ok(())
    }

    #[test]
    fn test_pmf() -> Result<()> {
        let view = create_test_view();
        let pmf = view.pmf(&[3, 7], SearchCriteria::Inclusive)?;
        // Two split points define three intervals.
        assert_eq!(pmf.len(), 3);
        // Probabilities must sum to (approximately) one.
        assert!(approx(pmf.iter().sum::<f64>(), 1.0));
        Ok(())
    }

    #[test]
    fn test_cdf() -> Result<()> {
        let view = create_test_view();
        let cdf = view.cdf(&[3, 7], SearchCriteria::Inclusive)?;
        assert_eq!(cdf.len(), 3);
        // CDF must be monotonically non-decreasing.
        assert!(cdf.windows(2).all(|pair| pair[1] >= pair[0]));
        // The final entry covers the whole domain.
        assert!(approx(cdf[cdf.len() - 1], 1.0));
        Ok(())
    }

    #[test]
    fn test_empty_view() {
        let view: SortedView<i32> = SortedView::new(Vec::new());
        assert!(view.is_empty());
        assert_eq!(view.len(), 0);
        assert_eq!(view.total_weight(), 0);
        // Every query on an empty view must fail.
        assert!(view
            .rank_no_interpolation(&5, SearchCriteria::Inclusive)
            .is_err());
        assert!(view.quantile(0.5, SearchCriteria::Inclusive).is_err());
    }
}