tinycortex 0.1.0

Rust core for the TinyCortex memory system
Documentation
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};

use async_trait::async_trait;

use super::types::{
    MemoryError, MemoryId, MemoryInput, MemoryQuery, MemoryRecord, MemoryResult, SearchHit,
};

#[async_trait]
pub trait MemoryStore: Send + Sync {
    async fn insert(&self, input: MemoryInput) -> MemoryResult<MemoryRecord>;
    async fn get(&self, id: MemoryId) -> MemoryResult<MemoryRecord>;
    async fn delete(&self, id: MemoryId) -> MemoryResult<MemoryRecord>;
    async fn search(&self, query: MemoryQuery) -> MemoryResult<Vec<SearchHit>>;
}

#[derive(Clone, Debug, Default)]
pub struct InMemoryMemoryStore {
    records: Arc<RwLock<BTreeMap<MemoryId, MemoryRecord>>>,
}

impl InMemoryMemoryStore {
    pub fn new() -> Self {
        Self::default()
    }
}

#[async_trait]
impl MemoryStore for InMemoryMemoryStore {
    async fn insert(&self, input: MemoryInput) -> MemoryResult<MemoryRecord> {
        let record = MemoryRecord::from_input(input)?;
        self.records
            .write()
            .expect("memory store lock poisoned")
            .insert(record.id, record.clone());
        Ok(record)
    }

    async fn get(&self, id: MemoryId) -> MemoryResult<MemoryRecord> {
        self.records
            .read()
            .expect("memory store lock poisoned")
            .get(&id)
            .cloned()
            .ok_or(MemoryError::NotFound(id))
    }

    async fn delete(&self, id: MemoryId) -> MemoryResult<MemoryRecord> {
        self.records
            .write()
            .expect("memory store lock poisoned")
            .remove(&id)
            .ok_or(MemoryError::NotFound(id))
    }

    async fn search(&self, query: MemoryQuery) -> MemoryResult<Vec<SearchHit>> {
        let needle = query.text.as_deref().map(str::to_lowercase);
        let limit = query.limit.unwrap_or(20);

        let mut hits = self
            .records
            .read()
            .expect("memory store lock poisoned")
            .values()
            .filter(|record| {
                query
                    .namespace
                    .as_deref()
                    .map_or(true, |namespace| record.namespace == namespace)
            })
            .filter_map(|record| {
                let score = match needle.as_deref() {
                    Some(text) if !text.is_empty() => {
                        let content = record.content.to_lowercase();
                        if !content.contains(text) {
                            return None;
                        }
                        text.split_whitespace()
                            .filter(|term| content.contains(term))
                            .count()
                            .max(1) as f32
                    }
                    _ => 1.0,
                };

                Some(SearchHit {
                    record: record.clone(),
                    score,
                })
            })
            .collect::<Vec<_>>();

        hits.sort_by(|left, right| {
            right
                .score
                .partial_cmp(&left.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| right.record.created_at.cmp(&left.record.created_at))
        });
        hits.truncate(limit);
        Ok(hits)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn inserts_and_retrieves_memory() {
        let store = InMemoryMemoryStore::new();
        let inserted = store
            .insert(MemoryInput::new("profile", "User prefers dark mode"))
            .await
            .expect("insert memory");

        let fetched = store.get(inserted.id).await.expect("get memory");

        assert_eq!(fetched.namespace, "profile");
        assert_eq!(fetched.content, "User prefers dark mode");
    }

    #[tokio::test]
    async fn searches_by_namespace_and_text() {
        let store = InMemoryMemoryStore::new();
        store
            .insert(MemoryInput::new("profile", "User prefers dark mode"))
            .await
            .expect("insert profile memory");
        store
            .insert(MemoryInput::new(
                "project",
                "TinyCortex stores durable memories",
            ))
            .await
            .expect("insert project memory");

        let hits = store
            .search(MemoryQuery {
                namespace: Some("project".to_owned()),
                text: Some("durable".to_owned()),
                limit: Some(10),
            })
            .await
            .expect("search memory");

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].record.namespace, "project");
    }
}