use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Connection settings consumed by `HybridStorage`.
///
/// Bundles the database URL, the auth token and two feature flags. The type
/// derives `Serialize`/`Deserialize`, so it can be round-tripped through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TursoConfig {
/// Database URL. A value starting with `file:` makes
/// `TursoConfig::is_local_only` return `true`.
pub database_url: String,
/// Auth token. An empty token makes `TursoConfig::is_local_only` return `true`.
pub auth_token: String,
/// When `false`, `HybridStorage::init_vectors` returns without doing anything.
pub enable_vectors: bool,
/// When `true`, `HybridStorage::new` does not open a local store.
pub remote_only: bool,
}
/// The default configuration uses the `file:local.db` URL, an empty auth token
/// and leaves both feature flags switched off.
impl Default for TursoConfig {
    fn default() -> Self {
        let database_url = String::from("file:local.db");
        let auth_token = String::new();
        Self {
            database_url,
            auth_token,
            enable_vectors: false,
            remote_only: false,
        }
    }
}
impl TursoConfig {
    /// Builds a configuration from a URL and a token.
    ///
    /// Vector support is disabled and `remote_only` is unset.
    #[must_use]
    pub fn new(database_url: String, auth_token: String) -> Self {
        Self {
            enable_vectors: false,
            remote_only: false,
            database_url,
            auth_token,
        }
    }

    /// Builds a configuration for the `file:local.db` URL with an empty token.
    #[must_use]
    pub fn local_only() -> Self {
        Self::new(String::from("file:local.db"), String::new())
    }

    /// Builds a configuration with the `remote_only` flag set, which makes
    /// `HybridStorage::new` skip opening a local store.
    #[must_use]
    pub fn remote_only(database_url: String, auth_token: String) -> Self {
        let mut config = Self::new(database_url, auth_token);
        config.remote_only = true;
        config
    }

    /// Builds a configuration with the given URL and token and `remote_only`
    /// unset; the resulting value is the same as the one produced by `new`.
    #[must_use]
    pub fn hybrid(database_url: String, auth_token: String) -> Self {
        Self::new(database_url, auth_token)
    }

    /// Returns the configuration with `enable_vectors` replaced by `enable`.
    #[must_use]
    pub fn with_vectors(self, enable: bool) -> Self {
        Self {
            enable_vectors: enable,
            ..self
        }
    }

    /// Returns `true` when the URL starts with `file:` or the auth token is empty.
    ///
    /// The `remote_only` flag is not consulted.
    #[must_use]
    pub fn is_local_only(&self) -> bool {
        let uses_file_url = self.database_url.starts_with("file:");
        let lacks_token = self.auth_token.is_empty();
        uses_file_url || lacks_token
    }

    /// Returns `true` when the URL does not start with `file:` and the auth
    /// token is non-empty, i.e. the exact negation of `is_local_only`.
    #[must_use]
    pub fn is_remote(&self) -> bool {
        !self.database_url.starts_with("file:") && !self.auth_token.is_empty()
    }
}
/// Counters returned by `HybridStorage::migrate_to_remote`.
///
/// `Default` yields all-zero counters, which is what the visible
/// `migrate_to_remote` implementation returns.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MigrationStats {
/// Presumably the number of graph nodes migrated; no visible code sets it.
pub nodes_migrated: usize,
/// Presumably the number of graph edges migrated; no visible code sets it.
pub edges_migrated: usize,
/// Presumably the number of embeddings migrated; no visible code sets it.
pub embeddings_migrated: usize,
/// Presumably the migration duration in milliseconds; no visible code sets it.
pub migration_time_ms: u64,
}
/// Operating mode reported by `HybridStorage::mode`.
///
/// The mode is derived from two facts: whether a local store is open and
/// whether `TursoConfig::is_remote` holds for the configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageMode {
/// No local store and the configuration is not remote.
None,
/// A local store is open and the configuration is not remote.
LocalOnly,
/// No local store, but the configuration is remote.
RemoteOnly,
/// A local store is open and the configuration is remote.
Hybrid,
}
/// Errors reported by the storage types in this module.
///
/// `Display` output comes from the `#[error]` attribute on each variant.
#[derive(Debug, Error)]
pub enum StorageError {
/// Renders as `Connection failed: <detail>`. Not constructed by the code
/// visible in this file.
#[error("Connection failed: {0}")]
ConnectionFailed(String),
/// Renders as `Migration failed: <detail>`. Not constructed by the code
/// visible in this file.
#[error("Migration failed: {0}")]
MigrationFailed(String),
/// Returned by the embedding APIs of `HybridStorage` when vectors have not
/// been initialized or when no local store is open.
#[error("Vector extension not available")]
VectorExtensionNotAvailable,
/// Wraps a failure of the local store (opening it, running a statement,
/// reading a row) or an embedding with the wrong dimension.
#[error("Local storage error: {0}")]
LocalStorageError(String),
/// Renders as `Remote query failed: <detail>`. Not constructed by the code
/// visible in this file.
#[error("Remote query failed: {0}")]
RemoteQueryFailed(String),
}
/// Storage front-end combining an optional local store with a `TursoConfig`.
pub struct HybridStorage {
/// Local store; `None` when the configuration had `remote_only` set.
pub local: Option<crate::storage::Storage>,
/// Configuration this instance was created with.
pub config: TursoConfig,
/// Set to `true` by `init_vectors` when `config.enable_vectors` is on; the
/// embedding APIs refuse to run while it is `false`.
pub vectors_initialized: bool,
}
impl HybridStorage {
    /// Number of `f32` components every stored or queried embedding must have.
    const EMBEDDING_DIM: usize = 768;

    /// Creates the storage front-end described by `config`.
    ///
    /// Unless `config.remote_only` is set, a local store is opened at the fixed
    /// path `local.db`. Vector support starts uninitialized; call
    /// `init_vectors` before using the embedding APIs.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::LocalStorageError` when the local store cannot be
    /// opened.
    pub fn new(config: TursoConfig) -> Result<Self, StorageError> {
        let local = if config.remote_only {
            None
        } else {
            // NOTE(review): `config.database_url` is not consulted here; the
            // local path is always `local.db`. Confirm that this is intended.
            let storage = crate::storage::Storage::open("local.db")
                .map_err(|e| StorageError::LocalStorageError(format!("{:?}", e)))?;
            Some(storage)
        };
        Ok(Self {
            local,
            config,
            vectors_initialized: false,
        })
    }

    /// Creates the embedding tables and marks vector support as initialized.
    ///
    /// Does nothing (and leaves `vectors_initialized` untouched) when
    /// `config.enable_vectors` is `false`.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::LocalStorageError` when a schema statement fails.
    pub fn init_vectors(&mut self) -> Result<(), StorageError> {
        if !self.config.enable_vectors {
            return Ok(());
        }
        if let Some(storage) = &self.local {
            self.init_local_vectors(storage)?;
        }
        // NOTE(review): without a local store no table is created, yet the flag
        // is still set and success is logged. The embedding APIs nevertheless
        // fail with `VectorExtensionNotAvailable` because `local` is `None`.
        self.vectors_initialized = true;
        tracing::info!("Vector extension initialized successfully");
        Ok(())
    }

    /// Creates `node_metadata`, `node_embeddings` and the dimension index in the
    /// local store if they do not exist yet.
    fn init_local_vectors(&self, storage: &crate::storage::Storage) -> Result<(), StorageError> {
        let conn = storage.conn();
        conn.execute(
            "CREATE TABLE IF NOT EXISTS node_metadata (
                node_id TEXT PRIMARY KEY,
                symbol_name TEXT NOT NULL,
                file_path TEXT NOT NULL,
                node_type TEXT NOT NULL,
                created_at INTEGER DEFAULT (strftime('%s', 'now'))
            )",
            [],
        )
        .map_err(|e| {
            StorageError::LocalStorageError(format!(
                "Failed to create node_metadata table: {:?}",
                e
            ))
        })?;
        conn.execute(
            "CREATE TABLE IF NOT EXISTS node_embeddings (
                node_id TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                dimension INTEGER NOT NULL,
                FOREIGN KEY (node_id) REFERENCES node_metadata(node_id) ON DELETE CASCADE
            )",
            [],
        )
        .map_err(|e| {
            StorageError::LocalStorageError(format!(
                "Failed to create node_embeddings table: {:?}",
                e
            ))
        })?;
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_node_embeddings_dimension
                ON node_embeddings(dimension)",
            [],
        )
        .map_err(|e| StorageError::LocalStorageError(format!("Failed to create index: {:?}", e)))?;
        Ok(())
    }

    /// Stores (or replaces) the metadata and embedding of one node.
    ///
    /// # Errors
    ///
    /// * `StorageError::VectorExtensionNotAvailable` when vectors are not
    ///   initialized or no local store is open.
    /// * `StorageError::LocalStorageError` when the embedding does not have
    ///   exactly 768 components or a statement fails.
    pub fn store_embedding(
        &self,
        node_id: &str,
        symbol_name: &str,
        file_path: &str,
        node_type: &str,
        embedding: &[f32],
    ) -> Result<(), StorageError> {
        if !self.vectors_initialized {
            return Err(StorageError::VectorExtensionNotAvailable);
        }
        let storage = self
            .local
            .as_ref()
            .ok_or(StorageError::VectorExtensionNotAvailable)?;
        self.store_local_embedding(
            storage,
            node_id,
            symbol_name,
            file_path,
            node_type,
            embedding,
        )
    }

    /// Writes one node's metadata row and its embedding blob to the local store.
    ///
    /// The embedding is serialized as consecutive little-endian `f32` values.
    fn store_local_embedding(
        &self,
        storage: &crate::storage::Storage,
        node_id: &str,
        symbol_name: &str,
        file_path: &str,
        node_type: &str,
        embedding: &[f32],
    ) -> Result<(), StorageError> {
        use rusqlite::params;
        if embedding.len() != Self::EMBEDDING_DIM {
            return Err(StorageError::LocalStorageError(format!(
                "Invalid embedding dimension: {}, expected {}",
                embedding.len(),
                Self::EMBEDDING_DIM
            )));
        }
        let conn = storage.conn();
        // NOTE(review): the two inserts below are not wrapped in a transaction,
        // so a failure of the second leaves a metadata row without embedding.
        conn.execute(
            "INSERT OR REPLACE INTO node_metadata (node_id, symbol_name, file_path, node_type)
                VALUES (?1, ?2, ?3, ?4)",
            params![node_id, symbol_name, file_path, node_type],
        )
        .map_err(|e| {
            StorageError::LocalStorageError(format!("Failed to insert metadata: {:?}", e))
        })?;
        let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|v| v.to_le_bytes()).collect();
        conn.execute(
            "INSERT OR REPLACE INTO node_embeddings (node_id, embedding, dimension)
                VALUES (?1, ?2, ?3)",
            // The length was validated above, so it equals the constant.
            params![node_id, embedding_bytes, Self::EMBEDDING_DIM as i64],
        )
        .map_err(|e| {
            StorageError::LocalStorageError(format!("Failed to insert embedding: {:?}", e))
        })?;
        Ok(())
    }

    /// Returns up to `k` `(node_id, similarity)` pairs, most similar first.
    ///
    /// # Errors
    ///
    /// * `StorageError::VectorExtensionNotAvailable` when vectors are not
    ///   initialized or no local store is open.
    /// * `StorageError::LocalStorageError` when the query embedding does not
    ///   have exactly 768 components or the query fails.
    pub fn search_similar(
        &self,
        query_embedding: &[f32],
        k: usize,
    ) -> Result<Vec<(String, f32)>, StorageError> {
        if !self.vectors_initialized {
            return Err(StorageError::VectorExtensionNotAvailable);
        }
        let storage = self
            .local
            .as_ref()
            .ok_or(StorageError::VectorExtensionNotAvailable)?;
        self.search_local_similar(storage, query_embedding, k)
    }

    /// Scores every stored embedding against `query_embedding` with cosine
    /// similarity and returns the `k` best matches in descending order.
    fn search_local_similar(
        &self,
        storage: &crate::storage::Storage,
        query_embedding: &[f32],
        k: usize,
    ) -> Result<Vec<(String, f32)>, StorageError> {
        use rusqlite::{params, Row};
        use std::cmp::Ordering;
        if query_embedding.len() != Self::EMBEDDING_DIM {
            return Err(StorageError::LocalStorageError(format!(
                "Invalid query embedding dimension: {}, expected {}",
                query_embedding.len(),
                Self::EMBEDDING_DIM
            )));
        }
        let conn = storage.conn();
        let mut stmt = conn
            .prepare(
                "SELECT n.node_id, e.embedding
                    FROM node_embeddings e
                    JOIN node_metadata n ON e.node_id = n.node_id
                    WHERE e.dimension = ?1",
            )
            .map_err(|e| {
                StorageError::LocalStorageError(format!("Failed to prepare query: {:?}", e))
            })?;
        let rows = stmt
            .query_map(params![Self::EMBEDDING_DIM as i64], |row: &Row<'_>| {
                let node_id: String = row.get(0)?;
                let embedding_bytes: Vec<u8> = row.get(1)?;
                Ok((node_id, embedding_bytes))
            })
            .map_err(|e| {
                StorageError::LocalStorageError(format!("Failed to execute query: {:?}", e))
            })?;
        let mut results: Vec<(String, f32)> = Vec::new();
        for row in rows {
            let (node_id, embedding_bytes) = row.map_err(|e| {
                StorageError::LocalStorageError(format!("Failed to read row: {:?}", e))
            })?;
            // Inverse of the little-endian encoding used when storing; trailing
            // bytes that do not form a full `f32` are ignored.
            let stored_embedding: Vec<f32> = embedding_bytes
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();
            let similarity = cosine_similarity(query_embedding, &stored_embedding);
            results.push((node_id, similarity));
        }
        // `sort_by` needs a total order. NaN scores (possible with corrupt or
        // non-finite stored data) are ordered after every real score instead of
        // being reported as "equal" to everything.
        results.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
        });
        results.truncate(k);
        Ok(results)
    }

    /// Stores several embeddings and returns how many were written.
    ///
    /// Each tuple is `(node_id, symbol_name, file_path, node_type, embedding)`.
    /// The operation is best-effort: an entry that fails is logged as a warning
    /// and skipped, and the remaining entries are still processed.
    ///
    /// # Errors
    ///
    /// Returns `StorageError::VectorExtensionNotAvailable` when vectors are not
    /// initialized or no local store is open.
    pub fn batch_store_embeddings(
        &self,
        embeddings: &[(&str, &str, &str, &str, &[f32])],
    ) -> Result<usize, StorageError> {
        if !self.vectors_initialized {
            return Err(StorageError::VectorExtensionNotAvailable);
        }
        let storage = self
            .local
            .as_ref()
            .ok_or(StorageError::VectorExtensionNotAvailable)?;
        let mut stored = 0;
        for (node_id, symbol_name, file_path, node_type, embedding) in embeddings {
            match self.store_local_embedding(
                storage,
                node_id,
                symbol_name,
                file_path,
                node_type,
                embedding,
            ) {
                Ok(()) => stored += 1,
                Err(err) => {
                    tracing::warn!("Skipping embedding for node {}: {}", node_id, err);
                }
            }
        }
        Ok(stored)
    }

    /// Returns the local store, if one is open.
    #[must_use]
    pub fn local(&self) -> Option<&crate::storage::Storage> {
        self.local.as_ref()
    }

    /// Returns the local store mutably, if one is open.
    pub fn local_mut(&mut self) -> Option<&mut crate::storage::Storage> {
        self.local.as_mut()
    }

    /// Placeholder for migrating local data to the remote database.
    ///
    /// NOTE(review): currently performs no work and always returns all-zero
    /// statistics.
    pub fn migrate_to_remote(&self) -> Result<MigrationStats, StorageError> {
        Ok(MigrationStats::default())
    }

    /// Returns `true` when a local store is open.
    #[must_use]
    pub fn has_local(&self) -> bool {
        self.local.is_some()
    }

    /// Returns `true` when the configuration satisfies `TursoConfig::is_remote`.
    #[must_use]
    pub fn has_remote(&self) -> bool {
        self.config.is_remote()
    }

    /// Reports the operating mode derived from `has_local` and `has_remote`.
    #[must_use]
    pub fn mode(&self) -> StorageMode {
        match (self.local.is_some(), self.config.is_remote()) {
            (true, false) => StorageMode::LocalOnly,
            (false, true) => StorageMode::RemoteOnly,
            (true, true) => StorageMode::Hybrid,
            (false, false) => StorageMode::None,
        }
    }
}
/// Computes the cosine similarity of two vectors.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or when
/// either vector has zero magnitude. Otherwise the result lies in `[-1.0, 1.0]`
/// (or is NaN if an input component is NaN or infinite).
///
/// The dot product and squared norms are accumulated in `f64`: squaring an
/// `f32` can overflow to infinity or underflow to zero in `f32`, which would
/// turn a perfectly valid similarity into NaN or `0.0`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass producing (dot product, |a|^2, |b|^2).
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );
    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }
    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());
    // Rounding can push the quotient marginally outside the valid range.
    cosine.clamp(-1.0, 1.0) as f32
}
/// Unit tests for the configuration helpers, the error messages, the storage
/// mode detection and the cosine-similarity helper.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_turso_config_default() {
        let config = TursoConfig::default();
        assert_eq!(config.database_url, "file:local.db");
        assert!(config.auth_token.is_empty());
        assert!(!config.enable_vectors);
        assert!(!config.remote_only);
    }

    #[test]
    fn test_turso_config_local_only() {
        let config = TursoConfig::local_only();
        assert!(config.is_local_only());
        assert!(!config.is_remote());
    }

    #[test]
    fn test_turso_config_remote_only() {
        let config = TursoConfig::remote_only(
            "libsql://token@db.turso.io".to_string(),
            "auth_token".to_string(),
        );
        assert!(config.is_remote());
        assert!(!config.is_local_only());
        assert!(config.remote_only);
    }

    #[test]
    fn test_turso_config_hybrid() {
        let config = TursoConfig::hybrid(
            "libsql://token@db.turso.io".to_string(),
            "auth_token".to_string(),
        );
        assert!(config.is_remote());
        assert!(!config.is_local_only());
        assert!(!config.remote_only);
    }

    #[test]
    fn test_turso_config_with_vectors() {
        let config = TursoConfig::local_only().with_vectors(true);
        assert!(config.enable_vectors);
    }

    #[test]
    fn test_migration_stats_default() {
        let stats = MigrationStats::default();
        assert_eq!(stats.nodes_migrated, 0);
        assert_eq!(stats.edges_migrated, 0);
        assert_eq!(stats.embeddings_migrated, 0);
        assert_eq!(stats.migration_time_ms, 0);
    }

    // NOTE: opening the local store touches `local.db` in the working directory.
    #[test]
    fn test_hybrid_storage_local_only() {
        let config = TursoConfig::local_only();
        let storage = HybridStorage::new(config).expect("local-only storage should open");
        assert!(storage.has_local());
        assert!(!storage.has_remote());
        assert_eq!(storage.mode(), StorageMode::LocalOnly);
    }

    // Despite the name, construction succeeds here: with an empty URL and token
    // the result simply has neither a local nor a remote side.
    #[test]
    fn test_hybrid_storage_remote_only_fails_without_url() {
        let config = TursoConfig::remote_only(String::new(), String::new());
        let storage = HybridStorage::new(config).expect("remote-only construction should succeed");
        assert!(!storage.has_local());
        assert!(!storage.has_remote());
        assert_eq!(storage.mode(), StorageMode::None);
    }

    // Exercises the derived `Debug` output (`StorageMode` has no `Display`).
    #[test]
    fn test_storage_mode_display() {
        assert_eq!(format!("{:?}", StorageMode::LocalOnly), "LocalOnly");
        assert_eq!(format!("{:?}", StorageMode::RemoteOnly), "RemoteOnly");
        assert_eq!(format!("{:?}", StorageMode::Hybrid), "Hybrid");
        assert_eq!(format!("{:?}", StorageMode::None), "None");
    }

    #[test]
    fn test_turso_config_is_local_only() {
        let config = TursoConfig::local_only();
        assert!(config.is_local_only());
        let config = TursoConfig {
            database_url: "file:test.db".to_string(),
            ..Default::default()
        };
        assert!(config.is_local_only());
    }

    #[test]
    fn test_turso_config_is_remote() {
        let config = TursoConfig {
            database_url: "libsql://token@db.turso.io".to_string(),
            auth_token: "some_token".to_string(),
            ..Default::default()
        };
        assert!(config.is_remote());
        assert!(!config.is_local_only());
    }

    #[test]
    fn test_storage_error_messages() {
        let err = StorageError::ConnectionFailed("test".to_string());
        assert_eq!(err.to_string(), "Connection failed: test");
        let err = StorageError::MigrationFailed("test".to_string());
        assert_eq!(err.to_string(), "Migration failed: test");
        let err = StorageError::VectorExtensionNotAvailable;
        assert_eq!(err.to_string(), "Vector extension not available");
        let err = StorageError::LocalStorageError("test".to_string());
        assert_eq!(err.to_string(), "Local storage error: test");
        let err = StorageError::RemoteQueryFailed("test".to_string());
        assert_eq!(err.to_string(), "Remote query failed: test");
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors.
        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&c, &d);
        assert!(sim.abs() < 0.001);

        // Length mismatch (non-empty vs. empty) yields exactly zero.
        let e: Vec<f32> = vec![];
        let sim = cosine_similarity(&a, &e);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_parallel() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![2.0, 4.0, 6.0, 8.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);
    }
}