use crate::config::ConsolidationConfig;
use crate::error::Result;
use crate::ltm::LongTermMemory;
use crate::stm::ShortTermMemory;
use crate::types::{Entity, Link, MemoryEntry, Relation, Timestamp};
/// Moves qualifying short-term memories into long-term memory and extracts
/// entity/link knowledge from the consolidated entries.
pub struct Consolidator {
// Thresholds and limits governing candidate selection and batch size.
config: ConsolidationConfig,
// When the most recent consolidation pass finished.
last_run: Timestamp,
// Cumulative bookkeeping across all runs (see `ConsolidationStats`).
stats: ConsolidationStats,
}
impl Consolidator {
    /// Creates a consolidator with the given configuration and fresh stats.
    pub fn new(config: ConsolidationConfig) -> Self {
        Self {
            config,
            last_run: Timestamp::now(),
            stats: ConsolidationStats::default(),
        }
    }

    /// Runs one consolidation pass: moves up to `batch_size` qualifying STM
    /// entries into LTM, extracts entity/link knowledge for each, and marks
    /// them consolidated in STM. Returns the number of entries moved.
    ///
    /// # Errors
    /// Propagates the first storage, extraction, or marking error; entries
    /// handled before the failure remain consolidated.
    pub fn run(&mut self, stm: &mut ShortTermMemory, ltm: &mut LongTermMemory) -> Result<usize> {
        let mut consolidated_count = 0;
        // `select_candidates` already limits the batch to `batch_size`.
        for entry in self.select_candidates(stm) {
            let id = entry.id.clone();
            // Extract knowledge from a borrow first so the owned entry can be
            // moved into LTM without an extra clone.
            self.extract_knowledge(&entry, ltm)?;
            ltm.store(entry)?;
            stm.mark_consolidated(&id)?;
            consolidated_count += 1;
            self.stats.total_consolidated += 1;
        }
        self.last_run = Timestamp::now();
        self.stats.last_run = self.last_run;
        self.stats.runs += 1;
        Ok(consolidated_count)
    }

    /// Whether an automatic consolidation pass should run now: auto-consolidation
    /// must be enabled and STM must have reached the configured size limit.
    pub fn should_run(&self, stm: &ShortTermMemory) -> bool {
        self.config.auto_consolidate && stm.len() >= self.config.max_stm_before_consolidate
    }

    /// Read-only access to the cumulative consolidation statistics.
    pub fn stats(&self) -> &ConsolidationStats {
        &self.stats
    }

    /// Selects entries eligible for consolidation: importance above the
    /// configured threshold (delegated to STM), old enough, and accessed
    /// often enough. Clones at most `batch_size` entries out of STM.
    fn select_candidates(&self, stm: &ShortTermMemory) -> Vec<MemoryEntry> {
        let min_age = self.config.min_age_secs;
        let min_access = self.config.min_access_count;
        stm.get_consolidation_candidates(self.config.importance_threshold)
            .into_iter()
            .filter(|entry| {
                entry.metadata.created_at.age_secs() >= min_age
                    && entry.metadata.access_count >= min_access
            })
            // Clone only the entries that will actually be processed this
            // pass, instead of the whole candidate set.
            .take(self.config.batch_size)
            .cloned()
            .collect()
    }

    /// Derives graph knowledge from a consolidated entry: an entity carrying
    /// the entry's JSON object fields and embedding, plus a TAGGED link to a
    /// tag entity for each of the entry's tags.
    fn extract_knowledge(&self, entry: &MemoryEntry, ltm: &mut LongTermMemory) -> Result<()> {
        if let Some(name) = self.extract_name(entry) {
            let mut entity = Entity::new(&entry.entry_type, &name);
            if let serde_json::Value::Object(map) = &entry.data {
                for (key, value) in map {
                    entity.properties.insert(key.clone(), value.clone());
                }
            }
            if let Some(ref emb) = entry.embedding {
                entity.embedding = Some(emb.clone());
            }
            let entity_id = ltm.add_entity(entity)?;
            for tag in &entry.tags {
                let tag_entity = Entity::new("tag", &tag.0);
                // Tag entities and links are best-effort: a failed tag must
                // not abort consolidation of the entry itself.
                if let Ok(tag_id) = ltm.add_entity(tag_entity) {
                    let link = Link::new(entity_id.clone(), Relation::new("TAGGED"), tag_id);
                    let _ = ltm.add_link(link);
                }
            }
        }
        Ok(())
    }

    /// Finds a human-readable name for the derived entity: the first string
    /// value among common name-like JSON fields, falling back to a prefix of
    /// the entry id's hex form.
    fn extract_name(&self, entry: &MemoryEntry) -> Option<String> {
        const NAME_FIELDS: [&str; 5] = ["name", "id", "identifier", "label", "title"];
        if let serde_json::Value::Object(map) = &entry.data {
            for field in NAME_FIELDS {
                if let Some(serde_json::Value::String(s)) = map.get(field) {
                    return Some(s.clone());
                }
            }
        }
        // Clamp to the available length: the original indexed `[..16]`
        // unconditionally, which panics on hex ids shorter than 16 bytes.
        // (Hex output is ASCII, so byte slicing cannot split a char.)
        let hex = entry.id.to_hex();
        let end = hex.len().min(16);
        Some(hex[..end].to_string())
    }
}
/// Cumulative bookkeeping about consolidation activity.
#[derive(Debug, Clone, Default)]
pub struct ConsolidationStats {
/// Total entries moved from STM to LTM across all runs.
pub total_consolidated: usize,
/// Number of completed consolidation passes.
pub runs: usize,
/// Timestamp of the most recent pass.
pub last_run: Timestamp,
}
/// Policy used by `AdvancedConsolidator` to rank STM entries for consolidation.
#[derive(Debug, Clone, Copy)]
pub enum ConsolidationStrategy {
/// Rank by access count (most frequently accessed first).
FrequencyBased,
/// Use the base consolidator's importance-threshold selection.
ImportanceBased,
/// Rank by dissimilarity to content already stored in LTM.
NoveltyBased,
/// Weighted blend of importance, frequency, recency, and novelty.
Combined,
}
impl Default for ConsolidationStrategy {
fn default() -> Self {
Self::Combined
}
}
/// Consolidator that layers a selectable ranking strategy on top of the base
/// `Consolidator`.
pub struct AdvancedConsolidator {
// Base consolidator supplying config, stats, and knowledge extraction.
base: Consolidator,
// Ranking policy applied when selecting which entries to consolidate.
strategy: ConsolidationStrategy,
}
impl AdvancedConsolidator {
pub fn new(config: ConsolidationConfig, strategy: ConsolidationStrategy) -> Self {
Self {
base: Consolidator::new(config),
strategy,
}
}
pub fn run(&mut self, stm: &mut ShortTermMemory, ltm: &mut LongTermMemory) -> Result<usize> {
match self.strategy {
ConsolidationStrategy::FrequencyBased => self.run_frequency_based(stm, ltm),
ConsolidationStrategy::ImportanceBased => self.base.run(stm, ltm),
ConsolidationStrategy::NoveltyBased => self.run_novelty_based(stm, ltm),
ConsolidationStrategy::Combined => self.run_combined(stm, ltm),
}
}
fn run_frequency_based(
&mut self,
stm: &mut ShortTermMemory,
ltm: &mut LongTermMemory,
) -> Result<usize> {
let mut candidates: Vec<_> = stm
.get_consolidation_candidates(0.0) .into_iter()
.cloned()
.collect();
candidates.sort_by(|a, b| b.metadata.access_count.cmp(&a.metadata.access_count));
let mut count = 0;
for entry in candidates.into_iter().take(self.base.config.batch_size) {
let id = entry.id.clone();
ltm.store(entry)?;
stm.mark_consolidated(&id)?;
count += 1;
}
Ok(count)
}
fn run_novelty_based(
&mut self,
stm: &mut ShortTermMemory,
ltm: &mut LongTermMemory,
) -> Result<usize> {
let candidates: Vec<_> = stm
.get_consolidation_candidates(0.3)
.into_iter()
.cloned()
.collect();
let mut scored: Vec<_> = candidates
.into_iter()
.map(|entry| {
let novelty = if let Some(ref emb) = entry.embedding {
let max_similarity = ltm
.semantic_search(emb, 1)
.first()
.map(|(_, sim)| *sim)
.unwrap_or(0.0);
1.0 - max_similarity } else {
0.5 };
(entry, novelty)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut count = 0;
for (entry, _novelty) in scored.into_iter().take(self.base.config.batch_size) {
let id = entry.id.clone();
ltm.store(entry)?;
stm.mark_consolidated(&id)?;
count += 1;
}
Ok(count)
}
fn run_combined(
&mut self,
stm: &mut ShortTermMemory,
ltm: &mut LongTermMemory,
) -> Result<usize> {
let candidates: Vec<_> = stm
.get_consolidation_candidates(self.base.config.importance_threshold * 0.5)
.into_iter()
.cloned()
.collect();
let mut scored: Vec<_> = candidates
.into_iter()
.map(|entry| {
let importance_score = entry.metadata.importance;
let frequency_score = (entry.metadata.access_count as f32 / 10.0).min(1.0);
let age_secs = entry.metadata.created_at.age_secs();
let recency_score = 1.0 / (1.0 + (age_secs as f32 / 3600.0));
let novelty_score = if let Some(ref emb) = entry.embedding {
let max_sim = ltm
.semantic_search(emb, 1)
.first()
.map(|(_, sim)| *sim)
.unwrap_or(0.0);
1.0 - max_sim
} else {
0.5
};
let combined = importance_score * 0.35
+ frequency_score * 0.25
+ recency_score * 0.15
+ novelty_score * 0.25;
(entry, combined)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut count = 0;
for (entry, _score) in scored.into_iter().take(self.base.config.batch_size) {
let id = entry.id.clone();
ltm.store(entry.clone())?;
self.base.extract_knowledge(&entry, ltm)?;
stm.mark_consolidated(&id)?;
count += 1;
self.base.stats.total_consolidated += 1;
}
self.base.stats.runs += 1;
self.base.last_run = Timestamp::now();
Ok(count)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{LtmConfig, StmConfig};
    use crate::types::MemoryEntry;

    /// Builds a test entry named `name` with the given importance and a
    /// fixed access count of 3 (above the thresholds used below).
    fn make_entry(name: &str, importance: f32) -> MemoryEntry {
        let mut entry = MemoryEntry::new("test", serde_json::json!({"name": name}));
        entry.metadata.importance = importance;
        entry.metadata.access_count = 3;
        entry
    }

    #[test]
    fn test_consolidation() {
        let cons_config = ConsolidationConfig {
            importance_threshold: 0.5,
            min_access_count: 2,
            min_age_secs: 0,
            batch_size: 10,
            ..Default::default()
        };
        let mut stm = ShortTermMemory::new(StmConfig::default());
        let mut ltm = LongTermMemory::new(LtmConfig::default());
        let mut consolidator = Consolidator::new(cons_config);
        // Only the entries at or above the 0.5 importance threshold qualify.
        for (name, importance) in [("low", 0.3), ("high", 0.8), ("medium", 0.6)] {
            stm.store(make_entry(name, importance)).unwrap();
        }
        let count = consolidator.run(&mut stm, &mut ltm).unwrap();
        assert!(count >= 1);
        assert!(ltm.memory_count() > 0);
    }

    #[test]
    fn test_knowledge_extraction() {
        let cons_config = ConsolidationConfig {
            importance_threshold: 0.1,
            min_access_count: 0,
            min_age_secs: 0,
            batch_size: 10,
            ..Default::default()
        };
        let mut stm = ShortTermMemory::new(StmConfig::default());
        let mut ltm = LongTermMemory::new(LtmConfig::default());
        let mut consolidator = Consolidator::new(cons_config);
        // A tagged entry should yield an entity plus tag entities in LTM.
        let entry = make_entry("sensor_001", 0.8).with_tags(&["iot", "temperature"]);
        stm.store(entry).unwrap();
        consolidator.run(&mut stm, &mut ltm).unwrap();
        assert!(ltm.entity_count() > 0);
    }
}