pub mod config;
pub mod consolidation;
pub mod error;
pub mod ltm;
pub mod stm;
pub mod types;
pub use config::{ConsolidationConfig, LtmConfig, MemoryConfig, StmConfig};
pub use consolidation::Consolidator;
pub use error::{Error, Result};
pub use ltm::{KnowledgeGraph, LongTermMemory};
pub use stm::ShortTermMemory;
pub use types::{
Embedding, Entity, EntityId, Link, LinkType, MemoryEntry, MemoryId, MemoryMetadata,
MemoryQuery, MemoryResult, Relation, SemanticTag,
};
/// Two-tier memory system: a short-term store (`stm`), a long-term store
/// (`ltm`), and a [`Consolidator`] that operates on both tiers.
pub struct TitansMemory {
    /// Short-term memory tier; `TitansMemory::remember` stores new entries here.
    pub stm: ShortTermMemory,
    /// Long-term memory tier; queried together with `stm` by `TitansMemory::recall`.
    pub ltm: LongTermMemory,
    /// Consolidation engine, built from `config.consolidation`.
    consolidator: Consolidator,
    /// Configuration this instance was built from; `config.stm.max_entries`
    /// is what `TitansMemory::stats` reports as `stm_capacity`.
    config: MemoryConfig,
}
impl TitansMemory {
    /// Creates a memory system from `config`.
    ///
    /// The short-term store, long-term store and consolidator are each built
    /// from a clone of the matching sub-configuration; `config` itself is kept
    /// for later inspection (e.g. by [`TitansMemory::stats`]).
    pub fn new(config: MemoryConfig) -> Self {
        Self {
            stm: ShortTermMemory::new(config.stm.clone()),
            ltm: LongTermMemory::new(config.ltm.clone()),
            consolidator: Consolidator::new(config.consolidation.clone()),
            config,
        }
    }

    /// Creates a memory system using [`MemoryConfig::iot_mode`].
    pub fn iot_mode() -> Self {
        Self::new(MemoryConfig::iot_mode())
    }

    /// Creates a memory system using [`MemoryConfig::agent_mode`].
    pub fn agent_mode() -> Self {
        Self::new(MemoryConfig::agent_mode())
    }

    /// Stores `entry` in short-term memory and returns its id.
    ///
    /// # Errors
    /// Propagates any error from the short-term store.
    pub fn remember(&mut self, entry: MemoryEntry) -> Result<MemoryId> {
        self.stm.store(entry)
    }

    /// Stores `entry` in short-term memory after overwriting its
    /// `metadata.importance` with `importance`, and returns its id.
    ///
    /// # Errors
    /// Propagates any error from the short-term store.
    pub fn remember_important(&mut self, mut entry: MemoryEntry, importance: f32) -> Result<MemoryId> {
        entry.metadata.importance = importance;
        self.stm.store(entry)
    }

    /// Queries both tiers and returns the merged results, most relevant first.
    ///
    /// Results from short-term and long-term memory are combined, sorted by
    /// descending `relevance`, and then truncated to `query.limit` if one is
    /// set (the limit applies to the merged list, not to each tier). Results
    /// whose relevance is NaN are ordered after all other results.
    ///
    /// # Errors
    /// Propagates the first error returned by either tier's `query`.
    pub fn recall(&self, query: &MemoryQuery) -> Result<Vec<MemoryResult>> {
        let mut results = Vec::new();
        results.extend(self.stm.query(query)?);
        results.extend(self.ltm.query(query)?);
        // The comparator must be a total order: `sort_by` may panic (Rust 1.81+)
        // if it is not, and `partial_cmp(..).unwrap_or(Equal)` is not one once a
        // NaN is present. Ranking NaN last keeps it total and keeps meaningful
        // scores at the front.
        results.sort_by(|a, b| match (a.relevance.is_nan(), b.relevance.is_nan()) {
            (false, false) => b
                .relevance
                .partial_cmp(&a.relevance)
                .unwrap_or(std::cmp::Ordering::Equal),
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            (true, true) => std::cmp::Ordering::Equal,
        });
        if let Some(limit) = query.limit {
            results.truncate(limit);
        }
        Ok(results)
    }

    /// Convenience wrapper: [`TitansMemory::recall`] with `MemoryQuery::text(query_text)`.
    pub fn recall_text(&self, query_text: &str) -> Result<Vec<MemoryResult>> {
        self.recall(&MemoryQuery::text(query_text))
    }

    /// Convenience wrapper: [`TitansMemory::recall`] with `MemoryQuery::tags(tags)`.
    pub fn recall_tagged(&self, tags: &[&str]) -> Result<Vec<MemoryResult>> {
        self.recall(&MemoryQuery::tags(tags))
    }

    /// Returns recent entries by delegating to `ShortTermMemory::get_recent`.
    ///
    /// Only short-term memory is consulted; long-term memory is not searched.
    pub fn recall_recent(&self, count: usize) -> Result<Vec<MemoryResult>> {
        self.stm.get_recent(count)
    }

    /// Runs the consolidator over both tiers and returns the count it reports.
    pub fn consolidate(&mut self) -> Result<usize> {
        self.consolidator.run(&mut self.stm, &mut self.ltm)
    }

    /// Moves the entry with `id` from short-term to long-term memory.
    ///
    /// Does nothing if `id` is not in short-term memory. The entry is written
    /// to long-term memory before it is removed from short-term memory, so a
    /// failed long-term write leaves the short-term copy intact.
    ///
    /// # Errors
    /// Propagates errors from the short-term lookup/removal or the long-term store.
    pub fn consolidate_memory(&mut self, id: &MemoryId) -> Result<()> {
        if let Some(entry) = self.stm.get(id)? {
            self.ltm.store(entry)?;
            self.stm.remove(id)?;
        }
        Ok(())
    }

    /// Looks up `id`, checking short-term memory first and then long-term memory.
    ///
    /// # Errors
    /// Propagates errors from either tier's `get`.
    pub fn get(&self, id: &MemoryId) -> Result<Option<MemoryEntry>> {
        if let Some(entry) = self.stm.get(id)? {
            return Ok(Some(entry));
        }
        self.ltm.get(id)
    }

    /// Removes the memory with `id` from both short-term and long-term memory.
    ///
    /// Both removals are always attempted, so a failure in the short-term tier
    /// does not prevent removal from the long-term tier.
    ///
    /// # Errors
    /// Returns the short-term error if that removal failed, otherwise the
    /// long-term error if that one failed.
    pub fn forget(&mut self, id: &MemoryId) -> Result<()> {
        let stm_result = self.stm.remove(id);
        let ltm_result = self.ltm.remove(id);
        stm_result?;
        ltm_result?;
        Ok(())
    }

    /// Applies decay to short-term memory (delegates to `ShortTermMemory::decay`).
    pub fn decay(&mut self) -> Result<()> {
        self.stm.decay()
    }

    /// Prunes short-term memory and returns the count reported by
    /// `ShortTermMemory::prune` (presumably the number of entries removed).
    pub fn prune_stm(&mut self) -> Result<usize> {
        self.stm.prune()
    }

    /// Returns a snapshot of entry counts, STM capacity and memory usage.
    pub fn stats(&self) -> MemoryStats {
        MemoryStats {
            stm_count: self.stm.len(),
            stm_capacity: self.config.stm.max_entries,
            ltm_entity_count: self.ltm.entity_count(),
            ltm_link_count: self.ltm.link_count(),
            total_memory_bytes: self.stm.memory_usage() + self.ltm.memory_usage(),
        }
    }

    /// Clears both short-term and long-term memory.
    ///
    /// Both tiers are always cleared, even if clearing the short-term tier fails.
    ///
    /// # Errors
    /// Returns the short-term error if that clear failed, otherwise the
    /// long-term error if that one failed.
    pub fn clear(&mut self) -> Result<()> {
        let stm_result = self.stm.clear();
        let ltm_result = self.ltm.clear();
        stm_result?;
        ltm_result?;
        Ok(())
    }
}
/// Point-in-time snapshot of the memory system's size and usage,
/// as produced by `TitansMemory::stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Number of entries currently held in short-term memory.
    pub stm_count: usize,
    /// Configured `max_entries` for short-term memory.
    pub stm_capacity: usize,
    /// Number of entities in the long-term knowledge store.
    pub ltm_entity_count: usize,
    /// Number of links in the long-term knowledge store.
    pub ltm_link_count: usize,
    /// Sum of the memory usage reported by both tiers' `memory_usage` methods
    /// (in bytes, per the field name).
    pub total_memory_bytes: usize,
}
impl Default for TitansMemory {
fn default() -> Self {
Self::new(MemoryConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_creation() {
        let memory = TitansMemory::default();
        assert_eq!(memory.stats().stm_count, 0);
    }

    #[test]
    fn test_remember_recall() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("test", serde_json::json!({"value": 42}));
        let id = memory.remember(entry).unwrap();
        let retrieved = memory.get(&id).unwrap();
        assert!(retrieved.is_some());
    }

    #[test]
    fn test_iot_mode() {
        let memory = TitansMemory::iot_mode();
        assert!(memory.config.stm.max_entries <= 100);
    }

    #[test]
    fn test_agent_mode() {
        let memory = TitansMemory::agent_mode();
        assert!(memory.config.stm.max_entries >= 100);
    }

    #[test]
    fn test_remember_important() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("important", serde_json::json!({"critical": true}));
        let id = memory.remember_important(entry, 0.95).unwrap();
        let retrieved = memory.get(&id).unwrap().unwrap();
        assert_eq!(retrieved.metadata.importance, 0.95);
    }

    #[test]
    fn test_recall_empty() {
        let memory = TitansMemory::default();
        let query = MemoryQuery::text("anything");
        let results = memory.recall(&query).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_recall_with_limit() {
        let mut memory = TitansMemory::default();
        for i in 0..10 {
            let entry = MemoryEntry::new(&format!("entry_{}", i), serde_json::json!({"n": i}));
            memory.remember(entry).unwrap();
        }
        let mut query = MemoryQuery::text("entry");
        query.limit = Some(3);
        let results = memory.recall(&query).unwrap();
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_recall_text() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("sensor_data", serde_json::json!({"temp": 25.0}));
        memory.remember(entry).unwrap();
        let results = memory.recall_text("sensor").unwrap();
        assert!(results.len() <= 1);
    }

    #[test]
    fn test_recall_tagged() {
        let mut memory = TitansMemory::default();
        let mut entry = MemoryEntry::new("tagged_entry", serde_json::json!({"data": 123}));
        entry.tags.push(SemanticTag::new("test_tag"));
        memory.remember(entry).unwrap();
        let results = memory.recall_tagged(&["test_tag"]).unwrap();
        assert!(results.len() <= 1);
    }

    #[test]
    fn test_recall_recent() {
        let mut memory = TitansMemory::default();
        for i in 0..5 {
            let entry = MemoryEntry::new(&format!("recent_{}", i), serde_json::json!({"n": i}));
            memory.remember(entry).unwrap();
        }
        let results = memory.recall_recent(3).unwrap();
        assert!(results.len() <= 3);
    }

    #[test]
    fn test_consolidate() {
        let mut memory = TitansMemory::default();
        for i in 0..3 {
            let entry = MemoryEntry::new(&format!("important_{}", i), serde_json::json!({"n": i}));
            memory.remember_important(entry, 0.9).unwrap();
        }
        let before = memory.stats().stm_count;
        memory.consolidate().unwrap();
        // Consolidation moves entries out of STM; it must never grow it.
        // (A `usize >= 0` check here would be vacuous.)
        assert!(memory.stats().stm_count <= before);
    }

    #[test]
    fn test_consolidate_memory() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("to_consolidate", serde_json::json!({"data": 1}));
        let id = memory.remember(entry).unwrap();
        memory.consolidate_memory(&id).unwrap();
        // The entry was stored in LTM and then removed from STM.
        assert_eq!(memory.stats().stm_count, 0);
    }

    #[test]
    fn test_consolidate_nonexistent() {
        let mut memory = TitansMemory::default();
        let fake_id = MemoryId::from_bytes([0u8; 32]);
        memory.consolidate_memory(&fake_id).unwrap();
    }

    #[test]
    fn test_forget() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("to_forget", serde_json::json!({"temp": 1}));
        let id = memory.remember(entry).unwrap();
        assert!(memory.get(&id).unwrap().is_some());
        memory.forget(&id).unwrap();
        assert!(memory.get(&id).unwrap().is_none());
    }

    #[test]
    fn test_decay() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("decaying", serde_json::json!({"val": 1}));
        memory.remember_important(entry, 1.0).unwrap();
        memory.decay().unwrap();
    }

    #[test]
    fn test_prune_stm() {
        let mut memory = TitansMemory::iot_mode();
        for i in 0..200 {
            let entry = MemoryEntry::new(&format!("entry_{}", i), serde_json::json!({"n": i}));
            memory.remember(entry).unwrap();
        }
        let before = memory.stats().stm_count;
        let pruned = memory.prune_stm().unwrap();
        // Pruning cannot remove more than was there, nor grow the store.
        assert!(pruned <= before);
        assert!(memory.stats().stm_count <= before);
    }

    #[test]
    fn test_stats() {
        let mut memory = TitansMemory::default();
        for i in 0..5 {
            let entry = MemoryEntry::new(&format!("stat_{}", i), serde_json::json!({"n": i}));
            memory.remember(entry).unwrap();
        }
        let stats = memory.stats();
        assert_eq!(stats.stm_count, 5);
        assert!(stats.stm_capacity > 0);
        assert!(stats.total_memory_bytes > 0);
    }

    #[test]
    fn test_stats_clone() {
        let stats = MemoryStats {
            stm_count: 10,
            stm_capacity: 100,
            ltm_entity_count: 5,
            ltm_link_count: 3,
            total_memory_bytes: 1024,
        };
        let cloned = stats.clone();
        assert_eq!(cloned.stm_count, 10);
        assert_eq!(cloned.ltm_entity_count, 5);
    }

    #[test]
    fn test_stats_debug() {
        let stats = MemoryStats {
            stm_count: 1,
            stm_capacity: 50,
            ltm_entity_count: 2,
            ltm_link_count: 1,
            total_memory_bytes: 512,
        };
        let debug_str = format!("{:?}", stats);
        assert!(debug_str.contains("MemoryStats"));
        assert!(debug_str.contains("stm_count"));
    }

    #[test]
    fn test_clear() {
        let mut memory = TitansMemory::default();
        for i in 0..5 {
            let entry = MemoryEntry::new(&format!("clear_{}", i), serde_json::json!({"n": i}));
            memory.remember(entry).unwrap();
        }
        assert_eq!(memory.stats().stm_count, 5);
        memory.clear().unwrap();
        assert_eq!(memory.stats().stm_count, 0);
    }

    #[test]
    fn test_get_from_ltm() {
        let mut memory = TitansMemory::default();
        let entry = MemoryEntry::new("ltm_entry", serde_json::json!({"data": 42}));
        let id = memory.remember(entry).unwrap();
        memory.consolidate_memory(&id).unwrap();
        // The entry is no longer in STM...
        assert_eq!(memory.stats().stm_count, 0);
        // ...so this lookup falls through to LTM and must not error.
        memory.get(&id).unwrap();
    }

    #[test]
    fn test_get_nonexistent() {
        let memory = TitansMemory::default();
        let fake_id = MemoryId::from_bytes([99u8; 32]);
        let result = memory.get(&fake_id).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_multiple_operations() {
        let mut memory = TitansMemory::default();
        let ids: Vec<MemoryId> = (0..10)
            .map(|i| {
                let entry = MemoryEntry::new(&format!("op_{}", i), serde_json::json!({"n": i}));
                memory.remember(entry).unwrap()
            })
            .collect();
        assert_eq!(memory.stats().stm_count, 10);
        memory.forget(&ids[0]).unwrap();
        memory.forget(&ids[1]).unwrap();
        assert_eq!(memory.stats().stm_count, 8);
        memory.consolidate_memory(&ids[2]).unwrap();
        assert_eq!(memory.stats().stm_count, 7);
        memory.decay().unwrap();
        memory.clear().unwrap();
        assert_eq!(memory.stats().stm_count, 0);
    }
}