oxirs_vec/learned_index/
types.rs

1//! Core types for learned indexes
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
/// Result type for learned index operations.
///
/// Shorthand for `std::result::Result<T, LearnedIndexError>`, so fallible
/// functions only have to name their success type.
pub type LearnedIndexResult<T> = std::result::Result<T, LearnedIndexError>;
8
/// Errors for learned index operations.
///
/// The human-readable message of each variant comes from its `#[error(...)]`
/// attribute; fields named in braces are interpolated into that message.
/// The type is cloneable and (de)serializable via the derives below.
#[derive(Debug, Error, Clone, Serialize, Deserialize)]
pub enum LearnedIndexError {
    /// The model has not been trained.
    // NOTE(review): presumably returned when an operation requires a trained
    // model — confirm at the call sites.
    #[error("Model not trained")]
    ModelNotTrained,

    /// Training failed; `message` is appended to the display text.
    #[error("Training failed: {message}")]
    TrainingFailed { message: String },

    /// A prediction fell outside the valid range; both values are reported in
    /// the display text.
    #[error("Prediction out of bounds: predicted={predicted}, actual_size={actual_size}")]
    PredictionOutOfBounds {
        /// The predicted value that was rejected.
        predicted: usize,
        /// The size it was compared against (presumably the number of indexed
        /// entries — confirm where this variant is constructed).
        actual_size: usize,
    },

    /// The configuration was rejected; `message` is appended to the display text.
    #[error("Invalid configuration: {message}")]
    InvalidConfiguration { message: String },

    /// Fewer items were supplied than required: `min_required` is the minimum,
    /// `actual` is what was received.
    #[error("Insufficient data: need at least {min_required}, got {actual}")]
    InsufficientData { min_required: usize, actual: usize },

    /// Any other failure; `message` is appended to the display text.
    #[error("Internal error: {message}")]
    InternalError { message: String },
}
33
/// Prediction bounds for error correction.
///
/// Pairs a predicted position with the window `lower_bound..upper_bound`
/// that `search_range` hands out for a binary search fallback.
/// All fields are public, so the relationships described below are only
/// guaranteed for values built through `PredictionBounds::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionBounds {
    /// Predicted position
    pub predicted: usize,

    /// Lower bound (min error); this is the start of `search_range()`.
    pub lower_bound: usize,

    /// Upper bound (max error); this is the end of `search_range()`, which is
    /// a half-open `Range` and therefore does not include this position.
    pub upper_bound: usize,

    /// Error magnitude: `new` stores `upper - lower`, saturating at zero.
    pub error_magnitude: usize,

    /// Confidence score (0.0 to 1.0); `new` clamps its argument into that range.
    pub confidence: f32,
}
52
53impl PredictionBounds {
54    /// Create new prediction bounds
55    pub fn new(predicted: usize, lower: usize, upper: usize, confidence: f32) -> Self {
56        let error_magnitude = upper.saturating_sub(lower);
57        Self {
58            predicted,
59            lower_bound: lower,
60            upper_bound: upper,
61            error_magnitude,
62            confidence: confidence.clamp(0.0, 1.0),
63        }
64    }
65
66    /// Get search range for binary search fallback
67    pub fn search_range(&self) -> std::ops::Range<usize> {
68        self.lower_bound..self.upper_bound
69    }
70}
71
/// Training example for learned index.
///
/// One (features, position) pair plus a weight for that pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Input features (e.g., vector or key)
    pub features: Vec<f32>,

    /// Target position in sorted order
    pub target_position: usize,

    /// Weight for importance sampling. Always present: `new` sets it to 1.0
    /// and `with_weight` replaces it.
    pub weight: f32,
}
84
85impl TrainingExample {
86    pub fn new(features: Vec<f32>, target_position: usize) -> Self {
87        Self {
88            features,
89            target_position,
90            weight: 1.0,
91        }
92    }
93
94    pub fn with_weight(mut self, weight: f32) -> Self {
95        self.weight = weight;
96        self
97    }
98}
99
/// Statistics for learned index performance.
///
/// Prediction figures are maintained by `record_prediction` and lookup
/// figures by `record_lookup`; the two groups are counted independently.
/// `Default` yields all-zero statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Total predictions made (one per `record_prediction` call)
    pub total_predictions: usize,

    /// Predictions within error bounds (calls that passed `within_bounds = true`)
    pub predictions_within_bounds: usize,

    /// Average prediction error: running mean of `|predicted - actual|`
    pub avg_prediction_error: f64,

    /// Max prediction error observed: largest `|predicted - actual|` so far
    pub max_prediction_error: usize,

    /// Average search range size: running mean over `record_lookup` calls
    pub avg_search_range_size: f64,

    /// Total lookups performed (one per `record_lookup` call)
    pub total_lookups: usize,

    /// Average lookup time (microseconds): running mean over `record_lookup` calls
    pub avg_lookup_time_us: f64,
}
124
125impl IndexStatistics {
126    pub fn new() -> Self {
127        Self::default()
128    }
129
130    pub fn record_prediction(&mut self, predicted: usize, actual: usize, within_bounds: bool) {
131        self.total_predictions += 1;
132
133        if within_bounds {
134            self.predictions_within_bounds += 1;
135        }
136
137        let error = predicted.abs_diff(actual);
138
139        // Update average error
140        let n = self.total_predictions as f64;
141        self.avg_prediction_error = (self.avg_prediction_error * (n - 1.0) + error as f64) / n;
142
143        // Update max error
144        if error > self.max_prediction_error {
145            self.max_prediction_error = error;
146        }
147    }
148
149    pub fn record_lookup(&mut self, search_range_size: usize, lookup_time_us: f64) {
150        self.total_lookups += 1;
151
152        let n = self.total_lookups as f64;
153        self.avg_search_range_size =
154            (self.avg_search_range_size * (n - 1.0) + search_range_size as f64) / n;
155
156        self.avg_lookup_time_us = (self.avg_lookup_time_us * (n - 1.0) + lookup_time_us) / n;
157    }
158
159    pub fn accuracy(&self) -> f64 {
160        if self.total_predictions == 0 {
161            0.0
162        } else {
163            self.predictions_within_bounds as f64 / self.total_predictions as f64
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores its inputs, derives the error magnitude, and
    /// `search_range` spans the stored bounds.
    #[test]
    fn test_prediction_bounds() {
        let pb = PredictionBounds::new(100, 95, 105, 0.9);

        assert_eq!(pb.predicted, 100);
        assert_eq!(pb.lower_bound, 95);
        assert_eq!(pb.upper_bound, 105);
        assert_eq!(pb.error_magnitude, 10);
        assert_eq!(pb.confidence, 0.9);

        // Comparing ranges checks start and end together.
        assert_eq!(pb.search_range(), 95..105);
    }

    /// `with_weight` overrides the default weight and leaves the rest intact.
    #[test]
    fn test_training_example() {
        let features = vec![1.0, 2.0, 3.0];
        let sample = TrainingExample::new(features.clone(), 42).with_weight(0.5);

        assert_eq!(sample.features, features);
        assert_eq!(sample.target_position, 42);
        assert_eq!(sample.weight, 0.5);
    }

    /// Counters, maximum error, and accuracy after three recorded predictions.
    #[test]
    fn test_index_statistics() {
        let mut collected = IndexStatistics::new();

        // (predicted, actual, within_bounds)
        let outcomes = [(100, 102, true), (200, 195, true), (300, 250, false)];
        for &(predicted, actual, within_bounds) in &outcomes {
            collected.record_prediction(predicted, actual, within_bounds);
        }

        assert_eq!(collected.total_predictions, 3);
        assert_eq!(collected.predictions_within_bounds, 2);
        assert!(collected.avg_prediction_error > 0.0);
        assert_eq!(collected.max_prediction_error, 50);
        assert!((collected.accuracy() - 0.666).abs() < 0.01);
    }

    /// Running averages after two recorded lookups.
    #[test]
    fn test_lookup_statistics() {
        let mut collected = IndexStatistics::new();

        // (search_range_size, lookup_time_us)
        for &(range_size, micros) in &[(10, 5.0), (20, 10.0)] {
            collected.record_lookup(range_size, micros);
        }

        assert_eq!(collected.total_lookups, 2);
        assert_eq!(collected.avg_search_range_size, 15.0);
        assert_eq!(collected.avg_lookup_time_us, 7.5);
    }
}