Skip to main content

diskann_disk/utils/
statistics.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

/// Measurements collected while executing a single disk query.
///
/// Durations are stored in microseconds (the `*_us` fields, as `u128`) and
/// event counts as `u32`. The derived `Default` produces an all-zero value, so
/// callers can fill in only the fields they measure using struct-update syntax.
#[derive(Clone, Debug, Default)]
pub struct QueryStatistics {
    /// Microseconds spent processing the query in total.
    pub total_execution_time_us: u128,

    /// Microseconds spent in IO operations.
    pub io_time_us: u128,

    /// Microseconds spent in CPU operations.
    pub cpu_time_us: u128,

    /// Microseconds spent preprocessing the query for the PQ.
    pub query_pq_preprocess_time_us: u128,

    /// Count of IO operations issued.
    pub total_io_operations: u32,

    /// Count of comparisons saved (an optimization metric).
    pub comparisons_saved: u32,

    /// Count of comparisons performed.
    pub total_comparisons: u32,

    /// Count of vertices loaded.
    pub total_vertices_loaded: u32,

    /// Count of hops taken during the search.
    pub search_hops: u32,
}

37/// Calculates the percentile value of a specific metric in a list of QueryStats.
38pub fn get_percentile_stats<T: Ord + Copy>(
39    stats: &[QueryStatistics],
40    percentile: f32,
41    member_fn: impl Fn(&QueryStatistics) -> T,
42) -> T {
43    let mut vals: Vec<T> = stats.iter().map(&member_fn).collect();
44    vals.sort_unstable();
45    let idx = ((percentile * stats.len() as f32) as usize).min(stats.len() - 1);
46    vals[idx]
47}
48
49/// Calculates the mean value of a specific metric in a list of QueryStats.
50pub fn get_mean_stats<T: Into<f64>>(
51    stats: &[QueryStatistics],
52    member_fn: impl Fn(&QueryStatistics) -> T,
53) -> f64 {
54    get_sum_stats(stats, member_fn) / (stats.len() as f64)
55}
56
57pub fn get_sum_stats<T: Into<f64>>(
58    stats: &[QueryStatistics],
59    member_fn: impl Fn(&QueryStatistics) -> T,
60) -> f64 {
61    stats.iter().map(&member_fn).map(|v| v.into()).sum()
62}
#[cfg(test)]
mod tests {
    use rand::Rng;

    use super::*;

    /// Metric accessor shared by the tests: the `total_io_operations` counter.
    fn io_operations(s: &QueryStatistics) -> u32 {
        s.total_io_operations
    }

    /// Builds one `QueryStatistics` per value, storing it in `total_io_operations`.
    fn stats_with_io_operations(values: &[u32]) -> Vec<QueryStatistics> {
        values
            .iter()
            .map(|&num| QueryStatistics {
                total_io_operations: num,
                ..Default::default()
            })
            .collect()
    }

    #[test]
    fn test_get_percentile_stats_batch() {
        for percentile in [0.0f32, 0.5, 0.57, 0.85, 0.95, 0.99, 1.0] {
            test_get_percentile_stats(percentile);
        }

        let mut rng = diskann_providers::utils::create_rnd_in_tests();
        for _ in 0..100 {
            test_get_percentile_stats(rng.random_range(0f32..1f32));
        }
    }

    /// Compares `get_percentile_stats` with indexing into a fully sorted copy
    /// of 1000 pseudo-random values.
    fn test_get_percentile_stats(percentile: f32) {
        let mut rng = diskann_providers::utils::create_rnd_in_tests();
        let mut random_numbers: Vec<u32> =
            (0..1000).map(|_| rng.random_range(0..999999)).collect();

        let query_stats = stats_with_io_operations(&random_numbers);
        let result = get_percentile_stats(&query_stats, percentile, io_operations);

        let index =
            ((percentile * random_numbers.len() as f32) as usize).min(random_numbers.len() - 1);
        random_numbers.sort_unstable();
        let expected_result: u32 = random_numbers[index];

        assert_eq!(result, expected_result);
    }

    #[test]
    fn test_get_percentile_stats_known_values() {
        let query_stats = stats_with_io_operations(&[7, 3, 10, 1, 9, 2, 8, 5, 4, 6]);

        assert_eq!(get_percentile_stats(&query_stats, 0.0, io_operations), 1);
        assert_eq!(get_percentile_stats(&query_stats, 0.25, io_operations), 3);
        assert_eq!(get_percentile_stats(&query_stats, 0.5, io_operations), 6);
        assert_eq!(get_percentile_stats(&query_stats, 1.0, io_operations), 10);
        // Out-of-range percentiles are clamped to the extremes.
        assert_eq!(get_percentile_stats(&query_stats, 2.0, io_operations), 10);
        assert_eq!(get_percentile_stats(&query_stats, -1.0, io_operations), 1);
    }

    #[test]
    fn test_get_percentile_stats_uses_f32_rank() {
        // 0.57f32 * 1000.0 rounds to exactly 570.0 in f32 arithmetic.
        let values: Vec<u32> = (0..1000).rev().collect();
        let query_stats = stats_with_io_operations(&values);
        assert_eq!(get_percentile_stats(&query_stats, 0.57, io_operations), 570);
    }

    #[test]
    fn test_get_percentile_stats_single_element() {
        let query_stats = stats_with_io_operations(&[42]);
        for percentile in [0.0f32, 0.5, 1.0] {
            assert_eq!(get_percentile_stats(&query_stats, percentile, io_operations), 42);
        }
    }

    #[test]
    #[should_panic]
    fn test_get_percentile_stats_empty_input_panics() {
        let _ = get_percentile_stats(&[], 0.5, io_operations);
    }

    #[test]
    fn test_get_mean_stats() {
        let query_stats = stats_with_io_operations(&[1, 2, 3, 4, 5]);
        let result = get_mean_stats(&query_stats, io_operations);

        let expected_result: f64 = 3.0; // (1 + 2 + 3 + 4 + 5) / 5 = 3

        assert!((result - expected_result).abs() <= 1e-3);
    }

    #[test]
    fn test_get_mean_stats_empty_is_nan() {
        assert!(get_mean_stats(&[], io_operations).is_nan());
    }

    #[test]
    fn test_get_sum_stats() {
        let query_stats = stats_with_io_operations(&[1, 2, 3, 4, 5]);
        let result = get_sum_stats(&query_stats, io_operations);

        let expected_result: f64 = 15.0; // 1 + 2 + 3 + 4 + 5 = 15

        assert!((result - expected_result).abs() <= 1e-3);
    }

    #[test]
    fn test_get_sum_stats_empty_is_zero() {
        assert_eq!(get_sum_stats(&[], io_operations), 0.0);
    }
}
145        let expected_result: f64 = 15.0; // 1 + 2 + 3 + 4 + 5 = 15
146
147        assert!((result - expected_result).abs() <= 1e-3);
148    }
149}