1use chrono::Utc;
2use open_kioku_core::{Confidence, EntityLink, MemoryFact, MemoryFactId, MemorySearchResult};
3use open_kioku_errors::{OkError, Result};
4use rusqlite::{params, Connection, OptionalExtension};
5use sha2::{Digest, Sha256};
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9pub struct RepoMemoryStore {
10 connection: Mutex<Connection>,
11}
12
13impl RepoMemoryStore {
14 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
15 let path = path.as_ref();
16 if let Some(parent) = path.parent() {
17 std::fs::create_dir_all(parent)
18 .map_err(|err| OkError::Storage(format!("create memory dir: {err}")))?;
19 }
20 let connection = Connection::open(path).map_err(storage_err)?;
21 let store = Self {
22 connection: Mutex::new(connection),
23 };
24 store.initialize()?;
25 Ok(store)
26 }
27
28 pub fn open_repo(repo: impl AsRef<Path>) -> Result<Self> {
29 Self::open(default_memory_path(repo))
30 }
31
32 pub fn remember(&self, text: &str, source: &str, confidence: Confidence) -> Result<MemoryFact> {
33 let text = text.trim();
34 if text.is_empty() {
35 return Err(OkError::Config("memory fact text cannot be empty".into()));
36 }
37 let created_at = Utc::now();
38 let fact = MemoryFact {
39 id: MemoryFactId::new(memory_id(text, source, created_at.timestamp_micros())),
40 text: text.into(),
41 source: source.into(),
42 confidence,
43 entities: extract_entities(text),
44 created_at,
45 };
46 let conn = self
47 .connection
48 .lock()
49 .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
50 conn.execute(
51 "INSERT INTO memory_facts(id, created_at, source, text, json) VALUES(?1, ?2, ?3, ?4, ?5)",
52 params![
53 &fact.id.0,
54 fact.created_at.to_rfc3339(),
55 &fact.source,
56 &fact.text,
57 serde_json::to_string(&fact)?
58 ],
59 )
60 .map_err(storage_err)?;
61 Ok(fact)
62 }
63
64 pub fn get(&self, id: &MemoryFactId) -> Result<Option<MemoryFact>> {
65 let conn = self
66 .connection
67 .lock()
68 .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
69 let raw = conn
70 .query_row(
71 "SELECT json FROM memory_facts WHERE id = ?1",
72 params![&id.0],
73 |row| row.get::<_, String>(0),
74 )
75 .optional()
76 .map_err(storage_err)?;
77 raw.map(|json| serde_json::from_str(&json).map_err(Into::into))
78 .transpose()
79 }
80
81 pub fn search(&self, query: &str, limit: usize) -> Result<Vec<MemorySearchResult>> {
82 let query = query.trim();
83 if query.is_empty() {
84 return Ok(Vec::new());
85 }
86 let facts = self.recent(500)?;
87 let query_terms = terms(query);
88 let query_entities = extract_entities(query);
89 let mut scored = facts
90 .into_iter()
91 .filter_map(|fact| score_fact(fact, &query_terms, &query_entities))
92 .collect::<Vec<_>>();
93 scored.sort_by(|a, b| {
94 b.score
95 .partial_cmp(&a.score)
96 .unwrap_or(std::cmp::Ordering::Equal)
97 .then_with(|| b.fact.created_at.cmp(&a.fact.created_at))
98 });
99 scored.truncate(limit.min(100));
100 Ok(scored)
101 }
102
103 pub fn recent(&self, limit: usize) -> Result<Vec<MemoryFact>> {
104 let conn = self
105 .connection
106 .lock()
107 .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
108 let mut stmt = conn
109 .prepare("SELECT json FROM memory_facts ORDER BY created_at DESC LIMIT ?1")
110 .map_err(storage_err)?;
111 let rows = stmt
112 .query_map(params![limit as i64], |row| row.get::<_, String>(0))
113 .map_err(storage_err)?;
114 let mut facts = Vec::new();
115 for row in rows {
116 facts.push(serde_json::from_str(&row.map_err(storage_err)?)?);
117 }
118 Ok(facts)
119 }
120
121 fn initialize(&self) -> Result<()> {
122 let conn = self
123 .connection
124 .lock()
125 .map_err(|_| OkError::Storage("memory sqlite mutex poisoned".into()))?;
126 conn.execute_batch(
127 "
128 CREATE TABLE IF NOT EXISTS memory_facts (
129 id TEXT PRIMARY KEY,
130 created_at TEXT NOT NULL,
131 source TEXT NOT NULL,
132 text TEXT NOT NULL,
133 json TEXT NOT NULL
134 );
135 CREATE INDEX IF NOT EXISTS idx_memory_created_at ON memory_facts(created_at);
136 CREATE INDEX IF NOT EXISTS idx_memory_source ON memory_facts(source);
137 ",
138 )
139 .map_err(storage_err)?;
140 Ok(())
141 }
142}
143
/// Returns the default location of a repository's memory database:
/// `<repo>/.ok/memory.sqlite`.
pub fn default_memory_path(repo: impl AsRef<Path>) -> PathBuf {
    // Join component-by-component rather than embedding a `/` separator, so the
    // result is built with the platform's native path separator.
    repo.as_ref().join(".ok").join("memory.sqlite")
}
147
148pub fn extract_entities(text: &str) -> Vec<EntityLink> {
149 let mut entities = Vec::new();
150 for token in text.split_whitespace() {
151 let cleaned = token.trim_matches(|ch: char| {
152 !(ch.is_ascii_alphanumeric()
153 || ch == '_'
154 || ch == '-'
155 || ch == '/'
156 || ch == '.'
157 || ch == ':')
158 });
159 if cleaned.len() < 3 {
160 continue;
161 }
162 let kind = if is_path_like(cleaned) {
163 "file"
164 } else if is_ticket_id(cleaned) {
165 "ticket"
166 } else if cleaned.starts_with("cargo")
167 || cleaned.starts_with("./")
168 || cleaned.starts_with("npm")
169 || cleaned.starts_with("pytest")
170 {
171 "command"
172 } else if is_identifier(cleaned) {
173 "symbol"
174 } else {
175 continue;
176 };
177 if !entities
178 .iter()
179 .any(|entity: &EntityLink| entity.kind == kind && entity.value == cleaned)
180 {
181 entities.push(EntityLink {
182 kind: kind.into(),
183 value: cleaned.into(),
184 file_range: None,
185 confidence: Confidence::Medium,
186 });
187 }
188 }
189 entities
190}
191
192fn score_fact(
193 fact: MemoryFact,
194 query_terms: &[String],
195 query_entities: &[EntityLink],
196) -> Option<MemorySearchResult> {
197 let lower = fact.text.to_ascii_lowercase();
198 let mut score = 0.0;
199 let mut evidence = Vec::new();
200 let term_hits = query_terms
201 .iter()
202 .filter(|term| lower.contains(term.as_str()))
203 .count();
204 if term_hits > 0 {
205 score += 0.25 + term_hits as f32 * 0.08;
206 evidence.push(format!("{term_hits} lexical term match(es)"));
207 }
208 let entity_hits = query_entities
209 .iter()
210 .filter(|query_entity| {
211 fact.entities.iter().any(|fact_entity| {
212 fact_entity.kind == query_entity.kind && fact_entity.value == query_entity.value
213 })
214 })
215 .count();
216 if entity_hits > 0 {
217 score += 0.45 + entity_hits as f32 * 0.15;
218 evidence.push(format!("{entity_hits} entity link match(es)"));
219 }
220 score += fact.confidence.score() * 0.1;
221 if evidence.is_empty() {
222 return None;
223 }
224 Some(MemorySearchResult {
225 fact,
226 score,
227 match_reason: "repo memory lexical/entity match".into(),
228 evidence,
229 })
230}
231
/// Splits `query` into lowercase ASCII-alphanumeric terms of at least three bytes.
///
/// Duplicates (compared after lowercasing) are dropped, keeping first-seen
/// order. A word repeated in the query therefore does not count as several
/// lexical hits when scoring.
fn terms(query: &str) -> Vec<String> {
    let mut seen = std::collections::HashSet::new();
    query
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|term| term.len() >= 3)
        .map(|term| term.to_ascii_lowercase())
        .filter(|term| seen.insert(term.clone()))
        .collect()
}
239
240fn memory_id(text: &str, source: &str, timestamp: i64) -> String {
241 let mut hasher = Sha256::new();
242 hasher.update(text.as_bytes());
243 hasher.update(source.as_bytes());
244 hasher.update(timestamp.to_le_bytes());
245 format!("mem:{}", hex_prefix(&hasher.finalize(), 16))
246}
247
/// Renders at most `len` lowercase hex digits of `bytes`, high nibble first.
fn hex_prefix(bytes: &[u8], len: usize) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::new();
    for &byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            // Every digit is ASCII, so byte length equals digit count.
            if out.len() == len {
                return out;
            }
            out.push(DIGITS[nibble as usize] as char);
        }
    }
    out
}
256
/// Returns true if `value` contains a `/` or ends with a recognized source or
/// doc file extension.
fn is_path_like(value: &str) -> bool {
    const KNOWN_EXTENSIONS: [&str; 9] = [
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".java", ".py", ".go", ".md",
    ];
    if value.contains('/') {
        return true;
    }
    KNOWN_EXTENSIONS.iter().any(|ext| value.ends_with(*ext))
}
269
/// Returns true for `PREFIX-NUMBER` tokens.
///
/// `PREFIX` must be two or more ASCII uppercase letters, and everything after
/// the first `-` must be two or more ASCII digits.
fn is_ticket_id(value: &str) -> bool {
    match value.split_once('-') {
        Some((prefix, number)) => {
            prefix.len() >= 2
                && number.len() >= 2
                && prefix.bytes().all(|b| b.is_ascii_uppercase())
                && number.bytes().all(|b| b.is_ascii_digit())
        }
        None => false,
    }
}
279
/// Returns true if `value` looks like a code symbol.
///
/// That means it contains `_` or `::`, or mixes ASCII lowercase and uppercase
/// letters.
fn is_identifier(value: &str) -> bool {
    if value.contains('_') || value.contains("::") {
        return true;
    }
    let mut saw_lower = false;
    let mut saw_upper = false;
    for ch in value.chars() {
        saw_lower |= ch.is_ascii_lowercase();
        saw_upper |= ch.is_ascii_uppercase();
    }
    saw_lower && saw_upper
}
286
287fn storage_err(err: rusqlite::Error) -> OkError {
288 OkError::Storage(err.to_string())
289}
290
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a store in a fresh temp directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the store is used.
    fn temp_store() -> (tempfile::TempDir, RepoMemoryStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = RepoMemoryStore::open_repo(dir.path()).unwrap();
        (dir, store)
    }

    #[test]
    fn stores_and_searches_entity_linked_facts() {
        let (_dir, store) = temp_store();
        let fact = store
            .remember(
                "RATE-7031 maps PublishRestrictionsMutation to GqlPublishRestrictionsTest",
                "test",
                Confidence::High,
            )
            .unwrap();

        let results = store
            .search("PublishRestrictionsMutation RATE-7031", 5)
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].fact.id, fact.id);
        assert!(results[0]
            .evidence
            .iter()
            .any(|evidence| evidence.contains("entity link")));
    }

    #[test]
    fn get_round_trips_trimmed_fact_and_misses_unknown_id() {
        let (_dir, store) = temp_store();
        let fact = store
            .remember("  cache lives in src/cache.rs  ", "test", Confidence::Medium)
            .unwrap();
        assert_eq!(fact.text, "cache lives in src/cache.rs");
        assert!(fact
            .entities
            .iter()
            .any(|entity| entity.kind == "file" && entity.value == "src/cache.rs"));

        let loaded = store
            .get(&fact.id)
            .unwrap()
            .expect("remembered fact should be stored");
        assert_eq!(loaded.id, fact.id);
        assert_eq!(loaded.text, fact.text);
        assert_eq!(loaded.source, fact.source);

        let missing = MemoryFactId::new("mem:missing".to_string());
        assert!(store.get(&missing).unwrap().is_none());
    }

    #[test]
    fn rejects_blank_text() {
        let (_dir, store) = temp_store();
        assert!(store
            .remember("   \n\t", "test", Confidence::Medium)
            .is_err());
        assert!(store.recent(10).unwrap().is_empty());
    }

    #[test]
    fn blank_or_unmatched_queries_return_nothing() {
        let (_dir, store) = temp_store();
        store
            .remember("cache lives in src/cache.rs", "test", Confidence::Medium)
            .unwrap();
        assert!(store.search("   ", 5).unwrap().is_empty());
        assert!(store.search("zebra", 5).unwrap().is_empty());
        assert_eq!(store.recent(10).unwrap().len(), 1);
    }
}