use crate::embedder::f32_to_bytes;
use crate::errors::AppError;
use rusqlite::{params, Connection};
/// One row of the `memory_chunks` table: a contiguous piece of a memory's body.
#[derive(Debug, Clone)]
pub struct Chunk {
    /// Id of the owning row in `memories`.
    pub memory_id: i64,
    /// Position of this chunk within its memory; `get_chunks_by_memory` orders by it.
    pub chunk_idx: i32,
    /// Text of the chunk as stored in the database.
    pub chunk_text: String,
    /// Start offset of the chunk inside the memory body.
    /// NOTE(review): unit (bytes vs. chars) is not established here — confirm against `crate::chunking`.
    pub start_offset: i32,
    /// End offset of the chunk inside the memory body (same unit as `start_offset`).
    pub end_offset: i32,
    /// Token count for the chunk; `insert_chunk_slices` fills this column from
    /// `token_count_approx`, so it may be an approximation.
    pub token_count: i32,
}
/// Inserts every chunk in `chunks` into `memory_chunks`, in slice order.
///
/// The INSERT is prepared once through the connection's statement cache and then
/// executed per chunk. The original re-parsed the SQL on every iteration.
/// No transaction is opened here. If a row fails part-way, the earlier rows stay
/// inserted, so wrap the call in a transaction when the batch must be atomic.
///
/// # Errors
/// Returns the first SQLite error encountered, e.g. when a chunk's `memory_id` does
/// not reference an existing memory.
pub fn insert_chunks(conn: &Connection, chunks: &[Chunk]) -> Result<(), AppError> {
    // Nothing to insert: return before touching the database, as the original loop did.
    if chunks.is_empty() {
        return Ok(());
    }
    let mut stmt = conn.prepare_cached(
        "INSERT INTO memory_chunks (memory_id, chunk_idx, chunk_text, start_offset, end_offset, token_count)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
    )?;
    for chunk in chunks {
        stmt.execute(params![
            chunk.memory_id,
            chunk.chunk_idx,
            chunk.chunk_text,
            chunk.start_offset,
            chunk.end_offset,
            chunk.token_count,
        ])?;
    }
    Ok(())
}
/// Inserts the chunks produced by `crate::chunking` for a single memory body.
///
/// Each row's text is sliced out of `body` with `crate::chunking::chunk_text`.
/// Its `chunk_idx` is the chunk's position in `chunks`.
/// The INSERT uses the same SQL text as [`insert_chunks`], so both share one cached
/// prepared statement. It is prepared once and executed per chunk.
/// No transaction is opened here, so a failure part-way leaves earlier rows inserted.
///
/// # Errors
/// Returns the first SQLite error encountered.
pub fn insert_chunk_slices(
    conn: &Connection,
    memory_id: i64,
    body: &str,
    chunks: &[crate::chunking::Chunk],
) -> Result<(), AppError> {
    // Nothing to insert: return before touching the database, as the original loop did.
    if chunks.is_empty() {
        return Ok(());
    }
    let mut stmt = conn.prepare_cached(
        "INSERT INTO memory_chunks (memory_id, chunk_idx, chunk_text, start_offset, end_offset, token_count)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
    )?;
    for (chunk_idx, chunk) in chunks.iter().enumerate() {
        // NOTE(review): the `as i32` casts mirror `Chunk`'s i32 columns and silently
        // truncate values above i32::MAX; presumably offsets/counts never get that
        // large — confirm against `crate::chunking`.
        stmt.execute(params![
            memory_id,
            chunk_idx as i32,
            crate::chunking::chunk_text(body, chunk),
            chunk.start_offset as i32,
            chunk.end_offset as i32,
            chunk.token_count_approx as i32,
        ])?;
    }
    Ok(())
}
pub fn upsert_chunk_vec(
conn: &Connection,
_rowid: i64,
memory_id: i64,
chunk_idx: i32,
embedding: &[f32],
) -> Result<(), AppError> {
conn.execute(
"INSERT OR REPLACE INTO chunk_embeddings(chunk_id, memory_id, embedding, source, model, dim)
VALUES (
(SELECT id FROM memory_chunks WHERE memory_id = ?1 AND chunk_idx = ?2),
?1, ?3, 'llm-headless', ?4, ?5
)",
params![
memory_id,
chunk_idx,
f32_to_bytes(embedding),
crate::constants::SQLITE_GRAPHRAG_VERSION,
crate::constants::EMBEDDING_DIM as i64,
],
)?;
Ok(())
}
/// Deletes every `memory_chunks` row that belongs to `memory_id`.
///
/// Calling this for a memory with no chunks (or an unknown id) succeeds and removes
/// nothing. This statement only touches `memory_chunks`. Whether matching
/// `chunk_embeddings` rows are removed depends on the schema (not visible here).
///
/// # Errors
/// Propagates any SQLite error from the DELETE.
pub fn delete_chunks(conn: &Connection, memory_id: i64) -> Result<(), AppError> {
    let sql = "DELETE FROM memory_chunks WHERE memory_id = ?1";
    let _removed_rows = conn.execute(sql, params![memory_id])?;
    Ok(())
}
/// Brute-force nearest-neighbour search over every stored chunk embedding.
///
/// Every row of `chunk_embeddings` is decoded with `crate::embedder::bytes_to_f32`.
/// It is then scored against `embedding` with `crate::similarity::cosine_similarity`.
/// Rows whose stored vector length differs from `embedding.len()` are skipped, as are
/// rows with a NULL blob.
///
/// Returns up to `k` `(memory_id, chunk_idx, score)` tuples, highest score first.
/// `chunk_idx` is looked up in `memory_chunks` through `chunk_id`. An embedding whose
/// chunk row cannot be found reports `chunk_idx = 0`. NaN scores rank last.
///
/// # Errors
/// - `AppError::Embedding` if `embedding` does not have `EMBEDDING_DIM` dimensions.
/// - Any SQLite error raised while preparing or stepping the query. The previous
///   version silently skipped such rows and returned partial results.
pub fn knn_search_chunks(
    conn: &Connection,
    embedding: &[f32],
    k: usize,
) -> Result<Vec<(i64, i32, f32)>, AppError> {
    if embedding.len() != crate::constants::EMBEDDING_DIM {
        return Err(AppError::Embedding(format!(
            "knn_search_chunks embedding has {} dims, expected {}",
            embedding.len(),
            crate::constants::EMBEDDING_DIM
        )));
    }
    // The LEFT JOIN keeps every embedding row the old query returned while
    // recovering the real chunk index. The old version always reported 0.
    let mut stmt = conn.prepare_cached(
        "SELECT ce.memory_id, mc.chunk_idx, ce.embedding
         FROM chunk_embeddings ce
         LEFT JOIN memory_chunks mc ON mc.id = ce.chunk_id",
    )?;
    let rows = stmt.query_map([], |r| {
        let memory_id: i64 = r.get(0)?;
        let chunk_idx: Option<i32> = r.get(1)?;
        let bytes: Option<Vec<u8>> = r.get(2)?;
        Ok((memory_id, chunk_idx.unwrap_or(0), bytes))
    })?;

    let mut scored: Vec<(i64, i32, f32)> = Vec::new();
    for row in rows {
        let (memory_id, chunk_idx, bytes) = row?;
        let bytes = match bytes {
            Some(b) => b,
            None => continue,
        };
        let stored = crate::embedder::bytes_to_f32(&bytes);
        if stored.len() != embedding.len() {
            continue;
        }
        let score = crate::similarity::cosine_similarity(embedding, &stored);
        scored.push((memory_id, chunk_idx, score));
    }

    // Map NaN to -inf so the comparator is a genuine total order. A stable sort keeps
    // scan order among ties, as before.
    let rank = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    scored.sort_by(|a, b| rank(b.2).total_cmp(&rank(a.2)));
    scored.truncate(k);
    Ok(scored)
}
/// Loads every chunk of `memory_id` from `memory_chunks`, ordered by `chunk_idx`.
///
/// An unknown memory id, or a memory without chunks, yields an empty vector.
///
/// # Errors
/// Returns the first SQLite error from preparing/stepping the query or from decoding
/// a column into the corresponding `Chunk` field.
pub fn get_chunks_by_memory(conn: &Connection, memory_id: i64) -> Result<Vec<Chunk>, AppError> {
    let sql = "SELECT memory_id, chunk_idx, chunk_text, start_offset, end_offset, token_count
FROM memory_chunks WHERE memory_id = ?1 ORDER BY chunk_idx";
    let mut stmt = conn.prepare_cached(sql)?;
    let mapped = stmt.query_map(params![memory_id], |row| {
        Ok(Chunk {
            memory_id: row.get(0)?,
            chunk_idx: row.get(1)?,
            chunk_text: row.get(2)?,
            start_offset: row.get(3)?,
            end_offset: row.get(4)?,
            token_count: row.get(5)?,
        })
    })?;
    let mut chunks = Vec::new();
    for item in mapped {
        chunks.push(item?);
    }
    Ok(chunks)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::EMBEDDING_DIM;
    use crate::storage::connection::register_vec_extension;
    use rusqlite::Connection;
    use tempfile::TempDir;

    /// Opens a freshly migrated database file inside a temporary directory.
    ///
    /// The `TempDir` is handed back so the directory outlives the connection for the
    /// duration of the test.
    fn setup_db() -> (TempDir, Connection) {
        register_vec_extension();
        let dir = TempDir::new().unwrap();
        let mut conn = Connection::open(dir.path().join("test.db")).unwrap();
        crate::migrations::runner().run(&mut conn).unwrap();
        (dir, conn)
    }

    /// Inserts one parent row into `memories` and returns its rowid.
    fn insert_memory(conn: &Connection) -> i64 {
        conn.execute(
            "INSERT INTO memories (namespace, name, type, description, body, body_hash)
VALUES ('global', 'test-mem', 'user', 'desc', 'body', 'hash1')",
            [],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Builds a `Chunk` fixture from its field values.
    fn make_chunk(
        memory_id: i64,
        chunk_idx: i32,
        text: &str,
        start_offset: i32,
        end_offset: i32,
        token_count: i32,
    ) -> Chunk {
        Chunk {
            memory_id,
            chunk_idx,
            chunk_text: text.to_string(),
            start_offset,
            end_offset,
            token_count,
        }
    }

    #[test]
    fn test_insert_chunks_empty_ok() {
        let (_tmp, conn) = setup_db();
        assert!(insert_chunks(&conn, &[]).is_ok());
    }

    #[test]
    fn test_insert_chunks_and_get_by_memory() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);
        let fixtures = [
            make_chunk(memory_id, 0, "primeiro chunk", 0, 14, 3),
            make_chunk(memory_id, 1, "segundo chunk", 15, 28, 3),
        ];
        insert_chunks(&conn, &fixtures).unwrap();

        let fetched = get_chunks_by_memory(&conn, memory_id).unwrap();
        assert_eq!(fetched.len(), 2);

        let first = &fetched[0];
        assert_eq!(first.chunk_idx, 0);
        assert_eq!(first.chunk_text, "primeiro chunk");
        assert_eq!(first.start_offset, 0);
        assert_eq!(first.end_offset, 14);
        assert_eq!(first.token_count, 3);

        let second = &fetched[1];
        assert_eq!(second.chunk_idx, 1);
        assert_eq!(second.chunk_text, "segundo chunk");
    }

    #[test]
    fn test_get_chunks_missing_memory_returns_empty() {
        let (_tmp, conn) = setup_db();
        let fetched = get_chunks_by_memory(&conn, 9999).unwrap();
        assert!(fetched.is_empty());
    }

    #[test]
    fn test_delete_chunks_removes_all() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);
        let fixtures = [
            make_chunk(memory_id, 0, "chunk a", 0, 7, 2),
            make_chunk(memory_id, 1, "chunk b", 8, 15, 2),
        ];
        insert_chunks(&conn, &fixtures).unwrap();
        delete_chunks(&conn, memory_id).unwrap();
        assert!(get_chunks_by_memory(&conn, memory_id).unwrap().is_empty());
    }

    #[test]
    fn test_delete_chunks_memory_without_chunks_ok() {
        let (_tmp, conn) = setup_db();
        assert!(delete_chunks(&conn, 9999).is_ok());
    }

    #[test]
    fn test_get_chunks_ordered_by_chunk_idx() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);
        // Deliberately inserted out of order to exercise the ORDER BY.
        let fixtures = [
            make_chunk(memory_id, 2, "terceiro", 20, 28, 1),
            make_chunk(memory_id, 0, "primeiro", 0, 8, 1),
            make_chunk(memory_id, 1, "segundo", 9, 16, 1),
        ];
        insert_chunks(&conn, &fixtures).unwrap();

        let fetched = get_chunks_by_memory(&conn, memory_id).unwrap();
        assert_eq!(fetched.len(), 3);
        let order: Vec<i32> = fetched.iter().map(|c| c.chunk_idx).collect();
        assert_eq!(order, vec![0, 1, 2]);
    }

    #[test]
    fn test_upsert_chunk_vec_and_knn_search() {
        let (_tmp, conn) = setup_db();
        let memory_id = insert_memory(&conn);
        insert_chunks(
            &conn,
            &[make_chunk(memory_id, 0, "embedding test", 0, 14, 2)],
        )
        .unwrap();

        let mut embedding = vec![0.0f32; EMBEDDING_DIM];
        embedding[0] = 1.0;

        let chunk_id: i64 = conn
            .query_row(
                "SELECT id FROM memory_chunks WHERE memory_id = ?1 AND chunk_idx = 0",
                params![memory_id],
                |r| r.get(0),
            )
            .unwrap();
        upsert_chunk_vec(&conn, chunk_id, memory_id, 0, &embedding).unwrap();

        let hits = knn_search_chunks(&conn, &embedding, 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, memory_id);
        assert_eq!(hits[0].1, 0);
    }

    #[test]
    fn test_knn_search_chunks_without_data_returns_empty() {
        let (_tmp, conn) = setup_db();
        let query = vec![0.0f32; EMBEDDING_DIM];
        assert!(knn_search_chunks(&conn, &query, 5).unwrap().is_empty());
    }

    #[test]
    fn test_insert_chunks_invalid_fk_fails() {
        let (_tmp, conn) = setup_db();
        let orphan = make_chunk(99999, 0, "sem pai", 0, 7, 1);
        assert!(insert_chunks(&conn, &[orphan]).is_err());
    }
}