#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
use crate::core::{Buffer, BufferMetadata, Chunk, ChunkMetadata, Context};
use crate::error::{Result, StorageError};
use crate::storage::schema::{
CHECK_SCHEMA_SQL, CURRENT_SCHEMA_VERSION, GET_VERSION_SQL, SCHEMA_SQL, SET_VERSION_SQL,
};
use crate::storage::traits::{Storage, StorageStats};
use rusqlite::{Connection, OptionalExtension, params};
use std::path::{Path, PathBuf};
/// SQLite-backed implementation of the [`Storage`] trait.
///
/// Holds a single open connection; file-backed databases also remember their
/// on-disk path so `stats()` can report the database size.
pub struct SqliteStorage {
// Open database connection (file-backed or in-memory).
conn: Connection,
// Path of the database file; `None` for in-memory databases.
path: Option<PathBuf>,
}
impl SqliteStorage {
    /// Opens (or creates) the database file at `path`.
    ///
    /// Missing parent directories are created first. Foreign-key enforcement
    /// and WAL journaling are enabled on the fresh connection.
    ///
    /// # Errors
    /// Returns a [`StorageError`] if the directory cannot be created or the
    /// database cannot be opened/configured.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let db_path = path.as_ref().to_path_buf();
        if let Some(dir) = db_path.parent()
            && !dir.exists()
        {
            std::fs::create_dir_all(dir).map_err(|e| StorageError::Database(e.to_string()))?;
        }
        let conn = Connection::open(&db_path).map_err(StorageError::from)?;
        conn.execute("PRAGMA foreign_keys = ON;", [])
            .map_err(StorageError::from)?;
        // `PRAGMA journal_mode` returns the resulting mode as a row, so it
        // must be read with `query_row` rather than `execute`.
        let _mode: String = conn
            .query_row("PRAGMA journal_mode = WAL;", [], |row| row.get(0))
            .map_err(StorageError::from)?;
        Ok(Self {
            conn,
            path: Some(db_path),
        })
    }

    /// Opens a transient in-memory database (useful for tests).
    ///
    /// # Errors
    /// Returns a [`StorageError`] if the connection cannot be created.
    pub fn in_memory() -> Result<Self> {
        let conn = Connection::open_in_memory().map_err(StorageError::from)?;
        conn.execute("PRAGMA foreign_keys = ON;", [])
            .map_err(StorageError::from)?;
        Ok(Self { conn, path: None })
    }

    /// The on-disk path of this database, or `None` when in-memory.
    #[must_use]
    pub fn path(&self) -> Option<&Path> {
        self.path.as_deref()
    }

    /// Reads the stored schema version, if any.
    ///
    /// A missing row and an unparseable value both read as `None`.
    fn get_schema_version(&self) -> Result<Option<u32>> {
        let stored: Option<String> = self
            .conn
            .query_row(GET_VERSION_SQL, [], |row| row.get(0))
            .optional()
            .map_err(StorageError::from)?;
        Ok(stored.as_deref().and_then(|v| v.parse().ok()))
    }

    /// Persists `version` as the current schema version.
    fn set_schema_version(&self, version: u32) -> Result<()> {
        self.conn
            .execute(SET_VERSION_SQL, params![version.to_string()])
            .map_err(StorageError::from)
            .map(|_| ())
    }

    /// Current Unix timestamp in seconds; 0 if the clock is before the epoch.
    #[allow(clippy::cast_possible_wrap)]
    fn now() -> i64 {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |d| d.as_secs() as i64)
    }
}
impl Storage for SqliteStorage {
    /// Creates the schema on first use, or applies pending migrations when an
    /// older schema version is detected. Safe to call repeatedly.
    fn init(&mut self) -> Result<()> {
        let is_init: i64 = self
            .conn
            .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
            .map_err(StorageError::from)?;
        if is_init == 0 {
            self.conn
                .execute_batch(SCHEMA_SQL)
                .map_err(StorageError::from)?;
            self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
        } else if let Some(current) = self.get_schema_version()?
            && current < CURRENT_SCHEMA_VERSION
        {
            // Apply each pending migration in order; failures surface as
            // `Migration` errors so callers can tell them apart from ordinary
            // database errors.
            let migrations = crate::storage::schema::get_migrations_from(current);
            for migration in migrations {
                self.conn
                    .execute_batch(migration.sql)
                    .map_err(|e| StorageError::Migration(e.to_string()))?;
            }
            self.set_schema_version(CURRENT_SCHEMA_VERSION)?;
        }
        Ok(())
    }

    /// Whether the schema objects exist in this database.
    fn is_initialized(&self) -> Result<bool> {
        let count: i64 = self
            .conn
            .query_row(CHECK_SCHEMA_SQL, [], |row| row.get(0))
            .map_err(StorageError::from)?;
        Ok(count > 0)
    }

    /// Deletes all rows from every table while keeping the schema intact.
    ///
    /// Note: this also clears the `metadata` table, which removes the stored
    /// schema version.
    fn reset(&mut self) -> Result<()> {
        self.conn
            .execute_batch(
                r"
                DELETE FROM chunk_embeddings;
                DELETE FROM chunks;
                DELETE FROM buffers;
                DELETE FROM context;
                DELETE FROM metadata;
                ",
            )
            .map_err(StorageError::from)?;
        Ok(())
    }

    /// Upserts the singleton context row (id = 1) as JSON.
    ///
    /// The COALESCE subquery preserves the original `created_at` across
    /// replacements so only `updated_at` moves forward.
    fn save_context(&mut self, context: &Context) -> Result<()> {
        let data = serde_json::to_string(context).map_err(StorageError::from)?;
        let now = Self::now();
        self.conn
            .execute(
                r"
                INSERT OR REPLACE INTO context (id, data, created_at, updated_at)
                VALUES (1, ?, COALESCE((SELECT created_at FROM context WHERE id = 1), ?), ?)
                ",
                params![data, now, now],
            )
            .map_err(StorageError::from)?;
        Ok(())
    }

    /// Loads the singleton context, or `None` when no row exists.
    fn load_context(&self) -> Result<Option<Context>> {
        let data: Option<String> = self
            .conn
            .query_row("SELECT data FROM context WHERE id = 1", [], |row| {
                row.get(0)
            })
            .optional()
            .map_err(StorageError::from)?;
        match data {
            Some(json) => {
                let context = serde_json::from_str(&json).map_err(StorageError::from)?;
                Ok(Some(context))
            }
            None => Ok(None),
        }
    }

    /// Removes the singleton context row; a no-op when none exists.
    fn delete_context(&mut self) -> Result<()> {
        self.conn
            .execute("DELETE FROM context WHERE id = 1", [])
            .map_err(StorageError::from)?;
        Ok(())
    }

    /// Inserts `buffer` and returns the assigned row id.
    #[allow(clippy::cast_possible_wrap)]
    fn add_buffer(&mut self, buffer: &Buffer) -> Result<i64> {
        let now = Self::now();
        self.conn
            .execute(
                r"
                INSERT INTO buffers (
                    name, source_path, content, content_type, content_hash,
                    size, line_count, chunk_count, created_at, updated_at
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                ",
                params![
                    buffer.name,
                    buffer
                        .source
                        .as_ref()
                        .map(|p| p.to_string_lossy().to_string()),
                    buffer.content,
                    buffer.metadata.content_type,
                    buffer.metadata.content_hash,
                    buffer.metadata.size as i64,
                    buffer.metadata.line_count.map(|c| c as i64),
                    buffer.metadata.chunk_count.map(|c| c as i64),
                    now,
                    now,
                ],
            )
            .map_err(StorageError::from)?;
        Ok(self.conn.last_insert_rowid())
    }

    /// Fetches a buffer by id, or `None` when it does not exist.
    fn get_buffer(&self, id: i64) -> Result<Option<Buffer>> {
        let result = self
            .conn
            .query_row(
                r"
                SELECT id, name, source_path, content, content_type, content_hash,
                       size, line_count, chunk_count, created_at, updated_at
                FROM buffers WHERE id = ?
                ",
                params![id],
                |row| {
                    Ok(Buffer {
                        id: Some(row.get::<_, i64>(0)?),
                        name: row.get(1)?,
                        source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
                        content: row.get(3)?,
                        metadata: BufferMetadata {
                            content_type: row.get(4)?,
                            content_hash: row.get(5)?,
                            size: row.get::<_, i64>(6)? as usize,
                            line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
                            chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
                            created_at: row.get(9)?,
                            updated_at: row.get(10)?,
                        },
                    })
                },
            )
            .optional()
            .map_err(StorageError::from)?;
        Ok(result)
    }

    /// Fetches a buffer by its unique name, or `None` when absent.
    ///
    /// Resolves name -> id first, then delegates to [`Self::get_buffer`] so
    /// the row-mapping logic lives in one place.
    fn get_buffer_by_name(&self, name: &str) -> Result<Option<Buffer>> {
        let id: Option<i64> = self
            .conn
            .query_row(
                "SELECT id FROM buffers WHERE name = ?",
                params![name],
                |row| row.get(0),
            )
            .optional()
            .map_err(StorageError::from)?;
        id.map_or(Ok(None), |id| self.get_buffer(id))
    }

    /// Lists all buffers ordered by id (insertion order).
    fn list_buffers(&self) -> Result<Vec<Buffer>> {
        let mut stmt = self
            .conn
            .prepare(
                r"
                SELECT id, name, source_path, content, content_type, content_hash,
                       size, line_count, chunk_count, created_at, updated_at
                FROM buffers ORDER BY id
                ",
            )
            .map_err(StorageError::from)?;
        let buffers = stmt
            .query_map([], |row| {
                Ok(Buffer {
                    id: Some(row.get::<_, i64>(0)?),
                    name: row.get(1)?,
                    source: row.get::<_, Option<String>>(2)?.map(PathBuf::from),
                    content: row.get(3)?,
                    metadata: BufferMetadata {
                        content_type: row.get(4)?,
                        content_hash: row.get(5)?,
                        size: row.get::<_, i64>(6)? as usize,
                        line_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
                        chunk_count: row.get::<_, Option<i64>>(8)?.map(|c| c as usize),
                        created_at: row.get(9)?,
                        updated_at: row.get(10)?,
                    },
                })
            })
            .map_err(StorageError::from)?
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(StorageError::from)?;
        Ok(buffers)
    }

    /// Rewrites every column of the buffer row identified by `buffer.id`.
    ///
    /// # Errors
    /// Returns [`StorageError::BufferNotFound`] when the buffer has no id.
    #[allow(clippy::cast_possible_wrap)]
    fn update_buffer(&mut self, buffer: &Buffer) -> Result<()> {
        let id = buffer.id.ok_or_else(|| StorageError::BufferNotFound {
            identifier: "no ID".to_string(),
        })?;
        let now = Self::now();
        self.conn
            .execute(
                r"
                UPDATE buffers SET
                    name = ?, source_path = ?, content = ?, content_type = ?,
                    content_hash = ?, size = ?, line_count = ?, chunk_count = ?,
                    updated_at = ?
                WHERE id = ?
                ",
                params![
                    buffer.name,
                    buffer
                        .source
                        .as_ref()
                        .map(|p| p.to_string_lossy().to_string()),
                    buffer.content,
                    buffer.metadata.content_type,
                    buffer.metadata.content_hash,
                    buffer.metadata.size as i64,
                    buffer.metadata.line_count.map(|c| c as i64),
                    buffer.metadata.chunk_count.map(|c| c as i64),
                    now,
                    id,
                ],
            )
            .map_err(StorageError::from)?;
        Ok(())
    }

    /// Deletes the buffer row; dependent chunks are removed via the schema's
    /// foreign-key cascade (foreign_keys is enabled at connection time).
    fn delete_buffer(&mut self, id: i64) -> Result<()> {
        self.conn
            .execute("DELETE FROM buffers WHERE id = ?", params![id])
            .map_err(StorageError::from)?;
        Ok(())
    }

    /// Total number of stored buffers.
    fn buffer_count(&self) -> Result<usize> {
        let count: i64 = self
            .conn
            .query_row("SELECT COUNT(*) FROM buffers", [], |row| row.get(0))
            .map_err(StorageError::from)?;
        Ok(count as usize)
    }

    /// Inserts `chunks` for `buffer_id` and refreshes the buffer's stored
    /// `chunk_count`, all within a single transaction.
    ///
    /// FIX: the count update previously ran *after* commit (non-atomic) and
    /// set `chunk_count = chunks.len()`, so adding a second batch for the
    /// same buffer overwrote the count with only the new batch's size. The
    /// count is now computed from the actual rows inside the transaction.
    #[allow(clippy::cast_possible_wrap)]
    fn add_chunks(&mut self, buffer_id: i64, chunks: &[Chunk]) -> Result<()> {
        let tx = self.conn.transaction().map_err(StorageError::from)?;
        let now = Self::now();
        {
            let mut stmt = tx
                .prepare(
                    r"
                    INSERT INTO chunks (
                        buffer_id, content, byte_start, byte_end, chunk_index,
                        strategy, token_count, line_start, line_end, has_overlap,
                        content_hash, custom_metadata, created_at
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ",
                )
                .map_err(StorageError::from)?;
            for chunk in chunks {
                let custom_meta = chunk.metadata.custom.clone();
                // Split the optional line range into two nullable columns.
                let (line_start, line_end) = chunk
                    .metadata
                    .line_range
                    .as_ref()
                    .map_or((None, None), |r| (Some(r.start as i64), Some(r.end as i64)));
                stmt.execute(params![
                    buffer_id,
                    chunk.content,
                    chunk.byte_range.start as i64,
                    chunk.byte_range.end as i64,
                    chunk.index as i64,
                    chunk.metadata.strategy,
                    chunk.metadata.token_count.map(|c| c as i64),
                    line_start,
                    line_end,
                    i64::from(chunk.metadata.has_overlap),
                    chunk.metadata.content_hash,
                    custom_meta,
                    now,
                ])
                .map_err(StorageError::from)?;
            }
        }
        // Derive the count from the table itself so repeated batches stay
        // correct, and commit it atomically with the inserts.
        tx.execute(
            "UPDATE buffers SET chunk_count = (SELECT COUNT(*) FROM chunks WHERE buffer_id = ?1) WHERE id = ?1",
            params![buffer_id],
        )
        .map_err(StorageError::from)?;
        tx.commit().map_err(StorageError::from)?;
        Ok(())
    }

    /// Fetches all chunks of a buffer ordered by their chunk index.
    fn get_chunks(&self, buffer_id: i64) -> Result<Vec<Chunk>> {
        let mut stmt = self
            .conn
            .prepare(
                r"
                SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
                       strategy, token_count, line_start, line_end, has_overlap,
                       content_hash, custom_metadata, created_at
                FROM chunks WHERE buffer_id = ? ORDER BY chunk_index
                ",
            )
            .map_err(StorageError::from)?;
        let chunks = stmt
            .query_map(params![buffer_id], |row| {
                let line_start: Option<i64> = row.get(8)?;
                let line_end: Option<i64> = row.get(9)?;
                // Both endpoints must be present to reconstruct the range.
                let line_range = match (line_start, line_end) {
                    (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
                    _ => None,
                };
                Ok(Chunk {
                    id: Some(row.get::<_, i64>(0)?),
                    buffer_id: row.get(1)?,
                    content: row.get(2)?,
                    byte_range: (row.get::<_, i64>(3)? as usize)..(row.get::<_, i64>(4)? as usize),
                    index: row.get::<_, i64>(5)? as usize,
                    metadata: ChunkMetadata {
                        strategy: row.get(6)?,
                        token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
                        line_range,
                        has_overlap: row.get::<_, i64>(10)? != 0,
                        content_hash: row.get(11)?,
                        custom: row.get(12)?,
                        created_at: row.get(13)?,
                    },
                })
            })
            .map_err(StorageError::from)?
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(StorageError::from)?;
        Ok(chunks)
    }

    /// Fetches a single chunk by id, or `None` when it does not exist.
    fn get_chunk(&self, id: i64) -> Result<Option<Chunk>> {
        let result = self
            .conn
            .query_row(
                r"
                SELECT id, buffer_id, content, byte_start, byte_end, chunk_index,
                       strategy, token_count, line_start, line_end, has_overlap,
                       content_hash, custom_metadata, created_at
                FROM chunks WHERE id = ?
                ",
                params![id],
                |row| {
                    let line_start: Option<i64> = row.get(8)?;
                    let line_end: Option<i64> = row.get(9)?;
                    let line_range = match (line_start, line_end) {
                        (Some(s), Some(e)) => Some((s as usize)..(e as usize)),
                        _ => None,
                    };
                    Ok(Chunk {
                        id: Some(row.get::<_, i64>(0)?),
                        buffer_id: row.get(1)?,
                        content: row.get(2)?,
                        byte_range: (row.get::<_, i64>(3)? as usize)
                            ..(row.get::<_, i64>(4)? as usize),
                        index: row.get::<_, i64>(5)? as usize,
                        metadata: ChunkMetadata {
                            strategy: row.get(6)?,
                            token_count: row.get::<_, Option<i64>>(7)?.map(|c| c as usize),
                            line_range,
                            has_overlap: row.get::<_, i64>(10)? != 0,
                            content_hash: row.get(11)?,
                            custom: row.get(12)?,
                            created_at: row.get(13)?,
                        },
                    })
                },
            )
            .optional()
            .map_err(StorageError::from)?;
        Ok(result)
    }

    /// Deletes every chunk of a buffer and resets its stored count to zero.
    fn delete_chunks(&mut self, buffer_id: i64) -> Result<()> {
        self.conn
            .execute("DELETE FROM chunks WHERE buffer_id = ?", params![buffer_id])
            .map_err(StorageError::from)?;
        self.conn
            .execute(
                "UPDATE buffers SET chunk_count = 0 WHERE id = ?",
                params![buffer_id],
            )
            .map_err(StorageError::from)?;
        Ok(())
    }

    /// Number of chunks currently stored for `buffer_id`.
    fn chunk_count(&self, buffer_id: i64) -> Result<usize> {
        let count: i64 = self
            .conn
            .query_row(
                "SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
                params![buffer_id],
                |row| row.get(0),
            )
            .map_err(StorageError::from)?;
        Ok(count as usize)
    }

    /// Concatenates all buffer contents (in id order) separated by blank
    /// lines.
    fn export_buffers(&self) -> Result<String> {
        let buffers = self.list_buffers()?;
        let mut output = String::new();
        for (i, buffer) in buffers.iter().enumerate() {
            if i > 0 {
                output.push_str("\n\n");
            }
            output.push_str(&buffer.content);
        }
        Ok(output)
    }

    /// Aggregates storage statistics; `db_size` is only available for
    /// file-backed databases whose metadata can be read.
    fn stats(&self) -> Result<StorageStats> {
        let buffer_count = self.buffer_count()?;
        let chunk_count: i64 = self
            .conn
            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
            .map_err(StorageError::from)?;
        let total_size: i64 = self
            .conn
            .query_row("SELECT COALESCE(SUM(size), 0) FROM buffers", [], |row| {
                row.get(0)
            })
            .map_err(StorageError::from)?;
        let has_context = self.load_context()?.is_some();
        let schema_version = self.get_schema_version()?.unwrap_or(0);
        let db_size = self
            .path
            .as_ref()
            .and_then(|p| std::fs::metadata(p).ok().map(|m| m.len()));
        Ok(StorageStats {
            buffer_count,
            chunk_count: chunk_count as usize,
            total_content_size: total_size as usize,
            has_context,
            schema_version,
            db_size,
        })
    }
}
impl SqliteStorage {
#[allow(clippy::cast_possible_wrap)]
pub fn store_embedding(
&mut self,
chunk_id: i64,
embedding: &[f32],
model_name: Option<&str>,
) -> Result<()> {
let now = Self::now();
let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
self.conn
.execute(
r"
INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
VALUES (?, ?, ?, ?, ?)
",
params![chunk_id, bytes, embedding.len() as i64, model_name, now],
)
.map_err(StorageError::from)?;
Ok(())
}
pub fn get_embedding(&self, chunk_id: i64) -> Result<Option<Vec<f32>>> {
let result: Option<Vec<u8>> = self
.conn
.query_row(
"SELECT embedding FROM chunk_embeddings WHERE chunk_id = ?",
params![chunk_id],
|row| row.get(0),
)
.optional()
.map_err(StorageError::from)?;
Ok(result.map(|bytes| {
bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect()
}))
}
pub fn get_embedding_models(&self, buffer_id: i64) -> Result<Vec<String>> {
let mut stmt = self
.conn
.prepare(
r"
SELECT DISTINCT ce.model_name
FROM chunk_embeddings ce
JOIN chunks c ON ce.chunk_id = c.id
WHERE c.buffer_id = ? AND ce.model_name IS NOT NULL
",
)
.map_err(StorageError::from)?;
let models = stmt
.query_map(params![buffer_id], |row| row.get::<_, String>(0))
.map_err(StorageError::from)?
.filter_map(std::result::Result::ok)
.collect();
Ok(models)
}
pub fn get_embedding_model_counts(&self, buffer_id: i64) -> Result<Vec<(Option<String>, i64)>> {
let mut stmt = self
.conn
.prepare(
r"
SELECT ce.model_name, COUNT(*) as count
FROM chunk_embeddings ce
JOIN chunks c ON ce.chunk_id = c.id
WHERE c.buffer_id = ?
GROUP BY ce.model_name
",
)
.map_err(StorageError::from)?;
let counts = stmt
.query_map(params![buffer_id], |row| {
Ok((row.get::<_, Option<String>>(0)?, row.get::<_, i64>(1)?))
})
.map_err(StorageError::from)?
.filter_map(std::result::Result::ok)
.collect();
Ok(counts)
}
#[allow(clippy::cast_possible_wrap)]
pub fn store_embeddings_batch(
&mut self,
embeddings: &[(i64, Vec<f32>)],
model_name: Option<&str>,
) -> Result<()> {
let tx = self.conn.transaction().map_err(StorageError::from)?;
let now = Self::now();
{
let mut stmt = tx
.prepare(
r"
INSERT OR REPLACE INTO chunk_embeddings (chunk_id, embedding, dimensions, model_name, created_at)
VALUES (?, ?, ?, ?, ?)
",
)
.map_err(StorageError::from)?;
for (chunk_id, embedding) in embeddings {
let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
stmt.execute(params![
chunk_id,
bytes,
embedding.len() as i64,
model_name,
now
])
.map_err(StorageError::from)?;
}
}
tx.commit().map_err(StorageError::from)?;
Ok(())
}
pub fn delete_embedding(&mut self, chunk_id: i64) -> Result<()> {
self.conn
.execute(
"DELETE FROM chunk_embeddings WHERE chunk_id = ?",
params![chunk_id],
)
.map_err(StorageError::from)?;
Ok(())
}
#[allow(clippy::cast_possible_wrap)]
pub fn search_fts(&self, query: &str, limit: usize) -> Result<Vec<(i64, f64)>> {
let fts_query = query
.split_whitespace()
.map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
.collect::<Vec<_>>()
.join(" OR ");
let mut stmt = self
.conn
.prepare(
r"
SELECT rowid, -bm25(chunks_fts) as score
FROM chunks_fts
WHERE chunks_fts MATCH ?
ORDER BY score DESC
LIMIT ?
",
)
.map_err(StorageError::from)?;
let results = stmt
.query_map(params![fts_query, limit as i64], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, f64>(1)?))
})
.map_err(StorageError::from)?
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(StorageError::from)?;
Ok(results)
}
pub fn get_all_embeddings(&self) -> Result<Vec<(i64, Vec<f32>)>> {
let mut stmt = self
.conn
.prepare("SELECT chunk_id, embedding FROM chunk_embeddings")
.map_err(StorageError::from)?;
let results = stmt
.query_map([], |row| {
let chunk_id: i64 = row.get(0)?;
let bytes: Vec<u8> = row.get(1)?;
let embedding: Vec<f32> = bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
Ok((chunk_id, embedding))
})
.map_err(StorageError::from)?
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(StorageError::from)?;
Ok(results)
}
pub fn embedding_count(&self) -> Result<usize> {
let count: i64 = self
.conn
.query_row("SELECT COUNT(*) FROM chunk_embeddings", [], |row| {
row.get(0)
})
.map_err(StorageError::from)?;
Ok(count as usize)
}
pub fn has_embedding(&self, chunk_id: i64) -> Result<bool> {
let count: i64 = self
.conn
.query_row(
"SELECT COUNT(*) FROM chunk_embeddings WHERE chunk_id = ?",
params![chunk_id],
|row| row.get(0),
)
.map_err(StorageError::from)?;
Ok(count > 0)
}
pub fn get_chunks_needing_embedding(
&self,
buffer_id: i64,
current_model: Option<&str>,
) -> Result<Vec<i64>> {
let mut results = Vec::new();
let mut stmt = self
.conn
.prepare(
r"
SELECT c.id FROM chunks c
LEFT JOIN chunk_embeddings e ON c.id = e.chunk_id
WHERE c.buffer_id = ? AND e.chunk_id IS NULL
",
)
.map_err(StorageError::from)?;
let rows = stmt
.query_map(params![buffer_id], |row| row.get(0))
.map_err(StorageError::from)?;
for row in rows {
results.push(row.map_err(StorageError::from)?);
}
if let Some(model) = current_model {
let mut stmt = self
.conn
.prepare(
r"
SELECT c.id FROM chunks c
INNER JOIN chunk_embeddings e ON c.id = e.chunk_id
WHERE c.buffer_id = ? AND (e.model_name IS NULL OR e.model_name != ?)
",
)
.map_err(StorageError::from)?;
let rows = stmt
.query_map(params![buffer_id, model], |row| row.get(0))
.map_err(StorageError::from)?;
for row in rows {
results.push(row.map_err(StorageError::from)?);
}
}
results.sort_unstable();
results.dedup();
Ok(results)
}
pub fn get_chunks_without_embedding(&self, buffer_id: i64) -> Result<Vec<i64>> {
self.get_chunks_needing_embedding(buffer_id, None)
}
pub fn delete_embeddings_by_model(
&mut self,
buffer_id: i64,
model_name: Option<&str>,
) -> Result<usize> {
let deleted = match model_name {
Some(name) => self
.conn
.execute(
r"
DELETE FROM chunk_embeddings
WHERE chunk_id IN (
SELECT id FROM chunks WHERE buffer_id = ?
) AND model_name = ?
",
params![buffer_id, name],
)
.map_err(StorageError::from)?,
None => self
.conn
.execute(
r"
DELETE FROM chunk_embeddings
WHERE chunk_id IN (
SELECT id FROM chunks WHERE buffer_id = ?
) AND model_name IS NULL
",
params![buffer_id],
)
.map_err(StorageError::from)?,
};
Ok(deleted)
}
pub fn get_embedding_stats(&self, buffer_id: i64) -> Result<EmbeddingStats> {
let total_chunks: i64 = self
.conn
.query_row(
"SELECT COUNT(*) FROM chunks WHERE buffer_id = ?",
params![buffer_id],
|row| row.get(0),
)
.map_err(StorageError::from)?;
let embedded_chunks: i64 = self
.conn
.query_row(
r"
SELECT COUNT(*) FROM chunk_embeddings e
INNER JOIN chunks c ON e.chunk_id = c.id
WHERE c.buffer_id = ?
",
params![buffer_id],
|row| row.get(0),
)
.map_err(StorageError::from)?;
let model_counts = self.get_embedding_model_counts(buffer_id)?;
Ok(EmbeddingStats {
total_chunks: total_chunks as usize,
embedded_chunks: embedded_chunks as usize,
model_counts,
})
}
}
/// Embedding coverage summary for a single buffer, as produced by
/// `SqliteStorage::get_embedding_stats`.
#[derive(Debug, Clone)]
pub struct EmbeddingStats {
// Number of chunks stored for the buffer.
pub total_chunks: usize,
// Number of those chunks that have an embedding.
pub embedded_chunks: usize,
// Embedding count per model name; `None` groups embeddings stored
// without a model name.
pub model_counts: Vec<(Option<String>, i64)>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ContextValue;

    /// Returns a fresh, fully initialized in-memory store.
    fn setup() -> SqliteStorage {
        let mut store = SqliteStorage::in_memory().unwrap();
        store.init().unwrap();
        store
    }

    #[test]
    fn test_init() {
        let mut store = SqliteStorage::in_memory().unwrap();
        assert!(store.init().is_ok());
        assert!(store.is_initialized().unwrap());
    }

    #[test]
    fn test_init_idempotent() {
        // Running init twice must not fail.
        let mut store = SqliteStorage::in_memory().unwrap();
        assert!(store.init().is_ok());
        assert!(store.init().is_ok());
    }

    #[test]
    fn test_context_crud() {
        let mut store = setup();
        assert!(store.load_context().unwrap().is_none());

        let mut ctx = Context::new();
        ctx.set_variable("key".to_string(), ContextValue::String("value".to_string()));
        store.save_context(&ctx).unwrap();

        let roundtrip = store.load_context().unwrap().unwrap();
        assert_eq!(
            roundtrip.get_variable("key"),
            Some(&ContextValue::String("value".to_string()))
        );

        store.delete_context().unwrap();
        assert!(store.load_context().unwrap().is_none());
    }

    #[test]
    fn test_buffer_crud() {
        let mut store = setup();
        let buf = Buffer::from_named("test".to_string(), "Hello, world!".to_string());

        // Create.
        let id = store.add_buffer(&buf).unwrap();
        assert!(id > 0);

        // Read (by id and by name).
        let fetched = store.get_buffer(id).unwrap().unwrap();
        assert_eq!(fetched.name, Some("test".to_string()));
        assert_eq!(fetched.content, "Hello, world!");
        let by_name = store.get_buffer_by_name("test").unwrap().unwrap();
        assert_eq!(by_name.id, Some(id));
        assert_eq!(store.list_buffers().unwrap().len(), 1);

        // Update.
        let mut changed = fetched;
        changed.content = "Updated content".to_string();
        store.update_buffer(&changed).unwrap();
        assert_eq!(
            store.get_buffer(id).unwrap().unwrap().content,
            "Updated content"
        );

        // Delete.
        store.delete_buffer(id).unwrap();
        assert!(store.get_buffer(id).unwrap().is_none());
    }

    #[test]
    fn test_chunk_crud() {
        let mut store = setup();
        let buf_id = store
            .add_buffer(&Buffer::from_content("Hello, world!".to_string()))
            .unwrap();
        let pieces = vec![
            Chunk::new(buf_id, "Hello, ".to_string(), 0..7, 0),
            Chunk::new(buf_id, "world!".to_string(), 7..13, 1),
        ];
        store.add_chunks(buf_id, &pieces).unwrap();

        let fetched = store.get_chunks(buf_id).unwrap();
        assert_eq!(fetched.len(), 2);
        assert_eq!(fetched[0].content, "Hello, ");
        assert_eq!(fetched[1].content, "world!");
        assert_eq!(store.chunk_count(buf_id).unwrap(), 2);

        let first_id = fetched[0].id.unwrap();
        assert_eq!(store.get_chunk(first_id).unwrap().unwrap().content, "Hello, ");

        store.delete_chunks(buf_id).unwrap();
        assert_eq!(store.chunk_count(buf_id).unwrap(), 0);
    }

    #[test]
    fn test_cascade_delete() {
        let mut store = setup();
        let buf_id = store
            .add_buffer(&Buffer::from_content("Hello, world!".to_string()))
            .unwrap();
        store
            .add_chunks(buf_id, &[Chunk::new(buf_id, "Hello".to_string(), 0..5, 0)])
            .unwrap();
        assert_eq!(store.chunk_count(buf_id).unwrap(), 1);

        // Deleting the buffer must cascade to its chunks.
        store.delete_buffer(buf_id).unwrap();
        let remaining: i64 = store
            .conn
            .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
            .unwrap();
        assert_eq!(remaining, 0);
    }

    #[test]
    fn test_reset() {
        let mut store = setup();
        store.save_context(&Context::new()).unwrap();
        store
            .add_buffer(&Buffer::from_content("test".to_string()))
            .unwrap();

        store.reset().unwrap();
        assert!(store.load_context().unwrap().is_none());
        assert_eq!(store.buffer_count().unwrap(), 0);
    }

    #[test]
    fn test_stats() {
        let mut store = setup();

        // Empty database.
        let empty = store.stats().unwrap();
        assert_eq!(empty.buffer_count, 0);
        assert_eq!(empty.chunk_count, 0);
        assert!(!empty.has_context);

        // Populated database.
        store.save_context(&Context::new()).unwrap();
        let buf_id = store
            .add_buffer(&Buffer::from_content("Hello, world!".to_string()))
            .unwrap();
        store
            .add_chunks(buf_id, &[Chunk::new(buf_id, "Hello".to_string(), 0..5, 0)])
            .unwrap();

        let populated = store.stats().unwrap();
        assert_eq!(populated.buffer_count, 1);
        assert_eq!(populated.chunk_count, 1);
        assert!(populated.has_context);
        assert_eq!(populated.total_content_size, 13);
    }

    #[test]
    fn test_export_buffers() {
        let mut store = setup();
        store
            .add_buffer(&Buffer::from_content("First".to_string()))
            .unwrap();
        store
            .add_buffer(&Buffer::from_content("Second".to_string()))
            .unwrap();
        assert_eq!(store.export_buffers().unwrap(), "First\n\nSecond");
    }
}