use crate::{Document, DocumentChunk, Embedding, RragError, RragResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::fs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
/// A single value that can be persisted in a [`Storage`] backend.
///
/// Serialized via serde (the file backend writes it as JSON), so each variant
/// round-trips through `put`/`get` unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StorageEntry {
    /// A complete source document.
    Document(Document),
    /// One chunk derived from a document.
    Chunk(DocumentChunk),
    /// An embedding vector for a document or chunk.
    Embedding(Embedding),
    /// Free-form metadata as a JSON object.
    Metadata(HashMap<String, serde_json::Value>),
}
/// Composite key identifying one [`StorageEntry`]: entry type + id, with an
/// optional namespace for scoping (e.g. per-tenant grouping — see `to_path`
/// for how it maps onto the on-disk layout).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct StorageKey {
    /// Which kind of entry this key addresses.
    pub entry_type: EntryType,
    /// Unique identifier within the (type, namespace) bucket.
    pub id: String,
    /// Optional namespace; `None` means the default/unscoped namespace.
    pub namespace: Option<String>,
}
/// Discriminates the kind of entry a [`StorageKey`] points at; each variant
/// maps to its own top-level directory in the file backend.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum EntryType {
    /// Maps to the "documents" bucket.
    Document,
    /// Maps to the "chunks" bucket.
    Chunk,
    /// Maps to the "embeddings" bucket.
    Embedding,
    /// Maps to the "metadata" bucket.
    Metadata,
}
impl StorageKey {
pub fn document(id: impl Into<String>) -> Self {
Self {
entry_type: EntryType::Document,
id: id.into(),
namespace: None,
}
}
pub fn chunk(document_id: impl Into<String>, chunk_index: usize) -> Self {
Self {
entry_type: EntryType::Chunk,
id: format!("{}_{}", document_id.into(), chunk_index),
namespace: None,
}
}
pub fn embedding(id: impl Into<String>) -> Self {
Self {
entry_type: EntryType::Embedding,
id: id.into(),
namespace: None,
}
}
pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
self.namespace = Some(namespace.into());
self
}
pub fn to_path(&self) -> PathBuf {
let type_str = match self.entry_type {
EntryType::Document => "documents",
EntryType::Chunk => "chunks",
EntryType::Embedding => "embeddings",
EntryType::Metadata => "metadata",
};
let mut path = PathBuf::from(type_str);
if let Some(namespace) = &self.namespace {
path.push(namespace);
}
path.push(format!("{}.json", self.id));
path
}
}
/// Filter + pagination parameters for [`Storage::list_keys`].
#[derive(Debug, Clone)]
pub struct StorageQuery {
    /// Restrict results to one entry type; `None` matches all types.
    pub entry_type: Option<EntryType>,
    /// Restrict results to one namespace; `None` matches all namespaces.
    pub namespace: Option<String>,
    /// Keep only keys whose id starts with this prefix.
    pub key_prefix: Option<String>,
    /// Metadata predicates. NOTE(review): not consulted by the in-memory or
    /// file backends in this module — confirm whether any backend honors it.
    pub metadata_filters: HashMap<String, serde_json::Value>,
    /// Maximum number of keys to return.
    pub limit: Option<usize>,
    /// Number of matching keys to skip before collecting results.
    pub offset: Option<usize>,
}
impl StorageQuery {
pub fn new() -> Self {
Self {
entry_type: None,
namespace: None,
key_prefix: None,
metadata_filters: HashMap::new(),
limit: None,
offset: None,
}
}
pub fn documents() -> Self {
Self::new().with_entry_type(EntryType::Document)
}
pub fn chunks() -> Self {
Self::new().with_entry_type(EntryType::Chunk)
}
pub fn embeddings() -> Self {
Self::new().with_entry_type(EntryType::Embedding)
}
pub fn with_entry_type(mut self, entry_type: EntryType) -> Self {
self.entry_type = Some(entry_type);
self
}
pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
self.namespace = Some(namespace.into());
self
}
pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
self.key_prefix = Some(prefix.into());
self
}
pub fn with_limit(mut self, limit: usize) -> Self {
self.limit = Some(limit);
self
}
pub fn with_offset(mut self, offset: usize) -> Self {
self.offset = Some(offset);
self
}
}
impl Default for StorageQuery {
fn default() -> Self {
Self::new()
}
}
/// Abstract async key-value storage backend for RAG artifacts.
///
/// Implementations must be shareable across tasks (`Send + Sync`).
#[async_trait]
pub trait Storage: Send + Sync {
    /// Short backend identifier (e.g. "in_memory", "file_system").
    fn name(&self) -> &str;

    /// Stores (or overwrites) `entry` under `key`.
    async fn put(&self, key: &StorageKey, entry: &StorageEntry) -> RragResult<()>;

    /// Fetches the entry for `key`, or `None` when absent.
    async fn get(&self, key: &StorageKey) -> RragResult<Option<StorageEntry>>;

    /// Removes the entry; returns `true` if something was actually deleted.
    async fn delete(&self, key: &StorageKey) -> RragResult<bool>;

    /// Checks for the presence of `key` without reading the value.
    async fn exists(&self, key: &StorageKey) -> RragResult<bool>;

    /// Lists keys matching `query` (type/namespace/prefix filters plus
    /// offset/limit pagination).
    async fn list_keys(&self, query: &StorageQuery) -> RragResult<Vec<StorageKey>>;

    /// Batch `get`; the result preserves the order of `keys`.
    async fn get_many(
        &self,
        keys: &[StorageKey],
    ) -> RragResult<Vec<(StorageKey, Option<StorageEntry>)>>;

    /// Batch `put`.
    async fn put_many(&self, entries: &[(StorageKey, StorageEntry)]) -> RragResult<()>;

    /// Batch `delete`; returns how many keys were actually removed.
    async fn delete_many(&self, keys: &[StorageKey]) -> RragResult<usize>;

    /// Removes every entry. The default implementation reports "unsupported";
    /// backends opt in by overriding.
    async fn clear(&self) -> RragResult<()> {
        Err(RragError::storage(
            "clear",
            std::io::Error::new(
                std::io::ErrorKind::Unsupported,
                "Clear operation not supported",
            ),
        ))
    }

    /// Returns backend statistics (entry counts, sizes, timestamp).
    async fn stats(&self) -> RragResult<StorageStats>;

    /// Lightweight liveness probe; `Ok(true)` means the backend is usable.
    async fn health_check(&self) -> RragResult<bool>;
}
/// Snapshot of a backend's contents, as reported by [`Storage::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
    /// Total number of stored entries.
    pub total_entries: usize,
    /// Entry counts keyed by bucket name ("documents", "chunks", ...).
    pub entries_by_type: HashMap<String, usize>,
    /// Storage consumed, in bytes (may be an estimate — see each backend).
    pub size_bytes: u64,
    /// Remaining capacity in bytes, when the backend has a budget.
    pub available_bytes: Option<u64>,
    /// Which backend produced this snapshot.
    pub backend_type: String,
    /// When the snapshot was taken.
    pub last_updated: chrono::DateTime<chrono::Utc>,
}
/// Non-persistent storage backed by a `HashMap` behind a tokio `RwLock`;
/// cheap to create, shareable across tasks via the inner `Arc`.
pub struct InMemoryStorage {
    /// Shared map of all entries; clones of the `Arc` see the same data.
    data: Arc<tokio::sync::RwLock<HashMap<StorageKey, StorageEntry>>>,
    /// Capacity limits; only `max_entries` is enforced (see `check_limits`).
    config: MemoryStorageConfig,
}
/// Capacity limits for [`InMemoryStorage`].
#[derive(Debug, Clone)]
pub struct MemoryStorageConfig {
    /// Maximum number of entries; writes fail once this is reached.
    pub max_entries: Option<usize>,
    /// Memory budget in bytes. NOTE(review): not enforced on writes — it is
    /// only used to derive `available_bytes` in `stats`.
    pub max_memory_bytes: Option<u64>,
}
impl Default for MemoryStorageConfig {
fn default() -> Self {
Self {
max_entries: Some(100_000),
max_memory_bytes: Some(1_000_000_000), }
}
}
impl InMemoryStorage {
    /// Creates an empty storage with the default limits.
    pub fn new() -> Self {
        Self::with_config(MemoryStorageConfig::default())
    }

    /// Creates an empty storage with explicit limits.
    pub fn with_config(config: MemoryStorageConfig) -> Self {
        Self {
            data: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Fails with an `OutOfMemory`-flavored storage error once the entry
    /// count has reached the configured maximum.
    async fn check_limits(&self) -> RragResult<()> {
        let current = self.data.read().await.len();
        match self.config.max_entries {
            Some(max) if current >= max => Err(RragError::storage(
                "memory_limit",
                std::io::Error::new(
                    std::io::ErrorKind::OutOfMemory,
                    format!("Exceeded maximum entries: {}", max),
                ),
            )),
            _ => Ok(()),
        }
    }

    /// Returns true when `key` satisfies every filter in `query`.
    /// (`metadata_filters` is not evaluated here.)
    fn matches_query(&self, key: &StorageKey, query: &StorageQuery) -> bool {
        let type_ok = query
            .entry_type
            .as_ref()
            .map_or(true, |t| key.entry_type == *t);
        // An empty namespace filter matches keys with no namespace at all.
        let ns_ok = query
            .namespace
            .as_ref()
            .map_or(true, |ns| match &key.namespace {
                Some(key_ns) => key_ns == ns,
                None => ns.is_empty(),
            });
        let prefix_ok = query
            .key_prefix
            .as_ref()
            .map_or(true, |p| key.id.starts_with(p));
        type_ok && ns_ok && prefix_ok
    }
}
impl Default for InMemoryStorage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Storage for InMemoryStorage {
fn name(&self) -> &str {
"in_memory"
}
async fn put(&self, key: &StorageKey, entry: &StorageEntry) -> RragResult<()> {
self.check_limits().await?;
let mut data = self.data.write().await;
data.insert(key.clone(), entry.clone());
Ok(())
}
async fn get(&self, key: &StorageKey) -> RragResult<Option<StorageEntry>> {
let data = self.data.read().await;
Ok(data.get(key).cloned())
}
async fn delete(&self, key: &StorageKey) -> RragResult<bool> {
let mut data = self.data.write().await;
Ok(data.remove(key).is_some())
}
async fn exists(&self, key: &StorageKey) -> RragResult<bool> {
let data = self.data.read().await;
Ok(data.contains_key(key))
}
async fn list_keys(&self, query: &StorageQuery) -> RragResult<Vec<StorageKey>> {
let data = self.data.read().await;
let mut keys: Vec<StorageKey> = data
.keys()
.filter(|key| self.matches_query(key, query))
.cloned()
.collect();
if let Some(offset) = query.offset {
if offset < keys.len() {
keys = keys.into_iter().skip(offset).collect();
} else {
keys.clear();
}
}
if let Some(limit) = query.limit {
keys.truncate(limit);
}
Ok(keys)
}
async fn get_many(
&self,
keys: &[StorageKey],
) -> RragResult<Vec<(StorageKey, Option<StorageEntry>)>> {
let data = self.data.read().await;
let results = keys
.iter()
.map(|key| (key.clone(), data.get(key).cloned()))
.collect();
Ok(results)
}
async fn put_many(&self, entries: &[(StorageKey, StorageEntry)]) -> RragResult<()> {
self.check_limits().await?;
let mut data = self.data.write().await;
for (key, entry) in entries {
data.insert(key.clone(), entry.clone());
}
Ok(())
}
async fn delete_many(&self, keys: &[StorageKey]) -> RragResult<usize> {
let mut data = self.data.write().await;
let mut deleted = 0;
for key in keys {
if data.remove(key).is_some() {
deleted += 1;
}
}
Ok(deleted)
}
async fn clear(&self) -> RragResult<()> {
let mut data = self.data.write().await;
data.clear();
Ok(())
}
async fn stats(&self) -> RragResult<StorageStats> {
let data = self.data.read().await;
let mut entries_by_type = HashMap::new();
for key in data.keys() {
let type_str = match key.entry_type {
EntryType::Document => "documents",
EntryType::Chunk => "chunks",
EntryType::Embedding => "embeddings",
EntryType::Metadata => "metadata",
};
*entries_by_type.entry(type_str.to_string()).or_insert(0) += 1;
}
let estimated_size = data.len() * 1024;
Ok(StorageStats {
total_entries: data.len(),
entries_by_type,
size_bytes: estimated_size as u64,
available_bytes: self
.config
.max_memory_bytes
.map(|max| max - estimated_size as u64),
backend_type: "in_memory".to_string(),
last_updated: chrono::Utc::now(),
})
}
async fn health_check(&self) -> RragResult<bool> {
let _data = self.data.read().await;
Ok(true)
}
}
/// Persistent storage that writes each entry as a JSON file under
/// `base_dir`, laid out as `<type>[/<namespace>]/<id>.json`
/// (see `StorageKey::to_path`).
pub struct FileStorage {
    /// Root directory for all entries.
    base_dir: PathBuf,
    /// Behavior flags (directory creation, fsync, ...).
    config: FileStorageConfig,
}
/// Behavior flags for [`FileStorage`].
#[derive(Debug, Clone)]
pub struct FileStorageConfig {
    /// Create missing directories on construction and writes.
    pub create_dirs: bool,
    /// Unix permission bits for created files.
    /// NOTE(review): currently never applied by `FileStorage`.
    pub file_permissions: Option<u32>,
    /// Compress stored entries.
    /// NOTE(review): currently never applied by `FileStorage`.
    pub compress: bool,
    /// Call `sync_all` after each write for durability.
    pub sync_writes: bool,
}
impl Default for FileStorageConfig {
fn default() -> Self {
Self {
create_dirs: true,
file_permissions: None,
compress: false,
sync_writes: false,
}
}
}
impl FileStorage {
    /// Opens (creating if needed) a storage rooted at `base_dir` with the
    /// default configuration.
    pub async fn new(base_dir: impl AsRef<Path>) -> RragResult<Self> {
        // Default config has `create_dirs: true`, so this creates the root
        // directory when missing, exactly like a direct construction would.
        Self::with_config(base_dir, FileStorageConfig::default()).await
    }

    /// Opens a storage rooted at `base_dir` with an explicit configuration;
    /// creates the root directory only when `config.create_dirs` is set.
    pub async fn with_config(
        base_dir: impl AsRef<Path>,
        config: FileStorageConfig,
    ) -> RragResult<Self> {
        let base_dir = base_dir.as_ref().to_path_buf();
        if config.create_dirs && !base_dir.exists() {
            fs::create_dir_all(&base_dir)
                .await
                .map_err(|e| RragError::storage("create_directory", e))?;
        }
        Ok(Self { base_dir, config })
    }

    /// Absolute path of the file that backs `key`.
    fn get_file_path(&self, key: &StorageKey) -> PathBuf {
        self.base_dir.join(key.to_path())
    }

    /// Creates the parent directory of `file_path` when it is missing.
    async fn ensure_parent_dir(&self, file_path: &Path) -> RragResult<()> {
        match file_path.parent() {
            Some(parent) if !parent.exists() => fs::create_dir_all(parent)
                .await
                .map_err(|e| RragError::storage("create_parent_directory", e)),
            _ => Ok(()),
        }
    }
}
#[async_trait]
impl Storage for FileStorage {
fn name(&self) -> &str {
"file_system"
}
async fn put(&self, key: &StorageKey, entry: &StorageEntry) -> RragResult<()> {
let file_path = self.get_file_path(key);
self.ensure_parent_dir(&file_path).await?;
let json_data =
serde_json::to_vec_pretty(entry).map_err(|e| RragError::storage("serialize", e))?;
let mut file = fs::File::create(&file_path)
.await
.map_err(|e| RragError::storage("create_file", e))?;
file.write_all(&json_data)
.await
.map_err(|e| RragError::storage("write_file", e))?;
if self.config.sync_writes {
file.sync_all()
.await
.map_err(|e| RragError::storage("sync_file", e))?;
}
Ok(())
}
async fn get(&self, key: &StorageKey) -> RragResult<Option<StorageEntry>> {
let file_path = self.get_file_path(key);
if !file_path.exists() {
return Ok(None);
}
let mut file = fs::File::open(&file_path)
.await
.map_err(|e| RragError::storage("open_file", e))?;
let mut contents = Vec::new();
file.read_to_end(&mut contents)
.await
.map_err(|e| RragError::storage("read_file", e))?;
let entry =
serde_json::from_slice(&contents).map_err(|e| RragError::storage("deserialize", e))?;
Ok(Some(entry))
}
async fn delete(&self, key: &StorageKey) -> RragResult<bool> {
let file_path = self.get_file_path(key);
if !file_path.exists() {
return Ok(false);
}
fs::remove_file(&file_path)
.await
.map_err(|e| RragError::storage("delete_file", e))?;
Ok(true)
}
async fn exists(&self, key: &StorageKey) -> RragResult<bool> {
let file_path = self.get_file_path(key);
Ok(file_path.exists())
}
async fn list_keys(&self, _query: &StorageQuery) -> RragResult<Vec<StorageKey>> {
let keys = Vec::new();
Ok(keys)
}
async fn get_many(
&self,
keys: &[StorageKey],
) -> RragResult<Vec<(StorageKey, Option<StorageEntry>)>> {
let mut results = Vec::with_capacity(keys.len());
for key in keys {
let entry = self.get(key).await?;
results.push((key.clone(), entry));
}
Ok(results)
}
async fn put_many(&self, entries: &[(StorageKey, StorageEntry)]) -> RragResult<()> {
for (key, entry) in entries {
self.put(key, entry).await?;
}
Ok(())
}
async fn delete_many(&self, keys: &[StorageKey]) -> RragResult<usize> {
let mut deleted = 0;
for key in keys {
if self.delete(key).await? {
deleted += 1;
}
}
Ok(deleted)
}
async fn stats(&self) -> RragResult<StorageStats> {
Ok(StorageStats {
total_entries: 0, entries_by_type: HashMap::new(),
size_bytes: 0,
available_bytes: None,
backend_type: "file_system".to_string(),
last_updated: chrono::Utc::now(),
})
}
async fn health_check(&self) -> RragResult<bool> {
Ok(self.base_dir.exists() && self.base_dir.is_dir())
}
}
/// High-level convenience wrapper over any [`Storage`] backend, providing
/// typed store/get helpers for documents, chunks, and embeddings.
pub struct StorageService {
    /// The underlying backend (trait object so backends are swappable).
    storage: Arc<dyn Storage>,
    /// Service options. Currently unused (hence the allow) — batching and
    /// caching are configured but not yet implemented.
    #[allow(dead_code)]
    config: StorageServiceConfig,
}
/// Options for [`StorageService`].
///
/// NOTE(review): none of these are acted on yet — the service field holding
/// this config is marked `#[allow(dead_code)]`.
#[derive(Debug, Clone)]
pub struct StorageServiceConfig {
    /// Whether writes should be batched.
    pub enable_batching: bool,
    /// Number of operations per batch.
    pub batch_size: usize,
    /// Maximum time to wait before flushing a partial batch.
    pub batch_timeout_ms: u64,
    /// Whether reads should be cached.
    pub enable_caching: bool,
    /// Cache entry time-to-live.
    pub cache_ttl_seconds: u64,
}
impl Default for StorageServiceConfig {
fn default() -> Self {
Self {
enable_batching: true,
batch_size: 100,
batch_timeout_ms: 1000,
enable_caching: false,
cache_ttl_seconds: 300,
}
}
}
impl StorageService {
    /// Wraps a backend with the default service configuration.
    pub fn new(storage: Arc<dyn Storage>) -> Self {
        Self::with_config(storage, StorageServiceConfig::default())
    }

    /// Wraps a backend with an explicit configuration.
    pub fn with_config(storage: Arc<dyn Storage>, config: StorageServiceConfig) -> Self {
        Self { storage, config }
    }

    /// Persists a document under its document key.
    pub async fn store_document(&self, document: &Document) -> RragResult<()> {
        let entry = StorageEntry::Document(document.clone());
        self.storage
            .put(&StorageKey::document(&document.id), &entry)
            .await
    }

    /// Persists a chunk keyed by its parent document id and chunk index.
    pub async fn store_chunk(&self, chunk: &DocumentChunk) -> RragResult<()> {
        let entry = StorageEntry::Chunk(chunk.clone());
        self.storage
            .put(
                &StorageKey::chunk(&chunk.document_id, chunk.chunk_index),
                &entry,
            )
            .await
    }

    /// Persists an embedding keyed by the id of its source.
    pub async fn store_embedding(&self, embedding: &Embedding) -> RragResult<()> {
        let entry = StorageEntry::Embedding(embedding.clone());
        self.storage
            .put(&StorageKey::embedding(&embedding.source_id), &entry)
            .await
    }

    /// Loads a document by id; `None` when the key is missing or the stored
    /// entry is not a document variant.
    pub async fn get_document(&self, document_id: &str) -> RragResult<Option<Document>> {
        let key = StorageKey::document(document_id);
        if let Some(StorageEntry::Document(doc)) = self.storage.get(&key).await? {
            Ok(Some(doc))
        } else {
            Ok(None)
        }
    }

    /// Delegates to the backend's statistics.
    pub async fn get_stats(&self) -> RragResult<StorageStats> {
        self.storage.stats().await
    }

    /// Delegates to the backend's health probe.
    pub async fn health_check(&self) -> RragResult<bool> {
        self.storage.health_check().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Round-trips a document through the in-memory backend and checks
    /// put/get/exists/delete semantics.
    #[tokio::test]
    async fn test_in_memory_storage() {
        let storage = InMemoryStorage::new();
        let doc = Document::new("Test document");
        let key = StorageKey::document(&doc.id);
        let entry = StorageEntry::Document(doc.clone());
        storage.put(&key, &entry).await.unwrap();
        let retrieved = storage.get(&key).await.unwrap();
        assert!(retrieved.is_some());
        // The stored entry must come back as the same document variant.
        if let Some(StorageEntry::Document(retrieved_doc)) = retrieved {
            assert_eq!(retrieved_doc.id, doc.id);
            assert_eq!(retrieved_doc.content_str(), doc.content_str());
        }
        // Delete reports true once, and the key is gone afterwards.
        assert!(storage.exists(&key).await.unwrap());
        assert!(storage.delete(&key).await.unwrap());
        assert!(!storage.exists(&key).await.unwrap());
    }

    /// Round-trips a document through the file backend inside a temp dir and
    /// verifies the expected JSON file lands on disk.
    #[tokio::test]
    async fn test_file_storage() {
        let temp_dir = TempDir::new().unwrap();
        let storage = FileStorage::new(temp_dir.path()).await.unwrap();
        let doc = Document::new("Test document for file storage");
        let key = StorageKey::document(&doc.id);
        let entry = StorageEntry::Document(doc.clone());
        storage.put(&key, &entry).await.unwrap();
        let retrieved = storage.get(&key).await.unwrap();
        assert!(retrieved.is_some());
        if let Some(StorageEntry::Document(retrieved_doc)) = retrieved {
            assert_eq!(retrieved_doc.id, doc.id);
        }
        // The file layout must match StorageKey::to_path.
        let file_path = temp_dir.path().join(key.to_path());
        assert!(file_path.exists());
    }

    /// Checks the key constructors: type tagging, chunk id composition, and
    /// namespace scoping.
    #[test]
    fn test_storage_key() {
        let doc_key = StorageKey::document("doc1");
        assert_eq!(doc_key.entry_type, EntryType::Document);
        assert_eq!(doc_key.id, "doc1");
        let chunk_key = StorageKey::chunk("doc1", 5);
        assert_eq!(chunk_key.entry_type, EntryType::Chunk);
        assert_eq!(chunk_key.id, "doc1_5");
        let ns_key = doc_key.with_namespace("test_namespace");
        assert_eq!(ns_key.namespace, Some("test_namespace".to_string()));
    }

    /// Exercises the service wrapper end-to-end over the in-memory backend,
    /// including the stats passthrough.
    #[tokio::test]
    async fn test_storage_service() {
        let storage = Arc::new(InMemoryStorage::new());
        let service = StorageService::new(storage);
        let doc = Document::new("Test document for service");
        service.store_document(&doc).await.unwrap();
        let retrieved = service.get_document(&doc.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, doc.id);
        let stats = service.get_stats().await.unwrap();
        assert_eq!(stats.total_entries, 1);
    }
}