Skip to main content

aonyx_memory/
chunks.rs

//! Searchable chunks store backed by SQLite FTS5.
//!
//! Port reference: Aonyx RAG `rag_system/utils/bm25_store.py` + `utils/hybrid_search.py`.
//!
//! ## V1 scope (this file)
//! - Chunk = a piece of text + project + source + timestamp + free-form metadata.
//! - A SQLite **FTS5** virtual table provides BM25-ranked full-text search out
//!   of the box, with a `unicode61 remove_diacritics 2` tokenizer that folds
//!   accents, so accented and unaccented spellings match each other.
//! - `search_bm25(project?, query, k)` returns the top-`k` chunks ordered by
//!   relevance, with positive `score = -bm25(...)` so larger = better.
//! - Optional per-chunk embeddings live in a `chunk_vectors` table (tagged with
//!   the `model_id` that produced them) and are searched by brute-force cosine
//!   similarity via [`SqliteChunksStore::vector_search`].
//!
//! ## V1.1 (deferred)
//! - Local embeddings via `fastembed-rs` (ONNX, ~30 MB model).
//! - HNSW index for vector ANN search, replacing the brute-force scan.
//! - **RRF** fusion with `k = 60` combining BM25 + vectors.
//! - Exponential temporal boost on recent chunks.
//!
//! Note: the [`ChunksStore`] trait does not yet take a search `mode`; adding
//! hybrid modes in V1.1 will need a new trait method (ideally with a default
//! implementation) so existing callers and implementors keep compiling.
21
22use std::path::Path;
23use std::sync::{Arc, Mutex};
24
25use aonyx_core::{AonyxError, Result};
26use async_trait::async_trait;
27use chrono::{DateTime, Utc};
28use rusqlite::{params, Connection};
29use serde::{Deserialize, Serialize};
30use serde_json::Value as JsonValue;
31use uuid::Uuid;
32
33/// Stable identifier for a [`Chunk`].
34pub type ChunkId = Uuid;
35
36/// A piece of indexable text.
37#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38pub struct Chunk {
39    /// Stable id (UUID v4 by default).
40    pub id: ChunkId,
41    /// Project slug this chunk belongs to.
42    pub project: String,
43    /// Source identifier (path, url, doc id).
44    pub source: String,
45    /// Raw chunk text.
46    pub content: String,
47    /// Creation timestamp.
48    pub ts: DateTime<Utc>,
49    /// Optional classifier (`"code"`, `"note"`, `"diary"`, `"doc"`).
50    pub kind: Option<String>,
51    /// Free-form JSON metadata (e.g. AST symbol name + line range for code chunks).
52    #[serde(default)]
53    pub metadata: JsonValue,
54}
55
56impl Chunk {
57    /// Build a new chunk with sensible defaults.
58    pub fn new(
59        project: impl Into<String>,
60        source: impl Into<String>,
61        content: impl Into<String>,
62    ) -> Self {
63        Self {
64            id: Uuid::new_v4(),
65            project: project.into(),
66            source: source.into(),
67            content: content.into(),
68            ts: Utc::now(),
69            kind: None,
70            metadata: JsonValue::Null,
71        }
72    }
73
74    /// Attach a classifier.
75    pub fn with_kind(mut self, kind: impl Into<String>) -> Self {
76        self.kind = Some(kind.into());
77        self
78    }
79
80    /// Attach JSON metadata.
81    pub fn with_metadata(mut self, metadata: JsonValue) -> Self {
82        self.metadata = metadata;
83        self
84    }
85}
86
87/// A search hit: a chunk and its score (larger = more relevant).
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
89pub struct ScoredChunk {
90    /// The matched chunk.
91    pub chunk: Chunk,
92    /// Relevance score (positive; we flip SQLite's negative BM25).
93    pub score: f32,
94}
95
96/// Async chunks store.
97#[async_trait]
98pub trait ChunksStore: Send + Sync {
99    /// Append a new chunk.
100    async fn append(&self, chunk: Chunk) -> Result<ChunkId>;
101
102    /// BM25 search.
103    ///
104    /// `project = None` searches across every project; `Some(p)` scopes to one.
105    /// `k` caps the number of hits.
106    async fn search_bm25(
107        &self,
108        project: Option<&str>,
109        query: &str,
110        k: usize,
111    ) -> Result<Vec<ScoredChunk>>;
112
113    /// Total chunk count, optionally scoped to a project.
114    async fn count(&self, project: Option<&str>) -> Result<usize>;
115}
116
117/// SQLite-backed [`ChunksStore`] using FTS5 for BM25 ranking.
118#[derive(Clone)]
119pub struct SqliteChunksStore {
120    conn: Arc<Mutex<Connection>>,
121}
122
123impl SqliteChunksStore {
124    /// Open (or create) the chunks database at `path`.
125    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
126        let conn = Connection::open(path.as_ref())
127            .map_err(|e| AonyxError::Memory(format!("open chunks db: {e}")))?;
128        Self::migrate(&conn)?;
129        Ok(Self {
130            conn: Arc::new(Mutex::new(conn)),
131        })
132    }
133
134    /// Open an in-memory database — convenient for tests.
135    pub fn open_in_memory() -> Result<Self> {
136        let conn = Connection::open_in_memory()
137            .map_err(|e| AonyxError::Memory(format!("open in-memory chunks: {e}")))?;
138        Self::migrate(&conn)?;
139        Ok(Self {
140            conn: Arc::new(Mutex::new(conn)),
141        })
142    }
143
144    fn migrate(conn: &Connection) -> Result<()> {
145        conn.execute_batch(MIGRATION_V1)
146            .map_err(|e| AonyxError::Memory(format!("migrate chunks schema: {e}")))?;
147        conn.execute_batch(MIGRATION_V2)
148            .map_err(|e| AonyxError::Memory(format!("migrate chunk_vectors schema: {e}")))?;
149        Ok(())
150    }
151
152    /// Store (or replace) the embedding `vec` for `chunk_id`, tagged with the
153    /// `model_id` that produced it (so a model change can be detected).
154    pub async fn upsert_vector(
155        &self,
156        chunk_id: ChunkId,
157        model_id: &str,
158        vec: &[f32],
159    ) -> Result<()> {
160        let conn = self.conn.clone();
161        let id = chunk_id.to_string();
162        let model_id = model_id.to_string();
163        let dim = vec.len() as i64;
164        let blob = vec_to_blob(vec);
165        tokio::task::spawn_blocking(move || -> Result<()> {
166            let lock = conn.lock().expect("chunks mutex poisoned");
167            lock.execute(
168                "INSERT INTO chunk_vectors (chunk_id, model_id, dim, vec) VALUES (?1, ?2, ?3, ?4)
169                 ON CONFLICT(chunk_id) DO UPDATE SET model_id = ?2, dim = ?3, vec = ?4",
170                params![id, model_id, dim, blob],
171            )
172            .map_err(|e| AonyxError::Memory(format!("upsert_vector: {e}")))?;
173            Ok(())
174        })
175        .await
176        .map_err(|e| AonyxError::Memory(format!("upsert_vector join: {e}")))?
177    }
178
179    /// Brute-force cosine search over stored vectors (optionally scoped to a
180    /// project). Returns the top-`k` chunks by similarity to `query`. Vectors
181    /// whose dimension differs from `query` (a stale embedder) are skipped.
182    pub async fn vector_search(
183        &self,
184        project: Option<&str>,
185        query: &[f32],
186        k: usize,
187    ) -> Result<Vec<ScoredChunk>> {
188        let conn = self.conn.clone();
189        let project = project.map(str::to_string);
190        let query = query.to_vec();
191        tokio::task::spawn_blocking(move || -> Result<Vec<ScoredChunk>> {
192            let lock = conn.lock().expect("chunks mutex poisoned");
193            let mut stmt = lock
194                .prepare(
195                    "SELECT f.uuid, f.project, f.source, f.ts, f.kind, f.metadata_json, f.content, v.vec
196                     FROM chunk_vectors v JOIN chunks_fts f ON f.uuid = v.chunk_id",
197                )
198                .map_err(|e| AonyxError::Memory(format!("prepare vector_search: {e}")))?;
199            let rows = stmt
200                .query_map([], |row| {
201                    let blob: Vec<u8> = row.get(7)?;
202                    Ok((chunk_from_row(row)?, blob_to_vec(&blob)))
203                })
204                .map_err(|e| AonyxError::Memory(format!("query vector_search: {e}")))?;
205            let qn = norm(&query);
206            let mut scored: Vec<ScoredChunk> = Vec::new();
207            for r in rows {
208                let (chunk, vec) = r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?;
209                if let Some(p) = &project {
210                    if &chunk.project != p {
211                        continue;
212                    }
213                }
214                if vec.len() != query.len() {
215                    continue; // dim mismatch → stale embedder, skip
216                }
217                scored.push(ScoredChunk {
218                    score: cosine(&query, qn, &vec),
219                    chunk,
220                });
221            }
222            scored.sort_by(|a, b| {
223                b.score
224                    .partial_cmp(&a.score)
225                    .unwrap_or(std::cmp::Ordering::Equal)
226            });
227            scored.truncate(k);
228            Ok(scored)
229        })
230        .await
231        .map_err(|e| AonyxError::Memory(format!("vector_search join: {e}")))?
232    }
233
234    /// Count stored vectors (diagnostics / reindex decisions).
235    pub async fn count_vectors(&self) -> Result<usize> {
236        let conn = self.conn.clone();
237        tokio::task::spawn_blocking(move || -> Result<usize> {
238            let lock = conn.lock().expect("chunks mutex poisoned");
239            let n: i64 = lock
240                .query_row("SELECT COUNT(*) FROM chunk_vectors", [], |r| r.get(0))
241                .map_err(|e| AonyxError::Memory(format!("count_vectors: {e}")))?;
242            Ok(n.max(0) as usize)
243        })
244        .await
245        .map_err(|e| AonyxError::Memory(format!("count_vectors join: {e}")))?
246    }
247}
248
/// Schema v1: the FTS5 virtual table holding every chunk.
///
/// Only `content` is full-text indexed; the other columns are stored
/// `UNINDEXED` alongside it so a match can be decoded back into a [`Chunk`].
/// The `unicode61 remove_diacritics 2` tokenizer makes matching
/// accent-insensitive. FTS5 tables cannot carry constraints, so nothing here
/// prevents duplicate `uuid` values. `IF NOT EXISTS` makes re-running it a no-op.
const MIGRATION_V1: &str = r#"
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
    uuid           UNINDEXED,
    project        UNINDEXED,
    source         UNINDEXED,
    ts             UNINDEXED,
    kind           UNINDEXED,
    metadata_json  UNINDEXED,
    content,
    tokenize = 'unicode61 remove_diacritics 2'
);
"#;

/// Schema v2: at most one embedding per chunk, keyed by the chunk's uuid text.
///
/// `vec` holds the little-endian `f32` bytes written by `vec_to_blob`, `dim`
/// its element count, and `model_id` the embedder that produced it. There is
/// no foreign key to `chunks_fts`; `vector_search` joins against that table,
/// so vectors without a matching chunk are never returned.
const MIGRATION_V2: &str = r#"
CREATE TABLE IF NOT EXISTS chunk_vectors (
    chunk_id TEXT PRIMARY KEY,
    model_id TEXT NOT NULL,
    dim      INTEGER NOT NULL,
    vec      BLOB NOT NULL
);
"#;
270
271#[async_trait]
272impl ChunksStore for SqliteChunksStore {
273    async fn append(&self, chunk: Chunk) -> Result<ChunkId> {
274        let conn = self.conn.clone();
275        let id = chunk.id;
276        tokio::task::spawn_blocking(move || -> Result<()> {
277            let lock = conn.lock().expect("chunks mutex poisoned");
278            lock.execute(
279                r#"
280                INSERT INTO chunks_fts (uuid, project, source, ts, kind, metadata_json, content)
281                VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
282                "#,
283                params![
284                    chunk.id.to_string(),
285                    chunk.project,
286                    chunk.source,
287                    chunk.ts.to_rfc3339(),
288                    chunk.kind,
289                    serde_json::to_string(&chunk.metadata).ok(),
290                    chunk.content,
291                ],
292            )
293            .map_err(|e| AonyxError::Memory(format!("chunks append: {e}")))?;
294            Ok(())
295        })
296        .await
297        .map_err(|e| AonyxError::Memory(format!("chunks append join: {e}")))??;
298        Ok(id)
299    }
300
301    async fn search_bm25(
302        &self,
303        project: Option<&str>,
304        query: &str,
305        k: usize,
306    ) -> Result<Vec<ScoredChunk>> {
307        let conn = self.conn.clone();
308        let query = query.to_string();
309        let project = project.map(str::to_string);
310        let limit = k as i64;
311        tokio::task::spawn_blocking(move || -> Result<Vec<ScoredChunk>> {
312            let lock = conn.lock().expect("chunks mutex poisoned");
313            let (sql, with_project) = if project.is_some() {
314                (
315                    "SELECT uuid, project, source, ts, kind, metadata_json, content, bm25(chunks_fts) AS score
316                     FROM chunks_fts
317                     WHERE chunks_fts MATCH ?1 AND project = ?2
318                     ORDER BY score ASC
319                     LIMIT ?3",
320                    true,
321                )
322            } else {
323                (
324                    "SELECT uuid, project, source, ts, kind, metadata_json, content, bm25(chunks_fts) AS score
325                     FROM chunks_fts
326                     WHERE chunks_fts MATCH ?1
327                     ORDER BY score ASC
328                     LIMIT ?2",
329                    false,
330                )
331            };
332            let mut stmt = lock
333                .prepare(sql)
334                .map_err(|e| AonyxError::Memory(format!("prepare search_bm25: {e}")))?;
335            let row_iter = if with_project {
336                stmt.query_map(
337                    params![query, project.as_ref().expect("project guarded above"), limit],
338                    decode_row,
339                )
340            } else {
341                stmt.query_map(params![query, limit], decode_row)
342            }
343            .map_err(|e| AonyxError::Memory(format!("query search_bm25: {e}")))?;
344
345            let mut out = Vec::new();
346            for r in row_iter {
347                out.push(r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?);
348            }
349            Ok(out)
350        })
351        .await
352        .map_err(|e| AonyxError::Memory(format!("chunks search join: {e}")))?
353    }
354
355    async fn count(&self, project: Option<&str>) -> Result<usize> {
356        let conn = self.conn.clone();
357        let project = project.map(str::to_string);
358        tokio::task::spawn_blocking(move || -> Result<usize> {
359            let lock = conn.lock().expect("chunks mutex poisoned");
360            let n: i64 = match project {
361                Some(p) => lock
362                    .query_row(
363                        "SELECT COUNT(*) FROM chunks_fts WHERE project = ?1",
364                        params![p],
365                        |r| r.get(0),
366                    )
367                    .map_err(|e| AonyxError::Memory(format!("count: {e}")))?,
368                None => lock
369                    .query_row("SELECT COUNT(*) FROM chunks_fts", [], |r| r.get(0))
370                    .map_err(|e| AonyxError::Memory(format!("count: {e}")))?,
371            };
372            Ok(n.max(0) as usize)
373        })
374        .await
375        .map_err(|e| AonyxError::Memory(format!("chunks count join: {e}")))?
376    }
377}
378
379/// Decode a chunk from a row whose first 7 columns are
380/// `uuid, project, source, ts, kind, metadata_json, content`.
381fn chunk_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Chunk> {
382    let uuid_str: String = row.get(0)?;
383    let project: String = row.get(1)?;
384    let source: String = row.get(2)?;
385    let ts_raw: String = row.get(3)?;
386    let kind: Option<String> = row.get(4)?;
387    let metadata_raw: Option<String> = row.get(5)?;
388    let content: String = row.get(6)?;
389
390    let id = Uuid::parse_str(&uuid_str).map_err(|e| {
391        rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
392    })?;
393    let ts = DateTime::parse_from_rfc3339(&ts_raw)
394        .map(|d| d.with_timezone(&Utc))
395        .unwrap_or_else(|_| Utc::now());
396    let metadata = metadata_raw
397        .and_then(|s| serde_json::from_str(&s).ok())
398        .unwrap_or(JsonValue::Null);
399
400    Ok(Chunk {
401        id,
402        project,
403        source,
404        content,
405        ts,
406        kind,
407        metadata,
408    })
409}
410
411fn decode_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<ScoredChunk> {
412    let chunk = chunk_from_row(row)?;
413    let raw_score: f64 = row.get(7)?;
414    Ok(ScoredChunk {
415        chunk,
416        // SQLite's bm25() returns negative values; flip the sign so larger = better.
417        score: -(raw_score as f32),
418    })
419}
420
/// Encode `v` as a byte blob: each `f32` becomes 4 little-endian bytes, in order.
fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    v.iter()
        .map(|x| x.to_le_bytes())
        .for_each(|quad| bytes.extend_from_slice(&quad));
    bytes
}
429
/// Decode a blob of little-endian `f32`s back into a vector.
///
/// Trailing bytes that do not form a complete 4-byte group are ignored.
fn blob_to_vec(b: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(b.len() / 4);
    for quad in b.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(quad);
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}
436
/// Euclidean (L2) length of `v`; zero for an empty slice.
fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
440
/// Cosine similarity between `q` and `d`, where `qn` is the precomputed L2
/// norm of `q` (so it is computed once per query, not once per candidate).
///
/// Returns `0.0` when either norm is zero. Only the first
/// `min(q.len(), d.len())` elements contribute to the dot product.
fn cosine(q: &[f32], qn: f32, d: &[f32]) -> f32 {
    let dn = d.iter().map(|x| x * x).sum::<f32>().sqrt();
    if qn == 0.0 || dn == 0.0 {
        return 0.0;
    }
    let dot = q.iter().zip(d).map(|(a, b)| a * b).sum::<f32>();
    dot / (qn * dn)
}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Store with two `demo` chunks and one `other` chunk.
    async fn seeded_store() -> SqliteChunksStore {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        store
            .append(Chunk::new(
                "demo",
                "src/lib.rs",
                "the agent loops over tool calls",
            ))
            .await
            .unwrap();
        store
            .append(Chunk::new(
                "demo",
                "src/runner.rs",
                "compaction kicks in at fifty percent",
            ))
            .await
            .unwrap();
        store
            .append(Chunk::new("other", "README.md", "another project entirely"))
            .await
            .unwrap();
        store
    }

    #[tokio::test]
    async fn append_then_count() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        store
            .append(Chunk::new("demo", "a.txt", "hello aonyx"))
            .await
            .unwrap();
        assert_eq!(store.count(None).await.unwrap(), 1);
        assert_eq!(store.count(Some("demo")).await.unwrap(), 1);
        assert_eq!(store.count(Some("other")).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn search_bm25_returns_relevant_chunks() {
        let store = seeded_store().await;
        let hits = store.search_bm25(None, "compaction", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].chunk.content.contains("compaction"));
        assert!(hits[0].score > 0.0);
    }

    #[tokio::test]
    async fn search_bm25_can_scope_to_project() {
        let store = seeded_store().await;
        let in_demo = store
            .search_bm25(Some("demo"), "project OR agent", 10)
            .await
            .unwrap();
        let in_other = store
            .search_bm25(Some("other"), "project OR agent", 10)
            .await
            .unwrap();
        assert!(in_demo.iter().all(|h| h.chunk.project == "demo"));
        assert!(in_other.iter().all(|h| h.chunk.project == "other"));
    }

    #[tokio::test]
    async fn search_bm25_returns_empty_when_no_match() {
        let store = seeded_store().await;
        let hits = store
            .search_bm25(None, "nothing_should_match_this", 10)
            .await
            .unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_bm25_honours_limit() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        for i in 0..5 {
            store
                .append(Chunk::new("demo", "x", format!("repeat token {i}")))
                .await
                .unwrap();
        }
        let hits = store.search_bm25(None, "repeat", 2).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[tokio::test]
    async fn search_bm25_round_trips_every_chunk_field() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        let original = Chunk::new("demo", "notes/today.md", "roundtrip marker")
            .with_kind("note")
            .with_metadata(serde_json::json!({ "line": 3 }));
        store.append(original.clone()).await.unwrap();
        let hits = store.search_bm25(None, "roundtrip", 1).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].chunk, original);
    }

    #[tokio::test]
    async fn vector_search_ranks_by_cosine_and_scopes_to_project() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        let near = store.append(Chunk::new("demo", "a", "near")).await.unwrap();
        let far = store.append(Chunk::new("demo", "b", "far")).await.unwrap();
        let elsewhere = store
            .append(Chunk::new("other", "c", "elsewhere"))
            .await
            .unwrap();
        store.upsert_vector(near, "m", &[1.0, 0.0]).await.unwrap();
        store.upsert_vector(far, "m", &[0.0, 1.0]).await.unwrap();
        store.upsert_vector(elsewhere, "m", &[1.0, 0.0]).await.unwrap();

        let hits = store
            .vector_search(Some("demo"), &[1.0, 0.1], 10)
            .await
            .unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].chunk.id, near);
        assert_eq!(hits[1].chunk.id, far);
        assert!(hits[0].score > hits[1].score);

        let all = store.vector_search(None, &[1.0, 0.0], 10).await.unwrap();
        assert_eq!(all.len(), 3);
        let top1 = store.vector_search(None, &[1.0, 0.0], 1).await.unwrap();
        assert_eq!(top1.len(), 1);
    }

    #[tokio::test]
    async fn upsert_vector_replaces_and_dim_mismatch_is_skipped() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        let id = store.append(Chunk::new("demo", "a", "text")).await.unwrap();
        store.upsert_vector(id, "m1", &[1.0, 0.0]).await.unwrap();
        store.upsert_vector(id, "m2", &[0.0, 0.0, 1.0]).await.unwrap();
        assert_eq!(store.count_vectors().await.unwrap(), 1);
        assert!(store
            .vector_search(None, &[1.0, 0.0], 5)
            .await
            .unwrap()
            .is_empty());
        let hits = store
            .vector_search(None, &[0.0, 0.0, 1.0], 5)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].chunk.id, id);
    }

    #[test]
    fn vector_blob_round_trips() {
        let v = [0.5_f32, -1.25, 3.0];
        assert_eq!(blob_to_vec(&vec_to_blob(&v)), v.to_vec());
    }
}