use super::facts::TypedFacts;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
/// Opaque identifier for a fact stored in a [`WorkingMemory`].
///
/// A handle is a thin wrapper around a `u64`; it is `Copy` and hashable so it
/// can be used directly as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FactHandle(u64);

impl FactHandle {
    /// Wraps a raw numeric id in a handle.
    pub fn new(id: u64) -> Self {
        FactHandle(id)
    }

    /// Returns the raw numeric id carried by this handle.
    pub fn id(&self) -> u64 {
        let FactHandle(raw) = *self;
        raw
    }
}

impl std::fmt::Display for FactHandle {
    /// Renders the handle as `FactHandle(<id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.id();
        write!(f, "FactHandle({})", raw)
    }
}
/// A single fact held in working memory, together with its bookkeeping data.
#[derive(Debug, Clone)]
pub struct WorkingMemoryFact {
/// Handle under which this fact is stored in working memory.
pub handle: FactHandle,
/// Type name of the fact.
pub fact_type: String,
/// Field values of the fact.
pub data: TypedFacts,
/// Insertion/update timestamps, update counter and retraction flag.
pub metadata: FactMetadata,
/// Name of the stream the fact came from: `Some(..)` for facts created by
/// `WorkingMemory::insert_from_stream`, `None` for facts created by
/// `WorkingMemory::insert`.
#[cfg(feature = "streaming")]
pub stream_source: Option<String>,
/// The originating stream event: `Some(..)` for facts created by
/// `WorkingMemory::insert_from_stream`, `None` for facts created by
/// `WorkingMemory::insert`.
#[cfg(feature = "streaming")]
pub stream_event: Option<crate::streaming::event::StreamEvent>,
}
/// Bookkeeping information attached to every fact in working memory.
#[derive(Debug, Clone)]
pub struct FactMetadata {
    /// Moment the metadata (and hence the fact) was created.
    pub inserted_at: std::time::Instant,
    /// Moment of the most recent update; equals `inserted_at` until the first update.
    pub updated_at: std::time::Instant,
    /// Number of updates applied to the fact so far.
    pub update_count: usize,
    /// `true` once the fact has been retracted.
    pub retracted: bool,
}

impl Default for FactMetadata {
    /// Creates metadata for a freshly inserted fact: both timestamps are taken
    /// from a single clock reading, the update counter is zero and the fact is
    /// not retracted.
    fn default() -> Self {
        use std::time::Instant;

        // One clock reading so that `inserted_at == updated_at` holds exactly.
        let created = Instant::now();
        FactMetadata {
            inserted_at: created,
            updated_at: created,
            update_count: 0,
            retracted: false,
        }
    }
}
/// Store of facts addressed by [`FactHandle`], with a per-type index and
/// tracking of which handles were modified or retracted.
pub struct WorkingMemory {
/// Every stored fact keyed by handle. Retracted facts stay in this map with
/// their `retracted` flag set; only `clear` removes entries.
facts: HashMap<FactHandle, WorkingMemoryFact>,
/// Handles grouped by type name. `retract` removes the handle from the set
/// found under the fact's `fact_type`.
type_index: HashMap<String, HashSet<FactHandle>>,
/// Source of handle ids; starts at 1 and is incremented on every insertion.
next_id: AtomicU64,
/// Handles inserted or updated since tracking was last cleared.
modified_handles: HashSet<FactHandle>,
/// Handles retracted since tracking was last cleared.
retracted_handles: HashSet<FactHandle>,
}
impl WorkingMemory {
/// Creates an empty working memory whose first issued handle id will be 1.
pub fn new() -> Self {
    // Ids start at 1, so 0 is never handed out by this store.
    let next_id = AtomicU64::new(1);
    Self {
        facts: HashMap::default(),
        type_index: HashMap::default(),
        next_id,
        modified_handles: HashSet::default(),
        retracted_handles: HashSet::default(),
    }
}
/// Stores a new fact of type `fact_type` carrying `data` and returns its handle.
///
/// The fact is registered in the type index under `fact_type` and its handle
/// is added to the modified set.
pub fn insert(&mut self, fact_type: String, data: TypedFacts) -> FactHandle {
    let handle = FactHandle::new(self.next_id.fetch_add(1, Ordering::SeqCst));

    let record = WorkingMemoryFact {
        handle,
        fact_type: fact_type.clone(),
        data,
        metadata: FactMetadata::default(),
        #[cfg(feature = "streaming")]
        stream_source: None,
        #[cfg(feature = "streaming")]
        stream_event: None,
    };
    self.facts.insert(handle, record);

    // The index key is the same string stored in `record.fact_type`.
    let same_type_handles = self.type_index.entry(fact_type).or_default();
    same_type_handles.insert(handle);

    self.modified_handles.insert(handle);
    handle
}
/// Stores a fact built from a stream `event` and returns its handle.
///
/// Every entry of `event.data` is cloned and converted into a `FactValue`.
/// The resulting fact keeps the stream name in `stream_source` and the
/// original event in `stream_event`.
///
/// The fact's `fact_type` is the stream name, i.e. exactly the key under
/// which the handle is placed in the type index. `retract` removes a handle
/// from the index by looking it up under `fact.fact_type`, so the two must
/// agree; otherwise a retracted stream fact would never leave the index and
/// `get_by_type(stream_name)` would return facts whose `fact_type` is a
/// different string.
#[cfg(feature = "streaming")]
pub fn insert_from_stream(
    &mut self,
    stream_name: String,
    event: crate::streaming::event::StreamEvent,
) -> FactHandle {
    let id = self.next_id.fetch_add(1, Ordering::SeqCst);
    let handle = FactHandle::new(id);

    let mut typed_facts = super::facts::TypedFacts::new();
    for (key, value) in &event.data {
        let fact_value: super::facts::FactValue = value.clone().into();
        typed_facts.set(key.clone(), fact_value);
    }

    let fact = WorkingMemoryFact {
        handle,
        // Must match the `type_index` key used below (see doc comment).
        fact_type: stream_name.clone(),
        data: typed_facts,
        metadata: FactMetadata::default(),
        stream_source: Some(stream_name.clone()),
        stream_event: Some(event),
    };
    self.facts.insert(handle, fact);
    self.type_index
        .entry(stream_name)
        .or_default()
        .insert(handle);
    self.modified_handles.insert(handle);
    handle
}
/// Replaces the data of the fact identified by `handle`.
///
/// On success the fact's `updated_at` is refreshed, its `update_count` is
/// incremented and the handle is added to the modified set.
///
/// # Errors
///
/// Returns an error string when the handle is unknown or the fact has
/// already been retracted.
pub fn update(&mut self, handle: FactHandle, data: TypedFacts) -> Result<(), String> {
    let entry = match self.facts.get_mut(&handle) {
        Some(entry) => entry,
        None => return Err(format!("FactHandle {} not found", handle)),
    };
    if entry.metadata.retracted {
        return Err(format!("FactHandle {} is retracted", handle));
    }

    entry.data = data;
    let meta = &mut entry.metadata;
    meta.updated_at = std::time::Instant::now();
    meta.update_count += 1;

    self.modified_handles.insert(handle);
    Ok(())
}
/// Marks the fact identified by `handle` as retracted.
///
/// The fact itself stays in the store with its `retracted` flag set; the
/// handle is added to the retracted set and removed from the type-index
/// entry found under the fact's `fact_type`. The modified set is left
/// untouched.
///
/// # Errors
///
/// Returns an error string when the handle is unknown or the fact has
/// already been retracted.
pub fn retract(&mut self, handle: FactHandle) -> Result<(), String> {
    let entry = match self.facts.get_mut(&handle) {
        Some(entry) => entry,
        None => return Err(format!("FactHandle {} not found", handle)),
    };
    if entry.metadata.retracted {
        return Err(format!("FactHandle {} already retracted", handle));
    }

    entry.metadata.retracted = true;
    self.retracted_handles.insert(handle);

    if let Some(indexed) = self.type_index.get_mut(&entry.fact_type) {
        indexed.remove(&handle);
    }
    Ok(())
}
/// Looks up the fact for `handle`.
///
/// Returns `None` when the handle is unknown or the fact has been retracted.
pub fn get(&self, handle: &FactHandle) -> Option<&WorkingMemoryFact> {
    match self.facts.get(handle) {
        Some(fact) if !fact.metadata.retracted => Some(fact),
        _ => None,
    }
}
/// Returns every non-retracted fact indexed under `fact_type`.
///
/// The result is empty when nothing is indexed under that name. Facts come
/// out in the iteration order of the underlying `HashSet`, which is
/// unspecified.
pub fn get_by_type(&self, fact_type: &str) -> Vec<&WorkingMemoryFact> {
    self.type_index
        .get(fact_type)
        // A missing index entry becomes an empty iteration.
        .into_iter()
        .flatten()
        // `get` skips handles that are unknown or retracted.
        .filter_map(|handle| self.get(handle))
        .collect()
}
/// Returns references to all facts that have not been retracted.
///
/// Order follows `HashMap` iteration and is therefore unspecified.
pub fn get_all_facts(&self) -> Vec<&WorkingMemoryFact> {
    let is_live = |fact: &&WorkingMemoryFact| !fact.metadata.retracted;
    self.facts.values().filter(is_live).collect()
}
/// Returns the handles of all facts that have not been retracted.
///
/// Order follows `HashMap` iteration and is therefore unspecified.
pub fn get_all_handles(&self) -> Vec<FactHandle> {
    self.facts
        .values()
        .filter_map(|fact| {
            if fact.metadata.retracted {
                None
            } else {
                Some(fact.handle)
            }
        })
        .collect()
}
/// Returns the handles inserted or updated since modification tracking was
/// last cleared.
pub fn get_modified_handles(&self) -> &HashSet<FactHandle> {
&self.modified_handles
}
/// Returns the handles retracted since modification tracking was last
/// cleared.
pub fn get_retracted_handles(&self) -> &HashSet<FactHandle> {
&self.retracted_handles
}
/// Empties both the modified and the retracted handle sets.
///
/// Stored facts and the type index are not touched.
pub fn clear_modification_tracking(&mut self) {
self.modified_handles.clear();
self.retracted_handles.clear();
}
/// Produces a snapshot of counters describing the current state of the store.
///
/// `types` is the number of keys in the type index, and the two `*_pending`
/// values are the sizes of the modified and retracted tracking sets.
pub fn stats(&self) -> WorkingMemoryStats {
    let total_facts = self.facts.len();
    let retracted_facts = self
        .facts
        .values()
        .filter(|fact| fact.metadata.retracted)
        .count();
    // Every stored fact is either retracted or active, so the active count
    // is simply the remainder.
    let active_facts = total_facts - retracted_facts;

    WorkingMemoryStats {
        total_facts,
        active_facts,
        retracted_facts,
        types: self.type_index.len(),
        modified_pending: self.modified_handles.len(),
        retracted_pending: self.retracted_handles.len(),
    }
}
/// Removes every fact, empties the type index and resets modification
/// tracking.
///
/// The id counter is not reset, so handles issued afterwards continue from
/// where the counter left off.
pub fn clear(&mut self) {
    self.facts.clear();
    self.type_index.clear();
    // Clears the modified and retracted handle sets.
    self.clear_modification_tracking();
}
/// Flattens all non-retracted facts into a single `TypedFacts`.
///
/// For every fact, each field is written twice: once under
/// `<fact_type>.<handle id>.<field>` and once under `<fact_type>.<field>`.
/// The fact's handle is also recorded under its type name via
/// `set_fact_handle`. When several facts share a type, the
/// `<fact_type>.<field>` entries and the recorded handle are overwritten by
/// whichever fact is visited last; visit order comes from `HashMap`
/// iteration and is unspecified.
pub fn to_typed_facts(&self) -> TypedFacts {
    let mut flattened = TypedFacts::new();

    for fact in self.get_all_facts() {
        let type_name = &fact.fact_type;
        let handle = fact.handle;

        // Pass 1: keys qualified by both type name and handle id.
        for (field, value) in fact.data.get_all() {
            let qualified = format!("{}.{}.{}", type_name, handle.id(), field);
            flattened.set(qualified, value.clone());
        }

        // Pass 2: keys qualified by type name only. Kept as a separate pass
        // so these writes always happen after all of pass 1 for this fact.
        for (field, value) in fact.data.get_all() {
            let by_type = format!("{}.{}", type_name, field);
            flattened.set(by_type, value.clone());
        }

        flattened.set_fact_handle(type_name.clone(), handle);
    }

    flattened
}
}
impl Default for WorkingMemory {
/// Equivalent to [`WorkingMemory::new`].
fn default() -> Self {
Self::new()
}
}
/// Counters describing the state of a working memory at one point in time.
#[derive(Debug, Clone)]
pub struct WorkingMemoryStats {
    /// Number of stored facts, retracted ones included.
    pub total_facts: usize,
    /// Number of stored facts that are not retracted.
    pub active_facts: usize,
    /// Number of stored facts flagged as retracted.
    pub retracted_facts: usize,
    /// Number of distinct keys in the type index.
    pub types: usize,
    /// Size of the modified-handles tracking set.
    pub modified_pending: usize,
    /// Size of the retracted-handles tracking set.
    pub retracted_pending: usize,
}

impl std::fmt::Display for WorkingMemoryStats {
    /// Writes a one-line summary. `total_facts` is not part of the rendered
    /// line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            active_facts,
            retracted_facts,
            types,
            modified_pending,
            retracted_pending,
            ..
        } = self;
        write!(
            f,
            "WM Stats: {} active, {} retracted, {} types, {} modified, {} pending retraction",
            active_facts, retracted_facts, types, modified_pending, retracted_pending
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts `count` facts of type `fact_type`, each carrying an `id` field
    /// numbered from zero.
    fn insert_numbered(wm: &mut WorkingMemory, fact_type: &str, count: i64) {
        for id in 0..count {
            let mut fields = TypedFacts::new();
            fields.set("id", id);
            wm.insert(fact_type.to_string(), fields);
        }
    }

    /// An inserted fact can be read back with its type and data intact.
    #[test]
    fn test_insert_and_get() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        person.set("age", 25i64);
        let handle = wm.insert("Person".to_string(), person);

        let stored = wm.get(&handle).unwrap();
        assert_eq!(stored.fact_type, "Person");
        assert_eq!(stored.data.get("name").unwrap().as_string(), "John");
    }

    /// Updating replaces the data and bumps the update counter.
    #[test]
    fn test_update() {
        let mut wm = WorkingMemory::new();

        let mut initial = TypedFacts::new();
        initial.set("age", 25i64);
        let handle = wm.insert("Person".to_string(), initial);

        let mut replacement = TypedFacts::new();
        replacement.set("age", 26i64);
        wm.update(handle, replacement).unwrap();

        let stored = wm.get(&handle).unwrap();
        assert_eq!(stored.data.get("age").unwrap().as_integer(), Some(26));
        assert_eq!(stored.metadata.update_count, 1);
    }

    /// A retracted fact is no longer visible through the read accessors.
    #[test]
    fn test_retract() {
        let mut wm = WorkingMemory::new();
        let handle = wm.insert("Person".to_string(), TypedFacts::new());

        wm.retract(handle).unwrap();

        assert!(wm.get(&handle).is_none());
        assert_eq!(wm.get_all_facts().len(), 0);
    }

    /// Facts are retrievable per type; unknown types yield nothing.
    #[test]
    fn test_type_index() {
        let mut wm = WorkingMemory::new();
        insert_numbered(&mut wm, "Person", 5);
        insert_numbered(&mut wm, "Order", 3);

        assert_eq!(wm.get_by_type("Person").len(), 5);
        assert_eq!(wm.get_by_type("Order").len(), 3);
        assert_eq!(wm.get_by_type("Unknown").len(), 0);
    }

    /// Inserts and updates land in the modified set, retractions in the
    /// retracted set, and clearing empties the tracking.
    #[test]
    fn test_modification_tracking() {
        let mut wm = WorkingMemory::new();
        let first = wm.insert("Person".to_string(), TypedFacts::new());
        let second = wm.insert("Person".to_string(), TypedFacts::new());
        assert_eq!(wm.get_modified_handles().len(), 2);

        wm.clear_modification_tracking();
        assert_eq!(wm.get_modified_handles().len(), 0);

        wm.update(first, TypedFacts::new()).unwrap();
        assert_eq!(wm.get_modified_handles().len(), 1);

        wm.retract(second).unwrap();
        assert_eq!(wm.get_retracted_handles().len(), 1);
    }
}