use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use rusqlite::Connection;
use crate::{
cache::{AdjacencyCache, CacheStats},
config::SqliteConfig,
errors::SqliteGraphError,
hnsw::HnswIndex,
introspection::{EdgeCount, GraphIntrospection, IntrospectError},
mvcc::SnapshotManager,
query_cache::QueryCache,
schema::ensure_schema,
};
use super::{
metrics::{GraphMetrics, StatementTracker},
pool::PoolManager,
};
/// SQLite-backed graph handle: owns the connection pool together with the
/// per-graph caches, metrics, MVCC snapshot state, and loaded HNSW indexes.
pub struct SqliteGraph {
    /// Pooled SQLite connections used for all database access.
    pub pool: PoolManager,
    // Adjacency cache for outgoing-edge lookups (aggregated in `cache_stats`).
    pub(crate) outgoing_cache: AdjacencyCache,
    // Adjacency cache for incoming-edge lookups.
    pub(crate) incoming_cache: AdjacencyCache,
    // Cache for query results; semantics defined by `QueryCache`.
    pub(crate) query_cache: QueryCache,
    // Shared runtime metrics; Arc suggests handles are shared elsewhere — see GraphMetrics.
    pub(crate) metrics: Arc<GraphMetrics>,
    // Shared prepared-statement tracking; see StatementTracker.
    pub(crate) statement_tracker: Arc<StatementTracker>,
    // MVCC snapshot bookkeeping; see SnapshotManager.
    pub(crate) snapshot_manager: SnapshotManager,
    /// HNSW vector indexes keyed by index name, loaded at open time.
    pub hnsw_indexes: Mutex<HashMap<String, HnswIndex>>,
}
/// Returns `true` when `conn` is backed by an in-memory database.
///
/// Queries `PRAGMA database_list`, whose rows are `(seq, name, file)`.
/// An in-memory database reports an empty `file` path (or the literal
/// `":memory:"`). Note: the file path is column 2 — column 1 is the
/// attachment name (e.g. "main"), which is never empty and therefore
/// cannot distinguish in-memory from file-backed connections.
pub fn is_in_memory_connection(conn: &Connection) -> bool {
    match conn.pragma_query_value(None, "database_list", |row| {
        // Column 2 of database_list is the database file path.
        let file: String = row.get(2)?;
        Ok(file)
    }) {
        Ok(file) => file.is_empty() || file == ":memory:",
        // If the pragma cannot be read, conservatively assume in-memory.
        Err(_) => true,
    }
}
impl SqliteGraph {
/// Opens (or creates) a file-backed graph at `path` using `cfg`.
///
/// Sets up the schema, tunes every pooled connection (WAL journaling with a
/// DELETE fallback, statement cache, user pragmas), and restores any
/// persisted HNSW indexes.
///
/// # Errors
/// Returns `SqliteGraphError` when the pool cannot be created, a connection
/// cannot be checked out, or schema setup fails.
pub fn open_with_config<P: AsRef<Path>>(
    path: P,
    cfg: &SqliteConfig,
) -> Result<Self, SqliteGraphError> {
    // Build the pool; default to 5 connections when unspecified.
    let max_connections = cfg.pool_size.unwrap_or(5) as u32;
    let pool = PoolManager::with_max_size(path, max_connections)
        .map_err(|e| SqliteGraphError::connection(e.to_string()))?;

    // Run schema setup on a scoped checkout so the connection returns to
    // the pool before per-connection configuration begins.
    {
        let conn = pool
            .get()
            .map_err(|e| SqliteGraphError::connection(e.to_string()))?;
        if cfg.without_migrations {
            crate::schema::ensure_schema_without_migrations(&conn)?;
        } else {
            ensure_schema(&conn)?;
        }
    }

    // Hoisted out of the closure: the configured value cannot change
    // between connections.
    let statement_cache_capacity = cfg.cache_size.unwrap_or(128);
    pool.configure_pool(|conn| {
        conn.set_prepared_statement_cache_capacity(statement_cache_capacity);
        // Prefer WAL journaling; fall back to DELETE if WAL is unavailable.
        if conn.pragma_update(None, "journal_mode", "WAL").is_err() {
            let _ = conn.pragma_update(None, "journal_mode", "DELETE");
        }
        // Best-effort performance pragmas; failures are ignored.
        let _ = conn.pragma_update(None, "synchronous", "NORMAL");
        let _ = conn.pragma_update(None, "cache_size", "-64000");
        let _ = conn.pragma_update(None, "temp_store", "MEMORY");
        let _ = conn.pragma_update(None, "mmap_size", "268435456");
        // User-supplied pragmas run last so they can override the defaults.
        for (key, value) in &cfg.pragma_settings {
            let _ = conn.pragma_update(None, key, value.as_str());
        }
        Ok(())
    })?;

    // Restore persisted HNSW indexes; a load failure yields an empty map.
    let hnsw_indexes = {
        let conn = pool
            .get()
            .map_err(|e| SqliteGraphError::connection(e.to_string()))?;
        Self::load_hnsw_indexes(&conn).unwrap_or_default()
    };

    Ok(Self {
        pool,
        outgoing_cache: AdjacencyCache::new(),
        incoming_cache: AdjacencyCache::new(),
        query_cache: QueryCache::new(),
        metrics: Arc::new(GraphMetrics::default()),
        statement_tracker: Arc::new(StatementTracker::default()),
        snapshot_manager: SnapshotManager::new(),
        hnsw_indexes: Mutex::new(hnsw_indexes),
    })
}
/// Opens a file-backed graph at `path` with the default configuration.
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SqliteGraphError> {
    let cfg = SqliteConfig::default();
    Self::open_with_config(path, &cfg)
}
/// Opens a file-backed graph at `path` with schema migrations disabled.
pub fn open_without_migrations<P: AsRef<Path>>(path: P) -> Result<Self, SqliteGraphError> {
    Self::open_with_config(path, &SqliteConfig::new().with_migrations_disabled(true))
}
pub fn open_in_memory_with_config(cfg: &SqliteConfig) -> Result<Self, SqliteGraphError> {
let mut pool =
PoolManager::in_memory().map_err(|e| SqliteGraphError::connection(e.to_string()))?;
let cache_size = cfg.cache_size.unwrap_or(128);
pool.configure_direct(|conn| {
if cfg.without_migrations {
crate::schema::ensure_schema_without_migrations(conn)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
} else {
ensure_schema(conn)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
}
conn.set_prepared_statement_cache_capacity(cache_size);
for (key, value) in &cfg.pragma_settings {
let _ = conn.pragma_update(None, key, value.as_str());
}
Ok(())
})
.map_err(|e| SqliteGraphError::connection(e.to_string()))?;
let hnsw_indexes = pool
.direct_connection()
.map(|conn| Self::load_hnsw_indexes(conn).unwrap_or_default())
.unwrap_or_default();
Ok(Self {
pool,
outgoing_cache: AdjacencyCache::new(),
incoming_cache: AdjacencyCache::new(),
query_cache: QueryCache::new(),
metrics: Arc::new(GraphMetrics::default()),
statement_tracker: Arc::new(StatementTracker::default()),
snapshot_manager: SnapshotManager::new(),
hnsw_indexes: Mutex::new(hnsw_indexes),
})
}
/// Opens an in-memory graph with the default configuration.
pub fn open_in_memory() -> Result<Self, SqliteGraphError> {
    let cfg = SqliteConfig::default();
    Self::open_in_memory_with_config(&cfg)
}
/// Opens an in-memory graph with schema migrations disabled.
pub fn open_in_memory_without_migrations() -> Result<Self, SqliteGraphError> {
    Self::open_in_memory_with_config(&SqliteConfig::new().with_migrations_disabled(true))
}
/// Loads every persisted HNSW index from the database, keyed by name.
///
/// Listing the index names is fatal on failure; loading an individual
/// index is best-effort — a failed load emits a warning to stderr and the
/// index is simply omitted from the result.
fn load_hnsw_indexes(
    conn: &Connection,
) -> Result<HashMap<String, HnswIndex>, SqliteGraphError> {
    let names = HnswIndex::list_indexes(conn).map_err(|e| {
        SqliteGraphError::invalid_input(format!("Failed to load HNSW indexes: {}", e))
    })?;
    let mut loaded = HashMap::new();
    for name in names {
        match HnswIndex::load_with_vectors(conn, &name) {
            Ok(index) => {
                loaded.insert(name, index);
            }
            // Non-fatal: warn and continue with the remaining indexes.
            Err(e) => eprintln!("Warning: Failed to load HNSW index '{}': {}", name, e),
        }
    }
    Ok(loaded)
}
/// Builds a point-in-time introspection snapshot of this graph: node and
/// edge counts, aggregated cache statistics, and (for file-backed
/// databases) the database and WAL file sizes.
///
/// # Errors
/// Returns `SqliteGraphError` when the node ids cannot be listed or the
/// edge count query fails.
pub fn introspect(&self) -> Result<GraphIntrospection, SqliteGraphError> {
    let backend_type = "sqlite".to_string();
    // NOTE(review): node count materializes all entity ids just to take
    // the length — O(n) in nodes; acceptable for introspection.
    let node_count = self
        .all_entity_ids()
        .map_err(|e| IntrospectError::NodeCountError(e.to_string()))?
        .len();
    let edge_count = self.count_edges()?;
    // Reuse the shared helper instead of duplicating the hit/miss/entry
    // aggregation inline (kept in sync with `cache_stats`).
    let cache_stats = self.cache_stats();
    let is_in_memory = self.pool.is_in_memory();
    // Resolve the database path once (a pool checkout + pragma query) and
    // reuse it for both size probes instead of resolving it twice.
    let db_path = if is_in_memory {
        None
    } else {
        self.get_database_path()
    };
    let file_size = db_path.clone().and_then(crate::introspection::get_file_size);
    let wal_size = db_path.and_then(crate::introspection::get_wal_size);
    // Memory usage is not currently tracked for the SQLite backend.
    let memory_usage = None;
    Ok(GraphIntrospection {
        backend_type,
        node_count,
        edge_count,
        cache_stats,
        memory_usage,
        file_size,
        wal_size,
        is_in_memory,
    })
}
/// Aggregates hit/miss/entry counters from the outgoing and incoming
/// adjacency caches into a single `CacheStats`.
pub fn cache_stats(&self) -> CacheStats {
    let out = self.outgoing_cache.stats();
    let inc = self.incoming_cache.stats();
    CacheStats {
        hits: out.hits + inc.hits,
        misses: out.misses + inc.misses,
        entries: out.entries + inc.entries,
    }
}
/// Counts edges, returning `EdgeCount::Exact` for small graphs and
/// `EdgeCount::Estimate` (±2% margin around the counted total) for large
/// ones.
///
/// The previous implementation issued an additional `ORDER BY RANDOM()`
/// sampling query here, but its result (`sample_count` / ratio) was
/// computed and then discarded — the returned estimate and margin only
/// ever depended on the initial `COUNT(*)`. The sampling query forced a
/// full table scan plus sort for nothing, so it has been removed; the
/// returned values are unchanged.
///
/// # Errors
/// Returns `SqliteGraphError::query` when the count query fails.
fn count_edges(&self) -> Result<EdgeCount, SqliteGraphError> {
    let conn = self.connection();
    let total: i64 = conn
        .query_row("SELECT COUNT(*) FROM graph_edges", [], |row| row.get(0))
        .map_err(|e| SqliteGraphError::query(e.to_string()))?;
    // Small graphs: report the exact count.
    if total < 10_000 {
        return Ok(EdgeCount::Exact(total as usize));
    }
    // Preserved for interface compatibility with EdgeCount::Estimate.
    let sample_size = 1000.min(total as usize);
    // Fixed ±2% margin around the counted total.
    let margin = total as f64 * 0.02;
    Ok(EdgeCount::Estimate {
        count: total as usize,
        min: (total as f64 - margin).floor() as usize,
        max: (total as f64 + margin).ceil() as usize,
        sample_size,
    })
}
/// Returns the filesystem path of the main database file, or `None` for
/// in-memory databases (or when the path cannot be determined).
///
/// `PRAGMA database_list` rows are `(seq, name, file)`; the file path is
/// column 2. The previous code read column 1 — the attachment name, which
/// is always `"main"` for the primary database — so it could never return
/// an actual path, silently breaking file/WAL size introspection.
fn get_database_path(&self) -> Option<String> {
    if self.pool.is_in_memory() {
        return None;
    }
    self.pool.get().ok().and_then(|conn| {
        conn.pragma_query_value(None, "database_list", |row| {
            // Column 2 of database_list is the database file path.
            let file: String = row.get(2)?;
            Ok(file)
        })
        .ok()
        // An empty path or ":memory:" still indicates an in-memory db.
        .filter(|file| !file.is_empty() && file != ":memory:")
    })
}
}