use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use super::config::{MemoryType, MemoryConfig};
pub trait MemorySystem: std::fmt::Debug {
fn store_examples(&mut self, data: &Array2<f64>, labels: &Array1<i32>) -> Result<()>;
fn retrieve_examples(&self, n_examples: usize) -> Result<(Array2<f64>, Array1<i32>)>;
fn get_statistics(&self) -> MemoryStatistics;
}
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
pub capacity: usize,
pub used: usize,
pub efficiency: f64,
}
#[derive(Debug)]
pub struct EpisodicMemory {
config: MemoryConfig,
stored_data: Vec<Array1<f64>>,
stored_labels: Vec<i32>,
}
impl EpisodicMemory {
pub fn new(config: MemoryConfig) -> Self {
Self {
config,
stored_data: Vec::new(),
stored_labels: Vec::new(),
}
}
}
impl MemorySystem for EpisodicMemory {
fn store_examples(&mut self, data: &Array2<f64>, labels: &Array1<i32>) -> Result<()> {
for (i, row) in data.outer_iter().enumerate() {
if self.stored_data.len() < self.config.capacity {
self.stored_data.push(row.to_owned());
self.stored_labels.push(labels[i]);
}
}
Ok(())
}
fn retrieve_examples(&self, n_examples: usize) -> Result<(Array2<f64>, Array1<i32>)> {
let n = n_examples.min(self.stored_data.len());
if n == 0 {
return Ok((Array2::zeros((0, 0)), Array1::zeros(0)));
}
let data = Array2::zeros((n, self.stored_data[0].len()));
let labels = Array1::zeros(n);
Ok((data, labels))
}
fn get_statistics(&self) -> MemoryStatistics {
MemoryStatistics {
capacity: self.config.capacity,
used: self.stored_data.len(),
efficiency: self.stored_data.len() as f64 / self.config.capacity as f64,
}
}
}
pub fn create_memory_system(
memory_type: MemoryType,
config: MemoryConfig,
) -> Result<Box<dyn MemorySystem>> {
match memory_type {
MemoryType::Episodic => Ok(Box::new(EpisodicMemory::new(config))),
MemoryType::Semantic => Ok(Box::new(EpisodicMemory::new(config))), MemoryType::Working => Ok(Box::new(EpisodicMemory::new(config))), MemoryType::Quantum => Ok(Box::new(EpisodicMemory::new(config))), }
}