ipfrs_semantic/
learned.rs

//! Learned index structures using ML models for data indexing.
//!
//! This module implements learned indices, which use machine learning models
//! to predict the position of data in the index, replacing traditional index
//! structures like B-trees with neural networks or linear models.
//!
//! # Architecture
//!
//! The implementation uses a Recursive Model Index (RMI) architecture:
//! - Stage 0: Root model that routes to second-stage models
//! - Stage 1: Multiple specialized models for different data ranges
//! - Each model learns to predict positions in the sorted data
//!
//! # Example
//!
//! ```
//! use ipfrs_semantic::learned::{LearnedIndex, RMIConfig};
//! use ipfrs_core::cid::Cid;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a learned index with default configuration
//! let mut index = LearnedIndex::new(RMIConfig::default());
//!
//! // Add embeddings with their CIDs
//! let cid = Cid::default();
//! let embedding = vec![0.1, 0.2, 0.3, 0.4];
//! index.add(cid.clone(), embedding.clone())?;
//!
//! // Search for nearest neighbors
//! let query = vec![0.15, 0.25, 0.35, 0.45];
//! let results = index.search(&query, 5)?;
//! # Ok(())
//! # }
//! ```
35
36use ipfrs_core::{Cid, Error, Result};
37use serde::{Deserialize, Serialize};
38
39/// Configuration for Recursive Model Index (RMI)
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct RMIConfig {
42    /// Number of models in the second stage
43    pub num_models: usize,
44    /// Model type to use
45    pub model_type: ModelType,
46    /// Training iterations for neural models
47    pub training_iterations: usize,
48    /// Learning rate for neural models
49    pub learning_rate: f32,
50    /// Error threshold for adaptive model selection
51    pub error_threshold: f32,
52}
53
54impl Default for RMIConfig {
55    fn default() -> Self {
56        Self {
57            num_models: 10,
58            model_type: ModelType::Linear,
59            training_iterations: 100,
60            learning_rate: 0.01,
61            error_threshold: 0.05,
62        }
63    }
64}
65
66/// Type of model to use in the learned index
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
68pub enum ModelType {
69    /// Linear regression model
70    Linear,
71    /// Simple neural network (single hidden layer)
72    NeuralNetwork,
73    /// Polynomial regression (degree 2)
74    Polynomial,
75}
76
77/// A single learned model that predicts positions
78#[derive(Debug, Clone, Serialize, Deserialize)]
79struct Model {
80    /// Model type
81    model_type: ModelType,
82    /// Model weights (interpretation depends on model_type)
83    weights: Vec<f32>,
84    /// Bias term
85    bias: f32,
86    /// Input dimension
87    input_dim: usize,
88}
89
90impl Model {
91    /// Create a new model with random initialization
92    fn new(model_type: ModelType, input_dim: usize) -> Self {
93        let weight_count = match model_type {
94            ModelType::Linear => input_dim,
95            ModelType::Polynomial => input_dim * 2, // Linear + quadratic terms
96            ModelType::NeuralNetwork => input_dim * 8 + 8, // Hidden layer + output layer
97        };
98
99        Self {
100            model_type,
101            weights: vec![0.01; weight_count],
102            bias: 0.0,
103            input_dim,
104        }
105    }
106
107    /// Predict position for given input (normalized 0-1)
108    fn predict(&self, input: &[f32]) -> f32 {
109        match self.model_type {
110            ModelType::Linear => self.predict_linear(input),
111            ModelType::Polynomial => self.predict_polynomial(input),
112            ModelType::NeuralNetwork => self.predict_neural(input),
113        }
114    }
115
116    fn predict_linear(&self, input: &[f32]) -> f32 {
117        let mut sum = self.bias;
118        for (i, &val) in input.iter().enumerate() {
119            if i < self.weights.len() {
120                sum += self.weights[i] * val;
121            }
122        }
123        sum.clamp(0.0, 1.0)
124    }
125
126    fn predict_polynomial(&self, input: &[f32]) -> f32 {
127        let mut sum = self.bias;
128        let half = self.weights.len() / 2;
129
130        // Linear terms
131        for (i, &val) in input.iter().enumerate() {
132            if i < half {
133                sum += self.weights[i] * val;
134            }
135        }
136
137        // Quadratic terms
138        for (i, &val) in input.iter().enumerate() {
139            if half + i < self.weights.len() {
140                sum += self.weights[half + i] * val * val;
141            }
142        }
143
144        sum.clamp(0.0, 1.0)
145    }
146
147    fn predict_neural(&self, input: &[f32]) -> f32 {
148        let hidden_size = 8;
149        let input_weights = &self.weights[0..self.input_dim * hidden_size];
150        let output_weights = &self.weights[self.input_dim * hidden_size..];
151
152        // Hidden layer with ReLU activation
153        let mut hidden = vec![0.0; hidden_size];
154        for h in 0..hidden_size {
155            let mut sum = 0.0;
156            for (i, &val) in input.iter().enumerate() {
157                if h * self.input_dim + i < input_weights.len() {
158                    sum += input_weights[h * self.input_dim + i] * val;
159                }
160            }
161            hidden[h] = sum.max(0.0); // ReLU
162        }
163
164        // Output layer with sigmoid
165        let mut output = self.bias;
166        for (h, &val) in hidden.iter().enumerate() {
167            if h < output_weights.len() {
168                output += output_weights[h] * val;
169            }
170        }
171
172        // Sigmoid activation
173        1.0 / (1.0 + (-output).exp())
174    }
175
176    /// Train the model on data (simple gradient descent)
177    #[allow(dead_code)]
178    fn train(&mut self, data: &[(Vec<f32>, f32)], learning_rate: f32, iterations: usize) {
179        for _ in 0..iterations {
180            for (input, target) in data {
181                let prediction = self.predict(input);
182                let error = target - prediction;
183
184                // Update weights (simplified gradient descent)
185                match self.model_type {
186                    ModelType::Linear => {
187                        for (i, &val) in input.iter().enumerate() {
188                            if i < self.weights.len() {
189                                self.weights[i] += learning_rate * error * val;
190                            }
191                        }
192                        self.bias += learning_rate * error;
193                    }
194                    ModelType::Polynomial => {
195                        let half = self.weights.len() / 2;
196                        for (i, &val) in input.iter().enumerate() {
197                            if i < half {
198                                self.weights[i] += learning_rate * error * val;
199                            }
200                            if half + i < self.weights.len() {
201                                self.weights[half + i] += learning_rate * error * val * val;
202                            }
203                        }
204                        self.bias += learning_rate * error;
205                    }
206                    ModelType::NeuralNetwork => {
207                        // Simplified backprop (full implementation would be more complex)
208                        for i in 0..self.weights.len() {
209                            self.weights[i] += learning_rate * error * 0.01;
210                        }
211                        self.bias += learning_rate * error;
212                    }
213                }
214            }
215        }
216    }
217}
218
219/// Recursive Model Index (RMI) for learned indexing
220pub struct LearnedIndex {
221    /// Configuration
222    config: RMIConfig,
223    /// Root model (stage 0)
224    root_model: Option<Model>,
225    /// Second stage models
226    stage1_models: Vec<Model>,
227    /// Sorted data storage (CID, embedding, position)
228    data: Vec<(Cid, Vec<f32>)>,
229    /// Dimension of embeddings
230    dimension: Option<usize>,
231    /// Statistics
232    stats: IndexStats,
233}
234
/// Internal counters kept by the index; all start at zero.
#[derive(Default, Debug)]
struct IndexStats {
    /// How many searches have been run.
    searches: usize,
    /// Accumulated prediction error across searches.
    total_error: f32,
    /// How many data points have been added.
    data_points: usize,
}
244
245impl LearnedIndex {
246    /// Create a new learned index
247    pub fn new(config: RMIConfig) -> Self {
248        Self {
249            config,
250            root_model: None,
251            stage1_models: Vec::new(),
252            data: Vec::new(),
253            dimension: None,
254            stats: IndexStats::default(),
255        }
256    }
257
258    /// Add an embedding to the index
259    pub fn add(&mut self, cid: Cid, embedding: Vec<f32>) -> Result<()> {
260        if let Some(dim) = self.dimension {
261            if embedding.len() != dim {
262                return Err(Error::InvalidInput(format!(
263                    "Dimension mismatch: expected {}, got {}",
264                    dim,
265                    embedding.len()
266                )));
267            }
268        } else {
269            self.dimension = Some(embedding.len());
270        }
271
272        self.data.push((cid, embedding));
273        self.stats.data_points += 1;
274
275        // Rebuild index if we have enough data
276        if self.data.len().is_multiple_of(100) {
277            self.rebuild()?;
278        }
279
280        Ok(())
281    }
282
283    /// Rebuild the learned index from scratch
284    pub fn rebuild(&mut self) -> Result<()> {
285        if self.data.is_empty() {
286            return Ok(());
287        }
288
289        let dim = self
290            .dimension
291            .ok_or_else(|| Error::InvalidInput("No dimension set".to_string()))?;
292
293        // Sort data by first dimension (simple heuristic)
294        self.data.sort_by(|a, b| {
295            a.1[0]
296                .partial_cmp(&b.1[0])
297                .unwrap_or(std::cmp::Ordering::Equal)
298        });
299
300        // Initialize models
301        self.root_model = Some(Model::new(self.config.model_type, dim));
302        self.stage1_models = (0..self.config.num_models)
303            .map(|_| Model::new(self.config.model_type, dim))
304            .collect();
305
306        // Train models (simplified - real implementation would use proper training)
307        self.train_models()?;
308
309        Ok(())
310    }
311
312    fn train_models(&mut self) -> Result<()> {
313        if self.data.is_empty() {
314            return Ok(());
315        }
316
317        let n = self.data.len();
318
319        // Prepare training data for root model
320        let mut root_training_data = Vec::new();
321        for (i, (_cid, embedding)) in self.data.iter().enumerate() {
322            let normalized_pos = i as f32 / n as f32;
323            let normalized_embedding = self.normalize_embedding(embedding);
324            root_training_data.push((normalized_embedding, normalized_pos));
325        }
326
327        // Train root model
328        if let Some(ref mut root) = self.root_model {
329            root.train(
330                &root_training_data,
331                self.config.learning_rate,
332                self.config.training_iterations,
333            );
334        }
335
336        // Train stage 1 models (each responsible for a range)
337        let chunk_size = n / self.config.num_models;
338
339        // First, collect all training data for all models
340        let mut all_model_training_data = Vec::new();
341        for model_idx in 0..self.config.num_models {
342            let start = model_idx * chunk_size;
343            let end = if model_idx == self.config.num_models - 1 {
344                n
345            } else {
346                (model_idx + 1) * chunk_size
347            };
348
349            let mut model_training_data = Vec::new();
350            for i in start..end {
351                if let Some((_cid, embedding)) = self.data.get(i) {
352                    let local_pos = (i - start) as f32 / (end - start) as f32;
353                    let normalized_embedding = self.normalize_embedding(embedding);
354                    model_training_data.push((normalized_embedding, local_pos));
355                }
356            }
357            all_model_training_data.push(model_training_data);
358        }
359
360        // Now train all models with their respective data
361        for (model, training_data) in self
362            .stage1_models
363            .iter_mut()
364            .zip(all_model_training_data.iter())
365        {
366            if !training_data.is_empty() {
367                model.train(
368                    training_data,
369                    self.config.learning_rate,
370                    self.config.training_iterations,
371                );
372            }
373        }
374
375        Ok(())
376    }
377
378    fn normalize_embedding(&self, embedding: &[f32]) -> Vec<f32> {
379        // Simple min-max normalization to [0, 1]
380        let min = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
381        let max = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
382        let range = max - min;
383
384        if range > 1e-6 {
385            embedding.iter().map(|&x| (x - min) / range).collect()
386        } else {
387            vec![0.5; embedding.len()]
388        }
389    }
390
391    /// Search for k nearest neighbors
392    pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<(Cid, f32)>> {
393        if self.data.is_empty() {
394            return Ok(Vec::new());
395        }
396
397        let dim = self
398            .dimension
399            .ok_or_else(|| Error::InvalidInput("No dimension set".to_string()))?;
400
401        if query.len() != dim {
402            return Err(Error::InvalidInput(format!(
403                "Dimension mismatch: expected {}, got {}",
404                dim,
405                query.len()
406            )));
407        }
408
409        // Rebuild index if not built yet
410        if self.root_model.is_none() {
411            self.rebuild()?;
412        }
413
414        self.stats.searches += 1;
415
416        // Use learned index to predict position
417        let predicted_pos = self.predict_position(query)?;
418        let n = self.data.len();
419        let start_idx = (predicted_pos * n as f32) as usize;
420
421        // Search around predicted position (adaptive window)
422        let window_size = (n as f32 * self.config.error_threshold).max(k as f32 * 2.0) as usize;
423        let search_start = start_idx.saturating_sub(window_size / 2);
424        let search_end = (start_idx + window_size / 2).min(n);
425
426        // Find k nearest neighbors in the search window
427        let mut candidates = Vec::new();
428        for i in search_start..search_end {
429            if let Some((cid, embedding)) = self.data.get(i) {
430                let distance = self.compute_distance(query, embedding);
431                candidates.push((*cid, distance));
432            }
433        }
434
435        // Sort by distance and return top k
436        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
437        Ok(candidates.into_iter().take(k).collect())
438    }
439
440    fn predict_position(&mut self, query: &[f32]) -> Result<f32> {
441        let normalized_query = self.normalize_embedding(query);
442
443        // Stage 0: Root model predicts which stage 1 model to use
444        let root_prediction = if let Some(ref root) = self.root_model {
445            root.predict(&normalized_query)
446        } else {
447            return Err(Error::InvalidInput("No root model".to_string()));
448        };
449
450        // Select stage 1 model
451        let model_idx = ((root_prediction * self.config.num_models as f32) as usize)
452            .min(self.config.num_models - 1);
453
454        // Stage 1: Selected model predicts position within its range
455        let local_prediction = if let Some(model) = self.stage1_models.get(model_idx) {
456            model.predict(&normalized_query)
457        } else {
458            0.5
459        };
460
461        // Combine predictions
462        let chunk_size = 1.0 / self.config.num_models as f32;
463        let final_prediction = model_idx as f32 * chunk_size + local_prediction * chunk_size;
464
465        Ok(final_prediction.clamp(0.0, 1.0))
466    }
467
468    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
469        // L2 distance
470        a.iter()
471            .zip(b.iter())
472            .map(|(x, y)| (x - y) * (x - y))
473            .sum::<f32>()
474            .sqrt()
475    }
476
477    /// Get index statistics
478    pub fn stats(&self) -> LearnedIndexStats {
479        LearnedIndexStats {
480            data_points: self.stats.data_points,
481            searches: self.stats.searches,
482            num_models: self.stage1_models.len() + 1,
483            avg_error: if self.stats.searches > 0 {
484                self.stats.total_error / self.stats.searches as f32
485            } else {
486                0.0
487            },
488        }
489    }
490
491    /// Get number of data points
492    pub fn size(&self) -> usize {
493        self.data.len()
494    }
495
496    /// Clear the index
497    pub fn clear(&mut self) {
498        self.data.clear();
499        self.root_model = None;
500        self.stage1_models.clear();
501        self.stats = IndexStats::default();
502    }
503}
504
/// Public snapshot of a learned index's counters.
#[derive(Clone, Debug)]
pub struct LearnedIndexStats {
    /// How many data points have been indexed.
    pub data_points: usize,
    /// How many searches have been run.
    pub searches: usize,
    /// Model count: the root model plus all stage-1 models.
    pub num_models: usize,
    /// Mean prediction error per search.
    pub avg_error: f32,
}
517
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default-configured index and stores every embedding under
    /// `Cid::default()`.
    fn index_with(embeddings: impl IntoIterator<Item = Vec<f32>>) -> LearnedIndex {
        let mut index = LearnedIndex::new(RMIConfig::default());
        for embedding in embeddings {
            index.add(Cid::default(), embedding).unwrap();
        }
        index
    }

    /// Checks that a fresh model of `model_type` predicts inside `[0, 1]`.
    fn assert_prediction_in_unit_range(model_type: ModelType, input: &[f32]) {
        let prediction = Model::new(model_type, input.len()).predict(input);
        assert!((0.0..=1.0).contains(&prediction));
    }

    #[test]
    fn test_learned_index_creation() {
        assert_eq!(LearnedIndex::new(RMIConfig::default()).size(), 0);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = index_with((0..100).map(|i| vec![i as f32 / 100.0, 0.5, 0.5, 0.5]));
        assert_eq!(index.size(), 100);

        let results = index.search(&[0.5, 0.5, 0.5, 0.5], 5).unwrap();
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_model_prediction() {
        assert_prediction_in_unit_range(ModelType::Linear, &[0.1, 0.2, 0.3, 0.4]);
    }

    #[test]
    fn test_polynomial_model() {
        assert_prediction_in_unit_range(ModelType::Polynomial, &[0.5, 0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_neural_model() {
        assert_prediction_in_unit_range(ModelType::NeuralNetwork, &[0.3, 0.4, 0.5, 0.6]);
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut index = index_with([vec![1.0, 2.0, 3.0]]);

        // A shorter embedding must be rejected once the dimension is fixed.
        assert!(index.add(Cid::default(), vec![1.0, 2.0]).is_err());
    }

    #[test]
    fn test_rebuild_index() {
        let mut index = index_with((0..50).map(|i| vec![i as f32, 0.0, 0.0]));
        index.rebuild().unwrap();

        let results = index.search(&[25.0, 0.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_stats() {
        let mut index = index_with((0..10).map(|i| vec![i as f32, 0.0]));
        let _ = index.search(&[5.0, 0.0], 3).unwrap();

        let stats = index.stats();
        assert_eq!(stats.data_points, 10);
        assert_eq!(stats.searches, 1);
    }

    #[test]
    fn test_clear() {
        let mut index = index_with([vec![1.0, 2.0, 3.0]]);
        assert_eq!(index.size(), 1);

        index.clear();
        assert_eq!(index.size(), 0);
    }

    #[test]
    fn test_config_variants() {
        let model_types = [
            ModelType::Linear,
            ModelType::Polynomial,
            ModelType::NeuralNetwork,
        ];

        for model_type in model_types {
            let mut index = LearnedIndex::new(RMIConfig {
                model_type,
                ..Default::default()
            });
            for i in 0..20 {
                index.add(Cid::default(), vec![i as f32, 0.0, 0.0]).unwrap();
            }

            let results = index.search(&[10.0, 0.0, 0.0], 5).unwrap();
            assert!(!results.is_empty());
        }
    }
}