1use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6pub type LearnedIndexResult<T> = std::result::Result<T, LearnedIndexError>;
8
9#[derive(Debug, Error, Clone, Serialize, Deserialize)]
11pub enum LearnedIndexError {
12 #[error("Model not trained")]
13 ModelNotTrained,
14
15 #[error("Training failed: {message}")]
16 TrainingFailed { message: String },
17
18 #[error("Prediction out of bounds: predicted={predicted}, actual_size={actual_size}")]
19 PredictionOutOfBounds {
20 predicted: usize,
21 actual_size: usize,
22 },
23
24 #[error("Invalid configuration: {message}")]
25 InvalidConfiguration { message: String },
26
27 #[error("Insufficient data: need at least {min_required}, got {actual}")]
28 InsufficientData { min_required: usize, actual: usize },
29
30 #[error("Internal error: {message}")]
31 InternalError { message: String },
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct PredictionBounds {
37 pub predicted: usize,
39
40 pub lower_bound: usize,
42
43 pub upper_bound: usize,
45
46 pub error_magnitude: usize,
48
49 pub confidence: f32,
51}
52
53impl PredictionBounds {
54 pub fn new(predicted: usize, lower: usize, upper: usize, confidence: f32) -> Self {
56 let error_magnitude = upper.saturating_sub(lower);
57 Self {
58 predicted,
59 lower_bound: lower,
60 upper_bound: upper,
61 error_magnitude,
62 confidence: confidence.clamp(0.0, 1.0),
63 }
64 }
65
66 pub fn search_range(&self) -> std::ops::Range<usize> {
68 self.lower_bound..self.upper_bound
69 }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct TrainingExample {
75 pub features: Vec<f32>,
77
78 pub target_position: usize,
80
81 pub weight: f32,
83}
84
85impl TrainingExample {
86 pub fn new(features: Vec<f32>, target_position: usize) -> Self {
87 Self {
88 features,
89 target_position,
90 weight: 1.0,
91 }
92 }
93
94 pub fn with_weight(mut self, weight: f32) -> Self {
95 self.weight = weight;
96 self
97 }
98}
99
100#[derive(Debug, Clone, Default, Serialize, Deserialize)]
102pub struct IndexStatistics {
103 pub total_predictions: usize,
105
106 pub predictions_within_bounds: usize,
108
109 pub avg_prediction_error: f64,
111
112 pub max_prediction_error: usize,
114
115 pub avg_search_range_size: f64,
117
118 pub total_lookups: usize,
120
121 pub avg_lookup_time_us: f64,
123}
124
125impl IndexStatistics {
126 pub fn new() -> Self {
127 Self::default()
128 }
129
130 pub fn record_prediction(&mut self, predicted: usize, actual: usize, within_bounds: bool) {
131 self.total_predictions += 1;
132
133 if within_bounds {
134 self.predictions_within_bounds += 1;
135 }
136
137 let error = predicted.abs_diff(actual);
138
139 let n = self.total_predictions as f64;
141 self.avg_prediction_error = (self.avg_prediction_error * (n - 1.0) + error as f64) / n;
142
143 if error > self.max_prediction_error {
145 self.max_prediction_error = error;
146 }
147 }
148
149 pub fn record_lookup(&mut self, search_range_size: usize, lookup_time_us: f64) {
150 self.total_lookups += 1;
151
152 let n = self.total_lookups as f64;
153 self.avg_search_range_size =
154 (self.avg_search_range_size * (n - 1.0) + search_range_size as f64) / n;
155
156 self.avg_lookup_time_us = (self.avg_lookup_time_us * (n - 1.0) + lookup_time_us) / n;
157 }
158
159 pub fn accuracy(&self) -> f64 {
160 if self.total_predictions == 0 {
161 0.0
162 } else {
163 self.predictions_within_bounds as f64 / self.total_predictions as f64
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prediction_bounds() {
        let bounds = PredictionBounds::new(100, 95, 105, 0.9);

        assert_eq!(bounds.predicted, 100);
        assert_eq!(bounds.lower_bound, 95);
        assert_eq!(bounds.upper_bound, 105);
        assert_eq!(bounds.error_magnitude, 10);
        assert_eq!(bounds.confidence, 0.9);

        let range = bounds.search_range();
        assert_eq!(range.start, 95);
        assert_eq!(range.end, 105);
    }

    #[test]
    fn test_prediction_bounds_clamps_confidence() {
        assert_eq!(PredictionBounds::new(10, 5, 15, 1.5).confidence, 1.0);
        assert_eq!(PredictionBounds::new(10, 5, 15, -0.25).confidence, 0.0);
    }

    #[test]
    fn test_training_example() {
        let example = TrainingExample::new(vec![1.0, 2.0, 3.0], 42).with_weight(0.5);

        assert_eq!(example.features, vec![1.0, 2.0, 3.0]);
        assert_eq!(example.target_position, 42);
        assert_eq!(example.weight, 0.5);
    }

    #[test]
    fn test_training_example_default_weight() {
        let example = TrainingExample::new(Vec::new(), 7);

        assert!(example.features.is_empty());
        assert_eq!(example.target_position, 7);
        assert_eq!(example.weight, 1.0);
    }

    #[test]
    fn test_index_statistics() {
        let mut stats = IndexStatistics::new();

        stats.record_prediction(100, 102, true);
        stats.record_prediction(200, 195, true);
        stats.record_prediction(300, 250, false);

        assert_eq!(stats.total_predictions, 3);
        assert_eq!(stats.predictions_within_bounds, 2);
        // Absolute errors are 2, 5 and 50, so the mean is exactly 19.
        assert!((stats.avg_prediction_error - 19.0).abs() < 1e-9);
        assert_eq!(stats.max_prediction_error, 50);
        assert!((stats.accuracy() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_empty_statistics() {
        let stats = IndexStatistics::new();

        assert_eq!(stats.total_predictions, 0);
        assert_eq!(stats.total_lookups, 0);
        // No predictions recorded: accuracy is defined as 0 rather than NaN.
        assert_eq!(stats.accuracy(), 0.0);
    }

    #[test]
    fn test_lookup_statistics() {
        let mut stats = IndexStatistics::new();

        stats.record_lookup(10, 5.0);
        stats.record_lookup(20, 10.0);

        assert_eq!(stats.total_lookups, 2);
        assert_eq!(stats.avg_search_range_size, 15.0);
        assert_eq!(stats.avg_lookup_time_us, 7.5);
    }
}