1use std::path::Path;
23use std::sync::{Arc, Mutex};
24
25use aonyx_core::{AonyxError, Result};
26use async_trait::async_trait;
27use chrono::{DateTime, Utc};
28use rusqlite::{params, Connection};
29use serde::{Deserialize, Serialize};
30use serde_json::Value as JsonValue;
31use uuid::Uuid;
32
33pub type ChunkId = Uuid;
35
36#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
38pub struct Chunk {
39 pub id: ChunkId,
41 pub project: String,
43 pub source: String,
45 pub content: String,
47 pub ts: DateTime<Utc>,
49 pub kind: Option<String>,
51 #[serde(default)]
53 pub metadata: JsonValue,
54}
55
56impl Chunk {
57 pub fn new(
59 project: impl Into<String>,
60 source: impl Into<String>,
61 content: impl Into<String>,
62 ) -> Self {
63 Self {
64 id: Uuid::new_v4(),
65 project: project.into(),
66 source: source.into(),
67 content: content.into(),
68 ts: Utc::now(),
69 kind: None,
70 metadata: JsonValue::Null,
71 }
72 }
73
74 pub fn with_kind(mut self, kind: impl Into<String>) -> Self {
76 self.kind = Some(kind.into());
77 self
78 }
79
80 pub fn with_metadata(mut self, metadata: JsonValue) -> Self {
82 self.metadata = metadata;
83 self
84 }
85}
86
87#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
89pub struct ScoredChunk {
90 pub chunk: Chunk,
92 pub score: f32,
94}
95
96#[async_trait]
98pub trait ChunksStore: Send + Sync {
99 async fn append(&self, chunk: Chunk) -> Result<ChunkId>;
101
102 async fn search_bm25(
107 &self,
108 project: Option<&str>,
109 query: &str,
110 k: usize,
111 ) -> Result<Vec<ScoredChunk>>;
112
113 async fn count(&self, project: Option<&str>) -> Result<usize>;
115}
116
117#[derive(Clone)]
119pub struct SqliteChunksStore {
120 conn: Arc<Mutex<Connection>>,
121}
122
123impl SqliteChunksStore {
124 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
126 let conn = Connection::open(path.as_ref())
127 .map_err(|e| AonyxError::Memory(format!("open chunks db: {e}")))?;
128 Self::migrate(&conn)?;
129 Ok(Self {
130 conn: Arc::new(Mutex::new(conn)),
131 })
132 }
133
134 pub fn open_in_memory() -> Result<Self> {
136 let conn = Connection::open_in_memory()
137 .map_err(|e| AonyxError::Memory(format!("open in-memory chunks: {e}")))?;
138 Self::migrate(&conn)?;
139 Ok(Self {
140 conn: Arc::new(Mutex::new(conn)),
141 })
142 }
143
144 fn migrate(conn: &Connection) -> Result<()> {
145 conn.execute_batch(MIGRATION_V1)
146 .map_err(|e| AonyxError::Memory(format!("migrate chunks schema: {e}")))?;
147 conn.execute_batch(MIGRATION_V2)
148 .map_err(|e| AonyxError::Memory(format!("migrate chunk_vectors schema: {e}")))?;
149 Ok(())
150 }
151
152 pub async fn upsert_vector(
155 &self,
156 chunk_id: ChunkId,
157 model_id: &str,
158 vec: &[f32],
159 ) -> Result<()> {
160 let conn = self.conn.clone();
161 let id = chunk_id.to_string();
162 let model_id = model_id.to_string();
163 let dim = vec.len() as i64;
164 let blob = vec_to_blob(vec);
165 tokio::task::spawn_blocking(move || -> Result<()> {
166 let lock = conn.lock().expect("chunks mutex poisoned");
167 lock.execute(
168 "INSERT INTO chunk_vectors (chunk_id, model_id, dim, vec) VALUES (?1, ?2, ?3, ?4)
169 ON CONFLICT(chunk_id) DO UPDATE SET model_id = ?2, dim = ?3, vec = ?4",
170 params![id, model_id, dim, blob],
171 )
172 .map_err(|e| AonyxError::Memory(format!("upsert_vector: {e}")))?;
173 Ok(())
174 })
175 .await
176 .map_err(|e| AonyxError::Memory(format!("upsert_vector join: {e}")))?
177 }
178
179 pub async fn vector_search(
183 &self,
184 project: Option<&str>,
185 query: &[f32],
186 k: usize,
187 ) -> Result<Vec<ScoredChunk>> {
188 let conn = self.conn.clone();
189 let project = project.map(str::to_string);
190 let query = query.to_vec();
191 tokio::task::spawn_blocking(move || -> Result<Vec<ScoredChunk>> {
192 let lock = conn.lock().expect("chunks mutex poisoned");
193 let mut stmt = lock
194 .prepare(
195 "SELECT f.uuid, f.project, f.source, f.ts, f.kind, f.metadata_json, f.content, v.vec
196 FROM chunk_vectors v JOIN chunks_fts f ON f.uuid = v.chunk_id",
197 )
198 .map_err(|e| AonyxError::Memory(format!("prepare vector_search: {e}")))?;
199 let rows = stmt
200 .query_map([], |row| {
201 let blob: Vec<u8> = row.get(7)?;
202 Ok((chunk_from_row(row)?, blob_to_vec(&blob)))
203 })
204 .map_err(|e| AonyxError::Memory(format!("query vector_search: {e}")))?;
205 let qn = norm(&query);
206 let mut scored: Vec<ScoredChunk> = Vec::new();
207 for r in rows {
208 let (chunk, vec) = r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?;
209 if let Some(p) = &project {
210 if &chunk.project != p {
211 continue;
212 }
213 }
214 if vec.len() != query.len() {
215 continue; }
217 scored.push(ScoredChunk {
218 score: cosine(&query, qn, &vec),
219 chunk,
220 });
221 }
222 scored.sort_by(|a, b| {
223 b.score
224 .partial_cmp(&a.score)
225 .unwrap_or(std::cmp::Ordering::Equal)
226 });
227 scored.truncate(k);
228 Ok(scored)
229 })
230 .await
231 .map_err(|e| AonyxError::Memory(format!("vector_search join: {e}")))?
232 }
233
234 pub async fn count_vectors(&self) -> Result<usize> {
236 let conn = self.conn.clone();
237 tokio::task::spawn_blocking(move || -> Result<usize> {
238 let lock = conn.lock().expect("chunks mutex poisoned");
239 let n: i64 = lock
240 .query_row("SELECT COUNT(*) FROM chunk_vectors", [], |r| r.get(0))
241 .map_err(|e| AonyxError::Memory(format!("count_vectors: {e}")))?;
242 Ok(n.max(0) as usize)
243 })
244 .await
245 .map_err(|e| AonyxError::Memory(format!("count_vectors join: {e}")))?
246 }
247}
248
/// Schema script for the chunk table.
///
/// `chunks_fts` is an FTS5 virtual table: `content` is the only column that is
/// full-text indexed, every other column is stored `UNINDEXED`.
const MIGRATION_V1: &str = r#"
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
    uuid UNINDEXED,
    project UNINDEXED,
    source UNINDEXED,
    ts UNINDEXED,
    kind UNINDEXED,
    metadata_json UNINDEXED,
    content,
    tokenize = 'unicode61 remove_diacritics 2'
);
"#;
261
/// Schema script for the embedding table.
///
/// `chunk_vectors` holds at most one vector per chunk (`chunk_id` is the
/// primary key), together with the model label and the vector dimension.
const MIGRATION_V2: &str = r#"
CREATE TABLE IF NOT EXISTS chunk_vectors (
    chunk_id TEXT PRIMARY KEY,
    model_id TEXT NOT NULL,
    dim INTEGER NOT NULL,
    vec BLOB NOT NULL
);
"#;
270
271#[async_trait]
272impl ChunksStore for SqliteChunksStore {
273 async fn append(&self, chunk: Chunk) -> Result<ChunkId> {
274 let conn = self.conn.clone();
275 let id = chunk.id;
276 tokio::task::spawn_blocking(move || -> Result<()> {
277 let lock = conn.lock().expect("chunks mutex poisoned");
278 lock.execute(
279 r#"
280 INSERT INTO chunks_fts (uuid, project, source, ts, kind, metadata_json, content)
281 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
282 "#,
283 params![
284 chunk.id.to_string(),
285 chunk.project,
286 chunk.source,
287 chunk.ts.to_rfc3339(),
288 chunk.kind,
289 serde_json::to_string(&chunk.metadata).ok(),
290 chunk.content,
291 ],
292 )
293 .map_err(|e| AonyxError::Memory(format!("chunks append: {e}")))?;
294 Ok(())
295 })
296 .await
297 .map_err(|e| AonyxError::Memory(format!("chunks append join: {e}")))??;
298 Ok(id)
299 }
300
301 async fn search_bm25(
302 &self,
303 project: Option<&str>,
304 query: &str,
305 k: usize,
306 ) -> Result<Vec<ScoredChunk>> {
307 let conn = self.conn.clone();
308 let query = query.to_string();
309 let project = project.map(str::to_string);
310 let limit = k as i64;
311 tokio::task::spawn_blocking(move || -> Result<Vec<ScoredChunk>> {
312 let lock = conn.lock().expect("chunks mutex poisoned");
313 let (sql, with_project) = if project.is_some() {
314 (
315 "SELECT uuid, project, source, ts, kind, metadata_json, content, bm25(chunks_fts) AS score
316 FROM chunks_fts
317 WHERE chunks_fts MATCH ?1 AND project = ?2
318 ORDER BY score ASC
319 LIMIT ?3",
320 true,
321 )
322 } else {
323 (
324 "SELECT uuid, project, source, ts, kind, metadata_json, content, bm25(chunks_fts) AS score
325 FROM chunks_fts
326 WHERE chunks_fts MATCH ?1
327 ORDER BY score ASC
328 LIMIT ?2",
329 false,
330 )
331 };
332 let mut stmt = lock
333 .prepare(sql)
334 .map_err(|e| AonyxError::Memory(format!("prepare search_bm25: {e}")))?;
335 let row_iter = if with_project {
336 stmt.query_map(
337 params![query, project.as_ref().expect("project guarded above"), limit],
338 decode_row,
339 )
340 } else {
341 stmt.query_map(params![query, limit], decode_row)
342 }
343 .map_err(|e| AonyxError::Memory(format!("query search_bm25: {e}")))?;
344
345 let mut out = Vec::new();
346 for r in row_iter {
347 out.push(r.map_err(|e| AonyxError::Memory(format!("row decode: {e}")))?);
348 }
349 Ok(out)
350 })
351 .await
352 .map_err(|e| AonyxError::Memory(format!("chunks search join: {e}")))?
353 }
354
355 async fn count(&self, project: Option<&str>) -> Result<usize> {
356 let conn = self.conn.clone();
357 let project = project.map(str::to_string);
358 tokio::task::spawn_blocking(move || -> Result<usize> {
359 let lock = conn.lock().expect("chunks mutex poisoned");
360 let n: i64 = match project {
361 Some(p) => lock
362 .query_row(
363 "SELECT COUNT(*) FROM chunks_fts WHERE project = ?1",
364 params![p],
365 |r| r.get(0),
366 )
367 .map_err(|e| AonyxError::Memory(format!("count: {e}")))?,
368 None => lock
369 .query_row("SELECT COUNT(*) FROM chunks_fts", [], |r| r.get(0))
370 .map_err(|e| AonyxError::Memory(format!("count: {e}")))?,
371 };
372 Ok(n.max(0) as usize)
373 })
374 .await
375 .map_err(|e| AonyxError::Memory(format!("chunks count join: {e}")))?
376 }
377}
378
379fn chunk_from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Chunk> {
382 let uuid_str: String = row.get(0)?;
383 let project: String = row.get(1)?;
384 let source: String = row.get(2)?;
385 let ts_raw: String = row.get(3)?;
386 let kind: Option<String> = row.get(4)?;
387 let metadata_raw: Option<String> = row.get(5)?;
388 let content: String = row.get(6)?;
389
390 let id = Uuid::parse_str(&uuid_str).map_err(|e| {
391 rusqlite::Error::FromSqlConversionFailure(0, rusqlite::types::Type::Text, Box::new(e))
392 })?;
393 let ts = DateTime::parse_from_rfc3339(&ts_raw)
394 .map(|d| d.with_timezone(&Utc))
395 .unwrap_or_else(|_| Utc::now());
396 let metadata = metadata_raw
397 .and_then(|s| serde_json::from_str(&s).ok())
398 .unwrap_or(JsonValue::Null);
399
400 Ok(Chunk {
401 id,
402 project,
403 source,
404 content,
405 ts,
406 kind,
407 metadata,
408 })
409}
410
411fn decode_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<ScoredChunk> {
412 let chunk = chunk_from_row(row)?;
413 let raw_score: f64 = row.get(7)?;
414 Ok(ScoredChunk {
415 chunk,
416 score: -(raw_score as f32),
418 })
419}
420
/// Serializes `v` as a flat byte buffer: four little-endian bytes per `f32`, in
/// element order.
fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|component| component.to_le_bytes()));
    bytes
}
429
/// Inverse of `vec_to_blob`: reads `b` as consecutive little-endian `f32`s.
///
/// Trailing bytes that do not fill a whole 4-byte group are ignored.
fn blob_to_vec(b: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(b.len() / 4);
    for word in b.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
436
/// Euclidean (L2) length of `v`: the square root of the sum of squares.
fn norm(v: &[f32]) -> f32 {
    let squares = v.iter().map(|&component| component * component);
    let sum_of_squares: f32 = squares.sum::<f32>();
    sum_of_squares.sqrt()
}
440
441fn cosine(q: &[f32], qn: f32, d: &[f32]) -> f32 {
443 let dot: f32 = q.iter().zip(d).map(|(a, b)| a * b).sum();
444 let dn = norm(d);
445 if qn == 0.0 || dn == 0.0 {
446 0.0
447 } else {
448 dot / (qn * dn)
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// In-memory store with two "demo" chunks and one "other" chunk.
    async fn seeded_store() -> SqliteChunksStore {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        store
            .append(Chunk::new(
                "demo",
                "src/lib.rs",
                "the agent loops over tool calls",
            ))
            .await
            .unwrap();
        store
            .append(Chunk::new(
                "demo",
                "src/runner.rs",
                "compaction kicks in at fifty percent",
            ))
            .await
            .unwrap();
        store
            .append(Chunk::new("other", "README.md", "another project entirely"))
            .await
            .unwrap();
        store
    }

    #[tokio::test]
    async fn append_then_count() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        store
            .append(Chunk::new("demo", "a.txt", "hello aonyx"))
            .await
            .unwrap();
        assert_eq!(store.count(None).await.unwrap(), 1);
        assert_eq!(store.count(Some("demo")).await.unwrap(), 1);
        assert_eq!(store.count(Some("other")).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn search_bm25_returns_relevant_chunks() {
        let store = seeded_store().await;
        let hits = store.search_bm25(None, "compaction", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].chunk.content.contains("compaction"));
        assert!(hits[0].score > 0.0);
    }

    #[tokio::test]
    async fn search_bm25_can_scope_to_project() {
        let store = seeded_store().await;
        let in_demo = store
            .search_bm25(Some("demo"), "project OR agent", 10)
            .await
            .unwrap();
        let in_other = store
            .search_bm25(Some("other"), "project OR agent", 10)
            .await
            .unwrap();
        // Without the non-empty checks the `all` assertions below would pass
        // vacuously if the search returned nothing at all.
        assert!(!in_demo.is_empty());
        assert!(!in_other.is_empty());
        assert!(in_demo.iter().all(|h| h.chunk.project == "demo"));
        assert!(in_other.iter().all(|h| h.chunk.project == "other"));
    }

    #[tokio::test]
    async fn search_bm25_returns_empty_when_no_match() {
        let store = seeded_store().await;
        let hits = store
            .search_bm25(None, "nothing_should_match_this", 10)
            .await
            .unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn search_bm25_honours_limit() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        for i in 0..5 {
            store
                .append(Chunk::new("demo", "x", format!("repeat token {i}")))
                .await
                .unwrap();
        }
        let hits = store.search_bm25(None, "repeat", 2).await.unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn blob_round_trip_preserves_components() {
        let original = vec![0.0_f32, -1.5, 3.25];
        let blob = vec_to_blob(&original);
        assert_eq!(blob.len(), original.len() * 4);
        assert_eq!(blob_to_vec(&blob), original);
    }

    #[test]
    fn cosine_handles_aligned_orthogonal_and_zero_vectors() {
        let q = [1.0_f32, 0.0];
        let qn = norm(&q);
        assert_eq!(cosine(&q, qn, &[1.0, 0.0]), 1.0);
        assert_eq!(cosine(&q, qn, &[0.0, 1.0]), 0.0);
        assert_eq!(cosine(&q, qn, &[0.0, 0.0]), 0.0);
    }

    #[tokio::test]
    async fn vector_search_ranks_by_similarity_and_scopes_to_project() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        let near = store
            .append(Chunk::new("demo", "a", "alpha"))
            .await
            .unwrap();
        let far = store.append(Chunk::new("demo", "b", "beta")).await.unwrap();
        let foreign = store
            .append(Chunk::new("other", "c", "gamma"))
            .await
            .unwrap();
        store
            .upsert_vector(near, "test-model", &[1.0, 0.0])
            .await
            .unwrap();
        store
            .upsert_vector(far, "test-model", &[0.0, 1.0])
            .await
            .unwrap();
        store
            .upsert_vector(foreign, "test-model", &[1.0, 0.0])
            .await
            .unwrap();
        assert_eq!(store.count_vectors().await.unwrap(), 3);

        let demo_hits = store
            .vector_search(Some("demo"), &[1.0, 0.0], 10)
            .await
            .unwrap();
        assert_eq!(demo_hits.len(), 2);
        assert_eq!(demo_hits[0].chunk.id, near);
        assert_eq!(demo_hits[1].chunk.id, far);
        assert!(demo_hits[0].score > demo_hits[1].score);

        let other_hits = store
            .vector_search(Some("other"), &[1.0, 0.0], 10)
            .await
            .unwrap();
        assert_eq!(other_hits.len(), 1);
        assert_eq!(other_hits[0].chunk.id, foreign);

        let limited = store.vector_search(None, &[1.0, 0.0], 1).await.unwrap();
        assert_eq!(limited.len(), 1);

        let none = store.vector_search(None, &[1.0, 0.0], 0).await.unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn vector_search_skips_vectors_of_other_dimensions() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        let id = store
            .append(Chunk::new("demo", "a", "alpha"))
            .await
            .unwrap();
        store
            .upsert_vector(id, "test-model", &[1.0, 0.0, 0.0])
            .await
            .unwrap();
        let hits = store.vector_search(None, &[1.0, 0.0], 10).await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn upsert_vector_replaces_the_previous_vector() {
        let store = SqliteChunksStore::open_in_memory().unwrap();
        let id = store
            .append(Chunk::new("demo", "a", "alpha"))
            .await
            .unwrap();
        store
            .upsert_vector(id, "model-a", &[0.0, 1.0])
            .await
            .unwrap();
        store
            .upsert_vector(id, "model-b", &[1.0, 0.0])
            .await
            .unwrap();
        assert_eq!(store.count_vectors().await.unwrap(), 1);

        let hits = store.vector_search(None, &[1.0, 0.0], 10).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 1.0).abs() < 1e-6);
    }
}