Skip to main content

open_kioku_memory/
lib.rs

1use chrono::Utc;
2use open_kioku_core::{Confidence, EntityLink, MemoryFact, MemoryFactId, MemorySearchResult};
3use open_kioku_errors::{OkError, Result};
4use rusqlite::{params, Connection, OptionalExtension};
5use sha2::{Digest, Sha256};
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9pub struct RepoMemoryStore {
10    connection: Mutex<Connection>,
11}
12
13impl RepoMemoryStore {
14    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
15        let path = path.as_ref();
16        if let Some(parent) = path.parent() {
17            std::fs::create_dir_all(parent)
18                .map_err(|err| OkError::Storage(format!("create memory dir: {err}")))?;
19        }
20        let connection = Connection::open(path).map_err(storage_err)?;
21        let store = Self {
22            connection: Mutex::new(connection),
23        };
24        store.initialize()?;
25        Ok(store)
26    }
27
28    pub fn open_repo(repo: impl AsRef<Path>) -> Result<Self> {
29        Self::open(default_memory_path(repo))
30    }
31
32    pub fn remember(&self, text: &str, source: &str, confidence: Confidence) -> Result<MemoryFact> {
33        let text = text.trim();
34        if text.is_empty() {
35            return Err(OkError::Config("memory fact text cannot be empty".into()));
36        }
37        let created_at = Utc::now();
38        let fact = MemoryFact {
39            id: MemoryFactId::new(memory_id(text, source, created_at.timestamp_micros())),
40            text: text.into(),
41            source: source.into(),
42            confidence,
43            entities: extract_entities(text),
44            created_at,
45        };
46        let conn = self
47            .connection
48            .lock()
49            .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
50        conn.execute(
51            "INSERT INTO memory_facts(id, created_at, source, text, json) VALUES(?1, ?2, ?3, ?4, ?5)",
52            params![
53                &fact.id.0,
54                fact.created_at.to_rfc3339(),
55                &fact.source,
56                &fact.text,
57                serde_json::to_string(&fact)?
58            ],
59        )
60        .map_err(storage_err)?;
61        Ok(fact)
62    }
63
64    pub fn get(&self, id: &MemoryFactId) -> Result<Option<MemoryFact>> {
65        let conn = self
66            .connection
67            .lock()
68            .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
69        let raw = conn
70            .query_row(
71                "SELECT json FROM memory_facts WHERE id = ?1",
72                params![&id.0],
73                |row| row.get::<_, String>(0),
74            )
75            .optional()
76            .map_err(storage_err)?;
77        raw.map(|json| serde_json::from_str(&json).map_err(Into::into))
78            .transpose()
79    }
80
81    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<MemorySearchResult>> {
82        let query = query.trim();
83        if query.is_empty() {
84            return Ok(Vec::new());
85        }
86        let facts = self.recent(500)?;
87        let query_terms = terms(query);
88        let query_entities = extract_entities(query);
89        let mut scored = facts
90            .into_iter()
91            .filter_map(|fact| score_fact(fact, &query_terms, &query_entities))
92            .collect::<Vec<_>>();
93        scored.sort_by(|a, b| {
94            b.score
95                .partial_cmp(&a.score)
96                .unwrap_or(std::cmp::Ordering::Equal)
97                .then_with(|| b.fact.created_at.cmp(&a.fact.created_at))
98        });
99        scored.truncate(limit.min(100));
100        Ok(scored)
101    }
102
103    pub fn recent(&self, limit: usize) -> Result<Vec<MemoryFact>> {
104        let conn = self
105            .connection
106            .lock()
107            .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
108        let mut stmt = conn
109            .prepare("SELECT json FROM memory_facts ORDER BY created_at DESC LIMIT ?1")
110            .map_err(storage_err)?;
111        let rows = stmt
112            .query_map(params![limit as i64], |row| row.get::<_, String>(0))
113            .map_err(storage_err)?;
114        let mut facts = Vec::new();
115        for row in rows {
116            facts.push(serde_json::from_str(&row.map_err(storage_err)?)?);
117        }
118        Ok(facts)
119    }
120
121    fn initialize(&self) -> Result<()> {
122        let conn = self
123            .connection
124            .lock()
125            .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
126        conn.execute_batch(
127            "
128            CREATE TABLE IF NOT EXISTS memory_facts (
129                id TEXT PRIMARY KEY,
130                created_at TEXT NOT NULL,
131                source TEXT NOT NULL,
132                text TEXT NOT NULL,
133                json TEXT NOT NULL
134            );
135            CREATE INDEX IF NOT EXISTS idx_memory_created_at ON memory_facts(created_at);
136            CREATE INDEX IF NOT EXISTS idx_memory_source ON memory_facts(source);
137            ",
138        )
139        .map_err(storage_err)?;
140        Ok(())
141    }
142}
143
/// Returns where the memory database lives for `repo`: `.ok/memory.sqlite`
/// appended to the repository path.
pub fn default_memory_path(repo: impl AsRef<Path>) -> PathBuf {
    let mut db_path = repo.as_ref().to_path_buf();
    db_path.push(".ok/memory.sqlite");
    db_path
}
147
148pub fn extract_entities(text: &str) -> Vec<EntityLink> {
149    let mut entities = Vec::new();
150    for token in text.split_whitespace() {
151        let cleaned = token.trim_matches(|ch: char| {
152            !(ch.is_ascii_alphanumeric()
153                || ch == '_'
154                || ch == '-'
155                || ch == '/'
156                || ch == '.'
157                || ch == ':')
158        });
159        if cleaned.len() < 3 {
160            continue;
161        }
162        let kind = if is_path_like(cleaned) {
163            "file"
164        } else if is_ticket_id(cleaned) {
165            "ticket"
166        } else if cleaned.starts_with("cargo")
167            || cleaned.starts_with("./")
168            || cleaned.starts_with("npm")
169            || cleaned.starts_with("pytest")
170        {
171            "command"
172        } else if is_identifier(cleaned) {
173            "symbol"
174        } else {
175            continue;
176        };
177        if !entities
178            .iter()
179            .any(|entity: &EntityLink| entity.kind == kind && entity.value == cleaned)
180        {
181            entities.push(EntityLink {
182                kind: kind.into(),
183                value: cleaned.into(),
184                file_range: None,
185                confidence: Confidence::Medium,
186            });
187        }
188    }
189    entities
190}
191
192fn score_fact(
193    fact: MemoryFact,
194    query_terms: &[String],
195    query_entities: &[EntityLink],
196) -> Option<MemorySearchResult> {
197    let lower = fact.text.to_ascii_lowercase();
198    let mut score = 0.0;
199    let mut evidence = Vec::new();
200    let term_hits = query_terms
201        .iter()
202        .filter(|term| lower.contains(term.as_str()))
203        .count();
204    if term_hits > 0 {
205        score += 0.25 + term_hits as f32 * 0.08;
206        evidence.push(format!("{term_hits} lexical term match(es)"));
207    }
208    let entity_hits = query_entities
209        .iter()
210        .filter(|query_entity| {
211            fact.entities.iter().any(|fact_entity| {
212                fact_entity.kind == query_entity.kind && fact_entity.value == query_entity.value
213            })
214        })
215        .count();
216    if entity_hits > 0 {
217        score += 0.45 + entity_hits as f32 * 0.15;
218        evidence.push(format!("{entity_hits} entity link match(es)"));
219    }
220    score += fact.confidence.score() * 0.1;
221    if evidence.is_empty() {
222        return None;
223    }
224    Some(MemorySearchResult {
225        fact,
226        score,
227        match_reason: "repo memory lexical/entity match".into(),
228        evidence,
229    })
230}
231
/// Splits a query into distinct lowercase search terms.
///
/// The query is split on every character that is not ASCII alphanumeric;
/// pieces shorter than three bytes are dropped and the rest are
/// ASCII-lowercased. Each term is returned once, in order of first
/// appearance, so repeating a word in the query cannot inflate the number of
/// "matching terms" counted by the scorer.
fn terms(query: &str) -> Vec<String> {
    let mut unique: Vec<String> = Vec::new();
    for word in query.split(|ch: char| !ch.is_ascii_alphanumeric()) {
        if word.len() < 3 {
            continue;
        }
        let term = word.to_ascii_lowercase();
        // Queries are a handful of words, so a linear scan keeps first-seen
        // order without needing a separate set.
        if !unique.contains(&term) {
            unique.push(term);
        }
    }
    unique
}
239
240fn memory_id(text: &str, source: &str, timestamp: i64) -> String {
241    let mut hasher = Sha256::new();
242    hasher.update(text.as_bytes());
243    hasher.update(source.as_bytes());
244    hasher.update(timestamp.to_le_bytes());
245    format!("mem:{}", hex_prefix(&hasher.finalize(), 16))
246}
247
/// Returns the first `len` lowercase hex digits of `bytes`.
///
/// Each byte contributes its high nibble, then its low nibble. If `bytes`
/// holds fewer than `len` digits, all of them are returned.
fn hex_prefix(bytes: &[u8], len: usize) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::new();
    for &byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            if out.len() == len {
                return out;
            }
            out.push(HEX_DIGITS[usize::from(nibble)] as char);
        }
    }
    out
}
256
/// Reports whether `value` looks like a file path: it either contains a `/`
/// or ends with one of the recognised source-file extensions.
fn is_path_like(value: &str) -> bool {
    const SOURCE_EXTENSIONS: [&str; 9] = [
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".java", ".py", ".go", ".md",
    ];

    if value.contains('/') {
        return true;
    }
    SOURCE_EXTENSIONS.iter().any(|ext| value.ends_with(*ext))
}
269
/// Reports whether `value` has the shape of a ticket id such as `RATE-7031`.
///
/// The part before the first `-` must be at least two ASCII uppercase
/// letters, and the part after it at least two ASCII digits.
fn is_ticket_id(value: &str) -> bool {
    match value.split_once('-') {
        Some((project, digits)) => {
            let project_ok =
                project.len() >= 2 && project.bytes().all(|b| b.is_ascii_uppercase());
            let digits_ok = digits.len() >= 2 && digits.bytes().all(|b| b.is_ascii_digit());
            project_ok && digits_ok
        }
        None => false,
    }
}
279
/// Reports whether `value` looks like a code identifier: it contains `_` or
/// `::`, or it mixes ASCII lowercase and uppercase letters.
fn is_identifier(value: &str) -> bool {
    if value.contains('_') || value.contains("::") {
        return true;
    }
    let mut saw_lower = false;
    let mut saw_upper = false;
    for ch in value.chars() {
        saw_lower |= ch.is_ascii_lowercase();
        saw_upper |= ch.is_ascii_uppercase();
    }
    saw_lower && saw_upper
}
286
287fn storage_err(err: rusqlite::Error) -> OkError {
288    OkError::Storage(err.to_string())
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// A remembered fact is found again by a query naming its symbol and
    /// ticket, and the hit reports entity-link evidence.
    #[test]
    fn stores_and_searches_entity_linked_facts() {
        let repo_dir = tempfile::tempdir().unwrap();
        let memory = RepoMemoryStore::open_repo(repo_dir.path()).unwrap();
        let text = "RATE-7031 maps PublishRestrictionsMutation to GqlPublishRestrictionsTest";
        let stored = memory.remember(text, "test", Confidence::High).unwrap();

        let hits = memory
            .search("PublishRestrictionsMutation RATE-7031", 5)
            .unwrap();

        assert_eq!(hits.len(), 1);
        let top = &hits[0];
        assert_eq!(top.fact.id, stored.id);
        let has_entity_evidence = top
            .evidence
            .iter()
            .any(|line| line.contains("entity link"));
        assert!(has_entity_evidence);
    }
}
318}