use ipfrs_core::{Cid, Error, Result};
use serde::{Deserialize, Serialize};
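
// This module implements a learned index in the RMI (recursive model index)
// style: a root model routes each query to one of `num_models` stage-1
// models, and the chosen stage-1 model predicts where matching entries sit in
// the sorted data array. Because predictions are approximate, searches scan a
// bounded window around the predicted position rather than trusting it
// exactly.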
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RMIConfig {
    /// Number of stage-1 models.
    pub num_models: usize,
    /// Model family used for the root and stage-1 models.
    pub model_type: ModelType,
    /// Gradient-descent passes over the training data.
    pub training_iterations: usize,
    /// Gradient-descent step size.
    pub learning_rate: f32,
    /// Fraction of the dataset scanned around a predicted position.
    pub error_threshold: f32,
}

impl Default for RMIConfig {
    fn default() -> Self {
        Self {
            num_models: 10,
            model_type: ModelType::Linear,
            training_iterations: 100,
            learning_rate: 0.01,
            error_threshold: 0.05,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Weighted sum of the inputs, clamped to [0, 1].
    Linear,
    /// One hidden ReLU layer followed by a sigmoid output.
    NeuralNetwork,
    /// Linear plus squared-input terms, clamped to [0, 1].
    Polynomial,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct Model {
    model_type: ModelType,
    weights: Vec<f32>,
    bias: f32,
    input_dim: usize,
}

impl Model {
    fn new(model_type: ModelType, input_dim: usize) -> Self {
        let weight_count = match model_type {
            ModelType::Linear => input_dim,
            ModelType::Polynomial => input_dim * 2,
            ModelType::NeuralNetwork => input_dim * 8 + 8,
        };

        Self {
            model_type,
            weights: vec![0.01; weight_count],
            bias: 0.0,
            input_dim,
        }
    }

    fn predict(&self, input: &[f32]) -> f32 {
        match self.model_type {
            ModelType::Linear => self.predict_linear(input),
            ModelType::Polynomial => self.predict_polynomial(input),
            ModelType::NeuralNetwork => self.predict_neural(input),
        }
    }

    fn predict_linear(&self, input: &[f32]) -> f32 {
        let mut sum = self.bias;
        for (i, &val) in input.iter().enumerate() {
            if i < self.weights.len() {
                sum += self.weights[i] * val;
            }
        }
        sum.clamp(0.0, 1.0)
    }

    fn predict_polynomial(&self, input: &[f32]) -> f32 {
        let mut sum = self.bias;
        let half = self.weights.len() / 2;

        // Linear terms use the first half of the weights.
        for (i, &val) in input.iter().enumerate() {
            if i < half {
                sum += self.weights[i] * val;
            }
        }

        // Quadratic terms use the second half.
        for (i, &val) in input.iter().enumerate() {
            if half + i < self.weights.len() {
                sum += self.weights[half + i] * val * val;
            }
        }

        sum.clamp(0.0, 1.0)
    }

    fn predict_neural(&self, input: &[f32]) -> f32 {
        let hidden_size = 8;
        let input_weights = &self.weights[0..self.input_dim * hidden_size];
        let output_weights = &self.weights[self.input_dim * hidden_size..];

        // Single hidden layer with ReLU activation.
        let mut hidden = vec![0.0; hidden_size];
        for h in 0..hidden_size {
            let mut sum = 0.0;
            for (i, &val) in input.iter().enumerate() {
                if h * self.input_dim + i < input_weights.len() {
                    sum += input_weights[h * self.input_dim + i] * val;
                }
            }
            hidden[h] = sum.max(0.0);
        }

        let mut output = self.bias;
        for (h, &val) in hidden.iter().enumerate() {
            if h < output_weights.len() {
                output += output_weights[h] * val;
            }
        }

        // Sigmoid squashes the output into (0, 1).
        1.0 / (1.0 + (-output).exp())
    }

    fn train(&mut self, data: &[(Vec<f32>, f32)], learning_rate: f32, iterations: usize) {
        for _ in 0..iterations {
            for (input, target) in data {
                let prediction = self.predict(input);
                let error = target - prediction;

                match self.model_type {
                    ModelType::Linear => {
                        for (i, &val) in input.iter().enumerate() {
                            if i < self.weights.len() {
                                self.weights[i] += learning_rate * error * val;
                            }
                        }
                        self.bias += learning_rate * error;
                    }
                    ModelType::Polynomial => {
                        let half = self.weights.len() / 2;
                        for (i, &val) in input.iter().enumerate() {
                            if i < half {
                                self.weights[i] += learning_rate * error * val;
                            }
                            if half + i < self.weights.len() {
                                self.weights[half + i] += learning_rate * error * val * val;
                            }
                        }
                        self.bias += learning_rate * error;
                    }
                    ModelType::NeuralNetwork => {
                        // Crude uniform update rather than true backpropagation:
                        // every weight is nudged by the same scaled error.
                        for i in 0..self.weights.len() {
                            self.weights[i] += learning_rate * error * 0.01;
                        }
                        self.bias += learning_rate * error;
                    }
                }
            }
        }
    }
}
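
/// A two-stage learned index over `(Cid, embedding)` pairs.
///
/// The root model routes a query to one of the stage-1 models, which predicts
/// the query's position in the sorted data. A minimal usage sketch (mirroring
/// the tests below, which assume `Cid::default()` is available):
///
/// ```ignore
/// let mut index = LearnedIndex::new(RMIConfig::default());
/// index.add(Cid::default(), vec![0.1, 0.2, 0.3])?;
/// let nearest = index.search(&[0.1, 0.2, 0.3], 1)?;
/// ```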
pub struct LearnedIndex {
    config: RMIConfig,
    /// Root model that routes queries to a stage-1 model.
    root_model: Option<Model>,
    /// Stage-1 models, one per chunk of the sorted data.
    stage1_models: Vec<Model>,
    /// Entries sorted by the first embedding component.
    data: Vec<(Cid, Vec<f32>)>,
    /// Embedding dimension, fixed by the first insertion.
    dimension: Option<usize>,
    stats: IndexStats,
}

#[derive(Debug, Default)]
struct IndexStats {
    searches: usize,
    /// Accumulated prediction error. Nothing updates this yet, so
    /// `avg_error` in `LearnedIndexStats` reports 0.0.
    total_error: f32,
    data_points: usize,
}

impl LearnedIndex {
    pub fn new(config: RMIConfig) -> Self {
        Self {
            config,
            root_model: None,
            stage1_models: Vec::new(),
            data: Vec::new(),
            dimension: None,
            stats: IndexStats::default(),
        }
    }

    /// Adds an entry, fixing the index dimension on first use and rebuilding
    /// the models every 100 insertions.
    pub fn add(&mut self, cid: Cid, embedding: Vec<f32>) -> Result<()> {
        if let Some(dim) = self.dimension {
            if embedding.len() != dim {
                return Err(Error::InvalidInput(format!(
                    "Dimension mismatch: expected {}, got {}",
                    dim,
                    embedding.len()
                )));
            }
        } else {
            self.dimension = Some(embedding.len());
        }

        self.data.push((cid, embedding));
        self.stats.data_points += 1;

        if self.data.len().is_multiple_of(100) {
            self.rebuild()?;
        }

        Ok(())
    }

    /// Re-sorts the data and retrains the root and stage-1 models.
    pub fn rebuild(&mut self) -> Result<()> {
        if self.data.is_empty() {
            return Ok(());
        }

        let dim = self
            .dimension
            .ok_or_else(|| Error::InvalidInput("No dimension set".to_string()))?;

        // Sort by the first embedding component so positions are learnable.
        self.data.sort_by(|a, b| {
            a.1[0]
                .partial_cmp(&b.1[0])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        self.root_model = Some(Model::new(self.config.model_type, dim));
        self.stage1_models = (0..self.config.num_models)
            .map(|_| Model::new(self.config.model_type, dim))
            .collect();

        self.train_models()?;

        Ok(())
    }

    fn train_models(&mut self) -> Result<()> {
        if self.data.is_empty() {
            return Ok(());
        }

        let n = self.data.len();

        // The root model learns each entry's normalized position in [0, 1).
        let mut root_training_data = Vec::new();
        for (i, (_cid, embedding)) in self.data.iter().enumerate() {
            let normalized_pos = i as f32 / n as f32;
            let normalized_embedding = self.normalize_embedding(embedding);
            root_training_data.push((normalized_embedding, normalized_pos));
        }

        if let Some(ref mut root) = self.root_model {
            root.train(
                &root_training_data,
                self.config.learning_rate,
                self.config.training_iterations,
            );
        }

        // Each stage-1 model learns positions local to its chunk. The
        // `.max(1)` guards against a zero chunk size when there are fewer
        // entries than models.
        let chunk_size = (n / self.config.num_models).max(1);

        let mut all_model_training_data = Vec::new();
        for model_idx in 0..self.config.num_models {
            let start = model_idx * chunk_size;
            let end = if model_idx == self.config.num_models - 1 {
                n
            } else {
                (model_idx + 1) * chunk_size
            };

            let mut model_training_data = Vec::new();
            for i in start..end {
                if let Some((_cid, embedding)) = self.data.get(i) {
                    let local_pos = (i - start) as f32 / (end - start) as f32;
                    let normalized_embedding = self.normalize_embedding(embedding);
                    model_training_data.push((normalized_embedding, local_pos));
                }
            }
            all_model_training_data.push(model_training_data);
        }

        for (model, training_data) in self
            .stage1_models
            .iter_mut()
            .zip(all_model_training_data.iter())
        {
            if !training_data.is_empty() {
                model.train(
                    training_data,
                    self.config.learning_rate,
                    self.config.training_iterations,
                );
            }
        }

        Ok(())
    }

    /// Min-max normalizes an embedding into [0, 1]; constant embeddings map
    /// to 0.5 to avoid dividing by a near-zero range.
    fn normalize_embedding(&self, embedding: &[f32]) -> Vec<f32> {
        let min = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let range = max - min;

        if range > 1e-6 {
            embedding.iter().map(|&x| (x - min) / range).collect()
        } else {
            vec![0.5; embedding.len()]
        }
    }

    /// Returns up to `k` nearest entries by Euclidean distance, scanning a
    /// window around the model-predicted position.
    pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<(Cid, f32)>> {
        if self.data.is_empty() {
            return Ok(Vec::new());
        }

        let dim = self
            .dimension
            .ok_or_else(|| Error::InvalidInput("No dimension set".to_string()))?;

        if query.len() != dim {
            return Err(Error::InvalidInput(format!(
                "Dimension mismatch: expected {}, got {}",
                dim,
                query.len()
            )));
        }

        if self.root_model.is_none() {
            self.rebuild()?;
        }

        self.stats.searches += 1;

        let predicted_pos = self.predict_position(query)?;
        let n = self.data.len();
        let start_idx = (predicted_pos * n as f32) as usize;

        // Window sized by the configured error threshold, but always wide
        // enough to yield k candidates.
        let window_size = (n as f32 * self.config.error_threshold).max(k as f32 * 2.0) as usize;
        let search_start = start_idx.saturating_sub(window_size / 2);
        let search_end = (start_idx + window_size / 2).min(n);

        let mut candidates = Vec::new();
        for i in search_start..search_end {
            if let Some((cid, embedding)) = self.data.get(i) {
                let distance = self.compute_distance(query, embedding);
                candidates.push((*cid, distance));
            }
        }

        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(candidates.into_iter().take(k).collect())
    }
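
    // Composes the two stages: the root prediction selects a stage-1 model,
    // and that model's local prediction is rescaled into the model's share of
    // [0, 1]. For example, with num_models = 10, model_idx = 3 and a local
    // prediction of 0.5 combine to 3 * 0.1 + 0.5 * 0.1 = 0.35.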
    fn predict_position(&self, query: &[f32]) -> Result<f32> {
        let normalized_query = self.normalize_embedding(query);

        let root_prediction = if let Some(ref root) = self.root_model {
            root.predict(&normalized_query)
        } else {
            return Err(Error::InvalidInput("No root model".to_string()));
        };

        let model_idx = ((root_prediction * self.config.num_models as f32) as usize)
            .min(self.config.num_models - 1);

        let local_prediction = if let Some(model) = self.stage1_models.get(model_idx) {
            model.predict(&normalized_query)
        } else {
            0.5
        };

        let chunk_size = 1.0 / self.config.num_models as f32;
        let final_prediction = model_idx as f32 * chunk_size + local_prediction * chunk_size;

        Ok(final_prediction.clamp(0.0, 1.0))
    }

    /// Euclidean distance between two embeddings.
    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    pub fn stats(&self) -> LearnedIndexStats {
        LearnedIndexStats {
            data_points: self.stats.data_points,
            searches: self.stats.searches,
            num_models: self.stage1_models.len() + 1,
            avg_error: if self.stats.searches > 0 {
                self.stats.total_error / self.stats.searches as f32
            } else {
                0.0
            },
        }
    }

    pub fn size(&self) -> usize {
        self.data.len()
    }

    pub fn clear(&mut self) {
        self.data.clear();
        self.root_model = None;
        self.stage1_models.clear();
        self.stats = IndexStats::default();
    }
}

#[derive(Debug, Clone)]
pub struct LearnedIndexStats {
    pub data_points: usize,
    pub searches: usize,
    pub num_models: usize,
    pub avg_error: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_learned_index_creation() {
        let index = LearnedIndex::new(RMIConfig::default());
        assert_eq!(index.size(), 0);
    }

    #[test]
    fn test_add_and_search() {
        let mut index = LearnedIndex::new(RMIConfig::default());

        for i in 0..100 {
            let cid = Cid::default();
            let embedding = vec![i as f32 / 100.0, 0.5, 0.5, 0.5];
            index.add(cid, embedding).unwrap();
        }

        assert_eq!(index.size(), 100);

        let query = vec![0.5, 0.5, 0.5, 0.5];
        let results = index.search(&query, 5).unwrap();
        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_model_prediction() {
        let model = Model::new(ModelType::Linear, 4);
        let input = vec![0.1, 0.2, 0.3, 0.4];
        let prediction = model.predict(&input);
        assert!((0.0..=1.0).contains(&prediction));
    }

    #[test]
    fn test_polynomial_model() {
        let model = Model::new(ModelType::Polynomial, 4);
        let input = vec![0.5, 0.5, 0.5, 0.5];
        let prediction = model.predict(&input);
        assert!((0.0..=1.0).contains(&prediction));
    }

    #[test]
    fn test_neural_model() {
        let model = Model::new(ModelType::NeuralNetwork, 4);
        let input = vec![0.3, 0.4, 0.5, 0.6];
        let prediction = model.predict(&input);
        assert!((0.0..=1.0).contains(&prediction));
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut index = LearnedIndex::new(RMIConfig::default());

        let cid1 = Cid::default();
        index.add(cid1, vec![1.0, 2.0, 3.0]).unwrap();

        let cid2 = Cid::default();
        let result = index.add(cid2, vec![1.0, 2.0]);
        assert!(result.is_err());
    }

    #[test]
    fn test_rebuild_index() {
        let mut index = LearnedIndex::new(RMIConfig::default());

        for i in 0..50 {
            let cid = Cid::default();
            let embedding = vec![i as f32, 0.0, 0.0];
            index.add(cid, embedding).unwrap();
        }

        index.rebuild().unwrap();

        let query = vec![25.0, 0.0, 0.0];
        let results = index.search(&query, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_stats() {
        let mut index = LearnedIndex::new(RMIConfig::default());

        for i in 0..10 {
            let cid = Cid::default();
            index.add(cid, vec![i as f32, 0.0]).unwrap();
        }

        let query = vec![5.0, 0.0];
        let _ = index.search(&query, 3).unwrap();

        let stats = index.stats();
        assert_eq!(stats.data_points, 10);
        assert_eq!(stats.searches, 1);
    }

    #[test]
    fn test_clear() {
        let mut index = LearnedIndex::new(RMIConfig::default());

        let cid = Cid::default();
        index.add(cid, vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(index.size(), 1);

        index.clear();
        assert_eq!(index.size(), 0);
    }
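
    // Additional check (not in the original suite): rebuilding an empty index
    // is a no-op and must succeed, per the early return in `rebuild`.
    #[test]
    fn test_rebuild_empty() {
        let mut index = LearnedIndex::new(RMIConfig::default());
        index.rebuild().unwrap();
        assert_eq!(index.size(), 0);
    }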

    #[test]
    fn test_config_variants() {
        let configs = vec![
            RMIConfig {
                model_type: ModelType::Linear,
                ..Default::default()
            },
            RMIConfig {
                model_type: ModelType::Polynomial,
                ..Default::default()
            },
            RMIConfig {
                model_type: ModelType::NeuralNetwork,
                ..Default::default()
            },
        ];

        for config in configs {
            let mut index = LearnedIndex::new(config);
            for i in 0..20 {
                let cid = Cid::default();
                index.add(cid, vec![i as f32, 0.0, 0.0]).unwrap();
            }

            let query = vec![10.0, 0.0, 0.0];
            let results = index.search(&query, 5).unwrap();
            assert!(!results.is_empty());
        }
    }
}