use std::path::Path;
use std::sync::Arc;
use deadpool_sqlite::{Config, Hook, HookError, Pool, Runtime};
use solo_core::{Error, Result, VectorIndex};
use crate::key_material::KeyMaterial;
/// Maximum pool size that [`ReaderPool::new`] passes to [`ReaderPool::with_size`]
/// when the caller does not choose one.
pub const DEFAULT_POOL_SIZE: usize = 2;
/// A pool of SQLite connections bundled with a shared vector index.
///
/// Closures are run against a pooled connection through `interact`, and the
/// index handle is exposed through `hnsw`. The type derives `Clone`; the index
/// lives behind an `Arc`, so every clone refers to the same index.
///
/// NOTE(review): nothing in this file opens the connections read-only, so the
/// "reader" in the name presumably describes intended use rather than an
/// enforced restriction — confirm against callers.
#[derive(Clone)]
pub struct ReaderPool {
/// Connection pool built by `with_size`; every connection it creates goes
/// through the `post_create` PRAGMA hook installed there.
pool: Pool,
/// Vector index handle supplied by the caller at construction time.
hnsw: Arc<dyn VectorIndex + Send + Sync>,
}
impl ReaderPool {
pub fn new(
db_path: &Path,
key: Option<KeyMaterial>,
hnsw: Arc<dyn VectorIndex + Send + Sync>,
) -> Result<Self> {
Self::with_size(db_path, key, DEFAULT_POOL_SIZE, hnsw)
}
pub fn with_size(
db_path: &Path,
key: Option<KeyMaterial>,
size: usize,
hnsw: Arc<dyn VectorIndex + Send + Sync>,
) -> Result<Self> {
let cfg = Config::new(db_path);
let mut builder = cfg
.builder(Runtime::Tokio1)
.map_err(|e| Error::storage(format!("deadpool config: {e:?}")))?
.max_size(size);
if let Some(key) = key {
let key_hex = key.as_hex();
builder = builder.post_create(Hook::async_fn(move |conn, _metrics| {
let key_hex = key_hex.clone();
Box::pin(async move {
let pragma = format!("PRAGMA key = \"x'{}'\"", &*key_hex);
conn.interact(move |c| {
c.execute_batch(&pragma)?;
c.execute_batch(
"PRAGMA foreign_keys = ON;
PRAGMA busy_timeout = 5000;",
)?;
Ok::<_, rusqlite::Error>(())
})
.await
.map_err(|e| HookError::message(format!("interact: {e}")))?
.map_err(|e| HookError::message(format!("PRAGMA key: {e}")))?;
Ok(())
})
}));
} else {
builder = builder.post_create(Hook::async_fn(|conn, _metrics| {
Box::pin(async move {
conn.interact(|c| {
c.execute_batch(
"PRAGMA foreign_keys = ON;
PRAGMA busy_timeout = 5000;",
)
})
.await
.map_err(|e| HookError::message(format!("interact: {e}")))?
.map_err(|e| HookError::message(format!("PRAGMA setup: {e}")))?;
Ok(())
})
}));
}
let pool = builder
.build()
.map_err(|e| Error::storage(format!("deadpool build: {e:?}")))?;
Ok(Self { pool, hnsw })
}
pub async fn interact<F, R>(&self, f: F) -> Result<R>
where
F: FnOnce(&mut rusqlite::Connection) -> rusqlite::Result<R> + Send + 'static,
R: Send + 'static,
{
let conn = self
.pool
.get()
.await
.map_err(|e| Error::storage(format!("pool get: {e:?}")))?;
conn.interact(f)
.await
.map_err(|e| Error::storage(format!("interact: {e}")))?
.map_err(|e| Error::storage(format!("rusqlite: {e}")))
}
pub fn hnsw(&self) -> &Arc<dyn VectorIndex + Send + Sync> {
&self.hnsw
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{
        StubVectorIndex, fixture_embedding, fixture_episode, open_test_db_at,
    };
    use crate::writer::WriterActor;

    /// Builds the two-worker multi-threaded Tokio runtime shared by the tests.
    fn rt() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap()
    }

    /// A default-sized, unkeyed pool can run a query through `interact`.
    #[test]
    fn pool_returns_connections() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        // Only the on-disk database is needed; the setup connection is released at once.
        drop(open_test_db_at(&path));
        runtime.block_on(async {
            let hnsw: Arc<dyn VectorIndex + Send + Sync> = Arc::new(StubVectorIndex::new(4));
            let pool = ReaderPool::new(&path, None, hnsw).unwrap();
            let version: u32 = pool
                .interact(|conn| {
                    conn.query_row(
                        "SELECT MAX(version) FROM schema_migrations",
                        [],
                        |row| row.get(0),
                    )
                })
                .await
                .unwrap();
            assert_eq!(version, 2);
        });
    }

    /// A row stored through the writer actor is visible to a pooled connection.
    #[test]
    fn reader_sees_writes_committed_through_writer_actor() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        let writer_conn = open_test_db_at(&path);
        let hnsw: Arc<dyn VectorIndex + Send + Sync> = Arc::new(StubVectorIndex::new(4));
        let crate::writer::WriterSpawn { handle, join: _ } =
            WriterActor::spawn(writer_conn, hnsw.clone());
        runtime.block_on(async {
            let pool = ReaderPool::new(&path, None, hnsw).unwrap();
            let episode = fixture_episode("reader-visibility test");
            let mid = handle
                .remember(episode.clone(), fixture_embedding(4))
                .await
                .unwrap();
            assert_eq!(mid, episode.memory_id);

            let mid_str = mid.to_string();
            let content: String = pool
                .interact(move |conn| {
                    conn.query_row(
                        "SELECT content FROM episodes WHERE memory_id = ?",
                        [mid_str],
                        |row| row.get(0),
                    )
                })
                .await
                .unwrap();
            assert_eq!(content, "reader-visibility test");
        });
        // Keep the writer handle alive until every assertion has run.
        drop(handle);
    }

    /// Thirty-two concurrent reads against a four-connection pool all succeed.
    ///
    /// Each task works on its own clone of the `ReaderPool` and goes through
    /// `ReaderPool::interact`, so the public checkout and error-mapping path is
    /// what gets exercised under contention.
    #[test]
    fn many_concurrent_reads_serve_from_pool() {
        let runtime = rt();
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("test.db");
        // Only the on-disk database is needed; the setup connection is released at once.
        drop(open_test_db_at(&path));
        runtime.block_on(async {
            let hnsw: Arc<dyn VectorIndex + Send + Sync> = Arc::new(StubVectorIndex::new(4));
            let pool = ReaderPool::with_size(&path, None, 4, hnsw).unwrap();

            let tasks: Vec<_> = (0..32)
                .map(|_| {
                    let reader = pool.clone();
                    tokio::spawn(async move {
                        reader
                            .interact(|conn| {
                                conn.query_row(
                                    "SELECT MAX(version) FROM schema_migrations",
                                    [],
                                    |row| row.get::<_, u32>(0),
                                )
                            })
                            .await
                            .unwrap()
                    })
                })
                .collect();

            let mut versions = Vec::with_capacity(tasks.len());
            for task in tasks {
                versions.push(task.await.unwrap());
            }
            assert_eq!(versions.len(), 32);
            assert!(versions.iter().all(|v| *v == 2));
        });
    }
}