oxirs_vec/learned_index/
rmi.rs1use super::neural_index::NeuralVectorIndex;
4use super::types::{LearnedIndexError, LearnedIndexResult, TrainingExample};
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Serialize, Deserialize)]
9pub struct RmiStage {
10 models: Vec<NeuralVectorIndex>,
12
13 num_models: usize,
15}
16
17impl RmiStage {
18 pub fn new(num_models: usize) -> Self {
19 Self {
20 models: Vec::with_capacity(num_models),
21 num_models,
22 }
23 }
24
25 pub fn num_models(&self) -> usize {
26 self.num_models
27 }
28
29 pub fn models(&self) -> &[NeuralVectorIndex] {
30 &self.models
31 }
32}
33
34#[derive(Clone, Serialize, Deserialize)]
36pub struct RecursiveModelIndex {
37 stages: Vec<RmiStage>,
39
40 num_records: usize,
42
43 is_trained: bool,
45}
46
47impl RecursiveModelIndex {
48 pub fn new(stage_sizes: Vec<usize>) -> Self {
50 let stages = stage_sizes.into_iter().map(RmiStage::new).collect();
51
52 Self {
53 stages,
54 num_records: 0,
55 is_trained: false,
56 }
57 }
58
59 pub fn train(&mut self, examples: Vec<TrainingExample>) -> LearnedIndexResult<()> {
61 if examples.is_empty() {
62 return Err(LearnedIndexError::InsufficientData {
63 min_required: 1,
64 actual: 0,
65 });
66 }
67
68 self.num_records = examples.len();
69
70 tracing::info!(
71 "Training RMI with {} stages on {} examples",
72 self.stages.len(),
73 examples.len()
74 );
75
76 self.is_trained = true;
81 Ok(())
82 }
83
84 pub fn predict(&self, key: &[f32]) -> LearnedIndexResult<usize> {
86 if !self.is_trained {
87 return Err(LearnedIndexError::ModelNotTrained);
88 }
89
90 let normalized: f32 = key.iter().sum::<f32>() / key.len() as f32;
92 let position = (normalized * self.num_records as f32) as usize;
93
94 Ok(position.min(self.num_records.saturating_sub(1)))
95 }
96
97 pub fn is_trained(&self) -> bool {
98 self.is_trained
99 }
100
101 pub fn num_stages(&self) -> usize {
102 self.stages.len()
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 100 one-dimensional examples whose key is `i / 100` and whose
    /// second constructor argument is `i`.
    fn linear_examples() -> Vec<TrainingExample> {
        (0..100)
            .map(|i| TrainingExample::new(vec![i as f32 / 100.0], i))
            .collect()
    }

    /// A freshly built index reports its stage count and is untrained.
    #[test]
    fn test_rmi_creation() {
        let rmi = RecursiveModelIndex::new(vec![1, 10, 100]);
        assert_eq!(rmi.num_stages(), 3);
        assert!(!rmi.is_trained());
    }

    /// Training on a non-empty example set succeeds and flips the flag.
    #[test]
    fn test_rmi_training() {
        let mut rmi = RecursiveModelIndex::new(vec![1, 10]);

        let outcome = rmi.train(linear_examples());
        assert!(outcome.is_ok());
        assert!(rmi.is_trained());
    }

    /// A prediction after training stays within the trained record range.
    #[test]
    fn test_rmi_prediction() {
        let mut rmi = RecursiveModelIndex::new(vec![1, 10]);
        rmi.train(linear_examples()).unwrap();

        let position = rmi.predict(&[0.5]).unwrap();
        assert!(position < 100);
    }
}