use super::config::LearnedIndexConfig;
use super::training::{IndexTrainer, TrainingStats};
use super::types::{
IndexStatistics, LearnedIndexError, LearnedIndexResult, PredictionBounds, TrainingExample,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
/// A learned index mapping feature-vector keys to positions in a sorted
/// key store: a small neural model predicts a position, which is then
/// refined by a bounded binary search (see `predict` / `lookup`).
#[derive(Clone, Serialize, Deserialize)]
pub struct NeuralVectorIndex {
    // Validated configuration the index was constructed with.
    config: LearnedIndexConfig,
    // Flattened per-layer weight matrices of the prediction model,
    // filled in by the trainer during `train`.
    weights: Vec<Vec<f32>>,
    // One bias value per layer (shared by all neurons of that layer,
    // as consumed in `forward`).
    biases: Vec<f32>,
    // Keys ordered by target position; the search space for lookups.
    sorted_keys: Vec<Vec<f32>>,
    // Exact-match fast path: stringified key -> index into `sorted_keys`.
    key_positions: HashMap<String, usize>,
    // Prediction-error statistics computed after training; used to size
    // the binary-search window in `predict`.
    error_stats: ErrorStatistics,
    // Runtime lookup/prediction statistics.
    stats: IndexStatistics,
    // Set to true once `train` completes successfully.
    is_trained: bool,
}
/// Summary of how far model position predictions deviate from the true
/// positions, computed over a sample of training examples.
#[derive(Clone, Serialize, Deserialize)]
struct ErrorStatistics {
    // Mean absolute position error.
    mean_error: f64,
    // Standard deviation of the position error; scales the search
    // window via `config.error_bound_multiplier`.
    std_error: f64,
    // Largest observed position error.
    max_error: usize,
}
impl NeuralVectorIndex {
    /// Creates an untrained index from a validated configuration.
    ///
    /// # Errors
    /// Returns [`LearnedIndexError::InvalidConfiguration`] when
    /// `config.validate()` rejects the configuration.
    pub fn new(config: LearnedIndexConfig) -> LearnedIndexResult<Self> {
        config
            .validate()
            .map_err(|e| LearnedIndexError::InvalidConfiguration { message: e })?;
        Ok(Self {
            config,
            weights: Vec::new(),
            biases: Vec::new(),
            sorted_keys: Vec::new(),
            key_positions: HashMap::new(),
            error_stats: ErrorStatistics {
                mean_error: 0.0,
                std_error: 0.0,
                max_error: 0,
            },
            stats: IndexStatistics::new(),
            is_trained: false,
        })
    }

    /// Trains the prediction model and rebuilds the sorted key store.
    ///
    /// Safe to call more than once: each call fully replaces the key
    /// store and the exact-match table built by any previous call.
    ///
    /// # Errors
    /// Returns [`LearnedIndexError::InsufficientData`] when fewer than
    /// `config.min_training_examples` examples are supplied, and
    /// propagates any error from the underlying [`IndexTrainer`].
    pub fn train(&mut self, examples: Vec<TrainingExample>) -> LearnedIndexResult<TrainingStats> {
        if examples.len() < self.config.min_training_examples {
            return Err(LearnedIndexError::InsufficientData {
                min_required: self.config.min_training_examples,
                actual: examples.len(),
            });
        }
        tracing::info!("Training learned index on {} examples", examples.len());
        let trainer = IndexTrainer::new(self.config.training.clone());
        let training_stats = trainer.train(&mut self.weights, &mut self.biases, &examples)?;
        // Keys are stored ordered by target position so that a predicted
        // position can be refined with a bounded binary search.
        let mut sorted_examples = examples.clone();
        sorted_examples.sort_by(|a, b| {
            a.target_position
                .partial_cmp(&b.target_position)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.sorted_keys = sorted_examples
            .iter()
            .map(|ex| ex.features.clone())
            .collect();
        // Fix: clear stale mappings so retraining on a new dataset does
        // not leave dangling key -> position entries from the old one.
        self.key_positions.clear();
        for (idx, example) in sorted_examples.iter().enumerate() {
            let key = self.key_to_string(&example.features);
            self.key_positions.insert(key, idx);
        }
        self.compute_error_statistics(&examples);
        self.is_trained = true;
        tracing::info!(
            "Training complete: loss={:.4}, accuracy={:.2}%",
            training_stats.final_loss,
            training_stats.final_accuracy * 100.0
        );
        Ok(training_stats)
    }

    /// Predicts the position of `key` together with a search window
    /// derived from the training-time error statistics.
    ///
    /// Takes `&mut self` only to record lookup-latency statistics.
    ///
    /// # Errors
    /// Returns [`LearnedIndexError::ModelNotTrained`] before `train`
    /// has succeeded.
    pub fn predict(&mut self, key: &[f32]) -> LearnedIndexResult<PredictionBounds> {
        if !self.is_trained {
            return Err(LearnedIndexError::ModelNotTrained);
        }
        let start = Instant::now();
        let prediction = self.forward(key);
        // Squash the raw model output into (0, 1) and scale to a slot
        // index, clamped into the valid range of `sorted_keys`.
        let normalized_pred = Self::sigmoid(prediction);
        let predicted_pos = (normalized_pred * self.sorted_keys.len() as f32) as usize;
        let predicted_pos = predicted_pos.min(self.sorted_keys.len().saturating_sub(1));
        // Window half-width: std deviation of the training error times a
        // configurable safety multiplier.
        let error_bound =
            (self.error_stats.std_error * self.config.error_bound_multiplier as f64) as usize;
        let lower = predicted_pos.saturating_sub(error_bound);
        let upper = (predicted_pos + error_bound).min(self.sorted_keys.len());
        // Heuristic confidence in [0, 1]: highest when the sigmoid output
        // is near 0.5, i.e. away from its saturated tails.
        let confidence = 1.0 - (normalized_pred - 0.5).abs() * 2.0;
        let bounds = PredictionBounds::new(predicted_pos, lower, upper, confidence);
        let elapsed = start.elapsed().as_micros() as f64;
        self.stats.record_lookup(bounds.error_magnitude, elapsed);
        Ok(bounds)
    }

    /// Resolves `key` to a position: exact-match table first, then a
    /// model prediction refined by binary search inside the error bounds.
    ///
    /// # Errors
    /// Propagates errors from `predict`, or returns
    /// [`LearnedIndexError::PredictionOutOfBounds`] when the search runs
    /// past the end of the key store.
    pub fn lookup(&mut self, key: &[f32]) -> LearnedIndexResult<usize> {
        let key_str = self.key_to_string(key);
        if let Some(&pos) = self.key_positions.get(&key_str) {
            return Ok(pos);
        }
        let bounds = self.predict(key)?;
        let search_range = bounds.search_range();
        let result = self.binary_search_range(key, search_range);
        if let Ok(actual_pos) = result {
            let within_bounds = actual_pos >= bounds.lower_bound && actual_pos < bounds.upper_bound;
            self.stats
                .record_prediction(bounds.predicted, actual_pos, within_bounds);
        }
        result
    }

    /// Forward pass of the small MLP. Layer weights are stored flattened
    /// and addressed as `j * output_size + i` (input j, neuron i).
    ///
    /// Falls back to the mean of the features when no model has been
    /// trained, and to the neutral value 0.5 for degenerate input.
    fn forward(&self, input: &[f32]) -> f32 {
        // Fix: an empty input previously divided by zero (NaN) in the
        // untrained branch and panicked (`len / 0`) in the layer loop.
        if input.is_empty() {
            return 0.5;
        }
        if self.weights.is_empty() {
            return input.iter().sum::<f32>() / input.len() as f32;
        }
        let mut activation = input.to_vec();
        for (layer_idx, layer_weights) in self.weights.iter().enumerate() {
            // Fix: a layer that produced no outputs would make the
            // division below panic; return the neutral fallback instead.
            if activation.is_empty() {
                return 0.5;
            }
            // Derive this layer's output width from the flattened weight
            // count and the current activation width.
            let output_size = layer_weights.len() / activation.len();
            let mut next_activation = Vec::with_capacity(output_size);
            for i in 0..output_size {
                let mut sum = 0.0;
                for (j, &input_val) in activation.iter().enumerate() {
                    let weight_idx = j * output_size + i;
                    if weight_idx < layer_weights.len() {
                        sum += input_val * layer_weights[weight_idx];
                    }
                }
                // NOTE(review): a single bias per layer is shared by all
                // of that layer's neurons — matches how `biases` is sized.
                if layer_idx < self.biases.len() {
                    sum += self.biases[layer_idx];
                }
                next_activation.push(Self::relu(sum));
            }
            activation = next_activation;
        }
        activation.first().copied().unwrap_or(0.5)
    }

    /// Binary-searches `sorted_keys` for `key` within `range`.
    ///
    /// Returns the exact position on a match; otherwise the insertion
    /// point (nearest-position semantics), or
    /// [`LearnedIndexError::PredictionOutOfBounds`] when that point lies
    /// past the end of the key store.
    fn binary_search_range(
        &self,
        key: &[f32],
        range: std::ops::Range<usize>,
    ) -> LearnedIndexResult<usize> {
        let mut left = range.start;
        let mut right = range.end;
        while left < right {
            // Overflow-safe midpoint (avoids `(left + right) / 2`).
            let mid = left + (right - left) / 2;
            if mid >= self.sorted_keys.len() {
                break;
            }
            let mid_key = &self.sorted_keys[mid];
            match Self::compare_keys(key, mid_key) {
                std::cmp::Ordering::Less => right = mid,
                std::cmp::Ordering::Greater => left = mid + 1,
                std::cmp::Ordering::Equal => return Ok(mid),
            }
        }
        if left < self.sorted_keys.len() {
            Ok(left)
        } else {
            Err(LearnedIndexError::PredictionOutOfBounds {
                predicted: left,
                actual_size: self.sorted_keys.len(),
            })
        }
    }

    /// Recomputes `error_stats` by comparing model predictions against
    /// true positions on a sample of up to 1000 training examples.
    fn compute_error_statistics(&mut self, examples: &[TrainingExample]) {
        let mut errors = Vec::new();
        for example in examples.iter().take(1000) {
            let predicted = self.forward(&example.features);
            let normalized = Self::sigmoid(predicted);
            let predicted_pos = (normalized * self.sorted_keys.len() as f32) as usize;
            let error = predicted_pos.abs_diff(example.target_position);
            errors.push(error as f64);
        }
        if !errors.is_empty() {
            let mean = errors.iter().sum::<f64>() / errors.len() as f64;
            // Population variance; its square root is the std deviation
            // that sizes the search window in `predict`.
            let variance =
                errors.iter().map(|e| (e - mean).powi(2)).sum::<f64>() / errors.len() as f64;
            self.error_stats = ErrorStatistics {
                mean_error: mean,
                std_error: variance.sqrt(),
                max_error: errors.iter().map(|&e| e as usize).max().unwrap_or(0),
            };
        }
    }

    /// Read-only view of the runtime lookup statistics.
    pub fn statistics(&self) -> &IndexStatistics {
        &self.stats
    }

    /// Whether `train` has completed successfully at least once.
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }

    /// Logistic squashing function mapping any real into (0, 1).
    fn sigmoid(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Rectified linear unit: max(x, 0).
    fn relu(x: f32) -> f32 {
        x.max(0.0)
    }

    /// Canonical string form of a key for the exact-match table.
    /// Components are quantized to 4 decimals, so keys within 1e-4 per
    /// component collide by design.
    fn key_to_string(&self, key: &[f32]) -> String {
        key.iter()
            .map(|v| format!("{:.4}", v))
            .collect::<Vec<_>>()
            .join(",")
    }

    /// Lexicographic comparison of two keys; on a common prefix the
    /// shorter key orders first. Incomparable components (NaN) are
    /// treated as equal so the ordering stays total.
    fn compare_keys(a: &[f32], b: &[f32]) -> std::cmp::Ordering {
        for (av, bv) in a.iter().zip(b.iter()) {
            match av.partial_cmp(bv) {
                Some(std::cmp::Ordering::Equal) => continue,
                Some(ord) => return ord,
                None => return std::cmp::Ordering::Equal,
            }
        }
        a.len().cmp(&b.len())
    }
}
#[cfg(test)]
mod tests {
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;
    use scirs2_core::random::Random;

    /// Builds `n` deterministic training examples whose first feature
    /// grows monotonically with the target position; the second feature
    /// is seeded noise so runs are reproducible.
    fn create_test_examples(n: usize) -> Vec<TrainingExample> {
        let mut rng = Random::seed(42);
        let mut examples = Vec::with_capacity(n);
        for position in 0..n {
            let features = vec![position as f32 / n as f32, rng.gen_range(0.0..1.0)];
            examples.push(TrainingExample::new(features, position));
        }
        examples
    }

    #[test]
    fn test_neural_index_creation() {
        // A default configuration must always yield a valid index.
        let index = NeuralVectorIndex::new(LearnedIndexConfig::default_config());
        assert!(index.is_ok());
    }

    #[test]
    fn test_training_insufficient_data() -> Result<()> {
        // Too few examples must be rejected, not silently trained on.
        let mut index = NeuralVectorIndex::new(LearnedIndexConfig::default_config())?;
        assert!(index.train(create_test_examples(10)).is_err());
        Ok(())
    }

    #[test]
    fn test_prediction_before_training() -> Result<()> {
        // Predicting with an untrained model is an error, not a guess.
        let mut index = NeuralVectorIndex::new(LearnedIndexConfig::default_config())?;
        assert!(index.predict(&[0.5, 0.5]).is_err());
        Ok(())
    }

    #[test]
    fn test_training_and_prediction() -> Result<()> {
        let mut config = LearnedIndexConfig::speed_optimized();
        config.min_training_examples = 100;
        let mut index = NeuralVectorIndex::new(config)?;
        let stats = index.train(create_test_examples(100))?;
        assert!(stats.final_loss >= 0.0);
        assert!(index.is_trained());
        // The predicted slot must be in range and bracketed by its bounds.
        let bounds = index.predict(&[0.5, 0.5])?;
        assert!(bounds.predicted < 100);
        assert!(bounds.lower_bound <= bounds.predicted);
        assert!(bounds.upper_bound > bounds.predicted);
        Ok(())
    }
}