use crate::routing::predictive::PredictiveLayer;
use crate::{NervousSystemError, Result};
/// Tuning parameters for a `PredictiveWriter`.
///
/// Start from [`PredictiveConfig::new`] (or `Default`) and refine with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct PredictiveConfig {
    /// Dimension handed to the prediction layer (presumably the expected
    /// vector length — confirm against `PredictiveLayer`).
    pub dimension: usize,
    /// Initial gating threshold handed to the prediction layer.
    pub threshold: f32,
    /// Learning rate handed to the prediction layer.
    pub learning_rate: f32,
    /// When `true`, the writer periodically re-tunes `threshold` so that the
    /// observed write ratio tracks `target_compression`.
    pub adaptive_threshold: bool,
    /// Desired fraction of write attempts that actually produce a write;
    /// only consulted when `adaptive_threshold` is enabled.
    pub target_compression: f32,
}

impl Default for PredictiveConfig {
    /// Dimension 128, threshold 0.1, learning rate 0.1, adaptive threshold
    /// enabled, target write ratio 0.1.
    fn default() -> Self {
        Self {
            dimension: 128,
            threshold: 0.1,
            learning_rate: 0.1,
            adaptive_threshold: true,
            target_compression: 0.1,
        }
    }
}

impl PredictiveConfig {
    /// Default configuration with the given `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            ..Default::default()
        }
    }

    /// Sets the initial gating threshold.
    #[must_use]
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Sets the prediction layer's learning rate.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the target write ratio used by threshold adaptation.
    #[must_use]
    pub fn with_target_compression(mut self, target: f32) -> Self {
        self.target_compression = target;
        self
    }

    /// Enables or disables automatic threshold adaptation.
    ///
    /// Every other field already has a builder; this one completes the set so
    /// callers do not need to mutate the public field directly.
    #[must_use]
    pub fn with_adaptive_threshold(mut self, enabled: bool) -> Self {
        self.adaptive_threshold = enabled;
        self
    }
}
/// Gates vector writes through a `PredictiveLayer` and keeps running counts
/// of how many write attempts actually produced a write.
pub struct PredictiveWriter {
    /// Prediction layer consulted for every write decision.
    prediction_layer: PredictiveLayer,
    /// Active configuration; `threshold` may be rewritten at runtime by
    /// threshold adaptation.
    config: PredictiveConfig,
    /// Attempt / write counters backing `stats()`.
    stats: WriterStats,
}
/// Running counters behind `PredictiveWriter::stats`.
#[derive(Debug, Clone)]
struct WriterStats {
    /// Total write attempts recorded.
    attempts: usize,
    /// Attempts that resulted in a write.
    writes: usize,
    /// Cached `writes / attempts`; stays 0.0 until the first attempt.
    compression: f32,
}

impl WriterStats {
    /// All-zero counters.
    fn new() -> Self {
        Self { attempts: 0, writes: 0, compression: 0.0 }
    }

    /// Records one attempt, counting it as a write when `wrote` is true, and
    /// refreshes the cached write ratio.
    fn record_attempt(&mut self, wrote: bool) {
        self.attempts += 1;
        self.writes += usize::from(wrote);
        if self.attempts == 0 {
            // Only possible if the counter overflowed and wrapped; keep the
            // previous ratio rather than dividing by zero.
            return;
        }
        self.compression = self.writes as f32 / self.attempts as f32;
    }
}
impl PredictiveWriter {
    /// Creates a writer whose prediction layer is built from `config`'s
    /// dimension, threshold and learning rate. All counters start at zero.
    pub fn new(config: PredictiveConfig) -> Self {
        let prediction_layer = PredictiveLayer::with_learning_rate(
            config.dimension,
            config.threshold,
            config.learning_rate,
        );
        Self {
            config,
            prediction_layer,
            stats: WriterStats::new(),
        }
    }

    /// Asks the prediction layer whether `new_vector` would be transmitted.
    ///
    /// Takes `&self`, so it neither records an attempt nor adapts the
    /// threshold.
    pub fn should_write(&self, new_vector: &[f32]) -> bool {
        self.prediction_layer.should_transmit(new_vector)
    }

    /// Runs a residual-gated write of `new_vector` through the prediction
    /// layer and returns its result unchanged.
    ///
    /// Every call records one attempt; a `Some(..)` result is counted as a
    /// write. When `config.adaptive_threshold` is set, the threshold is
    /// re-tuned whenever the total attempt count reaches a multiple of 100
    /// on this path (attempts added by `record_write` can make a multiple be
    /// skipped, since that method does not check).
    pub fn residual_write(&mut self, new_vector: &[f32]) -> Option<Vec<f32>> {
        let result = self.prediction_layer.residual_gated_write(new_vector);
        self.stats.record_attempt(result.is_some());
        if self.config.adaptive_threshold && self.stats.attempts % 100 == 0 {
            self.adapt_threshold();
        }
        result
    }

    /// Feeds a vector that was written by other means into the prediction
    /// layer and counts it as one attempt that produced a write.
    ///
    /// Does not trigger threshold adaptation.
    pub fn record_write(&mut self, written_vector: &[f32]) {
        self.prediction_layer.update(written_vector);
        self.stats.record_attempt(true);
    }

    /// Borrow of the prediction layer's current prediction.
    pub fn current_prediction(&self) -> &[f32] {
        self.prediction_layer.prediction()
    }

    /// Snapshot of the write statistics.
    ///
    /// `compression_ratio` is `writes / attempts`, and `bandwidth_reduction`
    /// is `1.0 - compression_ratio`. Before any attempt is recorded the ratio
    /// is 0.0, so the reported reduction is 1.0.
    pub fn stats(&self) -> CompressionStats {
        CompressionStats {
            total_attempts: self.stats.attempts,
            actual_writes: self.stats.writes,
            compression_ratio: self.stats.compression,
            bandwidth_reduction: 1.0 - self.stats.compression,
        }
    }

    /// Clears the attempt/write counters. The prediction layer and the
    /// current threshold are left untouched.
    pub fn reset_stats(&mut self) {
        self.stats = WriterStats::new();
    }

    /// Nudges the threshold by 10% so the cumulative write ratio moves toward
    /// `config.target_compression`, with a ±10% dead band around the target.
    fn adapt_threshold(&mut self) {
        // Range the adaptive loop steers the threshold into.
        const MIN_THRESHOLD: f32 = 0.01;
        const MAX_THRESHOLD: f32 = 0.5;

        let current = self.config.threshold;
        let ratio = self.stats.compression;
        let target = self.config.target_compression;

        let next = if ratio > target * 1.1 {
            // Writing too often: raise the threshold by 10%, kept inside
            // [MIN_THRESHOLD, MAX_THRESHOLD]. The trailing `.max(current)`
            // makes the step monotone — a user-set threshold above
            // MAX_THRESHOLD is left alone instead of being clamped *down*
            // (the opposite of the intended direction) — and the
            // MIN_THRESHOLD floor lets a threshold of 0.0 grow at all.
            (current * 1.1)
                .max(MIN_THRESHOLD)
                .min(MAX_THRESHOLD)
                .max(current)
        } else if ratio < target * 0.9 {
            // Writing too rarely: lower the threshold by 10%, floored at
            // MIN_THRESHOLD. The trailing `.min(current)` keeps a user-set
            // threshold already below the floor from being pushed *up*.
            (current * 0.9).max(MIN_THRESHOLD).min(current)
        } else {
            // Within the dead band around the target: nothing to do.
            return;
        };

        self.config.threshold = next;
        self.prediction_layer.set_threshold(next);
    }

    /// Current gating threshold (reflects any adaptation performed so far).
    pub fn threshold(&self) -> f32 {
        self.config.threshold
    }
}
/// Snapshot of a writer's gating statistics, as produced by
/// `PredictiveWriter::stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of write attempts recorded.
    pub total_attempts: usize,
    /// Number of attempts that resulted in a write.
    pub actual_writes: usize,
    /// Fraction of attempts that resulted in a write.
    pub compression_ratio: f32,
    /// Complement of `compression_ratio` (`1.0 - compression_ratio`).
    pub bandwidth_reduction: f32,
}

impl CompressionStats {
    /// `bandwidth_reduction` expressed as a percentage.
    pub fn reduction_percent(&self) -> f32 {
        let fraction = self.bandwidth_reduction;
        100.0 * fraction
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes `signal` through `residual_write` `times` times, discarding
    /// the residuals.
    fn feed(writer: &mut PredictiveWriter, signal: &[f32], times: usize) {
        for _ in 0..times {
            let _ = writer.residual_write(signal);
        }
    }

    /// Mean absolute element-wise difference, normalised by `b.len()`.
    fn mean_abs_error(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum::<f32>() / b.len() as f32
    }

    #[test]
    fn test_predictive_writer_creation() {
        let stats = PredictiveWriter::new(PredictiveConfig::new(128)).stats();
        assert_eq!(stats.total_attempts, 0);
        assert_eq!(stats.actual_writes, 0);
    }

    #[test]
    fn test_first_write_always_happens() {
        let writer = PredictiveWriter::new(PredictiveConfig::new(64));
        let probe = [0.5_f32; 64];
        assert!(writer.should_write(&probe));
    }

    #[test]
    fn test_residual_write() {
        let mut writer = PredictiveWriter::new(PredictiveConfig::new(64).with_threshold(0.1));
        let first = [0.5_f32; 64];
        assert!(writer.residual_write(&first).is_some());

        let nudged = [0.501_f32; 64];
        let _ = writer.residual_write(&nudged);

        assert!(writer.stats().total_attempts >= 2);
    }

    #[test]
    fn test_compression_statistics() {
        let mut writer = PredictiveWriter::new(PredictiveConfig::new(32).with_threshold(0.2));
        let stable = [1.0_f32; 32];
        feed(&mut writer, &stable, 100);

        let stats = writer.stats();
        assert_eq!(stats.total_attempts, 100);
        assert!(
            stats.compression_ratio < 0.5,
            "Compression ratio too high: {}",
            stats.compression_ratio
        );
        assert!(stats.bandwidth_reduction > 0.5);
    }

    #[test]
    fn test_adaptive_threshold() {
        let config = PredictiveConfig::new(32)
            .with_threshold(0.1)
            .with_target_compression(0.1);
        let mut writer = PredictiveWriter::new(config);

        for step in 0..200 {
            let mut signal = [1.0_f32; 32];
            signal[0] = 1.0 + (step as f32 * 0.001).sin() * 0.05;
            let _ = writer.residual_write(&signal);
        }

        let final_threshold = writer.threshold();
        assert!(final_threshold > 0.01 && final_threshold < 0.5);
    }

    #[test]
    fn test_record_write() {
        let mut writer = PredictiveWriter::new(PredictiveConfig::new(16));
        writer.record_write(&[0.5_f32; 16]);

        let stats = writer.stats();
        assert_eq!(stats.actual_writes, 1);
        assert_eq!(stats.total_attempts, 1);
    }

    #[test]
    fn test_config_builder() {
        let config = PredictiveConfig::new(256)
            .with_threshold(0.15)
            .with_learning_rate(0.2)
            .with_target_compression(0.05);

        assert_eq!(config.dimension, 256);
        assert_eq!(config.threshold, 0.15);
        assert_eq!(config.learning_rate, 0.2);
        assert_eq!(config.target_compression, 0.05);
    }

    #[test]
    fn test_prediction_convergence() {
        let mut writer = PredictiveWriter::new(PredictiveConfig::new(8).with_learning_rate(0.3));
        let signal = [0.7_f32; 8];
        feed(&mut writer, &signal, 50);

        let error = mean_abs_error(writer.current_prediction(), &signal);
        assert!(error < 0.05, "Prediction error too high: {}", error);
    }
}