use crate::{embedding::cosine_similarity, Document, Error, Result};
use async_trait::async_trait;
use qdrant_client::qdrant::{
CreateCollection, DeletePoints, Distance, GetPoints, PointId, PointStruct, QuantizationConfig,
ScalarQuantization, ScrollPoints, SearchPoints, UpsertPoints, VectorParams, VectorsConfig,
};
use qdrant_client::Qdrant;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
pub mod hot;
pub mod cold;
pub mod wal;
pub mod context;
pub mod dual_layer;
pub mod sync_worker;
pub mod config;
pub mod memory_types;
pub mod serde_utils;
pub use cold::{ColdMemory, ColdMemoryConfig, ColdMemoryEntry};
pub use context::{retrieve_context, ContextQuery, ContextResult};
pub use hot::{HotMemory, HotMemoryConfig, HotMemoryEntry};
pub use wal::{WalConfig, WalOperation, WriteAheadLog};
pub use dual_layer::{
ContextResult as DualContextResult, DualLayerError, DualLayerResult, DualLayerStorage,
StorageTier,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
pub id: Uuid,
pub content: String,
pub embedding: Option<Vec<f32>>,
pub metadata: HashMap<String, String>,
pub importance: f32,
pub access_count: u64,
pub created_at: chrono::DateTime<chrono::Utc>,
pub last_accessed: chrono::DateTime<chrono::Utc>,
pub ttl_secs: Option<u64>,
pub layer: MemoryLayer,
pub tags: Vec<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MemoryLayer {
Hot,
Cold,
#[default]
Pending,
}
impl MemoryEntry {
pub fn new(content: impl Into<String>) -> Self {
let now = chrono::Utc::now();
Self {
id: Uuid::new_v4(),
content: content.into(),
embedding: None,
metadata: HashMap::new(),
importance: 0.5,
access_count: 0,
created_at: now,
last_accessed: now,
ttl_secs: None,
layer: MemoryLayer::Pending,
tags: Vec::new(),
}
}
pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
self.embedding = Some(embedding);
self
}
pub fn with_importance(mut self, importance: f32) -> Self {
self.importance = importance.clamp(0.0, 1.0);
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn with_ttl(mut self, ttl_secs: u64) -> Self {
self.ttl_secs = Some(ttl_secs);
self
}
pub fn with_tags(mut self, tags: Vec<String>) -> Self {
self.tags = tags;
self
}
pub fn is_expired(&self) -> bool {
if let Some(ttl) = self.ttl_secs {
let elapsed = chrono::Utc::now()
.signed_duration_since(self.created_at)
.num_seconds() as u64;
elapsed > ttl
} else {
false
}
}
pub fn age_secs(&self) -> i64 {
chrono::Utc::now()
.signed_duration_since(self.created_at)
.num_seconds()
}
pub fn idle_secs(&self) -> i64 {
chrono::Utc::now()
.signed_duration_since(self.last_accessed)
.num_seconds()
}
}
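// A minimal sketch of the `MemoryEntry` builder chain above; the content,
// importance, TTL, and tags are illustrative values, not crate defaults.
#[allow(dead_code)]
fn example_memory_entry() -> MemoryEntry {
    MemoryEntry::new("user prefers dark mode")
        .with_importance(0.8)
        .with_metadata("source", "conversation")
        .with_ttl(24 * 60 * 60) // expire after one day
        .with_tags(vec!["preference".to_string()])
}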
pub use config::DualLayerConfig;
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SyncStats {
pub hot_to_cold: usize,
pub expired_removed: usize,
pub wal_replayed: usize,
pub wal_compacted: usize,
pub duration_ms: u64,
pub warnings: Vec<String>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RecoveryReport {
pub entries_recovered: usize,
pub entries_lost: usize,
pub operations_replayed: usize,
pub last_sequence: u64,
pub duration_ms: u64,
pub errors: Vec<String>,
pub success: bool,
}
pub struct DualLayerMemory {
hot: Arc<HotMemory>,
cold: Arc<ColdMemory>,
wal: Arc<WriteAheadLog>,
config: DualLayerConfig,
sync_handle: Option<tokio::task::JoinHandle<()>>,
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
is_shutdown: Arc<std::sync::atomic::AtomicBool>,
}
impl DualLayerMemory {
pub async fn new(config: DualLayerConfig) -> Result<Self> {
let hot_config = hot::HotMemoryConfig {
max_entries: config.hot.max_entries,
ttl: std::time::Duration::from_secs(config.hot.ttl_secs),
eviction_batch_size: config.hot.eviction_batch_size,
};
let hot = Arc::new(HotMemory::new(hot_config));
let cold_config = cold::ColdMemoryConfig {
db_path: config.cold.db_path.clone(),
cache_size_mb: config.cold.cache_size_mb,
flush_interval_secs: config.cold.flush_interval_secs,
enable_compression: config.cold.enable_compression,
// Fixed defaults; not currently exposed through `DualLayerConfig`.
parallel_scan_threshold: 1000,
use_simd: true,
};
let cold = Arc::new(ColdMemory::new(cold_config).await?);
let wal_config = wal::WalConfig {
dir: config.wal.dir.clone(),
segment_size_mb: config.wal.segment_size_mb,
sync_mode: match config.wal.sync_mode {
config::SyncMode::Sync => wal::SyncMode::Immediate,
config::SyncMode::Async => wal::SyncMode::Async,
config::SyncMode::Balanced => {
wal::SyncMode::Batched(std::time::Duration::from_millis(100))
}
config::SyncMode::OsDefault => wal::SyncMode::Async,
},
checkpoint_retention: config.wal.max_segments,
preallocate_segments: config.wal.preallocate,
};
let wal = Arc::new(WriteAheadLog::new(wal_config).await?);
let is_shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
let mut memory = Self {
hot,
cold,
wal,
config: config.clone(),
sync_handle: None,
shutdown_tx: None,
is_shutdown,
};
let recovery = memory.recover().await?;
if recovery.entries_recovered > 0 {
tracing::info!(
"Recovered {} entries from WAL in {}ms",
recovery.entries_recovered,
recovery.duration_ms
);
}
if config.sync.auto_sync_enabled {
memory.start_background_sync();
}
Ok(memory)
}
pub async fn store(&self, mut entry: MemoryEntry) -> Result<Uuid> {
if entry.id == Uuid::nil() {
entry.id = Uuid::new_v4();
}
let embedding = entry.embedding.clone().unwrap_or_default();
let operation = WalOperation::Insert {
id: entry.id,
content: entry.content.clone(),
embedding: embedding.clone(),
};
self.wal.append(operation).await?;
entry.layer = MemoryLayer::Hot;
let metadata_json = {
let mut obj = serde_json::Map::new();
for (k, v) in &entry.metadata {
obj.insert(k.clone(), serde_json::Value::String(v.clone()));
}
serde_json::Value::Object(obj)
};
let hot_entry =
HotMemoryEntry::new(entry.id, entry.content.clone(), embedding, metadata_json);
self.hot.put(hot_entry).await?;
tracing::debug!(id = %entry.id, "Stored entry in hot layer");
Ok(entry.id)
}
pub async fn get(&self, id: &Uuid) -> Result<Option<MemoryEntry>> {
if let Some(hot_entry) = self.hot.get(id).await {
let memory_entry = MemoryEntry {
id: hot_entry.id,
content: hot_entry.content,
embedding: Some(hot_entry.embedding),
metadata: {
let mut meta = HashMap::new();
if let Some(obj) = hot_entry.metadata.as_object() {
for (k, v) in obj {
meta.insert(k.clone(), v.to_string());
}
}
meta
},
importance: 0.5,
access_count: hot_entry.access_count,
// Hot-layer timestamps are `Instant`s, not UNIX timestamps; reconstruct
// approximate wall-clock times by subtracting the elapsed duration.
created_at: chrono::Utc::now()
- chrono::Duration::from_std(hot_entry.created_at.elapsed())
.unwrap_or_else(|_| chrono::Duration::zero()),
last_accessed: chrono::Utc::now()
- chrono::Duration::from_std(hot_entry.accessed_at.elapsed())
.unwrap_or_else(|_| chrono::Duration::zero()),
ttl_secs: None,
layer: MemoryLayer::Hot,
tags: Vec::new(),
};
return Ok(Some(memory_entry));
}
if let Some(cold_entry) = self.cold.get(id).await? {
let memory_entry = MemoryEntry {
id: cold_entry.id,
content: cold_entry.content,
embedding: Some(cold_entry.embedding),
metadata: {
let mut meta = HashMap::new();
if let Some(obj) = cold_entry.metadata.as_object() {
for (k, v) in obj {
meta.insert(k.clone(), v.to_string());
}
}
meta
},
importance: 0.5,
access_count: 0,
created_at: chrono::DateTime::from_timestamp(cold_entry.created_at, 0)
.unwrap_or_else(chrono::Utc::now),
last_accessed: chrono::Utc::now(),
ttl_secs: None,
layer: MemoryLayer::Cold,
tags: Vec::new(),
};
return Ok(Some(memory_entry));
}
Ok(None)
}
pub async fn delete(&self, id: &Uuid) -> Result<bool> {
let operation = WalOperation::Delete { id: *id };
self.wal.append(operation).await?;
let hot_deleted = self.hot.delete(id).await?;
let cold_deleted = self.cold.delete(id).await?;
Ok(hot_deleted || cold_deleted)
}
pub async fn retrieve_context(&self, query: &str, limit: usize) -> Result<Vec<ContextResult>> {
// Placeholder: a full implementation would embed `query` and rank hot
// and cold entries; until then this returns no results.
let _context_query = ContextQuery {
text: query.to_string(),
embedding: Vec::new(),
limit,
min_score: 0.0,
recency_weight: 0.3,
};
Ok(Vec::new())
}
pub async fn sync(&self) -> Result<SyncStats> {
let start = Instant::now();
let mut stats = SyncStats::default();
let threshold = Duration::from_secs(self.config.sync.hot_to_cold_age_secs);
let min_importance = 0.0;
// Placeholder: hot-entry enumeration is not wired up yet, so the
// migration loop below currently sees no entries.
let hot_entries: Vec<MemoryEntry> = Vec::new();
for entry in hot_entries {
// Remove expired entries first so they are never migrated.
if entry.is_expired() {
self.hot.delete(&entry.id).await?;
stats.expired_removed += 1;
continue;
}
// Entries at or above the importance floor stay in the hot layer.
if entry.importance >= min_importance {
continue;
}
let idle_duration = Duration::from_secs(entry.idle_secs() as u64);
if idle_duration >= threshold {
let cold_entry = ColdMemoryEntry {
id: entry.id,
content: entry.content.clone(),
embedding: entry.embedding.clone().unwrap_or_default(),
metadata: serde_json::json!(entry.metadata),
created_at: entry.created_at.timestamp(),
};
self.cold.store(&cold_entry).await?;
self.hot.delete(&entry.id).await?;
stats.hot_to_cold += 1;
}
}
// Cold-layer TTL scanning is not implemented yet.
let expired_cold = 0;
stats.expired_removed += expired_cold;
stats.wal_compacted = 0;
stats.duration_ms = start.elapsed().as_millis() as u64;
tracing::info!(
hot_to_cold = stats.hot_to_cold,
expired = stats.expired_removed,
wal_compacted = stats.wal_compacted,
duration_ms = stats.duration_ms,
"Sync completed"
);
Ok(stats)
}
pub async fn recover(&self) -> Result<RecoveryReport> {
let start = Instant::now();
let mut report = RecoveryReport::default();
// Placeholder: WAL segment replay is not wired up yet, so recovery
// currently sees no operations.
let operations: Vec<(u64, WalOperation)> = Vec::new();
for (_seq, operation) in operations {
match operation {
WalOperation::Insert {
id,
content,
embedding,
} => {
let hot_entry =
HotMemoryEntry::new(id, content, embedding, serde_json::json!({}));
match self.hot.put(hot_entry).await {
Ok(_) => {
report.entries_recovered += 1;
}
Err(e) => {
report
.errors
.push(format!("Failed to recover entry: {}", e));
report.entries_lost += 1;
}
}
}
WalOperation::Delete { id } => {
let _ = self.hot.delete(&id).await;
let _ = self.cold.delete(&id).await;
}
WalOperation::Update {
id,
content,
embedding,
} => {
let hot_entry =
HotMemoryEntry::new(id, content, embedding, serde_json::json!({}));
match self.hot.put(hot_entry).await {
Ok(_) => {
report.entries_recovered += 1;
}
Err(e) => {
report
.errors
.push(format!("Failed to recover update: {}", e));
}
}
}
WalOperation::Checkpoint {
lsn,
checkpoint_id: _,
} => {
report.last_sequence = lsn;
}
WalOperation::BatchInsert { .. }
| WalOperation::BatchDelete { .. }
| WalOperation::TxnBegin { .. }
| WalOperation::TxnCommit { .. }
| WalOperation::TxnRollback { .. } => {
// Batch and transaction records are not replayed individually yet.
}
}
report.operations_replayed += 1;
}
report.duration_ms = start.elapsed().as_millis() as u64;
report.success = report.errors.is_empty();
Ok(report)
}
pub async fn shutdown(&self) -> Result<()> {
// Signal the background worker via the atomic flag (the oneshot sender
// cannot be consumed through `&self`), then flush state one final time.
self.is_shutdown
.store(true, std::sync::atomic::Ordering::SeqCst);
let _ = self.sync().await;
self.wal.sync().await?;
tracing::info!("DualLayerMemory shutdown complete");
Ok(())
}
fn start_background_sync(&mut self) {
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
self.shutdown_tx = Some(shutdown_tx);
let hot = self.hot.clone();
let cold = self.cold.clone();
let wal = self.wal.clone();
let config = self.config.clone();
let is_shutdown = self.is_shutdown.clone();
let handle = tokio::spawn(async move {
let mut interval =
tokio::time::interval(Duration::from_secs(config.sync.interval_secs));
loop {
tokio::select! {
_ = interval.tick() => {
if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
break;
}
if let Err(e) = Self::background_sync_iteration(
&hot,
&cold,
&wal,
&config,
).await {
tracing::warn!(error = %e, "Background sync iteration failed");
}
}
_ = &mut shutdown_rx => {
tracing::debug!("Background sync received shutdown signal");
break;
}
}
}
tracing::debug!("Background sync worker exited");
});
self.sync_handle = Some(handle);
}
async fn background_sync_iteration(
hot: &Arc<HotMemory>,
cold: &Arc<ColdMemory>,
_wal: &Arc<WriteAheadLog>,
config: &DualLayerConfig,
) -> Result<()> {
let threshold = Duration::from_secs(config.sync.hot_to_cold_age_secs);
let min_importance = 0.0;
// Placeholder: hot-entry enumeration is not wired up yet (see `sync`),
// so the loop below currently sees no entries.
let hot_entries: Vec<MemoryEntry> = Vec::new();
let mut migrated = 0;
let mut expired = 0;
for entry in hot_entries {
if entry.is_expired() {
hot.delete(&entry.id).await?;
expired += 1;
continue;
}
if entry.importance >= min_importance {
continue;
}
let idle_duration = Duration::from_secs(entry.idle_secs() as u64);
if idle_duration >= threshold {
let cold_entry = ColdMemoryEntry {
id: entry.id,
content: entry.content.clone(),
embedding: entry.embedding.clone().unwrap_or_default(),
metadata: serde_json::json!(entry.metadata),
created_at: entry.created_at.timestamp(),
};
cold.store(&cold_entry).await?;
hot.delete(&entry.id).await?;
migrated += 1;
}
}
// Cold-layer TTL scanning is not implemented yet.
let cold_expired = 0;
if migrated > 0 || expired > 0 || cold_expired > 0 {
tracing::debug!(
migrated = migrated,
hot_expired = expired,
cold_expired = cold_expired,
"Background sync completed"
);
}
Ok(())
}
pub async fn stats(&self) -> Result<DualLayerStats> {
let hot_stats = self.hot.stats().await;
let cold_stats = self.cold.stats().await;
let wal_stats = self.wal.stats().await;
Ok(DualLayerStats {
hot_entry_count: hot_stats.entry_count,
hot_memory_bytes: 0, // the hot layer does not report byte usage yet
cold_entry_count: cold_stats.entry_count as usize,
cold_disk_bytes: cold_stats.embeddings_size_bytes + cold_stats.metadata_size_bytes,
wal_entry_count: 0, // WAL stats expose total size, not entry counts
wal_disk_bytes: wal_stats.total_size_bytes,
total_entries: hot_stats.entry_count + cold_stats.entry_count as usize,
})
}
pub fn hot(&self) -> &Arc<HotMemory> {
&self.hot
}
pub fn cold(&self) -> &Arc<ColdMemory> {
&self.cold
}
pub fn wal(&self) -> &Arc<WriteAheadLog> {
&self.wal
}
}
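// Sketch of a typical `DualLayerMemory` lifecycle, assuming the default
// `DualLayerConfig` (it is `Default`, per the tests below); the stored
// content is illustrative.
#[allow(dead_code)]
async fn example_dual_layer_lifecycle() -> Result<()> {
    let memory = DualLayerMemory::new(DualLayerConfig::default()).await?;
    let id = memory
        .store(MemoryEntry::new("user prefers dark mode").with_importance(0.8))
        .await?;
    // Reads hit the hot layer first, then fall back to cold storage.
    assert!(memory.get(&id).await?.is_some());
    memory.shutdown().await?;
    Ok(())
}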
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DualLayerStats {
pub hot_entry_count: usize,
pub hot_memory_bytes: usize,
pub cold_entry_count: usize,
pub cold_disk_bytes: u64,
pub wal_entry_count: usize,
pub wal_disk_bytes: u64,
pub total_entries: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantSecurityConfig {
pub api_key: Option<String>,
pub tls_enabled: bool,
pub ca_cert_path: Option<String>,
pub client_cert_path: Option<String>,
pub client_key_path: Option<String>,
pub skip_tls_verify: bool,
}
impl Default for QdrantSecurityConfig {
fn default() -> Self {
Self {
api_key: None,
tls_enabled: true,
ca_cert_path: None,
client_cert_path: None,
client_key_path: None,
skip_tls_verify: false,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QdrantConnectionConfig {
pub max_connections: usize,
pub connect_timeout_secs: u64,
pub request_timeout_secs: u64,
pub health_check_interval_secs: u64,
pub max_idle_secs: u64,
pub security: QdrantSecurityConfig,
}
impl Default for QdrantConnectionConfig {
fn default() -> Self {
Self {
max_connections: 10,
connect_timeout_secs: 30,
request_timeout_secs: 60,
health_check_interval_secs: 300,
max_idle_secs: 600,
security: QdrantSecurityConfig::default(),
}
}
}
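// Sketch: a locked-down connection profile built from the config types
// above; every value here is an illustrative assumption, not a recommendation.
#[allow(dead_code)]
fn example_hardened_connection_config() -> QdrantConnectionConfig {
    QdrantConnectionConfig {
        max_connections: 4,
        connect_timeout_secs: 5,
        request_timeout_secs: 15,
        security: QdrantSecurityConfig {
            api_key: Some("example-api-key".to_string()),
            tls_enabled: true,
            skip_tls_verify: false,
            ..QdrantSecurityConfig::default()
        },
        ..QdrantConnectionConfig::default()
    }
}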
fn qdrant_value_to_json(value: &qdrant_client::qdrant::Value) -> serde_json::Value {
use qdrant_client::qdrant::value::Kind;
match &value.kind {
Some(Kind::NullValue(_)) => serde_json::Value::Null,
Some(Kind::BoolValue(v)) => serde_json::Value::Bool(*v),
Some(Kind::IntegerValue(v)) => serde_json::Value::Number((*v).into()),
Some(Kind::DoubleValue(v)) => {
serde_json::Value::Number(serde_json::Number::from_f64(*v).unwrap_or(0.into()))
}
Some(Kind::StringValue(v)) => serde_json::Value::String(v.clone()),
Some(Kind::ListValue(v)) => {
let items = v.values.iter().map(qdrant_value_to_json).collect();
serde_json::Value::Array(items)
}
Some(Kind::StructValue(v)) => {
let fields = v
.fields
.iter()
.map(|(k, v)| (k.clone(), qdrant_value_to_json(v)))
.collect();
serde_json::Value::Object(fields)
}
None => serde_json::Value::Null,
}
}
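// Sketch of the `Kind` -> JSON mapping implemented above, using the
// prost-generated `Value`/`Kind` constructors from qdrant_client.
#[cfg(test)]
mod qdrant_value_conversion_tests {
    use super::*;
    use qdrant_client::qdrant::value::Kind;

    #[test]
    fn scalar_kinds_map_to_json() {
        let s = qdrant_client::qdrant::Value {
            kind: Some(Kind::StringValue("hi".to_string())),
        };
        assert_eq!(
            qdrant_value_to_json(&s),
            serde_json::Value::String("hi".to_string())
        );
        let n = qdrant_client::qdrant::Value {
            kind: Some(Kind::IntegerValue(7)),
        };
        assert_eq!(qdrant_value_to_json(&n), serde_json::json!(7));
    }
}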
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AccessLevel {
Read,
ReadWrite,
Admin,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessControlConfig {
pub read_level: AccessLevel,
pub write_level: AccessLevel,
pub delete_level: AccessLevel,
pub admin_level: AccessLevel,
pub enable_audit_log: bool,
}
impl Default for AccessControlConfig {
fn default() -> Self {
Self {
read_level: AccessLevel::Read,
write_level: AccessLevel::ReadWrite,
delete_level: AccessLevel::ReadWrite,
admin_level: AccessLevel::Admin,
enable_audit_log: true,
}
}
}
#[derive(Debug, Clone)]
pub struct AccessContext {
pub user_id: String,
pub access_level: AccessLevel,
pub operation: String,
pub timestamp: i64,
}
impl AccessContext {
pub fn new(user_id: String, access_level: AccessLevel, operation: String) -> Self {
Self {
user_id,
access_level,
operation,
timestamp: chrono::Utc::now().timestamp(),
}
}
pub fn has_permission(
&self,
required_level: &AccessLevel,
_config: &AccessControlConfig,
) -> bool {
match required_level {
AccessLevel::Read => matches!(
self.access_level,
AccessLevel::Read | AccessLevel::ReadWrite | AccessLevel::Admin
),
AccessLevel::ReadWrite => matches!(
self.access_level,
AccessLevel::ReadWrite | AccessLevel::Admin
),
AccessLevel::Admin => matches!(self.access_level, AccessLevel::Admin),
}
}
}
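// Sketch of the permission lattice encoded in `has_permission` above:
// Read < ReadWrite < Admin; the config argument is currently unused there.
#[cfg(test)]
mod access_context_tests {
    use super::*;

    #[test]
    fn read_write_covers_read_but_not_admin() {
        let ctx = AccessContext::new("u1".to_string(), AccessLevel::ReadWrite, "test".to_string());
        let cfg = AccessControlConfig::default();
        assert!(ctx.has_permission(&AccessLevel::Read, &cfg));
        assert!(ctx.has_permission(&AccessLevel::ReadWrite, &cfg));
        assert!(!ctx.has_permission(&AccessLevel::Admin, &cfg));
    }
}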
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingCacheConfig {
pub max_size: usize,
pub ttl_secs: u64,
}
impl Default for EmbeddingCacheConfig {
fn default() -> Self {
Self {
max_size: 10000,
ttl_secs: 3600,
}
}
}
#[derive(Debug, Clone)]
struct CachedEmbedding {
embedding: Vec<f32>,
created_at: Instant,
}
#[derive(Debug)]
pub struct EmbeddingCache {
cache: HashMap<Uuid, CachedEmbedding>,
access_order: Vec<Uuid>,
config: EmbeddingCacheConfig,
}
impl EmbeddingCache {
pub fn new(config: EmbeddingCacheConfig) -> Self {
Self {
cache: HashMap::new(),
access_order: Vec::new(),
config,
}
}
pub fn put(&mut self, chunk_id: Uuid, embedding: Vec<f32>) {
let entry = CachedEmbedding {
embedding,
created_at: Instant::now(),
};
if self.cache.contains_key(&chunk_id) {
self.access_order.retain(|&id| id != chunk_id);
}
self.cache.insert(chunk_id, entry);
self.access_order.push(chunk_id);
while self.cache.len() > self.config.max_size {
let oldest_id = self.access_order.remove(0);
self.cache.remove(&oldest_id);
}
}
pub fn get(&mut self, chunk_id: &Uuid) -> Option<Vec<f32>> {
if let Some(entry) = self.cache.get(chunk_id) {
if entry.created_at.elapsed().as_secs() <= self.config.ttl_secs {
self.access_order.retain(|&id| id != *chunk_id);
self.access_order.push(*chunk_id);
return Some(entry.embedding.clone());
} else {
self.cache.remove(chunk_id);
self.access_order.retain(|&id| id != *chunk_id);
}
}
None
}
pub fn cleanup_expired(&mut self) {
let mut to_remove = Vec::new();
for (id, entry) in &self.cache {
if entry.created_at.elapsed().as_secs() > self.config.ttl_secs {
to_remove.push(*id);
}
}
for id in to_remove {
self.cache.remove(&id);
self.access_order.retain(|&order_id| order_id != id);
}
}
pub fn remove(&mut self, chunk_id: &Uuid) {
self.cache.remove(chunk_id);
self.access_order.retain(|&id| id != *chunk_id);
}
}
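// Sketch of the LRU + TTL behaviour above: exceeding capacity evicts the
// least recently used id, and `get` refreshes recency.
#[cfg(test)]
mod embedding_cache_tests {
    use super::*;

    #[test]
    fn lru_eviction_respects_access_order() {
        let mut cache = EmbeddingCache::new(EmbeddingCacheConfig {
            max_size: 2,
            ttl_secs: 3600,
        });
        let (a, b, c) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        cache.put(a, vec![0.1]);
        cache.put(b, vec![0.2]);
        cache.get(&a); // refresh `a`, leaving `b` as the eviction candidate
        cache.put(c, vec![0.3]); // exceeds capacity and evicts `b`
        assert!(cache.get(&a).is_some());
        assert!(cache.get(&b).is_none());
        assert!(cache.get(&c).is_some());
    }
}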
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageStats {
pub document_count: usize,
pub chunk_count: usize,
pub embedding_count: usize,
pub size_bytes: u64,
}
#[async_trait]
pub trait StorageBackend: Send + Sync {
async fn store_document(&self, doc: &Document, context: &AccessContext) -> Result<()>;
async fn get_document(&self, id: &Uuid, context: &AccessContext) -> Result<Option<Document>>;
async fn delete_document(&self, id: &Uuid, context: &AccessContext) -> Result<()>;
async fn list_documents(&self, context: &AccessContext) -> Result<Vec<Uuid>>;
async fn store_embeddings(
&self,
chunk_id: &Uuid,
embeddings: &[f32],
context: &AccessContext,
) -> Result<()>;
async fn get_embeddings(
&self,
chunk_id: &Uuid,
context: &AccessContext,
) -> Result<Option<Vec<f32>>>;
async fn search_by_vector(
&self,
query_embedding: &[f32],
top_k: usize,
context: &AccessContext,
) -> Result<Vec<(Uuid, f32)>>;
async fn stats(&self, context: &AccessContext) -> Result<StorageStats>;
}
pub struct InMemoryStorage {
documents: Arc<RwLock<HashMap<Uuid, Document>>>,
embeddings: Arc<RwLock<HashMap<Uuid, Vec<f32>>>>,
}
impl Default for InMemoryStorage {
fn default() -> Self {
Self {
documents: Arc::new(RwLock::new(HashMap::new())),
embeddings: Arc::new(RwLock::new(HashMap::new())),
}
}
}
impl InMemoryStorage {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl StorageBackend for InMemoryStorage {
async fn store_document(&self, doc: &Document, _context: &AccessContext) -> Result<()> {
let mut docs = self.documents.write().await;
docs.insert(doc.id, doc.clone());
Ok(())
}
async fn get_document(&self, id: &Uuid, _context: &AccessContext) -> Result<Option<Document>> {
let docs = self.documents.read().await;
Ok(docs.get(id).cloned())
}
async fn delete_document(&self, id: &Uuid, _context: &AccessContext) -> Result<()> {
let removed = {
let mut docs = self.documents.write().await;
docs.remove(id)
};
if let Some(doc) = removed {
let mut embs = self.embeddings.write().await;
for chunk in &doc.chunks {
embs.remove(&chunk.id);
}
}
Ok(())
}
async fn list_documents(&self, _context: &AccessContext) -> Result<Vec<Uuid>> {
let docs = self.documents.read().await;
Ok(docs.keys().cloned().collect())
}
async fn store_embeddings(
&self,
chunk_id: &Uuid,
embeddings: &[f32],
_context: &AccessContext,
) -> Result<()> {
let mut embs = self.embeddings.write().await;
embs.insert(*chunk_id, embeddings.to_vec());
Ok(())
}
async fn get_embeddings(
&self,
chunk_id: &Uuid,
_context: &AccessContext,
) -> Result<Option<Vec<f32>>> {
let embs = self.embeddings.read().await;
Ok(embs.get(chunk_id).cloned())
}
async fn search_by_vector(
&self,
query_embedding: &[f32],
top_k: usize,
_context: &AccessContext,
) -> Result<Vec<(Uuid, f32)>> {
let embs = self.embeddings.read().await;
let mut results: Vec<(Uuid, f32)> = embs
.iter()
.map(|(id, emb)| (*id, cosine_similarity(query_embedding, emb)))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(top_k);
Ok(results)
}
async fn stats(&self, _context: &AccessContext) -> Result<StorageStats> {
let docs = self.documents.read().await;
let embs = self.embeddings.read().await;
let chunk_count: usize = docs.values().map(|d| d.chunks.len()).sum();
Ok(StorageStats {
document_count: docs.len(),
chunk_count,
embedding_count: embs.len(),
size_bytes: 0, // the in-memory backend does not track byte usage
})
}
}
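// Sketch: `search_by_vector` on the in-memory backend ranks candidates by
// cosine similarity; the 2-d embeddings here are toy values.
#[cfg(test)]
mod in_memory_search_tests {
    use super::*;

    #[tokio::test]
    async fn nearest_vector_ranks_first() {
        let storage = InMemoryStorage::new();
        let ctx = AccessContext::new("u".to_string(), AccessLevel::Admin, "test".to_string());
        let (close, far) = (Uuid::new_v4(), Uuid::new_v4());
        storage.store_embeddings(&close, &[1.0, 0.0], &ctx).await.unwrap();
        storage.store_embeddings(&far, &[0.0, 1.0], &ctx).await.unwrap();
        let hits = storage.search_by_vector(&[0.9, 0.1], 2, &ctx).await.unwrap();
        assert_eq!(hits[0].0, close);
    }
}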
struct PooledConnection {
client: Qdrant,
last_used: Instant,
#[allow(dead_code)]
created_at: Instant,
}
struct QdrantConnectionPool {
connections: Vec<PooledConnection>,
config: QdrantConnectionConfig,
client_config: qdrant_client::config::QdrantConfig,
}
impl QdrantConnectionPool {
fn new(
client_config: qdrant_client::config::QdrantConfig,
config: QdrantConnectionConfig,
) -> Self {
Self {
connections: Vec::new(),
config,
client_config,
}
}
async fn get_connection(&mut self) -> Result<&mut Qdrant> {
let available_index = self.connections.iter().position(|conn| {
conn.last_used.elapsed() < Duration::from_secs(self.config.max_idle_secs)
});
if let Some(index) = available_index {
self.connections[index].last_used = Instant::now();
return Ok(&mut self.connections[index].client);
}
if self.connections.len() < self.config.max_connections {
let client = Qdrant::new(self.client_config.clone())
.map_err(|e| Error::io(format!("Failed to create Qdrant client: {}", e)))?;
self.connections.push(PooledConnection {
client,
last_used: Instant::now(),
created_at: Instant::now(),
});
let len = self.connections.len();
return Ok(&mut self.connections[len - 1].client);
}
Err(Error::io("Connection pool exhausted".to_string()))
}
#[allow(dead_code)]
fn cleanup_expired(&mut self) {
// Expire by idle time, matching `max_idle_secs`, rather than by age.
self.connections.retain(|conn| {
conn.last_used.elapsed() < Duration::from_secs(self.config.max_idle_secs)
});
}
async fn health_check(&mut self) -> Result<()> {
if let Ok(client) = self.get_connection().await {
client
.list_collections()
.await
.map_err(|e| Error::io(format!("Health check failed: {}", e)))?;
}
Ok(())
}
}
pub struct FileStorage {
base_path: PathBuf,
documents: Arc<RwLock<HashMap<Uuid, Document>>>,
}
impl FileStorage {
pub async fn new(base_path: PathBuf) -> Result<Self> {
tokio::fs::create_dir_all(&base_path)
.await
.map_err(|e| Error::io(format!("Failed to create storage directory: {}", e)))?;
tokio::fs::create_dir_all(base_path.join("documents"))
.await
.map_err(|e| Error::io(format!("Failed to create documents directory: {}", e)))?;
tokio::fs::create_dir_all(base_path.join("embeddings"))
.await
.map_err(|e| Error::io(format!("Failed to create embeddings directory: {}", e)))?;
let documents = Arc::new(RwLock::new(HashMap::new()));
let storage = Self {
base_path,
documents,
};
storage.load_documents().await?;
Ok(storage)
}
async fn load_documents(&self) -> Result<()> {
let docs_path = self.base_path.join("documents");
let mut entries = tokio::fs::read_dir(&docs_path)
.await
.map_err(|e| Error::io(format!("Failed to read documents directory: {}", e)))?;
let mut docs = self.documents.write().await;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::io(format!("Failed to read directory entry: {}", e)))?
{
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "json") {
let content = tokio::fs::read_to_string(&path)
.await
.map_err(|e| Error::io(format!("Failed to read document file: {}", e)))?;
let doc: Document = serde_json::from_str(&content)
.map_err(|e| Error::parse(format!("Failed to parse document: {}", e)))?;
docs.insert(doc.id, doc);
}
}
Ok(())
}
fn doc_path(&self, id: &Uuid) -> PathBuf {
self.base_path
.join("documents")
.join(format!("{}.json", id))
}
fn embedding_path(&self, id: &Uuid) -> PathBuf {
self.base_path
.join("embeddings")
.join(format!("{}.bin", id))
}
}
#[async_trait]
impl StorageBackend for FileStorage {
async fn store_document(&self, doc: &Document, _context: &AccessContext) -> Result<()> {
let path = self.doc_path(&doc.id);
let content = serde_json::to_string_pretty(doc)
.map_err(|e| Error::parse(format!("Failed to serialize document: {}", e)))?;
tokio::fs::write(&path, content)
.await
.map_err(|e| Error::io(format!("Failed to write document: {}", e)))?;
let mut docs = self.documents.write().await;
docs.insert(doc.id, doc.clone());
Ok(())
}
async fn get_document(&self, id: &Uuid, _context: &AccessContext) -> Result<Option<Document>> {
let docs = self.documents.read().await;
Ok(docs.get(id).cloned())
}
async fn delete_document(&self, id: &Uuid, _context: &AccessContext) -> Result<()> {
let path = self.doc_path(id);
let doc = {
let docs = self.documents.read().await;
docs.get(id).cloned()
};
let doc = match doc {
Some(doc) => Some(doc),
None if path.exists() => {
let content = tokio::fs::read_to_string(&path)
.await
.map_err(|e| Error::io(format!("Failed to read document for deletion: {e}")))?;
let doc: Document = serde_json::from_str(&content).map_err(|e| {
Error::parse(format!("Failed to parse document for deletion: {e}"))
})?;
Some(doc)
}
None => None,
};
if let Some(doc) = doc {
for chunk in &doc.chunks {
let emb_path = self.embedding_path(&chunk.id);
match tokio::fs::remove_file(&emb_path).await {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => {
return Err(Error::io(format!(
"Failed to delete embedding {:?}: {e}",
emb_path
)));
}
}
}
}
if path.exists() {
tokio::fs::remove_file(&path)
.await
.map_err(|e| Error::io(format!("Failed to delete document: {}", e)))?;
}
let mut docs = self.documents.write().await;
docs.remove(id);
Ok(())
}
async fn list_documents(&self, _context: &AccessContext) -> Result<Vec<Uuid>> {
let docs = self.documents.read().await;
Ok(docs.keys().cloned().collect())
}
async fn store_embeddings(
&self,
chunk_id: &Uuid,
embeddings: &[f32],
_context: &AccessContext,
) -> Result<()> {
let path = self.embedding_path(chunk_id);
let bytes: Vec<u8> = embeddings.iter().flat_map(|f| f.to_le_bytes()).collect();
tokio::fs::write(&path, bytes)
.await
.map_err(|e| Error::io(format!("Failed to write embeddings: {}", e)))?;
Ok(())
}
async fn get_embeddings(
&self,
chunk_id: &Uuid,
_context: &AccessContext,
) -> Result<Option<Vec<f32>>> {
let path = self.embedding_path(chunk_id);
if !path.exists() {
return Ok(None);
}
let bytes = tokio::fs::read(&path)
.await
.map_err(|e| Error::io(format!("Failed to read embeddings: {}", e)))?;
// `chunks_exact` skips a trailing partial chunk (e.g. from a truncated
// file) instead of silently decoding it as zeros.
let embeddings: Vec<f32> = bytes
.chunks_exact(4)
.map(|chunk| {
let arr: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte slices");
f32::from_le_bytes(arr)
})
.collect();
Ok(Some(embeddings))
}
async fn search_by_vector(
&self,
query_embedding: &[f32],
top_k: usize,
_context: &AccessContext,
) -> Result<Vec<(Uuid, f32)>> {
let embeddings_dir = self.base_path.join("embeddings");
let mut results: Vec<(Uuid, f32)> = Vec::new();
let mut entries = tokio::fs::read_dir(&embeddings_dir)
.await
.map_err(|e| Error::io(format!("Failed to read embeddings directory: {}", e)))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| Error::io(format!("Failed to read entry: {}", e)))?
{
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "bin") {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
if let Ok(id) = Uuid::parse_str(stem) {
let bytes = tokio::fs::read(&path)
.await
.map_err(|e| Error::io(format!("Failed to read embeddings: {}", e)))?;
// As in `get_embeddings`, skip any trailing partial chunk.
let embeddings: Vec<f32> = bytes
.chunks_exact(4)
.map(|chunk| {
let arr: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte slices");
f32::from_le_bytes(arr)
})
.collect();
let score = cosine_similarity(query_embedding, &embeddings);
results.push((id, score));
}
}
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(top_k);
Ok(results)
}
async fn stats(&self, _context: &AccessContext) -> Result<StorageStats> {
let docs = self.documents.read().await;
let chunk_count: usize = docs.values().map(|d| d.chunks.len()).sum();
let embeddings_dir = self.base_path.join("embeddings");
let mut embedding_count = 0;
if let Ok(mut entries) = tokio::fs::read_dir(&embeddings_dir).await {
while let Ok(Some(_)) = entries.next_entry().await {
embedding_count += 1;
}
}
let mut size_bytes: u64 = 0;
let docs_dir = self.base_path.join("documents");
if let Ok(mut entries) = tokio::fs::read_dir(&docs_dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
if let Ok(metadata) = entry.metadata().await {
size_bytes += metadata.len();
}
}
}
Ok(StorageStats {
document_count: docs.len(),
chunk_count,
embedding_count,
size_bytes,
})
}
}
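// Sketch of FileStorage's on-disk embedding format used above: a flat
// sequence of little-endian f32 values, so encoding round-trips exactly.
#[cfg(test)]
mod embedding_codec_tests {
    #[test]
    fn f32_le_round_trip() {
        let original = [0.5_f32, -1.25, 3.0];
        let bytes: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
        let decoded: Vec<f32> = bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes(c.try_into().expect("4-byte chunk")))
            .collect();
        assert_eq!(decoded, original);
    }
}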
pub struct QdrantStorage {
pool: Arc<RwLock<QdrantConnectionPool>>,
collection_name: String,
vector_size: usize,
embedding_cache: Arc<RwLock<EmbeddingCache>>,
access_control: AccessControlConfig,
}
impl QdrantStorage {
pub async fn new(
host: &str,
port: u16,
grpc_port: u16,
collection_name: String,
vector_size: usize,
embedded: bool,
) -> Result<Self> {
Self::new_with_config(
host,
port,
grpc_port,
collection_name,
vector_size,
embedded,
QdrantConnectionConfig::default(),
EmbeddingCacheConfig::default(),
AccessControlConfig::default(),
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn new_with_config(
host: &str,
port: u16,
_grpc_port: u16,
collection_name: String,
vector_size: usize,
embedded: bool,
conn_config: QdrantConnectionConfig,
cache_config: EmbeddingCacheConfig,
access_config: AccessControlConfig,
) -> Result<Self> {
let config = if embedded {
// "Embedded" mode currently targets a local Qdrant instance on the
// default HTTP port rather than an in-process engine.
qdrant_client::config::QdrantConfig::from_url("http://localhost:6333")
} else {
qdrant_client::config::QdrantConfig::from_url(&format!("http://{}:{}", host, port))
};
let pool = Arc::new(RwLock::new(QdrantConnectionPool::new(
config,
conn_config.clone(),
)));
let embedding_cache = Arc::new(RwLock::new(EmbeddingCache::new(cache_config)));
let storage = Self {
pool: pool.clone(),
collection_name: collection_name.clone(),
vector_size,
embedding_cache,
access_control: access_config,
};
{
let mut pool_guard = pool.write().await;
let client = pool_guard.get_connection().await?;
Self::ensure_collection(client, &collection_name, vector_size).await?;
}
let pool_clone = pool.clone();
tokio::spawn(async move {
let mut interval =
tokio::time::interval(Duration::from_secs(conn_config.health_check_interval_secs));
loop {
interval.tick().await;
let mut pool = pool_clone.write().await;
if let Err(e) = pool.health_check().await {
tracing::warn!("Qdrant health check failed: {}", e);
}
}
});
Ok(storage)
}
async fn ensure_collection(
client: &Qdrant,
collection_name: &str,
vector_size: usize,
) -> Result<()> {
let collections = client
.list_collections()
.await
.map_err(|e| Error::io(format!("Failed to list collections: {}", e)))?;
let collection_exists = collections
.collections
.iter()
.any(|c| c.name == collection_name);
if !collection_exists {
let vector_params = VectorParams {
size: vector_size as u64,
distance: Distance::Cosine as i32,
hnsw_config: None,
quantization_config: Some(QuantizationConfig {
quantization: Some(
qdrant_client::qdrant::quantization_config::Quantization::Scalar(
ScalarQuantization {
r#type: qdrant_client::qdrant::QuantizationType::Int8 as i32,
quantile: None,
always_ram: None,
},
),
),
}),
on_disk: None,
datatype: None,
multivector_config: None,
};
let collection_params = CreateCollection {
collection_name: collection_name.to_string(),
vectors_config: Some(VectorsConfig {
config: Some(qdrant_client::qdrant::vectors_config::Config::Params(
vector_params,
)),
}),
..Default::default()
};
client
.create_collection(collection_params)
.await
.map_err(|e| Error::io(format!("Failed to create collection: {}", e)))?;
}
Ok(())
}
fn point_id_from_uuid(uuid: &Uuid) -> PointId {
PointId::from(uuid.to_string())
}
fn uuid_from_point_id(point_id: &PointId) -> Option<Uuid> {
match &point_id.point_id_options {
Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(uuid_str)) => {
Uuid::parse_str(uuid_str).ok()
}
Some(qdrant_client::qdrant::point_id::PointIdOptions::Num(num)) => {
tracing::warn!(
"Cannot convert numeric PointId {} to UUID - using UUID strings is required",
num
);
None
}
None => None,
}
}
fn check_access(&self, context: &AccessContext, required_level: &AccessLevel) -> Result<()> {
if !context.has_permission(required_level, &self.access_control) {
return Err(Error::validation(format!(
"Access denied: user {} requires {:?} level for operation '{}', has {:?}",
context.user_id, required_level, context.operation, context.access_level
)));
}
if self.access_control.enable_audit_log {
tracing::info!(
"Access granted: user={}, operation={}, level={:?}, timestamp={}",
context.user_id,
context.operation,
context.access_level,
context.timestamp
);
}
Ok(())
}
}
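// Sketch: UUID-based point ids round-trip through the string `PointId`
// variant used above, while numeric point ids deliberately map to `None`.
#[cfg(test)]
mod point_id_tests {
    use super::*;

    #[test]
    fn uuid_point_ids_round_trip() {
        let id = Uuid::new_v4();
        let point_id = QdrantStorage::point_id_from_uuid(&id);
        assert_eq!(QdrantStorage::uuid_from_point_id(&point_id), Some(id));
    }
}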
#[async_trait]
impl StorageBackend for QdrantStorage {
async fn store_document(&self, doc: &Document, context: &AccessContext) -> Result<()> {
self.check_access(context, &self.access_control.write_level)?;
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let payload: std::collections::HashMap<String, serde_json::Value> =
serde_json::from_str(&serde_json::to_string(doc)?)
.map_err(|e| Error::parse(format!("Failed to serialize document: {}", e)))?;
let point = PointStruct::new(
Self::point_id_from_uuid(&doc.id),
// Document points are payload-only; vectors are stored on chunk points.
vec![],
payload,
);
let points = vec![point];
let upsert_points = UpsertPoints {
collection_name: self.collection_name.clone(),
wait: Some(true),
points,
..Default::default()
};
client
.upsert_points(upsert_points)
.await
.map_err(|e| Error::io(format!("Failed to store document: {}", e)))?;
Ok(())
}
async fn get_document(&self, id: &Uuid, context: &AccessContext) -> Result<Option<Document>> {
self.check_access(context, &self.access_control.read_level)?;
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let point_id = Self::point_id_from_uuid(id);
let get_points = GetPoints {
collection_name: self.collection_name.clone(),
ids: vec![point_id],
with_payload: Some(qdrant_client::qdrant::WithPayloadSelector {
selector_options: Some(
qdrant_client::qdrant::with_payload_selector::SelectorOptions::Enable(true),
),
}),
with_vectors: Some(qdrant_client::qdrant::WithVectorsSelector {
selector_options: Some(
qdrant_client::qdrant::with_vectors_selector::SelectorOptions::Enable(false),
),
}),
..Default::default()
};
let response = client
.get_points(get_points)
.await
.map_err(|e| Error::io(format!("Failed to get document: {}", e)))?;
if let Some(point) = response.result.first() {
let json_payload: std::collections::HashMap<String, serde_json::Value> = point
.payload
.iter()
.map(|(k, v)| (k.clone(), qdrant_value_to_json(v)))
.collect();
let doc: Document = serde_json::from_value(serde_json::Value::Object(
json_payload.into_iter().collect(),
))
.map_err(|e| Error::parse(format!("Failed to deserialize document: {}", e)))?;
Ok(Some(doc))
} else {
Ok(None)
}
}
async fn delete_document(&self, id: &Uuid, context: &AccessContext) -> Result<()> {
self.check_access(context, &self.access_control.delete_level)?;
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let doc = {
let point_id = Self::point_id_from_uuid(id);
let get_points = GetPoints {
collection_name: self.collection_name.clone(),
ids: vec![point_id],
with_payload: Some(qdrant_client::qdrant::WithPayloadSelector {
selector_options: Some(
qdrant_client::qdrant::with_payload_selector::SelectorOptions::Enable(true),
),
}),
with_vectors: Some(qdrant_client::qdrant::WithVectorsSelector {
selector_options: Some(
qdrant_client::qdrant::with_vectors_selector::SelectorOptions::Enable(
false,
),
),
}),
..Default::default()
};
let response = client
.get_points(get_points)
.await
.map_err(|e| Error::io(format!("Failed to get document for deletion: {e}")))?;
if let Some(point) = response.result.first() {
let json_payload: std::collections::HashMap<String, serde_json::Value> = point
.payload
.iter()
.map(|(k, v)| (k.clone(), qdrant_value_to_json(v)))
.collect();
let doc: Document = serde_json::from_value(serde_json::Value::Object(
json_payload.into_iter().collect(),
))
.map_err(|e| Error::parse(format!("Failed to deserialize document: {e}")))?;
Some(doc)
} else {
None
}
};
let mut ids: Vec<PointId> = Vec::new();
if let Some(doc) = doc {
{
let mut cache = self.embedding_cache.write().await;
for chunk in &doc.chunks {
cache.remove(&chunk.id);
}
}
ids.extend(doc.chunks.iter().map(|c| Self::point_id_from_uuid(&c.id)));
}
ids.push(Self::point_id_from_uuid(id));
if ids.is_empty() {
return Ok(());
}
let delete_points = DeletePoints {
collection_name: self.collection_name.clone(),
wait: Some(true),
points: Some(qdrant_client::qdrant::PointsSelector {
points_selector_one_of: Some(
qdrant_client::qdrant::points_selector::PointsSelectorOneOf::Points(
qdrant_client::qdrant::PointsIdsList { ids },
),
),
}),
..Default::default()
};
client
.delete_points(delete_points)
.await
.map_err(|e| Error::io(format!("Failed to delete document: {}", e)))?;
Ok(())
}
async fn list_documents(&self, context: &AccessContext) -> Result<Vec<Uuid>> {
self.check_access(context, &self.access_control.read_level)?;
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let mut all_ids = Vec::new();
let mut offset = None;
loop {
let scroll_points = ScrollPoints {
collection_name: self.collection_name.clone(),
limit: Some(100),
offset,
with_payload: Some(qdrant_client::qdrant::WithPayloadSelector {
selector_options: Some(
qdrant_client::qdrant::with_payload_selector::SelectorOptions::Enable(
false,
),
),
}),
with_vectors: Some(qdrant_client::qdrant::WithVectorsSelector {
selector_options: Some(
qdrant_client::qdrant::with_vectors_selector::SelectorOptions::Enable(
false,
),
),
}),
..Default::default()
};
let response = client
.scroll(scroll_points)
.await
.map_err(|e| Error::io(format!("Failed to scroll points: {}", e)))?;
for point in &response.result {
if let Some(id) = &point.id {
if let Some(uuid) = Self::uuid_from_point_id(id) {
all_ids.push(uuid);
}
}
}
if response.next_page_offset.is_none() {
break;
}
offset = response.next_page_offset;
}
Ok(all_ids)
}
async fn store_embeddings(
&self,
chunk_id: &Uuid,
embeddings: &[f32],
context: &AccessContext,
) -> Result<()> {
self.check_access(context, &self.access_control.write_level)?;
if embeddings.len() != self.vector_size {
return Err(Error::validation(format!(
"Embedding size {} does not match configured vector size {}",
embeddings.len(),
self.vector_size
)));
}
{
let mut cache = self.embedding_cache.write().await;
cache.put(*chunk_id, embeddings.to_vec());
}
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let point_id = Self::point_id_from_uuid(chunk_id);
let mut payload: std::collections::HashMap<String, serde_json::Value> =
std::collections::HashMap::new();
payload.insert(
"chunk_id".to_string(),
serde_json::Value::String(chunk_id.to_string()),
);
let point = PointStruct::new(point_id, embeddings.to_vec(), payload);
let points = vec![point];
let upsert_points = UpsertPoints {
collection_name: self.collection_name.clone(),
wait: Some(true),
points,
..Default::default()
};
client
.upsert_points(upsert_points)
.await
.map_err(|e| Error::io(format!("Failed to store embeddings: {}", e)))?;
Ok(())
}
async fn get_embeddings(
&self,
chunk_id: &Uuid,
context: &AccessContext,
) -> Result<Option<Vec<f32>>> {
self.check_access(context, &self.access_control.read_level)?;
{
let mut cache = self.embedding_cache.write().await;
cache.cleanup_expired();
if let Some(embedding) = cache.get(chunk_id) {
return Ok(Some(embedding));
}
}
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let point_id = Self::point_id_from_uuid(chunk_id);
let get_points = GetPoints {
collection_name: self.collection_name.clone(),
ids: vec![point_id],
with_payload: Some(qdrant_client::qdrant::WithPayloadSelector {
selector_options: Some(
qdrant_client::qdrant::with_payload_selector::SelectorOptions::Enable(false),
),
}),
with_vectors: Some(qdrant_client::qdrant::WithVectorsSelector {
selector_options: Some(
qdrant_client::qdrant::with_vectors_selector::SelectorOptions::Enable(true),
),
}),
..Default::default()
};
let response = client
.get_points(get_points)
.await
.map_err(|e| Error::io(format!("Failed to get embeddings: {}", e)))?;
if let Some(point) = response.result.first() {
if let Some(vectors) = &point.vectors {
use qdrant_client::qdrant::vectors_output::VectorsOptions;
match &vectors.vectors_options {
Some(VectorsOptions::Vector(vector_output)) => {
use qdrant_client::qdrant::vector_output::Vector as OutputVector;
match vector_output.clone().into_vector() {
OutputVector::Dense(dense) => {
let embedding = dense.data;
self.embedding_cache
.write()
.await
.put(*chunk_id, embedding.clone());
Ok(Some(embedding))
}
// Sparse and multi-dense vectors are not supported here.
_ => Ok(None),
}
}
Some(VectorsOptions::Vectors(named_vectors)) => {
use qdrant_client::qdrant::vector_output::Vector as OutputVector;
if let Some(vector_output) = named_vectors.vectors.get("") {
match vector_output.clone().into_vector() {
OutputVector::Dense(dense) => {
let embedding = dense.data;
self.embedding_cache
.write()
.await
.put(*chunk_id, embedding.clone());
Ok(Some(embedding))
}
_ => Ok(None),
}
} else if let Some((_, vector_output)) = named_vectors.vectors.iter().next()
{
match vector_output.clone().into_vector() {
OutputVector::Dense(dense) => {
let embedding = dense.data;
self.embedding_cache
.write()
.await
.put(*chunk_id, embedding.clone());
Ok(Some(embedding))
}
_ => Ok(None),
}
} else {
Ok(None)
}
}
None => Ok(None),
}
} else {
Ok(None)
}
} else {
Ok(None)
}
}
async fn search_by_vector(
&self,
query_embedding: &[f32],
top_k: usize,
context: &AccessContext,
) -> Result<Vec<(Uuid, f32)>> {
self.check_access(context, &self.access_control.read_level)?;
if query_embedding.len() != self.vector_size {
return Err(Error::validation(format!(
"Query embedding size {} does not match configured vector size {}",
query_embedding.len(),
self.vector_size
)));
}
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let search_points = SearchPoints {
collection_name: self.collection_name.clone(),
vector: query_embedding.to_vec(),
limit: top_k as u64,
with_payload: Some(qdrant_client::qdrant::WithPayloadSelector {
selector_options: Some(
qdrant_client::qdrant::with_payload_selector::SelectorOptions::Enable(true),
),
}),
..Default::default()
};
let response = client
.search_points(search_points)
.await
.map_err(|e| Error::io(format!("Failed to search vectors: {}", e)))?;
let results = response
.result
.into_iter()
.filter_map(|scored_point| {
scored_point
.id
.as_ref()
.and_then(Self::uuid_from_point_id)
.map(|uuid| (uuid, scored_point.score))
})
.collect();
Ok(results)
}
async fn stats(&self, context: &AccessContext) -> Result<StorageStats> {
self.check_access(context, &self.access_control.admin_level)?;
let mut pool = self.pool.write().await;
let client = pool.get_connection().await?;
let collection_info = client
.collection_info(&self.collection_name)
.await
.map_err(|e| Error::io(format!("Failed to get collection info: {}", e)))?;
let points_count = collection_info
.result
.as_ref()
.map(|info| info.points_count.unwrap_or(0))
.unwrap_or(0);
// Document and chunk points share one collection, so a separate document
// count cannot be derived from the raw point count; report zero.
Ok(StorageStats {
document_count: 0,
chunk_count: points_count as usize,
embedding_count: points_count as usize,
size_bytes: 0, // collection info does not expose on-disk size
})
}
}
pub struct Storage {
backend: Box<dyn StorageBackend>,
}
impl Storage {
pub fn in_memory() -> Self {
Self {
backend: Box::new(InMemoryStorage::new()),
}
}
pub async fn file(base_path: PathBuf) -> Result<Self> {
Ok(Self {
backend: Box::new(FileStorage::new(base_path).await?),
})
}
pub async fn new_embedded() -> Result<Self> {
create_embedded_storage(EmbeddedStorageConfig::default()).await
}
pub async fn new_embedded_with_config(config: EmbeddedStorageConfig) -> Result<Self> {
create_embedded_storage(config).await
}
pub async fn qdrant(
host: &str,
port: u16,
grpc_port: u16,
collection_name: String,
vector_size: usize,
embedded: bool,
) -> Result<Self> {
Ok(Self {
backend: Box::new(
QdrantStorage::new(
host,
port,
grpc_port,
collection_name,
vector_size,
embedded,
)
.await?,
),
})
}
#[allow(clippy::too_many_arguments)]
pub async fn qdrant_with_config(
host: &str,
port: u16,
grpc_port: u16,
collection_name: String,
vector_size: usize,
embedded: bool,
conn_config: QdrantConnectionConfig,
cache_config: EmbeddingCacheConfig,
access_config: AccessControlConfig,
) -> Result<Self> {
Ok(Self {
backend: Box::new(
QdrantStorage::new_with_config(
host,
port,
grpc_port,
collection_name,
vector_size,
embedded,
conn_config,
cache_config,
access_config,
)
.await?,
),
})
}
pub async fn store_document(&self, doc: &Document, context: &AccessContext) -> Result<()> {
self.backend.store_document(doc, context).await
}
pub async fn get_document(
&self,
id: &Uuid,
context: &AccessContext,
) -> Result<Option<Document>> {
self.backend.get_document(id, context).await
}
pub async fn delete_document(&self, id: &Uuid, context: &AccessContext) -> Result<()> {
self.backend.delete_document(id, context).await
}
pub async fn list_documents(&self, context: &AccessContext) -> Result<Vec<Uuid>> {
self.backend.list_documents(context).await
}
pub async fn store_embeddings(
&self,
chunk_id: &Uuid,
embeddings: &[f32],
context: &AccessContext,
) -> Result<()> {
self.backend
.store_embeddings(chunk_id, embeddings, context)
.await
}
pub async fn get_embeddings(
&self,
chunk_id: &Uuid,
context: &AccessContext,
) -> Result<Option<Vec<f32>>> {
self.backend.get_embeddings(chunk_id, context).await
}
pub async fn search_by_vector(
&self,
query_embedding: &[f32],
top_k: usize,
context: &AccessContext,
) -> Result<Vec<(Uuid, f32)>> {
self.backend
.search_by_vector(query_embedding, top_k, context)
.await
}
pub async fn stats(&self, context: &AccessContext) -> Result<StorageStats> {
self.backend.stats(context).await
}
}
pub mod benchmarks;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddedStorageConfig {
pub data_path: PathBuf,
pub collection_name: String,
pub vector_size: usize,
pub require_qdrant: bool,
pub qdrant_url: String,
}
impl Default for EmbeddedStorageConfig {
fn default() -> Self {
Self {
data_path: dirs::data_local_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("reasonkit")
.join("storage"),
collection_name: "reasonkit_default".to_string(),
vector_size: 1536, // e.g. OpenAI text-embedding-ada-002 dimension
require_qdrant: false,
qdrant_url: "http://localhost:6333".to_string(),
}
}
}
impl EmbeddedStorageConfig {
pub fn file_only(data_path: PathBuf) -> Self {
Self {
data_path,
require_qdrant: false,
..Default::default()
}
}
pub fn with_qdrant(qdrant_url: &str, collection_name: &str, vector_size: usize) -> Self {
Self {
qdrant_url: qdrant_url.to_string(),
collection_name: collection_name.to_string(),
vector_size,
require_qdrant: true,
..Default::default()
}
}
}
pub async fn create_embedded_storage(config: EmbeddedStorageConfig) -> Result<Storage> {
if !config.data_path.exists() {
std::fs::create_dir_all(&config.data_path).map_err(|e| {
Error::io(format!(
"Failed to create storage directory {:?}: {}",
config.data_path, e
))
})?;
tracing::info!(path = ?config.data_path, "Created storage data directory");
}
if config.require_qdrant {
match check_qdrant_health(&config.qdrant_url).await {
Ok(()) => {
tracing::info!(url = %config.qdrant_url, "Connected to Qdrant server");
let (host, port) = parse_qdrant_url(&config.qdrant_url);
return Storage::qdrant(
&host,
port,
port + 1, // Qdrant's gRPC port conventionally follows the HTTP port
config.collection_name,
config.vector_size,
// Pass `false` so the host/port parsed from `qdrant_url` are used
// instead of the hardcoded embedded default of localhost:6333.
false,
)
.await;
}
Err(e) => {
tracing::warn!(
error = %e,
url = %config.qdrant_url,
"Qdrant not available, require_qdrant=true will fail"
);
return Err(Error::io(format!(
"Qdrant required but not available at {}: {}",
config.qdrant_url, e
)));
}
}
}
tracing::info!(path = ?config.data_path, "Using file-based storage (Qdrant not required)");
Storage::file(config.data_path).await
}
async fn check_qdrant_health(url: &str) -> Result<()> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.build()
.map_err(|e| Error::io(format!("Failed to create HTTP client: {}", e)))?;
let normalized_url = url.trim_end_matches('/');
let base_url =
if normalized_url.starts_with("http://") || normalized_url.starts_with("https://") {
normalized_url.to_string()
} else {
format!("http://{}", normalized_url)
};
let health_url = format!("{}/readyz", base_url);
let response = client
.get(&health_url)
.send()
.await
.map_err(|e| Error::io(format!("Qdrant health check failed: {}", e)))?;
if response.status().is_success() {
tracing::debug!(url = %base_url, "Qdrant health check passed");
Ok(())
} else {
Err(Error::io(format!(
"Qdrant health check returned status: {}",
response.status()
)))
}
}
fn parse_qdrant_url(url: &str) -> (String, u16) {
let url = url
.trim_start_matches("http://")
.trim_start_matches("https://");
let parts: Vec<&str> = url.split(':').collect();
let host = parts.first().unwrap_or(&"localhost").to_string();
let port: u16 = parts.get(1).and_then(|p| p.parse().ok()).unwrap_or(6333);
(host, port)
}
pub fn default_storage_path() -> PathBuf {
dirs::data_local_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("reasonkit")
.join("storage")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_embedded_config_default() {
let config = EmbeddedStorageConfig::default();
assert!(!config.require_qdrant);
assert_eq!(config.vector_size, 1536);
assert_eq!(config.collection_name, "reasonkit_default");
}
#[test]
fn test_embedded_config_file_only() {
let config = EmbeddedStorageConfig::file_only(PathBuf::from("/tmp/test"));
assert!(!config.require_qdrant);
assert_eq!(config.data_path, PathBuf::from("/tmp/test"));
}
#[test]
fn test_embedded_config_with_qdrant() {
let config =
EmbeddedStorageConfig::with_qdrant("http://localhost:6334", "test_collection", 768);
assert!(config.require_qdrant);
assert_eq!(config.qdrant_url, "http://localhost:6334");
assert_eq!(config.collection_name, "test_collection");
assert_eq!(config.vector_size, 768);
}
#[test]
fn test_default_storage_path() {
let path = default_storage_path();
assert!(path.ends_with("reasonkit/storage") || path.ends_with("reasonkit\\storage"));
}
#[test]
fn test_memory_entry_creation() {
let entry = MemoryEntry::new("Test content");
assert!(!entry.content.is_empty());
assert_eq!(entry.importance, 0.5);
assert_eq!(entry.layer, MemoryLayer::Pending);
}
#[test]
fn test_memory_entry_builder() {
let entry = MemoryEntry::new("Test content")
.with_importance(0.9)
.with_metadata("key", "value")
.with_ttl(3600)
.with_tags(vec!["tag1".to_string(), "tag2".to_string()]);
assert_eq!(entry.importance, 0.9);
assert_eq!(entry.metadata.get("key"), Some(&"value".to_string()));
assert_eq!(entry.ttl_secs, Some(3600));
assert_eq!(entry.tags.len(), 2);
}
#[test]
fn test_dual_layer_config_default() {
let config = DualLayerConfig::default();
assert_eq!(config.sync.interval_secs, 60);
assert!(config.sync.auto_sync_enabled);
assert!(config.hot.max_entries > 0);
}
#[test]
fn test_dual_layer_config_high_performance() {
let config = DualLayerConfig::high_performance(PathBuf::from("/tmp"));
assert!(config.hot.max_entries > 0);
assert!(config.sync.interval_secs > 0);
}
#[test]
fn test_dual_layer_config_low_memory() {
let config = DualLayerConfig::low_memory(PathBuf::from("/tmp"));
assert!(config.hot.max_entries > 0);
assert_eq!(config.sync.hot_to_cold_age_secs, 300);
}
#[tokio::test]
async fn test_in_memory_storage() {
use crate::{DocumentType, Source, SourceType};
use chrono::Utc;
let storage = Storage::in_memory();
let context = AccessContext::new(
"test_user".to_string(),
AccessLevel::Admin,
"test".to_string(),
);
let source = Source {
source_type: SourceType::Local,
url: None,
path: Some("test.md".to_string()),
arxiv_id: None,
github_repo: None,
retrieved_at: Utc::now(),
version: None,
};
let doc =
Document::new(DocumentType::Note, source).with_content("Test content".to_string());
storage.store_document(&doc, &context).await.unwrap();
let retrieved = storage.get_document(&doc.id, &context).await.unwrap();
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().content.raw, "Test content");
}
#[tokio::test]
async fn test_file_storage_creation() {
let temp_dir = std::env::temp_dir().join("reasonkit_storage_test");
if temp_dir.exists() {
std::fs::remove_dir_all(&temp_dir).ok();
}
let storage = Storage::file(temp_dir.clone()).await.unwrap();
let context = AccessContext::new(
"test_user".to_string(),
AccessLevel::Admin,
"test".to_string(),
);
let stats = storage.stats(&context).await.unwrap();
assert_eq!(stats.document_count, 0);
std::fs::remove_dir_all(&temp_dir).ok();
}
#[tokio::test]
async fn test_embedded_storage_file_fallback() {
let temp_dir = std::env::temp_dir().join("reasonkit_embedded_test");
if temp_dir.exists() {
std::fs::remove_dir_all(&temp_dir).ok();
}
let config = EmbeddedStorageConfig::file_only(temp_dir.clone());
let storage = create_embedded_storage(config).await.unwrap();
let context = AccessContext::new(
"test_user".to_string(),
AccessLevel::Admin,
"test".to_string(),
);
let stats = storage.stats(&context).await.unwrap();
assert_eq!(stats.document_count, 0);
std::fs::remove_dir_all(&temp_dir).ok();
}
#[test]
fn test_parse_qdrant_url() {
assert_eq!(
parse_qdrant_url("http://localhost:6333"),
("localhost".to_string(), 6333)
);
assert_eq!(
parse_qdrant_url("localhost:6333"),
("localhost".to_string(), 6333)
);
assert_eq!(
parse_qdrant_url("localhost"),
("localhost".to_string(), 6333)
);
assert_eq!(
parse_qdrant_url("127.0.0.1:6334"),
("127.0.0.1".to_string(), 6334)
);
assert_eq!(
parse_qdrant_url("https://qdrant.example.com:6333"),
("qdrant.example.com".to_string(), 6333)
);
}
#[tokio::test]
async fn test_embedded_storage_default_config() {
let temp_dir = std::env::temp_dir().join("reasonkit_embedded_default_test");
if temp_dir.exists() {
std::fs::remove_dir_all(&temp_dir).ok();
}
let config = EmbeddedStorageConfig {
data_path: temp_dir.clone(),
..Default::default()
};
let storage = create_embedded_storage(config).await.unwrap();
let context = AccessContext::new(
"test_user".to_string(),
AccessLevel::Admin,
"test".to_string(),
);
let stats = storage.stats(&context).await.unwrap();
assert_eq!(stats.document_count, 0);
std::fs::remove_dir_all(&temp_dir).ok();
}
#[tokio::test]
async fn test_embedded_storage_with_qdrant_required_but_unavailable() {
let temp_dir = std::env::temp_dir().join("reasonkit_embedded_qdrant_test");
if temp_dir.exists() {
std::fs::remove_dir_all(&temp_dir).ok();
}
let mut config = EmbeddedStorageConfig::with_qdrant(
"http://localhost:99999", // invalid port (>65535) guarantees the health check fails
"test_collection",
768,
);
config.data_path = temp_dir.clone();
match create_embedded_storage(config).await {
Ok(_) => panic!("Expected error when Qdrant is required but unavailable"),
Err(e) => {
let error_msg = e.to_string();
assert!(
error_msg.contains("Qdrant required but not available"),
"Error message should mention Qdrant not available, got: {}",
error_msg
);
}
}
std::fs::remove_dir_all(&temp_dir).ok();
}
}