1use chrono::Utc;
2use open_kioku_core::{
3 CompressedContextPack, Confidence, ContextHandle, ContextHandleId, ContextPack, Evidence,
4 EvidenceId, EvidenceSourceType, FileRange, LineRange, SearchResult,
5};
6use open_kioku_errors::{OkError, Result};
7use open_kioku_memory::extract_entities;
8use rusqlite::{params, Connection, OptionalExtension};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::path::{Path, PathBuf};
12use std::sync::Mutex;
13
/// SQLite-backed store that maps context handle ids to the original text each
/// handle summarizes.
///
/// The single connection sits behind a [`Mutex`]; the store's methods lock it
/// around their database calls.
pub struct ContextHandleStore {
    connection: Mutex<Connection>,
}
17
/// An entry loaded back from the store by [`ContextHandleStore::retrieve`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievedContext {
    /// The handle that was issued when the original text was stored.
    pub handle: ContextHandle,
    /// The full, uncompressed text the handle summarizes.
    pub original: String,
    /// UTC timestamp of the most recent write of this entry (re-storing the
    /// same content replaces the row and refreshes this value).
    pub created_at: chrono::DateTime<Utc>,
}
24
/// Row payload serialized to JSON into the `json` column of `context_handles`.
///
/// Field names are part of the on-disk format: renaming them would make rows
/// written by earlier versions undecodable.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredContext {
    /// Handle returned to the caller when the entry was stored.
    handle: ContextHandle,
    /// Full original text behind the handle.
    original: String,
    /// UTC time the row was (last) written.
    created_at: chrono::DateTime<Utc>,
}
31
32impl ContextHandleStore {
33 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
34 let path = path.as_ref();
35 if let Some(parent) = path.parent() {
36 std::fs::create_dir_all(parent)
37 .map_err(|err| OkError::Storage(format!("create context dir: {err}")))?;
38 }
39 let connection = Connection::open(path).map_err(storage_err)?;
40 let store = Self {
41 connection: Mutex::new(connection),
42 };
43 store.initialize()?;
44 Ok(store)
45 }
46
47 pub fn open_repo(repo: impl AsRef<Path>) -> Result<Self> {
48 Self::open(default_context_path(repo))
49 }
50
51 pub fn compress_pack(&self, pack: &ContextPack) -> Result<CompressedContextPack> {
52 let mut handles = Vec::new();
53 for result in &pack.primary_files {
54 handles.push(self.store_search_result("primary", result)?);
55 }
56 for result in &pack.supporting_files {
57 handles.push(self.store_search_result("impact", result)?);
58 }
59 for test in &pack.validation_plan.tests {
60 let original = format!(
61 "{}\ncommand: {}\nreason: {}",
62 test.name,
63 test.command.as_deref().unwrap_or("manual validation"),
64 test.reason
65 );
66 handles.push(self.store_original("test", &test.name, None, &original)?);
67 }
68
69 handles.sort_by(|a, b| a.id.cmp(&b.id));
70 handles.dedup_by(|a, b| a.id == b.id);
71
72 let original_tokens = handles
73 .iter()
74 .map(|handle| handle.original_tokens_estimate)
75 .sum::<usize>();
76 let compressed_tokens = handles
77 .iter()
78 .map(|handle| handle.compressed_tokens_estimate)
79 .sum::<usize>();
80 let compression_ratio = if original_tokens == 0 {
81 1.0
82 } else {
83 compressed_tokens as f32 / original_tokens as f32
84 };
85 let summary = format!(
86 "{} handle(s), estimated {} -> {} tokens. Retrieve originals with `retrieve_context`.",
87 handles.len(),
88 original_tokens,
89 compressed_tokens
90 );
91 Ok(CompressedContextPack {
92 task: pack.task.clone(),
93 summary,
94 handles,
95 original_tokens_estimate: original_tokens,
96 compressed_tokens_estimate: compressed_tokens,
97 compression_ratio,
98 evidence: vec![Evidence {
99 id: EvidenceId::new(format!("context-compress:{}", stable_hash(&pack.task, 12))),
100 source: "open-kioku-context-compress".into(),
101 source_type: EvidenceSourceType::Heuristic,
102 file_range: None,
103 symbol_id: None,
104 confidence: Confidence::Medium,
105 message: "context pack compressed into reversible local handles".into(),
106 indexed_at: Utc::now(),
107 }],
108 })
109 }
110
111 pub fn retrieve(&self, handle: &ContextHandleId) -> Result<Option<RetrievedContext>> {
112 let conn = self
113 .connection
114 .lock()
115 .map_err(|_| OkError::Storage("context sqlite mutex poisoned".into()))?;
116 let raw = conn
117 .query_row(
118 "SELECT json FROM context_handles WHERE id = ?1",
119 params![&handle.0],
120 |row| row.get::<_, String>(0),
121 )
122 .optional()
123 .map_err(storage_err)?;
124 let Some(raw) = raw else {
125 return Ok(None);
126 };
127 let stored: StoredContext = serde_json::from_str(&raw)?;
128 Ok(Some(RetrievedContext {
129 handle: stored.handle,
130 original: stored.original,
131 created_at: stored.created_at,
132 }))
133 }
134
135 fn store_search_result(&self, kind: &str, result: &SearchResult) -> Result<ContextHandle> {
136 let title = format!(
137 "{}{}",
138 result.path.display(),
139 line_suffix(&result.line_range)
140 );
141 let file_range = result.line_range.clone().map(|line_range| FileRange {
142 path: result.path.clone(),
143 line_range: Some(line_range),
144 });
145 self.store_original(kind, &title, file_range, &result.snippet)
146 }
147
148 fn store_original(
149 &self,
150 kind: &str,
151 title: &str,
152 file_range: Option<FileRange>,
153 original: &str,
154 ) -> Result<ContextHandle> {
155 let summary = summarize(kind, title, original);
156 let compressed_tokens_estimate = compressed_token_estimate(&summary);
157 let handle = ContextHandle {
158 id: ContextHandleId::new(format!(
159 "ctx:{}",
160 stable_hash(&format!("{kind}:{title}:{original}"), 16)
161 )),
162 kind: kind.into(),
163 summary,
164 file_range,
165 entities: extract_entities(&format!("{title} {original}")),
166 original_tokens_estimate: estimate_tokens(original),
167 compressed_tokens_estimate,
168 };
169 let stored = StoredContext {
170 handle: handle.clone(),
171 original: original.into(),
172 created_at: Utc::now(),
173 };
174 let conn = self
175 .connection
176 .lock()
177 .map_err(|_| OkError::Storage("context sqlite mutex poisoned".into()))?;
178 conn.execute(
179 "INSERT OR REPLACE INTO context_handles(id, kind, created_at, json) VALUES(?1, ?2, ?3, ?4)",
180 params![
181 &handle.id.0,
182 &handle.kind,
183 stored.created_at.to_rfc3339(),
184 serde_json::to_string(&stored)?
185 ],
186 )
187 .map_err(storage_err)?;
188 Ok(handle)
189 }
190
191 fn initialize(&self) -> Result<()> {
192 let conn = self
193 .connection
194 .lock()
195 .map_err(|_| OkError::Storage("context sqlite mutex poisoned".into()))?;
196 conn.execute_batch(
197 "
198 CREATE TABLE IF NOT EXISTS context_handles (
199 id TEXT PRIMARY KEY,
200 kind TEXT NOT NULL,
201 created_at TEXT NOT NULL,
202 json TEXT NOT NULL
203 );
204 CREATE INDEX IF NOT EXISTS idx_context_kind ON context_handles(kind);
205 CREATE INDEX IF NOT EXISTS idx_context_created_at ON context_handles(created_at);
206 ",
207 )
208 .map_err(storage_err)?;
209 Ok(())
210 }
211}
212
/// Returns the default location of the context-handle database for `repo`:
/// `<repo>/.ok/context.sqlite`.
pub fn default_context_path(repo: impl AsRef<Path>) -> PathBuf {
    // Join component by component rather than embedding `/` in one segment, so
    // the platform's own separator is used throughout.
    repo.as_ref().join(".ok").join("context.sqlite")
}
216
217fn summarize(kind: &str, title: &str, original: &str) -> String {
218 let signal = original
219 .lines()
220 .find(|line| !line.trim().is_empty())
221 .unwrap_or_default()
222 .split_whitespace()
223 .take(8)
224 .collect::<Vec<_>>()
225 .join(" ");
226 let title = compact_title(title);
227 if signal.is_empty() {
228 format!("{kind} {title}")
229 } else {
230 format!("{kind} {title}: {signal}")
231 }
232}
233
/// Rough token count for `value`: the larger of its whitespace-separated word
/// count and a quarter of its UTF-8 byte length.
fn estimate_tokens(value: &str) -> usize {
    let word_count = value.split_whitespace().count();
    let byte_quarter = value.len() / 4;
    std::cmp::max(word_count, byte_quarter)
}
237
/// Estimated token cost of a handle summary: its word count plus a fixed
/// allowance of three, and never less than four.
fn compressed_token_estimate(summary: &str) -> usize {
    let words = summary.split_whitespace().count();
    std::cmp::max(words.saturating_add(3), 4)
}
241
/// Shortens a `"<path>:<start>-<end>"` title to `"<file name>:<start>-<end>"`.
///
/// Titles without a trailing numeric `start-end` range (for example test names
/// such as `tests::issue_token`) are returned unchanged. Both `/` and `\` are
/// treated as path separators so Windows-style paths are shortened too.
fn compact_title(title: &str) -> String {
    let Some((path, range)) = title.rsplit_once(':') else {
        return title.into();
    };
    let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    // Only treat the suffix as a line range when it has the exact `start-end`
    // shape produced by `line_suffix`; malformed suffixes such as `-` or `7-`
    // are left untouched.
    let is_line_range = match range.split_once('-') {
        Some((start, end)) => is_number(start) && is_number(end),
        None => false,
    };
    if !is_line_range {
        return title.into();
    }
    let file = path
        .rsplit(|ch: char| ch == '/' || ch == '\\')
        .next()
        .unwrap_or(path);
    format!("{file}:{range}")
}
253
254fn line_suffix(range: &Option<LineRange>) -> String {
255 range
256 .as_ref()
257 .map(|range| format!(":{}-{}", range.start, range.end))
258 .unwrap_or_default()
259}
260
261fn stable_hash(value: &str, len: usize) -> String {
262 let mut hasher = Sha256::new();
263 hasher.update(value.as_bytes());
264 let digest = hasher.finalize();
265 digest
266 .iter()
267 .flat_map(|byte| [byte >> 4, byte & 0x0f])
268 .take(len)
269 .map(|nibble| char::from_digit(nibble as u32, 16).unwrap_or('0'))
270 .collect()
271}
272
273fn storage_err(err: rusqlite::Error) -> OkError {
274 OkError::Storage(err.to_string())
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use open_kioku_core::{ChangeBoundary, RiskReport, ScoreComponent, ValidationPlan};

    /// Multi-line snippet long enough that its summary is cheaper than the original.
    const AUTH_SNIPPET: &str = r#"pub fn issue_token(user: &User, grants: &[Grant]) -> Result<String> {
    let subject = user.subject().ok_or(AuthError::MissingSubject)?;
    let audience = grants
        .iter()
        .filter(|grant| grant.is_active())
        .map(|grant| grant.audience())
        .collect::<Vec<_>>();
    let claims = TokenClaims {
        subject: subject.to_owned(),
        audience,
        issued_at: clock::now(),
        expires_at: clock::now() + TOKEN_TTL,
    };
    signer::sign_claims(&claims).map_err(AuthError::from)
}"#;

    /// Opens a store in a fresh temp dir; the `TempDir` must outlive the store.
    fn temp_store() -> (tempfile::TempDir, ContextHandleStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = ContextHandleStore::open_repo(dir.path()).unwrap();
        (dir, store)
    }

    /// Builds a search result for `path` covering lines 1-18 with `snippet`.
    fn search_result(path: &'static str, snippet: &'static str) -> SearchResult {
        SearchResult {
            path: path.into(),
            line_range: Some(LineRange { start: 1, end: 18 }),
            snippet: snippet.into(),
            symbol: None,
            score: 1.0,
            match_reason: "test".into(),
            evidence: Vec::new(),
            evidence_refs: Vec::new(),
            confidence: 1.0,
            score_breakdown: vec![ScoreComponent::single(
                "test_score",
                1.0,
                Vec::new(),
                "test fixture",
            )],
        }
    }

    /// Builds a minimal pack whose only content is `primary_files`.
    fn pack_with(primary_files: Vec<SearchResult>) -> ContextPack {
        ContextPack {
            task: "token".into(),
            intent: "code_change".into(),
            primary_files,
            primary_symbols: Vec::new(),
            supporting_files: Vec::new(),
            dependency_edges: Vec::new(),
            runtime_signals: Vec::new(),
            test_candidates: Vec::new(),
            risk_report: RiskReport {
                level: "low".into(),
                score: 0.1,
                reasons: Vec::new(),
            },
            recommended_change_boundary: ChangeBoundary {
                allowed_files: Vec::new(),
                caution_files: Vec::new(),
                forbidden_files: Vec::new(),
                evidence_refs: Vec::new(),
                ..Default::default()
            },
            validation_plan: ValidationPlan {
                commands: Vec::new(),
                tests: Vec::new(),
                requires_approval: false,
                evidence: Vec::new(),
            },
            evidence: Vec::new(),
            negative_evidence: Vec::new(),
            confidence_summary: "test".into(),
            confidence_breakdown: open_kioku_core::ConfidenceBreakdown::default(),
        }
    }

    #[test]
    fn compresses_and_retrieves_context_handles() {
        let (_dir, store) = temp_store();
        let pack = pack_with(vec![search_result("src/auth.rs", AUTH_SNIPPET)]);

        let compressed = store.compress_pack(&pack).unwrap();
        let retrieved = store.retrieve(&compressed.handles[0].id).unwrap().unwrap();

        assert!(compressed.compression_ratio < 1.0);
        assert!(retrieved.original.contains("issue_token"));
    }

    #[test]
    fn retrieve_returns_none_for_unknown_handle() {
        let (_dir, store) = temp_store();
        let missing = ContextHandleId::new("ctx:does-not-exist".to_string());

        assert!(store.retrieve(&missing).unwrap().is_none());
    }

    #[test]
    fn identical_results_share_one_handle() {
        let (_dir, store) = temp_store();
        let pack = pack_with(vec![
            search_result("src/auth.rs", AUTH_SNIPPET),
            search_result("src/auth.rs", AUTH_SNIPPET),
        ]);

        let compressed = store.compress_pack(&pack).unwrap();

        assert_eq!(compressed.handles.len(), 1);
        assert_eq!(
            compressed.original_tokens_estimate,
            compressed.handles[0].original_tokens_estimate
        );
    }

    #[test]
    fn compact_title_shortens_paths_with_line_ranges() {
        assert_eq!(compact_title("src/auth.rs:1-18"), "auth.rs:1-18");
        assert_eq!(compact_title("tests::issue_token"), "tests::issue_token");
        assert_eq!(compact_title("README"), "README");
    }

    #[test]
    fn stable_hash_is_truncated_lowercase_hex() {
        // SHA-256("") = e3b0c44298fc1c14...
        assert_eq!(stable_hash("", 8), "e3b0c442");
        assert_eq!(stable_hash("token", 16).len(), 16);
        assert_eq!(stable_hash("token", 100).len(), 64);
    }
}