//! open-kioku-memory 2.1.0
//!
//! Append-only repo memory and entity-linked recall for Open Kioku.
use chrono::Utc;
use open_kioku_core::{Confidence, EntityLink, MemoryFact, MemoryFactId, MemorySearchResult};
use open_kioku_errors::{OkError, Result};
use rusqlite::{params, Connection, OptionalExtension};
use sha2::{Digest, Sha256};
use std::path::{Path, PathBuf};
use std::sync::Mutex;

/// Append-only, SQLite-backed store of repository memory facts.
///
/// Facts are inserted by [`RepoMemoryStore::remember`] and read back by id
/// ([`RepoMemoryStore::get`]), by recency ([`RepoMemoryStore::recent`]), or through a
/// lexical/entity-link search ([`RepoMemoryStore::search`]); no method updates or
/// deletes a fact. All access goes through a single connection guarded by a
/// [`Mutex`], which each method holds for the duration of its query.
pub struct RepoMemoryStore {
    connection: Mutex<Connection>,
}

impl RepoMemoryStore {
    /// How many of the most recent facts [`search`](Self::search) scores.
    /// Older facts are not considered by search.
    const SEARCH_WINDOW: usize = 500;

    /// Hard cap on the number of results [`search`](Self::search) returns,
    /// whatever `limit` the caller asks for.
    const MAX_SEARCH_RESULTS: usize = 100;

    /// Opens the memory database at `path`, creating its parent directories and
    /// the `memory_facts` schema if they do not exist yet.
    ///
    /// # Errors
    /// Returns `OkError::Storage` if the parent directory cannot be created, the
    /// database cannot be opened, or the schema cannot be initialised.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|err| OkError::Storage(format!("create memory dir: {err}")))?;
        }
        let connection = Connection::open(path).map_err(storage_err)?;
        let store = Self {
            connection: Mutex::new(connection),
        };
        store.initialize()?;
        Ok(store)
    }

    /// Opens the store at [`default_memory_path`] for `repo`
    /// (`<repo>/.ok/memory.sqlite`).
    ///
    /// # Errors
    /// Same as [`open`](Self::open).
    pub fn open_repo(repo: impl AsRef<Path>) -> Result<Self> {
        Self::open(default_memory_path(repo))
    }

    /// Records a new fact and returns it as stored.
    ///
    /// `text` is trimmed before use. Entity links are extracted from the trimmed
    /// text with [`extract_entities`], and the id is derived from the text, the
    /// `source`, and the creation time in microseconds.
    ///
    /// # Errors
    /// - `OkError::Config` if `text` is empty after trimming.
    /// - `OkError::Storage` if the connection lock is poisoned or the insert fails.
    /// - Whatever `OkError` variant `serde_json` errors convert into, if the fact
    ///   cannot be serialised.
    pub fn remember(&self, text: &str, source: &str, confidence: Confidence) -> Result<MemoryFact> {
        let text = text.trim();
        if text.is_empty() {
            return Err(OkError::Config("memory fact text cannot be empty".into()));
        }
        let created_at = Utc::now();
        let fact = MemoryFact {
            id: MemoryFactId::new(memory_id(text, source, created_at.timestamp_micros())),
            text: text.into(),
            source: source.into(),
            confidence,
            entities: extract_entities(text),
            created_at,
        };
        // Serialise before taking the lock so the critical section covers only SQL.
        let json = serde_json::to_string(&fact)?;
        let conn = self.lock_connection()?;
        conn.execute(
            "INSERT INTO memory_facts(id, created_at, source, text, json) VALUES(?1, ?2, ?3, ?4, ?5)",
            params![
                &fact.id.0,
                fact.created_at.to_rfc3339(),
                &fact.source,
                &fact.text,
                &json
            ],
        )
        .map_err(storage_err)?;
        Ok(fact)
    }

    /// Looks up a fact by id, returning `Ok(None)` when no row has that id.
    ///
    /// # Errors
    /// `OkError::Storage` on lock poisoning or SQL failure; a deserialisation
    /// error if the stored JSON cannot be decoded into a `MemoryFact`.
    pub fn get(&self, id: &MemoryFactId) -> Result<Option<MemoryFact>> {
        let conn = self.lock_connection()?;
        let raw = conn
            .query_row(
                "SELECT json FROM memory_facts WHERE id = ?1",
                params![&id.0],
                |row| row.get::<_, String>(0),
            )
            .optional()
            .map_err(storage_err)?;
        raw.map(|json| serde_json::from_str(&json).map_err(Into::into))
            .transpose()
    }

    /// Searches recent facts for lexical term and entity-link matches.
    ///
    /// Only the [`SEARCH_WINDOW`](Self::SEARCH_WINDOW) most recent facts are
    /// scored. Results are ordered by descending score, then newest first, and
    /// at most `min(limit, MAX_SEARCH_RESULTS)` are returned. A blank query or a
    /// zero `limit` yields an empty result without touching the database.
    ///
    /// # Errors
    /// Propagates errors from [`recent`](Self::recent).
    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<MemorySearchResult>> {
        let query = query.trim();
        if query.is_empty() || limit == 0 {
            return Ok(Vec::new());
        }
        let query_terms = terms(query);
        let query_entities = extract_entities(query);
        let mut scored: Vec<MemorySearchResult> = self
            .recent(Self::SEARCH_WINDOW)?
            .into_iter()
            .filter_map(|fact| score_fact(fact, &query_terms, &query_entities))
            .collect();
        scored.sort_by(|a, b| {
            b.score
                .total_cmp(&a.score)
                .then_with(|| b.fact.created_at.cmp(&a.fact.created_at))
        });
        scored.truncate(limit.min(Self::MAX_SEARCH_RESULTS));
        Ok(scored)
    }

    /// Returns up to `limit` facts, newest first.
    ///
    /// Ordering is by the `created_at` column, which `remember` writes as RFC 3339
    /// text in UTC.
    ///
    /// # Errors
    /// `OkError::Storage` on lock poisoning or SQL failure; a deserialisation
    /// error if any stored JSON cannot be decoded into a `MemoryFact`.
    pub fn recent(&self, limit: usize) -> Result<Vec<MemoryFact>> {
        // SQLite treats a negative LIMIT as "no upper bound"; saturate instead of
        // letting an oversized `usize` wrap around through an `as` cast.
        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
        let conn = self.lock_connection()?;
        let mut stmt = conn
            .prepare("SELECT json FROM memory_facts ORDER BY created_at DESC LIMIT ?1")
            .map_err(storage_err)?;
        let rows = stmt
            .query_map(params![limit], |row| row.get::<_, String>(0))
            .map_err(storage_err)?;
        let mut facts = Vec::new();
        for row in rows {
            let json = row.map_err(storage_err)?;
            facts.push(serde_json::from_str(&json)?);
        }
        Ok(facts)
    }

    /// Creates the `memory_facts` table and its indexes if they are missing.
    fn initialize(&self) -> Result<()> {
        let conn = self.lock_connection()?;
        conn.execute_batch(
            "
            CREATE TABLE IF NOT EXISTS memory_facts (
                id TEXT PRIMARY KEY,
                created_at TEXT NOT NULL,
                source TEXT NOT NULL,
                text TEXT NOT NULL,
                json TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_memory_created_at ON memory_facts(created_at);
            CREATE INDEX IF NOT EXISTS idx_memory_source ON memory_facts(source);
            ",
        )
        .map_err(storage_err)?;
        Ok(())
    }

    /// Locks the shared connection, reporting a poisoned mutex as a storage error.
    fn lock_connection(&self) -> Result<std::sync::MutexGuard<'_, Connection>> {
        self.connection
            .lock()
            .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))
    }
}

/// Returns the conventional location of a repository's memory database:
/// `<repo>/.ok/memory.sqlite`.
///
/// The path is built one component at a time so the platform's native separator
/// is used throughout. Joining the literal `".ok/memory.sqlite"` would leave a `/`
/// inside an otherwise `\`-separated path on Windows.
pub fn default_memory_path(repo: impl AsRef<Path>) -> PathBuf {
    repo.as_ref().join(".ok").join("memory.sqlite")
}

/// Extracts entity links (files, tickets, commands, and code symbols) from free text.
///
/// The text is split on whitespace. Each token is stripped of surrounding
/// punctuation (see [`clean_entity_token`]) and classified by the first matching
/// rule:
///
/// 1. `"file"`: [`is_path_like`] (contains `/` or ends in a known source extension);
/// 2. `"ticket"`: [`is_ticket_id`] (e.g. `RATE-7031`);
/// 3. `"command"`: starts with `cargo`, `npm`, or `pytest`;
/// 4. `"symbol"`: [`is_identifier`] (contains `_` or `::`, or mixes case).
///
/// Tokens shorter than three bytes, or matching no rule, are skipped. Each
/// `(kind, value)` pair is reported once, in first-occurrence order, with
/// `Confidence::Medium` and no file range.
pub fn extract_entities(text: &str) -> Vec<EntityLink> {
    let mut entities: Vec<EntityLink> = Vec::new();
    for token in text.split_whitespace() {
        let cleaned = clean_entity_token(token);
        if cleaned.len() < 3 {
            continue;
        }
        let Some(kind) = entity_kind(cleaned) else {
            continue;
        };
        let already_linked = entities
            .iter()
            .any(|entity| entity.kind == kind && entity.value == cleaned);
        if !already_linked {
            entities.push(EntityLink {
                kind: kind.into(),
                value: cleaned.into(),
                file_range: None,
                confidence: Confidence::Medium,
            });
        }
    }
    entities
}

/// Strips the punctuation surrounding a whitespace-delimited token.
///
/// Characters that may appear inside an entity (ASCII alphanumerics and
/// `_ - / . :`) are kept at the start, so `./run.sh` and `.ok/` survive intact.
/// A trailing `.` or `:` is treated as sentence punctuation and removed, so
/// `src/lib.rs.` becomes `src/lib.rs` and `RATE-7031:` becomes `RATE-7031`.
fn clean_entity_token(token: &str) -> &str {
    fn is_entity_char(ch: char) -> bool {
        ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | '/' | '.' | ':')
    }
    token
        .trim_start_matches(|ch: char| !is_entity_char(ch))
        .trim_end_matches(|ch: char| !is_entity_char(ch) || ch == '.' || ch == ':')
}

/// Classifies a cleaned token, returning `None` when it is not an entity.
fn entity_kind(token: &str) -> Option<&'static str> {
    // Tokens starting with `./` contain `/`, so the path rule claims them first
    // (e.g. `./run.sh` is linked as a file, not a command).
    if is_path_like(token) {
        Some("file")
    } else if is_ticket_id(token) {
        Some("ticket")
    } else if ["cargo", "npm", "pytest"]
        .iter()
        .any(|prefix| token.starts_with(*prefix))
    {
        Some("command")
    } else if is_identifier(token) {
        Some("symbol")
    } else {
        None
    }
}

/// Scores one stored fact against a parsed query, or returns `None` if nothing matched.
///
/// Two signals contribute:
/// - lexical: query terms found as substrings of the lowercased fact text. Any hit
///   adds `0.25`, plus `0.08` per matching term.
/// - entity: query entities whose `(kind, value)` also appears among the fact's
///   entity links. Any hit adds `0.45`, plus `0.15` per matching entity.
///
/// A fact with at least one hit also receives a small bonus of `0.1 ×` its
/// confidence score. One evidence line is recorded per contributing signal.
fn score_fact(
    fact: MemoryFact,
    query_terms: &[String],
    query_entities: &[EntityLink],
) -> Option<MemorySearchResult> {
    let haystack = fact.text.to_ascii_lowercase();
    let lexical_hits = query_terms
        .iter()
        .filter(|term| haystack.contains(term.as_str()))
        .count();
    let linked_hits = query_entities
        .iter()
        .filter(|wanted| {
            fact.entities
                .iter()
                .any(|have| have.kind == wanted.kind && have.value == wanted.value)
        })
        .count();
    if lexical_hits == 0 && linked_hits == 0 {
        return None;
    }

    let mut evidence = Vec::with_capacity(2);
    let mut score = 0.0;
    if lexical_hits > 0 {
        score += 0.25 + lexical_hits as f32 * 0.08;
        evidence.push(format!("{lexical_hits} lexical term match(es)"));
    }
    if linked_hits > 0 {
        score += 0.45 + linked_hits as f32 * 0.15;
        evidence.push(format!("{linked_hits} entity link match(es)"));
    }
    score += fact.confidence.score() * 0.1;

    Some(MemorySearchResult {
        fact,
        score,
        match_reason: "repo memory lexical/entity match".into(),
        evidence,
    })
}

/// Splits a search query into lowercase lexical terms.
///
/// The query is split on every non-ASCII-alphanumeric character. Fragments
/// shorter than three bytes are dropped, and each remaining term is returned
/// once, in first-occurrence order. Deduplication matters because `score_fact`
/// counts every matched term: without it, a query such as `rate RATE-7031` would
/// score the word `rate` twice.
fn terms(query: &str) -> Vec<String> {
    let mut unique: Vec<String> = Vec::new();
    for term in query
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|term| term.len() >= 3)
    {
        let term = term.to_ascii_lowercase();
        // Queries are short, so a linear membership check is cheaper than hashing.
        if !unique.contains(&term) {
            unique.push(term);
        }
    }
    unique
}

/// Derives the identifier for a new memory fact: `mem:` followed by the first
/// 16 hex digits of a SHA-256 over `text`, `source`, and `timestamp`.
///
/// Each variable-length field is length-prefixed before hashing. Plain
/// concatenation would feed identical bytes for different splits such as
/// `("ab", "c")` and `("a", "bc")`, letting them collide at the same timestamp.
fn memory_id(text: &str, source: &str, timestamp: i64) -> String {
    let mut hasher = Sha256::new();
    for field in [text, source] {
        hasher.update((field.len() as u64).to_le_bytes());
        hasher.update(field.as_bytes());
    }
    hasher.update(timestamp.to_le_bytes());
    format!("mem:{}", hex_prefix(&hasher.finalize(), 16))
}

/// Hex-encodes `bytes` as lowercase digits, high nibble first, stopping after
/// `len` digits. The output is shorter than `len` when `bytes` runs out first.
fn hex_prefix(bytes: &[u8], len: usize) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(len.min(bytes.len().saturating_mul(2)));
    'bytes: for &byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            if encoded.len() >= len {
                break 'bytes;
            }
            encoded.push(char::from(HEX_DIGITS[usize::from(nibble)]));
        }
    }
    encoded
}

/// Heuristically decides whether `value` names a file: it contains a `/`, or it
/// ends with one of a fixed set of source/doc extensions (case-sensitive).
fn is_path_like(value: &str) -> bool {
    const SOURCE_EXTENSIONS: [&str; 9] = [
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".java", ".py", ".go", ".md",
    ];
    if value.contains('/') {
        return true;
    }
    SOURCE_EXTENSIONS.iter().any(|&ext| value.ends_with(ext))
}

/// Recognises tracker ids of the form `PREFIX-NUMBER`, e.g. `RATE-7031`.
///
/// The text is split at the first `-`. The prefix must be at least two ASCII
/// uppercase letters and the remainder at least two ASCII digits.
fn is_ticket_id(value: &str) -> bool {
    match value.split_once('-') {
        Some((prefix, number)) => {
            let prefix_ok = prefix.len() >= 2 && prefix.bytes().all(|b| b.is_ascii_uppercase());
            let number_ok = number.len() >= 2 && number.bytes().all(|b| b.is_ascii_digit());
            prefix_ok && number_ok
        }
        None => false,
    }
}

/// Heuristically decides whether `value` looks like a code symbol: it contains
/// `_` or `::`, or it mixes ASCII lowercase and uppercase letters (camel/Pascal case).
fn is_identifier(value: &str) -> bool {
    if value.contains('_') || value.contains("::") {
        return true;
    }
    let mut saw_lower = false;
    let mut saw_upper = false;
    for ch in value.chars() {
        saw_lower |= ch.is_ascii_lowercase();
        saw_upper |= ch.is_ascii_uppercase();
        if saw_lower && saw_upper {
            return true;
        }
    }
    false
}

/// Wraps a SQLite error as `OkError::Storage`, keeping its display text.
fn storage_err(err: rusqlite::Error) -> OkError {
    let message = err.to_string();
    OkError::Storage(message)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A fact mentioning a ticket id and a symbol is found by a query naming both,
    /// and the hit is attributed (at least in part) to entity-link evidence.
    #[test]
    fn stores_and_searches_entity_linked_facts() {
        let workspace = tempfile::tempdir().unwrap();
        let store = RepoMemoryStore::open_repo(workspace.path()).unwrap();
        let stored = store
            .remember(
                "RATE-7031 maps PublishRestrictionsMutation to GqlPublishRestrictionsTest",
                "test",
                Confidence::High,
            )
            .unwrap();

        let hits = store
            .search("PublishRestrictionsMutation RATE-7031", 5)
            .unwrap();

        assert_eq!(hits.len(), 1);
        let top = &hits[0];
        assert_eq!(top.fact.id, stored.id);
        let has_entity_evidence = top
            .evidence
            .iter()
            .any(|line| line.contains("entity link"));
        assert!(has_entity_evidence);
    }
}