use crate::rag::engine::DocumentChunk;
use crate::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Shared, lock-guarded list of `(chunk_id, embedding)` pairs that backs
/// brute-force similarity search in [`InMemoryStorage`].
type VectorStore = Arc<RwLock<Vec<(String, Vec<f32>)>>>;
/// Bookkeeping metadata for a vector index: identity, algorithm family,
/// dimensionality, current size, free-form metadata, and timestamps.
/// This is descriptive only — it does not own the vectors themselves.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndex {
/// Unique identifier of the index.
pub id: String,
/// Human-readable index name.
pub name: String,
/// Indexing algorithm family (see [`IndexType`]).
pub index_type: IndexType,
/// Dimensionality of each stored vector.
pub dimensions: usize,
/// Number of vectors currently tracked by the index.
pub vector_count: usize,
/// Arbitrary key/value annotations.
pub metadata: HashMap<String, String>,
/// Creation timestamp (UTC).
pub created_at: chrono::DateTime<chrono::Utc>,
/// Timestamp of the last mutation (UTC).
pub updated_at: chrono::DateTime<chrono::Utc>,
}
impl VectorIndex {
    /// Creates an index record with no vectors and no metadata; both
    /// timestamps are set to the current UTC time.
    pub fn new(id: String, name: String, index_type: IndexType, dimensions: usize) -> Self {
        let now = chrono::Utc::now();
        Self {
            id,
            name,
            index_type,
            dimensions,
            vector_count: 0,
            metadata: HashMap::new(),
            created_at: now,
            updated_at: now,
        }
    }
    /// Inserts or overwrites a metadata entry and bumps `updated_at`.
    pub fn add_metadata(&mut self, key: String, value: String) {
        self.metadata.insert(key, value);
        self.updated_at = chrono::Utc::now();
    }
    /// Returns the metadata value for `key`, if present.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }
    /// Removes a metadata entry, bumping `updated_at` only when something was
    /// actually removed; returns the removed value.
    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
        let result = self.metadata.remove(key);
        if result.is_some() {
            self.updated_at = chrono::Utc::now();
        }
        result
    }
    /// Sets the tracked vector count and bumps `updated_at`.
    pub fn update_vector_count(&mut self, count: usize) {
        self.vector_count = count;
        self.updated_at = chrono::Utc::now();
    }
    /// Rough memory footprint: 4 bytes per f32 component plus ~1 KiB of fixed
    /// overhead for the index structure itself.
    pub fn estimated_size_bytes(&self) -> u64 {
        // Widen to u64 BEFORE multiplying: the original computed
        // `vector_count * dimensions * 4` in usize and only then cast, which
        // can overflow for large indexes on 32-bit targets.
        self.vector_count as u64 * self.dimensions as u64 * 4 + 1024
    }
    /// True when no vectors are tracked.
    pub fn is_empty(&self) -> bool {
        self.vector_count == 0
    }
    /// Produces a point-in-time [`IndexStats`] snapshot of this index.
    pub fn stats(&self) -> IndexStats {
        IndexStats {
            id: self.id.clone(),
            name: self.name.clone(),
            index_type: self.index_type.clone(),
            dimensions: self.dimensions,
            vector_count: self.vector_count,
            estimated_size_bytes: self.estimated_size_bytes(),
            metadata_count: self.metadata.len(),
            created_at: self.created_at,
            updated_at: self.updated_at,
        }
    }
}
/// Point-in-time snapshot of a [`VectorIndex`]'s descriptive statistics,
/// produced by [`VectorIndex::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
/// Identifier of the index the snapshot was taken from.
pub id: String,
/// Human-readable index name.
pub name: String,
/// Indexing algorithm family.
pub index_type: IndexType,
/// Dimensionality of each stored vector.
pub dimensions: usize,
/// Number of vectors tracked at snapshot time.
pub vector_count: usize,
/// Rough memory footprint in bytes (see `VectorIndex::estimated_size_bytes`).
pub estimated_size_bytes: u64,
/// Number of metadata entries on the index.
pub metadata_count: usize,
/// Creation timestamp of the index (UTC).
pub created_at: chrono::DateTime<chrono::Utc>,
/// Timestamp of the index's last mutation (UTC).
pub updated_at: chrono::DateTime<chrono::Utc>,
}
/// Supported index algorithm families; `Flat` (exhaustive scan) is the default.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum IndexType {
/// Exhaustive (brute-force) scan over all vectors.
#[default]
Flat,
/// Inverted-file index.
IVF,
/// Hierarchical navigable small-world graph.
HNSW,
/// Product quantization.
PQ,
/// Backend-specific index identified by an arbitrary name.
Custom(String),
}
/// Tuning knobs for a similarity search (see [`DocumentStorage::search_with_params`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchParams {
/// Maximum number of results to return.
pub top_k: usize,
/// Minimum similarity score a result must reach to be included.
pub threshold: f32,
/// Similarity/distance function to use.
pub search_method: SearchMethod,
/// Whether result chunks should include their metadata.
pub include_metadata: bool,
/// If set, only chunks from this document id are returned.
pub document_filter: Option<String>,
/// If set, a chunk must match every key/value pair to be returned.
pub metadata_filter: Option<HashMap<String, String>>,
}
impl Default for SearchParams {
    /// Defaults: cosine search returning the top 10 hits scoring at least
    /// 0.7, with metadata included and no document/metadata filters.
    fn default() -> Self {
        Self {
            search_method: SearchMethod::Cosine,
            top_k: 10,
            threshold: 0.7,
            include_metadata: true,
            document_filter: None,
            metadata_filter: None,
        }
    }
}
/// Similarity/distance measures for vector search; cosine is the default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum SearchMethod {
/// Cosine similarity (angle between vectors, magnitude-invariant).
#[default]
Cosine,
/// Euclidean (L2) distance.
Euclidean,
/// Raw dot product.
DotProduct,
/// Manhattan (L1) distance.
Manhattan,
}
/// Async storage abstraction for document chunks and their embeddings.
///
/// Implementations must be `Send + Sync` so one instance can be shared across
/// tasks (e.g. behind an `Arc`).
#[async_trait::async_trait]
pub trait DocumentStorage: Send + Sync {
/// Persists a batch of chunks (including their embeddings).
async fn store_chunks(&self, chunks: Vec<DocumentChunk>) -> Result<()>;
/// Returns up to `top_k` chunks most similar to `query_embedding`.
async fn search_similar(
&self,
query_embedding: &[f32],
top_k: usize,
) -> Result<Vec<DocumentChunk>>;
/// Similarity search with filtering/threshold controls (see [`SearchParams`]).
async fn search_with_params(
&self,
query_embedding: &[f32],
params: SearchParams,
) -> Result<Vec<DocumentChunk>>;
/// Fetches a single chunk by id, if it exists.
async fn get_chunk(&self, chunk_id: &str) -> Result<Option<DocumentChunk>>;
/// Deletes one chunk; returns `true` if it existed.
async fn delete_chunk(&self, chunk_id: &str) -> Result<bool>;
/// Returns all chunks belonging to the given document.
async fn get_chunks_by_document(&self, document_id: &str) -> Result<Vec<DocumentChunk>>;
/// Deletes every chunk of a document; returns the number removed.
async fn delete_document(&self, document_id: &str) -> Result<usize>;
/// Returns aggregate storage statistics.
async fn get_stats(&self) -> Result<StorageStats>;
/// Lists the distinct document ids currently stored.
async fn list_documents(&self) -> Result<Vec<String>>;
/// Returns the total number of stored chunks.
async fn get_total_chunks(&self) -> Result<usize>;
/// Removes all stored data.
async fn clear(&self) -> Result<()>;
/// Optional maintenance hook (compaction, reindexing, ...).
async fn optimize(&self) -> Result<()>;
/// Writes a backup of the store to `path`.
async fn create_backup(&self, path: &str) -> Result<()>;
/// Restores the store from a backup previously written to `path`.
async fn restore_backup(&self, path: &str) -> Result<()>;
/// Reports the backend's current health.
async fn health_check(&self) -> Result<StorageHealth>;
}
/// Aggregate statistics reported by a storage backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
/// Number of distinct documents stored.
pub total_documents: usize,
/// Total number of chunks across all documents.
pub total_chunks: usize,
/// Estimated size of the vector index in bytes.
pub index_size_bytes: u64,
/// When the stats were last refreshed (UTC).
pub last_updated: chrono::DateTime<chrono::Utc>,
/// Backend label, e.g. "memory", "file", "database", "vector-db".
pub backend_type: String,
/// Remaining capacity in bytes (`u64::MAX` for unbounded backends).
pub available_space_bytes: u64,
/// Space currently consumed in bytes.
pub used_space_bytes: u64,
}
/// Result of a health probe: an overall status plus free-form diagnostic
/// details and optional performance metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageHealth {
/// Overall condition of the backend.
pub status: HealthStatus,
/// When the probe ran (UTC).
pub checked_at: chrono::DateTime<chrono::Utc>,
/// Free-form key/value diagnostics (counts, error descriptions, ...).
pub details: HashMap<String, String>,
/// Performance metrics, when the backend collects them.
pub metrics: Option<StorageMetrics>,
}
/// Overall backend condition reported by [`DocumentStorage::health_check`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum HealthStatus {
/// Operating normally.
Healthy,
/// Operational but degraded; inspect `StorageHealth::details`.
Warning,
/// Internal inconsistency detected; results may be unreliable.
Unhealthy,
/// Backend cannot be reached at all.
Unavailable,
}
/// Optional performance metrics attached to a [`StorageHealth`] report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageMetrics {
/// Mean latency of similarity searches, in milliseconds.
pub average_search_time_ms: f64,
/// Mean latency of chunk inserts, in milliseconds.
pub average_insert_time_ms: f64,
/// Fraction of index space wasted by fragmentation (0.0 = none).
pub fragmentation_ratio: f32,
/// Fraction of lookups served from cache (0.0 - 1.0).
pub cache_hit_rate: f32,
/// Resident memory attributed to the backend, in bytes.
pub memory_usage_bytes: u64,
/// On-disk footprint, in bytes.
pub disk_usage_bytes: u64,
}
/// Simple in-memory [`DocumentStorage`] backend: chunks in a map keyed by
/// chunk id, embeddings in a parallel list, and cached stats — all behind
/// async `RwLock`s so reads can proceed concurrently.
pub struct InMemoryStorage {
/// chunk id -> full chunk (content + embedding + metadata).
chunks: Arc<RwLock<HashMap<String, DocumentChunk>>>,
/// `(chunk id, embedding)` pairs scanned during similarity search; kept in
/// lockstep with `chunks`.
vectors: VectorStore,
/// Cached aggregate statistics, refreshed on every mutation.
stats: Arc<RwLock<StorageStats>>,
}
impl InMemoryStorage {
    /// Creates an empty store that reports itself as the "memory" backend.
    pub fn new() -> Self {
        Self::new_with_backend_type("memory")
    }
    /// Creates an empty store that reports `backend_type` in its stats; used
    /// by `StorageFactory` to label the file/database/vector-db fallbacks.
    pub fn new_with_backend_type(backend_type: &str) -> Self {
        let created = chrono::Utc::now();
        let initial_stats = StorageStats {
            total_documents: 0,
            total_chunks: 0,
            index_size_bytes: 0,
            last_updated: created,
            backend_type: backend_type.to_string(),
            available_space_bytes: u64::MAX,
            used_space_bytes: 0,
        };
        Self {
            chunks: Arc::new(RwLock::new(HashMap::new())),
            vectors: Arc::new(RwLock::new(Vec::new())),
            stats: Arc::new(RwLock::new(initial_stats)),
        }
    }
    /// Cosine similarity of two vectors. Returns 0.0 when the lengths differ
    /// or either vector has zero magnitude.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        // Single pass accumulating the dot product and both squared norms.
        let mut dot = 0.0f32;
        let mut norm_a_sq = 0.0f32;
        let mut norm_b_sq = 0.0f32;
        for (x, y) in a.iter().zip(b.iter()) {
            dot += x * y;
            norm_a_sq += x * x;
            norm_b_sq += y * y;
        }
        let norm_a = norm_a_sq.sqrt();
        let norm_b = norm_b_sq.sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }
}
impl Default for InMemoryStorage {
/// Equivalent to [`InMemoryStorage::new`] (the "memory" backend).
fn default() -> Self {
Self::new()
}
}
/// Recomputes the derived counters in `stats` from the authoritative chunk
/// map. Centralizing this keeps `total_documents`/`total_chunks` consistent
/// on every mutation path — previously counters were adjusted ad hoc:
/// `total_chunks` over-counted when a chunk id was re-stored, and
/// `total_documents` was never maintained outside `clear`/`restore_backup`.
fn refresh_memory_stats(stats: &mut StorageStats, chunks: &HashMap<String, DocumentChunk>) {
    let document_ids: std::collections::HashSet<&str> =
        chunks.values().map(|c| c.document_id.as_str()).collect();
    stats.total_documents = document_ids.len();
    stats.total_chunks = chunks.len();
    // Rough size estimate assuming 1536-dimension f32 embeddings —
    // TODO(review): confirm against the actual embedding dimensionality.
    stats.index_size_bytes = stats.total_chunks as u64 * 1536 * 4;
    stats.used_space_bytes = stats.index_size_bytes;
    stats.last_updated = chrono::Utc::now();
}
#[async_trait::async_trait]
impl DocumentStorage for InMemoryStorage {
    /// Inserts (or replaces) chunks and their embeddings, then refreshes stats.
    async fn store_chunks(&self, chunks: Vec<DocumentChunk>) -> Result<()> {
        let mut chunks_map = self.chunks.write().await;
        let mut vectors = self.vectors.write().await;
        let mut stats = self.stats.write().await;
        for chunk in chunks {
            // Drop any stale vector entry first so re-storing a chunk id does
            // not leave a duplicate embedding behind (the original pushed
            // unconditionally and incremented total_chunks even on replace).
            if chunks_map.contains_key(&chunk.id) {
                vectors.retain(|(id, _)| id != &chunk.id);
            }
            vectors.push((chunk.id.clone(), chunk.embedding.clone()));
            chunks_map.insert(chunk.id.clone(), chunk);
        }
        refresh_memory_stats(&mut stats, &chunks_map);
        Ok(())
    }
    /// Brute-force cosine search over every stored vector; returns up to
    /// `top_k` chunks ordered by descending similarity.
    async fn search_similar(
        &self,
        query_embedding: &[f32],
        top_k: usize,
    ) -> Result<Vec<DocumentChunk>> {
        let vectors = self.vectors.read().await;
        let chunks = self.chunks.read().await;
        let mut similarities: Vec<(String, f32)> = vectors
            .iter()
            .map(|(chunk_id, embedding)| {
                (chunk_id.clone(), self.cosine_similarity(query_embedding, embedding))
            })
            .collect();
        // NaN-safe descending sort: incomparable pairs are treated as equal.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let results = similarities
            .iter()
            .take(top_k)
            .filter_map(|(chunk_id, _)| chunks.get(chunk_id).cloned())
            .collect();
        Ok(results)
    }
    /// Filtered search: applies the document/metadata filters and similarity
    /// threshold over the FULL corpus, then returns the top `params.top_k`.
    ///
    /// The original pre-truncated to `top_k * 2` candidates before filtering,
    /// which could silently drop matching chunks, and recomputed each cosine
    /// similarity during `retain` and O(log n) more times while sorting. Both
    /// are fixed by scoring every candidate exactly once.
    ///
    /// NOTE(review): `params.include_metadata` and `params.search_method` are
    /// not honored by this backend (cosine is always used) — unchanged from
    /// the original behavior.
    async fn search_with_params(
        &self,
        query_embedding: &[f32],
        params: SearchParams,
    ) -> Result<Vec<DocumentChunk>> {
        let chunks = self.chunks.read().await;
        let mut scored: Vec<(f32, &DocumentChunk)> = chunks
            .values()
            .filter(|chunk| {
                params
                    .document_filter
                    .as_ref()
                    .map_or(true, |doc_id| chunk.document_id == *doc_id)
            })
            .filter(|chunk| {
                params.metadata_filter.as_ref().map_or(true, |filters| {
                    filters.iter().all(|(key, value)| {
                        chunk.get_metadata(key).map(|v| v == value).unwrap_or(false)
                    })
                })
            })
            .map(|chunk| (self.cosine_similarity(query_embedding, &chunk.embedding), chunk))
            .filter(|(similarity, _)| *similarity >= params.threshold)
            .collect();
        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(params.top_k);
        Ok(scored.into_iter().map(|(_, chunk)| chunk.clone()).collect())
    }
    /// Looks up a single chunk by id.
    async fn get_chunk(&self, chunk_id: &str) -> Result<Option<DocumentChunk>> {
        let chunks = self.chunks.read().await;
        Ok(chunks.get(chunk_id).cloned())
    }
    /// Removes one chunk (and its vector); returns whether anything was deleted.
    async fn delete_chunk(&self, chunk_id: &str) -> Result<bool> {
        let mut chunks = self.chunks.write().await;
        let mut vectors = self.vectors.write().await;
        let mut stats = self.stats.write().await;
        let chunk_removed = chunks.remove(chunk_id).is_some();
        // `retain` returns (); the original pointlessly bound it to a local.
        vectors.retain(|(id, _)| id != chunk_id);
        if chunk_removed {
            refresh_memory_stats(&mut stats, &chunks);
        }
        Ok(chunk_removed)
    }
    /// Returns all chunks belonging to `document_id` (unordered).
    async fn get_chunks_by_document(&self, document_id: &str) -> Result<Vec<DocumentChunk>> {
        let chunks = self.chunks.read().await;
        Ok(chunks
            .values()
            .filter(|chunk| chunk.document_id == document_id)
            .cloned()
            .collect())
    }
    /// Removes every chunk of a document; returns how many were deleted.
    async fn delete_document(&self, document_id: &str) -> Result<usize> {
        let mut chunks = self.chunks.write().await;
        let mut vectors = self.vectors.write().await;
        let mut stats = self.stats.write().await;
        let initial_count = chunks.len();
        chunks.retain(|_, chunk| chunk.document_id != document_id);
        // Keep the vector list in lockstep with the surviving chunks.
        vectors.retain(|(id, _)| chunks.contains_key(id));
        let removed_count = initial_count - chunks.len();
        if removed_count > 0 {
            refresh_memory_stats(&mut stats, &chunks);
        }
        Ok(removed_count)
    }
    /// Returns a snapshot of the cached statistics.
    async fn get_stats(&self) -> Result<StorageStats> {
        Ok(self.stats.read().await.clone())
    }
    /// Lists the distinct document ids currently stored (order unspecified).
    async fn list_documents(&self) -> Result<Vec<String>> {
        let chunks = self.chunks.read().await;
        let documents: std::collections::HashSet<String> =
            chunks.values().map(|chunk| chunk.document_id.clone()).collect();
        Ok(documents.into_iter().collect())
    }
    /// Returns the cached total chunk count.
    async fn get_total_chunks(&self) -> Result<usize> {
        Ok(self.stats.read().await.total_chunks)
    }
    /// Drops all data and resets every stat counter.
    async fn clear(&self) -> Result<()> {
        let mut chunks = self.chunks.write().await;
        let mut vectors = self.vectors.write().await;
        let mut stats = self.stats.write().await;
        chunks.clear();
        vectors.clear();
        refresh_memory_stats(&mut stats, &chunks);
        Ok(())
    }
    /// No-op: the in-memory backend has nothing to compact.
    async fn optimize(&self) -> Result<()> {
        Ok(())
    }
    /// Serializes all chunks and vectors to a pretty-printed JSON file.
    ///
    /// NOTE(review): uses blocking `std::fs` inside an async fn; acceptable
    /// for a dev/test backend, but large backups should move to `tokio::fs`
    /// or `spawn_blocking`.
    async fn create_backup(&self, path: &str) -> Result<()> {
        let chunks = self.chunks.read().await;
        let vectors = self.vectors.read().await;
        let backup_data = serde_json::json!({
            "version": 1,
            "created_at": chrono::Utc::now().to_rfc3339(),
            "chunks": chunks.values().collect::<Vec<_>>(),
            "vectors": vectors.iter().collect::<Vec<_>>(),
        });
        let json_bytes = serde_json::to_vec_pretty(&backup_data)?;
        std::fs::write(path, json_bytes)?;
        Ok(())
    }
    /// Replaces the current contents with the data from a backup file written
    /// by `create_backup`. Entries that fail to deserialize are skipped.
    async fn restore_backup(&self, path: &str) -> Result<()> {
        let json_bytes = std::fs::read(path)?;
        let backup_data: serde_json::Value = serde_json::from_slice(&json_bytes)?;
        // `clear` takes and releases the locks before we re-acquire them below.
        self.clear().await?;
        let mut chunks_map = self.chunks.write().await;
        let mut vectors = self.vectors.write().await;
        let mut stats = self.stats.write().await;
        if let Some(chunks_arr) = backup_data.get("chunks").and_then(|v| v.as_array()) {
            for chunk_val in chunks_arr {
                if let Ok(chunk) = serde_json::from_value::<DocumentChunk>(chunk_val.clone()) {
                    chunks_map.insert(chunk.id.clone(), chunk);
                }
            }
        }
        if let Some(vectors_arr) = backup_data.get("vectors").and_then(|v| v.as_array()) {
            for vector_val in vectors_arr {
                if let Ok(vector) = serde_json::from_value::<(String, Vec<f32>)>(vector_val.clone())
                {
                    vectors.push(vector);
                }
            }
        }
        refresh_memory_stats(&mut stats, &chunks_map);
        Ok(())
    }
    /// Reports `Healthy` when the chunk map and vector list agree in size,
    /// `Unhealthy` (with an "error" detail) otherwise.
    async fn health_check(&self) -> Result<StorageHealth> {
        let chunks = self.chunks.read().await;
        let vectors = self.vectors.read().await;
        let mut details = HashMap::new();
        details.insert("chunk_count".to_string(), chunks.len().to_string());
        details.insert("vector_count".to_string(), vectors.len().to_string());
        details.insert("memory_usage".to_string(), "unknown".to_string());
        let status = if chunks.len() == vectors.len() {
            HealthStatus::Healthy
        } else {
            details.insert("error".to_string(), "Chunk/vector count mismatch".to_string());
            HealthStatus::Unhealthy
        };
        Ok(StorageHealth {
            status,
            checked_at: chrono::Utc::now(),
            details,
            metrics: None,
        })
    }
}
pub struct StorageFactory;
impl StorageFactory {
pub fn create_memory() -> Box<dyn DocumentStorage> {
Box::new(InMemoryStorage::new())
}
pub fn create_file(path: &str) -> Result<Box<dyn DocumentStorage>> {
if path.trim().is_empty() {
return Err(crate::Error::generic("File storage path cannot be empty"));
}
std::fs::create_dir_all(path)?;
Ok(Box::new(InMemoryStorage::new_with_backend_type("file")))
}
pub fn create_database(connection_string: &str) -> Result<Box<dyn DocumentStorage>> {
if connection_string.trim().is_empty() {
return Err(crate::Error::generic("Database connection string cannot be empty"));
}
Ok(Box::new(InMemoryStorage::new_with_backend_type("database")))
}
pub fn create_vector_db(config: HashMap<String, String>) -> Result<Box<dyn DocumentStorage>> {
if config.is_empty() {
return Err(crate::Error::generic("Vector database configuration cannot be empty"));
}
Ok(Box::new(InMemoryStorage::new_with_backend_type("vector-db")))
}
}
#[cfg(test)]
mod tests {
use super::StorageFactory;
use std::collections::HashMap;
// Smoke test: the module builds and its items are reachable.
#[test]
fn test_module_compiles() {
}
// The "file" factory currently falls back to in-memory storage; verify the
// reported backend type, using a process-unique temp dir that is cleaned up
// before and after the test.
#[tokio::test]
async fn test_create_file_storage_fallback_backend_type() {
let dir =
std::env::temp_dir().join(format!("mockforge-data-storage-{}", std::process::id()));
let _ = std::fs::remove_dir_all(&dir);
let storage = StorageFactory::create_file(dir.to_str().expect("path")).expect("create");
let stats = storage.get_stats().await.expect("stats");
assert_eq!(stats.backend_type, "file");
let _ = std::fs::remove_dir_all(&dir);
}
// The "database" factory is likewise an in-memory fallback labelled "database".
#[tokio::test]
async fn test_create_database_storage_fallback_backend_type() {
let storage =
StorageFactory::create_database("postgres://user:pass@localhost/db").expect("create");
let stats = storage.get_stats().await.expect("stats");
assert_eq!(stats.backend_type, "database");
}
// The vector-db factory requires a non-empty config and reports "vector-db".
#[tokio::test]
async fn test_create_vector_storage_fallback_backend_type() {
let mut cfg = HashMap::new();
cfg.insert("provider".to_string(), "qdrant".to_string());
let storage = StorageFactory::create_vector_db(cfg).expect("create");
let stats = storage.get_stats().await.expect("stats");
assert_eq!(stats.backend_type, "vector-db");
}
}