oxirs_vec/learned_index/
rmi.rs

1//! Recursive Model Index (RMI) implementation
2
3use super::neural_index::NeuralVectorIndex;
4use super::types::{LearnedIndexError, LearnedIndexResult, TrainingExample};
5use serde::{Deserialize, Serialize};
6
/// One stage (level) of a recursive model index, holding the models that
/// belong to that level.
///
/// The type is cloneable and (de)serializable through serde.
#[derive(Clone, Serialize, Deserialize)]
pub struct RmiStage {
    /// Models in this stage.
    ///
    /// `RmiStage::new` only reserves capacity for this vector; it starts out
    /// empty.
    models: Vec<NeuralVectorIndex>,

    /// Number of models this stage was configured with.
    ///
    /// Stored separately from `models`, so it reflects the requested size
    /// rather than `models.len()`.
    num_models: usize,
}
16
17impl RmiStage {
18    pub fn new(num_models: usize) -> Self {
19        Self {
20            models: Vec::with_capacity(num_models),
21            num_models,
22        }
23    }
24
25    pub fn num_models(&self) -> usize {
26        self.num_models
27    }
28
29    pub fn models(&self) -> &[NeuralVectorIndex] {
30        &self.models
31    }
32}
33
/// Recursive Model Index (RMI): an ordered list of model stages plus the
/// bookkeeping needed to turn a key into a predicted record position.
///
/// The type is cloneable and (de)serializable through serde.
#[derive(Clone, Serialize, Deserialize)]
pub struct RecursiveModelIndex {
    /// Stages of models, in the order their sizes were supplied at
    /// construction.
    stages: Vec<RmiStage>,

    /// Total number of records; `0` until training records the number of
    /// training examples here.
    num_records: usize,

    /// Whether training has completed successfully; prediction is refused
    /// while this is `false`.
    is_trained: bool,
}
46
47impl RecursiveModelIndex {
48    /// Create new RMI with specified stage sizes
49    pub fn new(stage_sizes: Vec<usize>) -> Self {
50        let stages = stage_sizes.into_iter().map(RmiStage::new).collect();
51
52        Self {
53            stages,
54            num_records: 0,
55            is_trained: false,
56        }
57    }
58
59    /// Train the RMI
60    pub fn train(&mut self, examples: Vec<TrainingExample>) -> LearnedIndexResult<()> {
61        if examples.is_empty() {
62            return Err(LearnedIndexError::InsufficientData {
63                min_required: 1,
64                actual: 0,
65            });
66        }
67
68        self.num_records = examples.len();
69
70        tracing::info!(
71            "Training RMI with {} stages on {} examples",
72            self.stages.len(),
73            examples.len()
74        );
75
76        // For simplified implementation:
77        // Each stage predicts which model in next stage to use
78        // Last stage predicts actual position
79
80        self.is_trained = true;
81        Ok(())
82    }
83
84    /// Predict position using RMI
85    pub fn predict(&self, key: &[f32]) -> LearnedIndexResult<usize> {
86        if !self.is_trained {
87            return Err(LearnedIndexError::ModelNotTrained);
88        }
89
90        // Simplified: use first stage to predict position
91        let normalized: f32 = key.iter().sum::<f32>() / key.len() as f32;
92        let position = (normalized * self.num_records as f32) as usize;
93
94        Ok(position.min(self.num_records.saturating_sub(1)))
95    }
96
97    pub fn is_trained(&self) -> bool {
98        self.is_trained
99    }
100
101    pub fn num_stages(&self) -> usize {
102        self.stages.len()
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 100 one-dimensional examples whose key is `i / 100` and whose
    /// second constructor argument is `i`.
    fn sample_examples() -> Vec<TrainingExample> {
        (0..100)
            .map(|i| TrainingExample::new(vec![i as f32 / 100.0], i))
            .collect()
    }

    /// A freshly built index reports its stage count and is untrained.
    #[test]
    fn test_rmi_creation() {
        let index = RecursiveModelIndex::new(vec![1, 10, 100]);

        assert_eq!(index.num_stages(), 3);
        assert!(!index.is_trained());
    }

    /// Training on a non-empty example set succeeds and flips the trained flag.
    #[test]
    fn test_rmi_training() {
        let mut index = RecursiveModelIndex::new(vec![1, 10]);

        let outcome = index.train(sample_examples());

        assert!(outcome.is_ok());
        assert!(index.is_trained());
    }

    /// A trained index predicts a position inside the record range.
    #[test]
    fn test_rmi_prediction() {
        let mut index = RecursiveModelIndex::new(vec![1, 10]);
        index.train(sample_examples()).unwrap();

        let key = vec![0.5];
        let predicted = index.predict(&key).unwrap();

        assert!(predicted < 100);
    }
}