Skip to main content

open_kioku_context_compress/
lib.rs

1use chrono::Utc;
2use open_kioku_core::{
3    CompressedContextPack, Confidence, ContextHandle, ContextHandleId, ContextPack, Evidence,
4    EvidenceId, EvidenceSourceType, FileRange, LineRange, SearchResult,
5};
6use open_kioku_errors::{OkError, Result};
7use open_kioku_memory::extract_entities;
8use rusqlite::{params, Connection, OptionalExtension};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::path::{Path, PathBuf};
12use std::sync::Mutex;
13
14pub struct ContextHandleStore {
15    connection: Mutex<Connection>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RetrievedContext {
20    pub handle: ContextHandle,
21    pub original: String,
22    pub created_at: chrono::DateTime<Utc>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26struct StoredContext {
27    handle: ContextHandle,
28    original: String,
29    created_at: chrono::DateTime<Utc>,
30}
31
32impl ContextHandleStore {
33    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
34        let path = path.as_ref();
35        if let Some(parent) = path.parent() {
36            std::fs::create_dir_all(parent)
37                .map_err(|err| OkError::Storage(format!("create context dir: {err}")))?;
38        }
39        let connection = Connection::open(path).map_err(storage_err)?;
40        let store = Self {
41            connection: Mutex::new(connection),
42        };
43        store.initialize()?;
44        Ok(store)
45    }
46
47    pub fn open_repo(repo: impl AsRef<Path>) -> Result<Self> {
48        Self::open(default_context_path(repo))
49    }
50
51    pub fn compress_pack(&self, pack: &ContextPack) -> Result<CompressedContextPack> {
52        let mut handles = Vec::new();
53        for result in &pack.primary_files {
54            handles.push(self.store_search_result("primary", result)?);
55        }
56        for result in &pack.supporting_files {
57            handles.push(self.store_search_result("impact", result)?);
58        }
59        for test in &pack.validation_plan.tests {
60            let original = format!(
61                "{}\ncommand: {}\nreason: {}",
62                test.name,
63                test.command.as_deref().unwrap_or("manual validation"),
64                test.reason
65            );
66            handles.push(self.store_original("test", &test.name, None, &original)?);
67        }
68
69        handles.sort_by(|a, b| a.id.cmp(&b.id));
70        handles.dedup_by(|a, b| a.id == b.id);
71
72        let original_tokens = handles
73            .iter()
74            .map(|handle| handle.original_tokens_estimate)
75            .sum::<usize>();
76        let compressed_tokens = handles
77            .iter()
78            .map(|handle| handle.compressed_tokens_estimate)
79            .sum::<usize>();
80        let compression_ratio = if original_tokens == 0 {
81            1.0
82        } else {
83            compressed_tokens as f32 / original_tokens as f32
84        };
85        let summary = format!(
86            "{} handle(s), estimated {} -> {} tokens. Retrieve originals with `retrieve_context`.",
87            handles.len(),
88            original_tokens,
89            compressed_tokens
90        );
91        Ok(CompressedContextPack {
92            task: pack.task.clone(),
93            summary,
94            handles,
95            original_tokens_estimate: original_tokens,
96            compressed_tokens_estimate: compressed_tokens,
97            compression_ratio,
98            evidence: vec![Evidence {
99                id: EvidenceId::new(format!("context-compress:{}", stable_hash(&pack.task, 12))),
100                source: "open-kioku-context-compress".into(),
101                source_type: EvidenceSourceType::Heuristic,
102                file_range: None,
103                symbol_id: None,
104                confidence: Confidence::Medium,
105                message: "context pack compressed into reversible local handles".into(),
106                indexed_at: Utc::now(),
107            }],
108        })
109    }
110
111    pub fn retrieve(&self, handle: &ContextHandleId) -> Result<Option<RetrievedContext>> {
112        let conn = self
113            .connection
114            .lock()
115            .map_err(|_| OkError::Storage("context sqlite mutex poisoned".into()))?;
116        let raw = conn
117            .query_row(
118                "SELECT json FROM context_handles WHERE id = ?1",
119                params![&handle.0],
120                |row| row.get::<_, String>(0),
121            )
122            .optional()
123            .map_err(storage_err)?;
124        let Some(raw) = raw else {
125            return Ok(None);
126        };
127        let stored: StoredContext = serde_json::from_str(&raw)?;
128        Ok(Some(RetrievedContext {
129            handle: stored.handle,
130            original: stored.original,
131            created_at: stored.created_at,
132        }))
133    }
134
135    fn store_search_result(&self, kind: &str, result: &SearchResult) -> Result<ContextHandle> {
136        let title = format!(
137            "{}{}",
138            result.path.display(),
139            line_suffix(&result.line_range)
140        );
141        let file_range = result.line_range.clone().map(|line_range| FileRange {
142            path: result.path.clone(),
143            line_range: Some(line_range),
144        });
145        self.store_original(kind, &title, file_range, &result.snippet)
146    }
147
148    fn store_original(
149        &self,
150        kind: &str,
151        title: &str,
152        file_range: Option<FileRange>,
153        original: &str,
154    ) -> Result<ContextHandle> {
155        let summary = summarize(kind, title, original);
156        let compressed_tokens_estimate = compressed_token_estimate(&summary);
157        let handle = ContextHandle {
158            id: ContextHandleId::new(format!(
159                "ctx:{}",
160                stable_hash(&format!("{kind}:{title}:{original}"), 16)
161            )),
162            kind: kind.into(),
163            summary,
164            file_range,
165            entities: extract_entities(&format!("{title} {original}")),
166            original_tokens_estimate: estimate_tokens(original),
167            compressed_tokens_estimate,
168        };
169        let stored = StoredContext {
170            handle: handle.clone(),
171            original: original.into(),
172            created_at: Utc::now(),
173        };
174        let conn = self
175            .connection
176            .lock()
177            .map_err(|_| OkError::Storage("context sqlite mutex poisoned".into()))?;
178        conn.execute(
179            "INSERT OR REPLACE INTO context_handles(id, kind, created_at, json) VALUES(?1, ?2, ?3, ?4)",
180            params![
181                &handle.id.0,
182                &handle.kind,
183                stored.created_at.to_rfc3339(),
184                serde_json::to_string(&stored)?
185            ],
186        )
187        .map_err(storage_err)?;
188        Ok(handle)
189    }
190
191    fn initialize(&self) -> Result<()> {
192        let conn = self
193            .connection
194            .lock()
195            .map_err(|_| OkError::Storage("context sqlite mutex poisoned".into()))?;
196        conn.execute_batch(
197            "
198            CREATE TABLE IF NOT EXISTS context_handles (
199                id TEXT PRIMARY KEY,
200                kind TEXT NOT NULL,
201                created_at TEXT NOT NULL,
202                json TEXT NOT NULL
203            );
204            CREATE INDEX IF NOT EXISTS idx_context_kind ON context_handles(kind);
205            CREATE INDEX IF NOT EXISTS idx_context_created_at ON context_handles(created_at);
206            ",
207        )
208        .map_err(storage_err)?;
209        Ok(())
210    }
211}
212
/// Returns the default location of the context-handle database for `repo`.
///
/// The location is `<repo>/.ok/context.sqlite`.
pub fn default_context_path(repo: impl AsRef<Path>) -> PathBuf {
    // Join one component at a time so the platform separator is used throughout,
    // instead of embedding a literal `/` inside a single component.
    repo.as_ref().join(".ok").join("context.sqlite")
}
216
217fn summarize(kind: &str, title: &str, original: &str) -> String {
218    let signal = original
219        .lines()
220        .find(|line| !line.trim().is_empty())
221        .unwrap_or_default()
222        .split_whitespace()
223        .take(8)
224        .collect::<Vec<_>>()
225        .join(" ");
226    let title = compact_title(title);
227    if signal.is_empty() {
228        format!("{kind} {title}")
229    } else {
230        format!("{kind} {title}: {signal}")
231    }
232}
233
/// Rough token estimate for `value`.
///
/// Returns the larger of the whitespace-separated word count and the UTF-8
/// byte length divided by four (rounded down).
fn estimate_tokens(value: &str) -> usize {
    let by_words = value.split_whitespace().count();
    let by_bytes = value.len() / 4;
    std::cmp::max(by_words, by_bytes)
}
237
/// Token estimate for a handle summary.
///
/// The estimate is the summary's word count plus a fixed overhead of three,
/// and never less than four.
fn compressed_token_estimate(summary: &str) -> usize {
    let words = summary.split_whitespace().count();
    std::cmp::max(words.saturating_add(3), 4)
}
241
/// Shortens a `path:start-end` title to `file:start-end`.
///
/// Any other title (test names, paths without a line range, malformed
/// suffixes) is returned unchanged.
fn compact_title(title: &str) -> String {
    let Some((path, range)) = title.rsplit_once(':') else {
        return title.into();
    };
    // Only accept the exact shape `line_suffix` produces: digits '-' digits.
    let is_number = |part: &str| !part.is_empty() && part.bytes().all(|b| b.is_ascii_digit());
    let is_line_range = match range.split_once('-') {
        Some((start, end)) => is_number(start) && is_number(end),
        None => false,
    };
    if !is_line_range {
        return title.into();
    }
    // Titles are built from `Path::display`, which uses `\` on Windows, so
    // treat both separators as path boundaries.
    let file = path
        .rsplit(|ch: char| ch == '/' || ch == '\\')
        .next()
        .unwrap_or(path);
    format!("{file}:{range}")
}
253
254fn line_suffix(range: &Option<LineRange>) -> String {
255    range
256        .as_ref()
257        .map(|range| format!(":{}-{}", range.start, range.end))
258        .unwrap_or_default()
259}
260
261fn stable_hash(value: &str, len: usize) -> String {
262    let mut hasher = Sha256::new();
263    hasher.update(value.as_bytes());
264    let digest = hasher.finalize();
265    digest
266        .iter()
267        .flat_map(|byte| [byte >> 4, byte & 0x0f])
268        .take(len)
269        .map(|nibble| char::from_digit(nibble as u32, 16).unwrap_or('0'))
270        .collect()
271}
272
273fn storage_err(err: rusqlite::Error) -> OkError {
274    OkError::Storage(err.to_string())
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use open_kioku_core::{ChangeBoundary, RiskReport, ScoreComponent, ValidationPlan};

    /// Multi-line snippet long enough that its summary is cheaper than the original.
    const TOKEN_SNIPPET: &str = r#"pub fn issue_token(user: &User, grants: &[Grant]) -> Result<String> {
    let subject = user.subject().ok_or(AuthError::MissingSubject)?;
    let audience = grants
        .iter()
        .filter(|grant| grant.is_active())
        .map(|grant| grant.audience())
        .collect::<Vec<_>>();
    let claims = TokenClaims {
        subject: subject.to_owned(),
        audience,
        issued_at: clock::now(),
        expires_at: clock::now() + TOKEN_TTL,
    };
    signer::sign_claims(&claims).map_err(AuthError::from)
}"#;

    /// Opens a fresh store; the returned `TempDir` must outlive the store.
    fn open_store() -> (tempfile::TempDir, ContextHandleStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = ContextHandleStore::open_repo(dir.path()).unwrap();
        (dir, store)
    }

    /// Builds a search hit for `path` covering lines 1-18 with the given snippet.
    fn search_result(path: &str, snippet: &str) -> SearchResult {
        SearchResult {
            path: path.into(),
            line_range: Some(LineRange { start: 1, end: 18 }),
            snippet: snippet.into(),
            symbol: None,
            score: 1.0,
            match_reason: "test".into(),
            evidence: Vec::new(),
            evidence_refs: Vec::new(),
            confidence: 1.0,
            score_breakdown: vec![ScoreComponent::single(
                "test_score",
                1.0,
                Vec::new(),
                "test fixture",
            )],
        }
    }

    /// Builds a minimal low-risk pack whose only content is `primary_files`.
    fn pack_with(primary_files: Vec<SearchResult>) -> ContextPack {
        ContextPack {
            task: "token".into(),
            intent: "code_change".into(),
            primary_files,
            primary_symbols: Vec::new(),
            supporting_files: Vec::new(),
            dependency_edges: Vec::new(),
            runtime_signals: Vec::new(),
            test_candidates: Vec::new(),
            risk_report: RiskReport {
                level: "low".into(),
                score: 0.1,
                reasons: Vec::new(),
            },
            recommended_change_boundary: ChangeBoundary {
                allowed_files: Vec::new(),
                caution_files: Vec::new(),
                forbidden_files: Vec::new(),
                evidence_refs: Vec::new(),
                ..Default::default()
            },
            validation_plan: ValidationPlan {
                commands: Vec::new(),
                tests: Vec::new(),
                requires_approval: false,
                evidence: Vec::new(),
            },
            evidence: Vec::new(),
            negative_evidence: Vec::new(),
            confidence_summary: "test".into(),
            confidence_breakdown: open_kioku_core::ConfidenceBreakdown::default(),
        }
    }

    #[test]
    fn compresses_and_retrieves_context_handles() {
        let (_dir, store) = open_store();
        let pack = pack_with(vec![search_result("src/auth.rs", TOKEN_SNIPPET)]);

        let compressed = store.compress_pack(&pack).unwrap();
        assert_eq!(compressed.handles.len(), 1);
        let handle = &compressed.handles[0];
        let retrieved = store.retrieve(&handle.id).unwrap().unwrap();

        assert!(compressed.compression_ratio < 1.0);
        assert_eq!(retrieved.handle.id.0, handle.id.0);
        assert!(retrieved.original.contains("issue_token"));
    }

    #[test]
    fn deduplicates_identical_results() {
        let (_dir, store) = open_store();
        let pack = pack_with(vec![
            search_result("src/auth.rs", TOKEN_SNIPPET),
            search_result("src/auth.rs", TOKEN_SNIPPET),
        ]);

        let compressed = store.compress_pack(&pack).unwrap();

        assert_eq!(compressed.handles.len(), 1);
    }

    #[test]
    fn unknown_handle_is_not_found() {
        let (_dir, store) = open_store();
        let missing = ContextHandleId::new("ctx:missing".to_string());

        assert!(store.retrieve(&missing).unwrap().is_none());
    }
}