use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use super::filter::MemoriesFilter;
use super::types::{Memory, MemoryId};
/// Store of [`Memory`] records keyed by [`MemoryId`].
///
/// Each record sits behind an [`Arc`]. Cloning the state (derived `Clone`) clones the
/// map and bumps reference counts, so the clone shares the same `Memory` allocations.
/// With serde, an `Arc` serializes as its pointee. On deserialization each entry gets
/// its own fresh `Arc` (this needs serde's `rc` feature).
#[derive(Debug, Default, Clone)]
#[derive(Serialize, Deserialize)]
pub struct MemoriesState {
    /// Backing map from id to shared record; visible to the parent module (`pub(super)`).
    pub(super) memories: HashMap<MemoryId, Arc<Memory>>,
}
impl MemoriesState {
pub fn new() -> Self {
Self::default()
}
pub fn get(&self, id: MemoryId) -> Option<&Memory> {
self.memories.get(&id).map(|a| a.as_ref())
}
pub fn get_arc(&self, id: MemoryId) -> Option<Arc<Memory>> {
self.memories.get(&id).cloned()
}
pub fn len(&self) -> usize {
self.memories.len()
}
pub fn is_empty(&self) -> bool {
self.memories.is_empty()
}
pub fn contains(&self, id: MemoryId) -> bool {
self.memories.contains_key(&id)
}
pub fn all(&self) -> impl Iterator<Item = &Memory> {
self.memories.values().map(|a| a.as_ref())
}
pub fn pinned(&self) -> impl Iterator<Item = &Memory> {
self.memories
.values()
.map(|a| a.as_ref())
.filter(|m| m.pinned)
}
pub fn unpinned(&self) -> impl Iterator<Item = &Memory> {
self.memories
.values()
.map(|a| a.as_ref())
.filter(|m| !m.pinned)
}
pub fn find_unique(&self, id: MemoryId) -> Option<&Memory> {
self.get(id)
}
pub fn find_many(&self, filter: &MemoriesFilter) -> Vec<Arc<Memory>> {
filter.apply(self.query()).collect()
}
pub fn count_where(&self, filter: &MemoriesFilter) -> usize {
filter.apply(self.query()).count()
}
pub fn exists_where(&self, filter: &MemoriesFilter) -> bool {
filter.apply(self.query()).exists()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shared test memory whose content is `mem-{id}`.
    fn mem(id: MemoryId, pinned: bool) -> Arc<Memory> {
        Arc::new(Memory {
            id,
            content: format!("mem-{}", id),
            tags: Vec::new(),
            source: "test".into(),
            created_ns: 0,
            updated_ns: 0,
            pinned,
        })
    }

    #[test]
    fn test_empty_state() {
        let s = MemoriesState::new();
        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
        assert!(s.get(1).is_none());
        assert!(s.get_arc(1).is_none());
        assert!(s.find_unique(1).is_none());
        assert!(!s.contains(1));
        assert_eq!(s.all().count(), 0);
        assert_eq!(s.pinned().count(), 0);
        assert_eq!(s.unpinned().count(), 0);
    }

    #[test]
    fn test_pin_split() {
        let mut s = MemoriesState::new();
        s.memories.insert(1, mem(1, true));
        s.memories.insert(2, mem(2, false));
        s.memories.insert(3, mem(3, true));
        assert_eq!(s.len(), 3);
        assert_eq!(s.all().count(), 3);
        assert_eq!(s.pinned().count(), 2);
        assert_eq!(s.unpinned().count(), 1);
        assert!(s.contains(2));
        assert!(!s.contains(99));
    }

    #[test]
    fn test_lookup_returns_stored_memory() {
        let mut s = MemoriesState::new();
        let original = mem(7, true);
        s.memories.insert(7, Arc::clone(&original));

        let borrowed = s.get(7).expect("memory 7 was inserted");
        assert_eq!(borrowed.id, 7);
        assert_eq!(borrowed.content, "mem-7");
        assert!(borrowed.pinned);

        // get_arc must hand out the same allocation, not a copy.
        let shared = s.get_arc(7).expect("memory 7 was inserted");
        assert!(Arc::ptr_eq(&shared, &original));

        // find_unique resolves to the very same stored record as get.
        let unique = s.find_unique(7).expect("memory 7 was inserted");
        assert!(std::ptr::eq(unique, borrowed));

        assert!(s.get(8).is_none());
        assert!(s.get_arc(8).is_none());
    }
}