use crate::error::{MemError, MemResult};
use crate::storage::{
AccessContext, AccessControlConfig, AccessLevel, EmbeddingCacheConfig, QdrantConnectionConfig,
QdrantSecurityConfig, StorageBackend, StorageStats,
};
use crate::{Document, Result};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
pub use super::config::DualLayerConfig;
/// Configuration for the hot (fast, size-limited) storage tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HotMemoryConfig {
    /// Upper bound on total payload bytes held by the tier.
    pub max_capacity_bytes: u64,
    /// Upper bound on the number of entries held by the tier.
    pub max_entries: usize,
    /// Time-to-live for entries, in seconds.
    pub ttl_secs: u64,
    /// Strategy used to pick eviction victims under pressure.
    pub eviction_policy: EvictionPolicy,
    /// Which concrete hot-tier backend to use.
    pub backend: HotBackendType,
    /// Whether stored payloads are compressed.
    pub compression_enabled: bool,
    /// Compression level when compression is enabled (codec-specific scale).
    pub compression_level: u8,
}
impl Default for HotMemoryConfig {
fn default() -> Self {
Self {
max_capacity_bytes: 1024 * 1024 * 1024, max_entries: 100_000,
ttl_secs: 3600, eviction_policy: EvictionPolicy::Lru,
backend: HotBackendType::InMemory,
compression_enabled: false,
compression_level: 3,
}
}
}
/// Strategy for choosing which hot-tier entries to evict under pressure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum EvictionPolicy {
    /// Least-recently-used (the default).
    #[default]
    Lru,
    /// Least-frequently-used.
    Lfu,
    /// First-in, first-out.
    Fifo,
    /// Evict based on entry age/TTL only.
    TimeOnly,
    /// Adaptive strategy; exact heuristic is not defined in this file.
    Adaptive,
}
/// Concrete backend used by the hot tier.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum HotBackendType {
    /// Process-local in-memory maps (the default).
    #[default]
    InMemory,
    /// Memory-mapped file at `path`.
    Mmap {
        path: PathBuf,
    },
    /// RocksDB database rooted at `path`.
    RocksDb {
        path: PathBuf,
    },
    /// Redis server reachable at `url`.
    Redis {
        url: String,
    },
}
/// Configuration for the cold (durable, vector-searchable) storage tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColdMemoryConfig {
    /// Which concrete cold-tier backend to use.
    pub backend: ColdBackendType,
    /// Dimensionality of stored embedding vectors.
    pub vector_size: usize,
    /// Collection/namespace name in the backing store.
    pub collection_name: String,
    /// Whether vector quantization is applied.
    pub quantization_enabled: bool,
    /// Quantization scheme when enabled.
    pub quantization_type: QuantizationType,
    /// Connection settings for Qdrant-backed variants.
    pub connection: QdrantConnectionConfig,
    /// Embedding cache settings.
    pub embedding_cache: EmbeddingCacheConfig,
    /// Access-control settings applied to cold-tier operations.
    pub access_control: AccessControlConfig,
}
impl Default for ColdMemoryConfig {
fn default() -> Self {
Self {
backend: ColdBackendType::File {
base_path: crate::storage::default_storage_path(),
},
vector_size: 1536,
collection_name: "reasonkit_cold".to_string(),
quantization_enabled: true,
quantization_type: QuantizationType::Int8,
connection: QdrantConnectionConfig::default(),
embedding_cache: EmbeddingCacheConfig::default(),
access_control: AccessControlConfig::default(),
}
}
}
/// Concrete backend used by the cold tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ColdBackendType {
    /// Local filesystem storage rooted at `base_path`.
    File {
        base_path: PathBuf,
    },
    /// Self-hosted Qdrant instance at `url`.
    QdrantLocal {
        url: String,
    },
    /// Managed Qdrant Cloud instance.
    /// NOTE(review): `api_key` is held in plain text and this type derives
    /// `Serialize`/`Debug` — confirm it is never logged or persisted as-is.
    QdrantCloud {
        url: String,
        api_key: String,
    },
    /// S3-compatible object store (same credential caveat as above).
    S3 {
        endpoint: String,
        bucket: String,
        access_key: String,
        secret_key: String,
    },
}
impl Default for ColdBackendType {
fn default() -> Self {
Self::File {
base_path: crate::storage::default_storage_path(),
}
}
}
/// Vector quantization scheme for the cold tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum QuantizationType {
    /// Store full-precision vectors.
    None,
    /// 8-bit integer quantization (the default).
    #[default]
    Int8,
    /// 1-bit binary quantization.
    Binary,
    /// Product quantization split into `segments` sub-vectors.
    ProductQuantization {
        segments: usize,
    },
}
/// Write-ahead-log configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalConfig {
    /// Whether the WAL is active at all.
    pub enabled: bool,
    /// Directory where WAL segment files live.
    pub path: PathBuf,
    /// Maximum size of a single WAL file, in bytes.
    pub max_file_size: u64,
    /// When WAL writes are fsynced.
    pub sync_mode: WalSyncMode,
    /// Sync period in milliseconds (used with `WalSyncMode::Interval`).
    pub sync_interval_ms: u64,
    /// How long checkpointed WAL data is retained, in seconds.
    pub retention_secs: u64,
    /// Whether WAL entries are compressed.
    pub compression_enabled: bool,
    /// Number of pending entries that triggers a checkpoint.
    pub checkpoint_threshold: usize,
}
impl Default for WalConfig {
fn default() -> Self {
Self {
enabled: true,
path: crate::storage::default_storage_path().join("wal"),
max_file_size: 64 * 1024 * 1024, sync_mode: WalSyncMode::Interval,
sync_interval_ms: 1000,
retention_secs: 86400, compression_enabled: true,
checkpoint_threshold: 10000,
}
}
}
/// When WAL writes are flushed to durable storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum WalSyncMode {
    /// Fsync after every write.
    Immediate,
    /// Fsync on a timer (`WalConfig::sync_interval_ms`); the default.
    #[default]
    Interval,
    /// Leave flushing to the operating system.
    OsManaged,
    /// Never fsync explicitly.
    None,
}
/// Settings for background synchronization and tier movement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SyncConfig {
    /// Policy deciding which tier entries belong in.
    pub tiering_policy: TieringPolicy,
    /// Period of the background sync loop, in seconds.
    pub background_sync_interval_secs: u64,
    /// Number of items moved per sync batch.
    pub sync_batch_size: usize,
    /// Whether frequently-read cold entries are promoted automatically.
    pub auto_promote: bool,
    /// Access count required before promotion.
    pub promotion_threshold: usize,
    /// Idle seconds before a hot entry is eligible for demotion.
    pub demotion_threshold_secs: u64,
    /// Whether sync batches may run concurrently.
    pub parallel_sync: bool,
    /// Cap on concurrent sync tasks when `parallel_sync` is set.
    pub max_concurrent_syncs: usize,
}
impl Default for SyncConfig {
fn default() -> Self {
Self {
tiering_policy: TieringPolicy::default(),
background_sync_interval_secs: 60,
sync_batch_size: 100,
auto_promote: true,
promotion_threshold: 3,
demotion_threshold_secs: 3600,
parallel_sync: true,
max_concurrent_syncs: 4,
}
}
}
/// Rules deciding which tier an entry should live in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TieringPolicy {
    /// Tier assigned to new entries.
    pub default_tier: StorageTier,
    /// Consider access frequency when tiering.
    pub access_based: bool,
    /// Consider entry age when tiering.
    pub age_based: bool,
    /// Consider entry size when tiering.
    pub size_based: bool,
    /// Entries above this many bytes are candidates for the cold tier.
    pub size_threshold_bytes: u64,
    /// Entries idle longer than this many seconds are demotion candidates.
    pub age_threshold_secs: u64,
}
impl Default for TieringPolicy {
fn default() -> Self {
Self {
default_tier: StorageTier::Hot,
access_based: true,
age_based: true,
size_based: true,
size_threshold_bytes: 10 * 1024 * 1024, age_threshold_secs: 86400, }
}
}
/// The two storage tiers managed by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum StorageTier {
    /// Fast, capacity-limited tier (the default for new entries).
    #[default]
    Hot,
    /// Durable, larger-capacity tier.
    Cold,
}
/// Aggregated statistics snapshot across both tiers, the WAL, and tiering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DualLayerStats {
    /// Hot-tier statistics.
    pub hot: TierStats,
    /// Cold-tier statistics.
    pub cold: TierStats,
    /// Write-ahead-log statistics.
    pub wal: WalStats,
    /// Totals across both tiers.
    pub overall: OverallStats,
    /// Promotion/demotion activity.
    pub tiering: TieringStats,
    /// Throughput/latency metrics.
    pub performance: PerformanceMetrics,
    /// When this snapshot was taken.
    pub collected_at: DateTime<Utc>,
}
/// Per-tier counters and latency figures.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TierStats {
    /// Full documents stored in the tier.
    pub document_count: usize,
    /// Chunks stored in the tier.
    pub chunk_count: usize,
    /// Embeddings stored in the tier.
    pub embedding_count: usize,
    /// Total payload bytes.
    pub size_bytes: u64,
    /// Percentage (0–100) of configured capacity in use.
    pub capacity_used_pct: f32,
    /// Fraction of lookups served by this tier.
    pub hit_rate: f32,
    /// Mean operation latency, microseconds.
    pub avg_latency_us: u64,
    /// 95th-percentile operation latency, microseconds.
    pub p95_latency_us: u64,
}
/// Write-ahead-log counters for a stats snapshot.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WalStats {
    /// Whether the WAL is active.
    pub enabled: bool,
    /// Bytes in the current WAL file.
    pub current_size_bytes: u64,
    /// Entries written but not yet checkpointed.
    pub pending_entries: usize,
    /// Last successful sync time, if any.
    pub last_sync_at: Option<DateTime<Utc>>,
    /// Last successful checkpoint time, if any.
    pub last_checkpoint_at: Option<DateTime<Utc>>,
    /// Write throughput (entries or bytes per second — unit not fixed here;
    /// confirm against the producer of this value).
    pub write_throughput: f64,
}
/// Totals across both tiers plus process uptime.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OverallStats {
    /// Documents across both tiers.
    pub total_documents: usize,
    /// Chunks across both tiers.
    pub total_chunks: usize,
    /// Embeddings across both tiers.
    pub total_embeddings: usize,
    /// Payload bytes across both tiers.
    pub total_size_bytes: u64,
    /// Seconds since the storage instance was created.
    pub uptime_secs: u64,
}
/// Tier-movement activity counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TieringStats {
    /// Entries currently in the hot tier.
    pub hot_count: usize,
    /// Entries currently in the cold tier.
    pub cold_count: usize,
    /// Cold→hot moves in the last hour.
    pub promotions_last_hour: usize,
    /// Hot→cold moves in the last hour.
    pub demotions_last_hour: usize,
    /// Promotions queued but not yet executed.
    pub pending_promotions: usize,
    /// Demotions queued but not yet executed.
    pub pending_demotions: usize,
}
/// Throughput and latency metrics for the whole store.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Read operations per second.
    pub reads_per_sec: f64,
    /// Write operations per second.
    pub writes_per_sec: f64,
    /// Mean read latency, microseconds.
    pub avg_read_latency_us: u64,
    /// Mean write latency, microseconds.
    pub avg_write_latency_us: u64,
    /// Fraction of reads served from any cache.
    pub cache_hit_ratio: f32,
    /// Fraction of reads served by the hot tier.
    pub hot_hit_ratio: f32,
}
/// Health report for the write-ahead log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalStatus {
    /// Whether the WAL is active.
    pub enabled: bool,
    /// Whether the WAL is operating without errors.
    pub healthy: bool,
    /// Path of the WAL file currently being written, if any.
    pub current_file: Option<PathBuf>,
    /// Entries written but not yet checkpointed.
    pub pending_entries: usize,
    /// Milliseconds since the last sync.
    pub ms_since_last_sync: u64,
    /// Milliseconds since the last checkpoint.
    pub ms_since_last_checkpoint: u64,
    /// Recent error messages, most recent ordering unspecified here.
    pub errors: Vec<String>,
}
impl Default for WalStatus {
fn default() -> Self {
Self {
enabled: false,
healthy: true,
current_file: None,
pending_entries: 0,
ms_since_last_sync: 0,
ms_since_last_checkpoint: 0,
errors: Vec::new(),
}
}
}
/// Outcome of a WAL checkpoint run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointResult {
    /// Entries flushed from the WAL into the main store.
    pub entries_checkpointed: usize,
    /// Bytes written during the checkpoint.
    pub bytes_written: u64,
    /// Wall-clock time the checkpoint took.
    pub duration: Duration,
    /// Non-fatal errors encountered while checkpointing.
    pub errors: Vec<String>,
}
impl Default for CheckpointResult {
fn default() -> Self {
Self {
entries_checkpointed: 0,
bytes_written: 0,
duration: Duration::ZERO,
errors: Vec::new(),
}
}
}
/// Outcome of one tiering (promotion/demotion) pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TieringResult {
    /// Entries moved cold→hot.
    pub promoted: usize,
    /// Entries moved hot→cold.
    pub demoted: usize,
    /// Candidates considered but not moved.
    pub skipped: usize,
    /// Wall-clock time the pass took.
    pub duration: Duration,
}
impl Default for TieringResult {
fn default() -> Self {
Self {
promoted: 0,
demoted: 0,
skipped: 0,
duration: Duration::ZERO,
}
}
}
/// Outcome of a bulk store/get/delete operation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BulkOperationResult {
    /// Items processed successfully.
    pub succeeded: usize,
    /// Items that failed.
    pub failed: usize,
    /// Per-item failures as `(id, message)` pairs.
    pub errors: Vec<(Uuid, String)>,
    /// Wall-clock time the bulk operation took.
    pub duration: Duration,
}
/// Errors produced by the dual-layer storage subsystem.
///
/// Use [`DualLayerError::is_retryable`] and
/// [`DualLayerError::may_cause_data_loss`] to classify instances.
#[derive(Debug, thiserror::Error)]
pub enum DualLayerError {
    #[error("Hot layer error: {0}")]
    HotLayer(String),
    #[error("Cold layer error: {0}")]
    ColdLayer(String),
    #[error("WAL error: {0}")]
    Wal(String),
    #[error("Document not found: {0}")]
    NotFound(Uuid),
    #[error("Tier routing error: {0}")]
    TierRouting(String),
    #[error("Sync error: {0}")]
    Sync(String),
    #[error("Transaction error: {0}")]
    Transaction(String),
    #[error("Capacity exceeded: {message}")]
    CapacityExceeded { message: String, tier: StorageTier },
    #[error("Connection error: {0}")]
    Connection(String),
    #[error("Operation timed out after {duration_ms}ms")]
    Timeout { duration_ms: u64 },
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    #[error("Access denied: {0}")]
    AccessDenied(String),
    // The three variants below auto-convert via `?` thanks to `#[from]`.
    #[error("Storage error: {0}")]
    Storage(#[from] MemError),
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
impl DualLayerError {
pub fn is_retryable(&self) -> bool {
matches!(
self,
DualLayerError::Connection(_)
| DualLayerError::Timeout { .. }
| DualLayerError::Sync(_)
)
}
pub fn may_cause_data_loss(&self) -> bool {
matches!(
self,
DualLayerError::Wal(_) | DualLayerError::Sync(_) | DualLayerError::Transaction(_)
)
}
}
/// Convenience alias for fallible dual-layer operations.
pub type DualLayerResult<T> = std::result::Result<T, DualLayerError>;
/// Abstraction over a hot-tier store of full [`Document`]s.
#[async_trait]
pub trait HotMemoryLayer: Send + Sync {
    /// Inserts or replaces a document.
    async fn store(&mut self, doc: &Document) -> Result<()>;
    /// Fetches a document by id, if present.
    async fn get(&self, id: &Uuid) -> Result<Option<Document>>;
    /// Removes a document; returns whether it existed.
    async fn delete(&mut self, id: &Uuid) -> Result<bool>;
    /// Cheap existence check.
    async fn contains(&self, id: &Uuid) -> bool;
    /// All stored ids, in arbitrary order.
    async fn keys(&self) -> Vec<Uuid>;
    /// Point-in-time statistics for this tier.
    fn stats(&self) -> TierStats;
    /// Removes every entry.
    async fn clear(&mut self);
    /// Evicts up to `count` entries; returns how many were evicted.
    async fn evict(&mut self, count: usize) -> usize;
}
/// Tier-aware storage backend: extends [`StorageBackend`] with explicit tier
/// placement, bulk operations, WAL control, and tiering maintenance.
#[async_trait]
pub trait DualLayerBackend: StorageBackend {
    /// Stores `doc` directly into the requested tier.
    async fn store_document_in_tier(
        &self,
        doc: &Document,
        tier: StorageTier,
        context: &AccessContext,
    ) -> Result<()>;
    /// Fetches a document together with the tier it was found in.
    async fn get_document_with_tier(
        &self,
        id: &Uuid,
        context: &AccessContext,
    ) -> Result<Option<(Document, StorageTier)>>;
    /// Moves a document to `target_tier`.
    async fn move_to_tier(
        &self,
        id: &Uuid,
        target_tier: StorageTier,
        context: &AccessContext,
    ) -> Result<()>;
    /// Reports which tier currently holds `id`, if any.
    async fn get_tier(&self, id: &Uuid) -> Result<Option<StorageTier>>;
    /// Stores many documents; per-item failures are reported in the result.
    async fn store_documents_bulk(
        &self,
        docs: &[Document],
        context: &AccessContext,
    ) -> Result<BulkOperationResult>;
    /// Fetches many documents; missing ids yield `None`.
    async fn get_documents_bulk(
        &self,
        ids: &[Uuid],
        context: &AccessContext,
    ) -> Result<Vec<(Uuid, Option<Document>)>>;
    /// Deletes many documents; per-item failures are reported in the result.
    async fn delete_documents_bulk(
        &self,
        ids: &[Uuid],
        context: &AccessContext,
    ) -> Result<BulkOperationResult>;
    /// Stores many `(chunk_id, vector)` embedding pairs.
    async fn store_embeddings_bulk(
        &self,
        embeddings: &[(Uuid, Vec<f32>)],
        context: &AccessContext,
    ) -> Result<BulkOperationResult>;
    /// Forces a WAL sync to durable storage.
    async fn sync_wal(&self) -> Result<()>;
    /// Flushes pending WAL entries into the main store.
    async fn checkpoint(&self) -> Result<CheckpointResult>;
    /// Reports current WAL health.
    async fn wal_status(&self) -> Result<WalStatus>;
    /// Runs one promotion/demotion pass.
    async fn run_tiering(&self) -> Result<TieringResult>;
    /// Cold ids most deserving of promotion, capped at `limit`.
    async fn get_promotion_candidates(&self, limit: usize) -> Result<Vec<Uuid>>;
    /// Hot ids most deserving of demotion, capped at `limit`.
    async fn get_demotion_candidates(&self, limit: usize) -> Result<Vec<Uuid>>;
    /// Full statistics snapshot across tiers, WAL, and tiering.
    async fn detailed_stats(&self, context: &AccessContext) -> Result<DualLayerStats>;
    /// Hot-tier statistics only.
    async fn hot_stats(&self) -> Result<TierStats>;
    /// Cold-tier statistics only.
    async fn cold_stats(&self, context: &AccessContext) -> Result<TierStats>;
}
/// Lifecycle state of a storage transaction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionState {
    /// Open and accepting operations.
    Active,
    /// Committed successfully.
    Committed,
    /// Explicitly rolled back.
    RolledBack,
    /// Aborted due to an error.
    Failed,
}
/// Standard transaction isolation levels (ANSI SQL naming).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    RepeatableRead,
    Serializable,
}
/// A single operation buffered inside a transaction.
#[derive(Debug, Clone)]
pub enum TransactionOperation {
    /// Store a document into `tier`. The document is boxed to keep this enum
    /// small for the common variants.
    StoreDocument {
        doc: Box<Document>,
        tier: StorageTier,
    },
    /// Delete a document by id.
    DeleteDocument {
        id: Uuid,
    },
    /// Store an embedding vector for a chunk.
    StoreEmbedding {
        chunk_id: Uuid,
        embedding: Vec<f32>,
    },
    /// Move a document between tiers.
    MoveTier {
        id: Uuid,
        from: StorageTier,
        to: StorageTier,
    },
}
/// A named position inside a transaction's operation log that can be rolled
/// back to.
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Caller-supplied savepoint name.
    pub name: String,
    /// Index into the transaction's operation list at creation time.
    pub operation_index: usize,
    /// When the savepoint was created.
    pub timestamp: DateTime<Utc>,
}
/// Summary of a completed transaction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionResult {
    /// Id of the transaction this result describes.
    pub transaction_id: Uuid,
    /// Operations that applied successfully.
    pub operations_succeeded: usize,
    /// Operations that failed.
    pub operations_failed: usize,
    /// Wall-clock time from begin to commit/rollback.
    pub duration: Duration,
}
/// In-memory map from document id to the tier that currently holds it.
#[derive(Debug, Default)]
pub struct TierIndex {
    // id -> tier; no persistence — rebuilt by whoever owns this index.
    tiers: HashMap<Uuid, StorageTier>,
}
impl TierIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records (or overwrites) the tier assignment for `id`.
    pub fn set_tier(&mut self, id: Uuid, tier: StorageTier) {
        self.tiers.insert(id, tier);
    }

    /// Looks up the tier for `id`, if tracked.
    pub fn get_tier(&self, id: &Uuid) -> Option<StorageTier> {
        self.tiers.get(id).copied()
    }

    /// Drops `id` from the index, returning its previous tier.
    pub fn remove(&mut self, id: &Uuid) -> Option<StorageTier> {
        self.tiers.remove(id)
    }

    /// All tracked ids, in arbitrary order.
    pub fn all_ids(&self) -> Vec<Uuid> {
        self.tiers.keys().copied().collect()
    }

    /// Returns `(hot_count, cold_count)` in a single pass.
    pub fn count_by_tier(&self) -> (usize, usize) {
        let mut hot = 0usize;
        let mut cold = 0usize;
        for tier in self.tiers.values() {
            match tier {
                StorageTier::Hot => hot += 1,
                StorageTier::Cold => cold += 1,
            }
        }
        (hot, cold)
    }

    /// Ids currently assigned to `tier`, in arbitrary order.
    pub fn ids_in_tier(&self, tier: StorageTier) -> Vec<Uuid> {
        self.tiers
            .iter()
            .filter_map(|(id, &t)| if t == tier { Some(*id) } else { None })
            .collect()
    }
}
/// Per-id access history used for tiering decisions.
#[derive(Debug, Clone)]
pub struct AccessEntry {
    /// Total recorded accesses.
    pub access_count: usize,
    /// When the id was last accessed.
    pub last_access: Instant,
    /// When the id was first accessed.
    pub first_access: Instant,
}
/// Tracks per-id access patterns and derives promotion/demotion candidates
/// according to a [`TieringPolicy`].
#[derive(Debug)]
pub struct AccessTracker {
    // Access history per id; entries only exist for ids seen at least once.
    entries: HashMap<Uuid, AccessEntry>,
    // Owned copy of the policy taken at construction time.
    policy: TieringPolicy,
}
impl AccessTracker {
pub fn new(policy: &TieringPolicy) -> Self {
Self {
entries: HashMap::new(),
policy: policy.clone(),
}
}
pub fn record_access(&mut self, id: &Uuid) {
let now = Instant::now();
self.entries
.entry(*id)
.and_modify(|e| {
e.access_count += 1;
e.last_access = now;
})
.or_insert(AccessEntry {
access_count: 1,
last_access: now,
first_access: now,
});
}
pub fn remove(&mut self, id: &Uuid) {
self.entries.remove(id);
}
pub fn get_access_count(&self, id: &Uuid) -> usize {
self.entries.get(id).map(|e| e.access_count).unwrap_or(0)
}
pub fn get_promotion_candidates(&self, current_cold_ids: &[Uuid], limit: usize) -> Vec<Uuid> {
let mut candidates: Vec<_> = current_cold_ids
.iter()
.filter_map(|id| self.entries.get(id).map(|entry| (*id, entry.access_count)))
.filter(|(_, count)| *count >= self.policy.promotion_threshold_count())
.collect();
candidates.sort_by(|a, b| b.1.cmp(&a.1));
candidates
.into_iter()
.take(limit)
.map(|(id, _)| id)
.collect()
}
pub fn get_demotion_candidates(&self, current_hot_ids: &[Uuid], limit: usize) -> Vec<Uuid> {
let now = Instant::now();
let threshold = Duration::from_secs(self.policy.age_threshold_secs);
let mut candidates: Vec<_> = current_hot_ids
.iter()
.filter_map(|id| {
self.entries.get(id).and_then(|entry| {
if now.duration_since(entry.last_access) > threshold {
Some((*id, entry.last_access))
} else {
None
}
})
})
.collect();
candidates.sort_by(|a, b| a.1.cmp(&b.1));
candidates
.into_iter()
.take(limit)
.map(|(id, _)| id)
.collect()
}
}
impl TieringPolicy {
    /// Minimum recorded accesses before a cold entry becomes a promotion
    /// candidate.
    ///
    /// NOTE(review): this is a hard-coded constant and is not wired to
    /// `SyncConfig::promotion_threshold` (which defaults to the same value) —
    /// confirm whether the two were meant to be a single knob.
    const PROMOTION_THRESHOLD_COUNT: usize = 3;

    /// Access-count threshold used by [`AccessTracker::get_promotion_candidates`].
    fn promotion_threshold_count(&self) -> usize {
        Self::PROMOTION_THRESHOLD_COUNT
    }
}
/// A single record in the write-ahead log. Payloads are referenced by id;
/// the log records intent, not document contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WalEntry {
    /// A document was stored into `tier`.
    StoreDocument {
        id: Uuid,
        tier: StorageTier,
        timestamp: DateTime<Utc>,
    },
    /// A document was deleted.
    DeleteDocument { id: Uuid, timestamp: DateTime<Utc> },
    /// An embedding was stored for a chunk.
    StoreEmbedding {
        chunk_id: Uuid,
        timestamp: DateTime<Utc>,
    },
    /// A document moved between tiers.
    MoveTier {
        id: Uuid,
        from: StorageTier,
        to: StorageTier,
        timestamp: DateTime<Utc>,
    },
    /// Transaction lifecycle markers.
    TransactionBegin { id: Uuid, timestamp: DateTime<Utc> },
    TransactionCommit { id: Uuid, timestamp: DateTime<Utc> },
    TransactionRollback { id: Uuid, timestamp: DateTime<Utc> },
    /// A checkpoint boundary; entries before it are durable in the main store.
    Checkpoint { timestamp: DateTime<Utc> },
}
/// One search hit returned by [`DualLayerStorage::retrieve_context`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextResult {
    /// Id of the matching entry.
    pub id: Uuid,
    /// Entry content ("[cold storage]" placeholder for un-hydrated cold hits).
    pub content: String,
    /// Cosine-similarity score against the query.
    pub score: f32,
    /// Tier the hit was served from.
    pub layer: StorageTier,
    /// Creation time of the entry (falls back to "now" when unknown).
    pub created_at: DateTime<Utc>,
    /// Accesses recorded for the entry before this query.
    pub access_count: usize,
}
/// Facade combining an in-memory hot layer with a durable cold backend,
/// plus the bookkeeping (tier index, access tracker) that ties them together.
pub struct DualLayerStorage {
    // Hot tier; RwLock because reads (search) dominate writes.
    hot: Arc<RwLock<InMemoryHotLayer>>,
    // Cold tier behind the generic storage-backend trait.
    cold: Arc<dyn StorageBackend>,
    // Which tier each id lives in.
    tier_index: Arc<RwLock<TierIndex>>,
    // Access history feeding promotion/demotion decisions.
    access_tracker: Arc<RwLock<AccessTracker>>,
    // Configuration captured at construction.
    config: DualLayerConfig,
    // For uptime reporting in `stats()`.
    started_at: Instant,
}
/// Simple map-backed hot tier holding raw content + embeddings.
struct InMemoryHotLayer {
    // Full documents; NOT populated by `insert_content` — only content/
    // embeddings/metadata are used by the current code paths.
    documents: HashMap<Uuid, Document>,
    // id -> embedding vector.
    embeddings: HashMap<Uuid, Vec<f32>>,
    // id -> raw text content.
    content: HashMap<Uuid, String>,
    // id -> recency/frequency/size bookkeeping for eviction.
    metadata: HashMap<Uuid, HotEntryMeta>,
    // Effective hot-tier configuration.
    config: HotMemoryConfig,
}
/// Per-entry bookkeeping used for LRU eviction and stats.
#[derive(Debug, Clone)]
struct HotEntryMeta {
    // Wall-clock creation time (exposed via ContextResult).
    created_at: DateTime<Utc>,
    // Monotonic time of the last read; drives LRU eviction.
    last_accessed: Instant,
    // Reads since insertion.
    access_count: usize,
    // Approximate payload size: content bytes + 4 bytes per f32.
    size_bytes: usize,
}
impl InMemoryHotLayer {
    /// Builds an empty hot layer from the module-level hot config, mapped onto
    /// this file's `HotMemoryConfig`. Byte-capacity accounting and compression
    /// are not implemented by this in-memory layer.
    fn new(config: super::config::HotMemoryConfig) -> Self {
        let dual_config = HotMemoryConfig {
            max_capacity_bytes: 0, // capacity-by-bytes is unused here
            max_entries: config.max_entries,
            ttl_secs: config.ttl_secs,
            eviction_policy: EvictionPolicy::Lru,
            backend: HotBackendType::InMemory,
            compression_enabled: false,
            compression_level: 1,
        };
        Self {
            documents: HashMap::new(),
            embeddings: HashMap::new(),
            content: HashMap::new(),
            metadata: HashMap::new(),
            config: dual_config,
        }
    }

    /// Inserts a content/embedding pair, evicting LRU entries first so the
    /// entry count stays within `config.max_entries`.
    fn insert_content(&mut self, id: Uuid, content: String, embedding: Vec<f32>) {
        // Approximate payload size: UTF-8 bytes plus 4 bytes per f32.
        let size = content.len() + embedding.len() * 4;
        // BUG FIX: occupancy was previously checked against `self.documents`,
        // which this insertion path never populates, so eviction never fired
        // and the layer grew without bound. Count `self.content` instead.
        while self.content.len() >= self.config.max_entries && !self.content.is_empty() {
            self.evict_one();
        }
        self.content.insert(id, content);
        self.embeddings.insert(id, embedding);
        self.metadata.insert(
            id,
            HotEntryMeta {
                created_at: Utc::now(),
                last_accessed: Instant::now(),
                access_count: 0,
                size_bytes: size,
            },
        );
    }

    /// Fetches a content/embedding pair, updating recency/frequency metadata.
    fn get_content(&mut self, id: &Uuid) -> Option<(String, Vec<f32>)> {
        if let Some(meta) = self.metadata.get_mut(id) {
            meta.last_accessed = Instant::now();
            meta.access_count += 1;
        }
        let content = self.content.get(id)?;
        let embedding = self.embeddings.get(id)?;
        Some((content.clone(), embedding.clone()))
    }

    /// Removes every trace of `id`; returns whether anything was removed.
    fn remove(&mut self, id: &Uuid) -> bool {
        // BUG FIX: this previously reported only `documents.remove(..)`,
        // which `insert_content` never populates, so callers always saw
        // `false` even when content was actually deleted.
        let had_content = self.content.remove(id).is_some();
        self.embeddings.remove(id);
        self.metadata.remove(id);
        let had_document = self.documents.remove(id).is_some();
        had_content || had_document
    }

    /// Evicts the least-recently-accessed entry, if any.
    fn evict_one(&mut self) {
        let victim = self
            .metadata
            .iter()
            .min_by(|a, b| a.1.last_accessed.cmp(&b.1.last_accessed))
            .map(|(id, _)| *id);
        if let Some(id) = victim {
            self.remove(&id);
        }
    }

    /// Brute-force cosine search over all embeddings; returns up to `limit`
    /// positive-scoring hits sorted best-first.
    fn search(&self, query_embedding: &[f32], limit: usize) -> Vec<(Uuid, f32, String)> {
        let mut results: Vec<_> = self
            .embeddings
            .iter()
            .filter_map(|(id, emb)| {
                let score = cosine_similarity(query_embedding, emb);
                if score > 0.0 {
                    let content = self.content.get(id)?.clone();
                    Some((*id, score, content))
                } else {
                    None
                }
            })
            .collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(limit);
        results
    }

    /// Number of content entries currently held.
    fn len(&self) -> usize {
        self.content.len()
    }

    /// Read-only metadata for `id`, if present.
    fn get_meta(&self, id: &Uuid) -> Option<&HotEntryMeta> {
        self.metadata.get(id)
    }
}
impl DualLayerStorage {
    /// Creates a dual-layer store: an in-memory hot tier plus a file-backed
    /// cold tier rooted at `config.cold.db_path` (created if missing).
    pub async fn new(config: DualLayerConfig) -> DualLayerResult<Self> {
        let hot = Arc::new(RwLock::new(InMemoryHotLayer::new(config.hot.clone())));
        let cold_path = config.cold.db_path.clone();
        tokio::fs::create_dir_all(&cold_path).await?;
        let cold: Arc<dyn StorageBackend> = Arc::new(
            crate::storage::FileStorage::new(cold_path)
                .await
                .map_err(|e| DualLayerError::ColdLayer(e.to_string()))?,
        );
        let tier_index = Arc::new(RwLock::new(TierIndex::new()));
        // NOTE(review): the tracker uses TieringPolicy::default() rather than
        // anything derived from `config` — confirm this is intentional.
        let default_tiering = TieringPolicy::default();
        let access_tracker = Arc::new(RwLock::new(AccessTracker::new(&default_tiering)));
        Ok(Self {
            hot,
            cold,
            tier_index,
            access_tracker,
            config,
            started_at: Instant::now(),
        })
    }

    /// Convenience constructor using `DualLayerConfig::default()`.
    pub async fn default_instance() -> DualLayerResult<Self> {
        Self::new(DualLayerConfig::default()).await
    }

    /// Searches both tiers for `query`, merges per-id results keeping the
    /// best score, records accesses, and returns the top `limit` hits.
    pub async fn retrieve_context(
        &self,
        query: &str,
        limit: usize,
    ) -> DualLayerResult<Vec<ContextResult>> {
        let query_embedding = generate_query_embedding(query);
        // Over-fetch from each tier so the merged top-`limit` is well filled.
        let hot_results = {
            let hot = self.hot.read().await;
            hot.search(&query_embedding, limit * 2)
        };
        let context = AccessContext::new(
            "system".to_string(),
            AccessLevel::Read,
            "retrieve_context".to_string(),
        );
        let cold_results = self
            .cold
            .search_by_vector(&query_embedding, limit * 2, &context)
            .await
            .map_err(|e| DualLayerError::ColdLayer(e.to_string()))?;
        let mut merged: HashMap<Uuid, ContextResult> = HashMap::new();
        {
            // PERF FIX: acquire the hot-layer read lock once for the whole
            // merge instead of re-locking inside the loop for every result.
            let hot = self.hot.read().await;
            for (id, score, content) in hot_results {
                let meta = hot.get_meta(&id);
                merged.insert(
                    id,
                    ContextResult {
                        id,
                        content,
                        score,
                        layer: StorageTier::Hot,
                        created_at: meta.map(|m| m.created_at).unwrap_or_else(Utc::now),
                        access_count: meta.map(|m| m.access_count).unwrap_or(0),
                    },
                );
            }
        }
        for (id, score) in cold_results {
            merged
                .entry(id)
                .and_modify(|existing| {
                    if score > existing.score {
                        existing.score = score;
                    }
                })
                .or_insert_with(|| ContextResult {
                    id,
                    // Cold content is not hydrated here; callers see a marker.
                    content: "[cold storage]".to_string(),
                    score,
                    layer: StorageTier::Cold,
                    created_at: Utc::now(),
                    access_count: 0,
                });
        }
        let mut results: Vec<_> = merged.into_values().collect();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(limit);
        {
            let mut tracker = self.access_tracker.write().await;
            for r in &results {
                tracker.record_access(&r.id);
            }
        }
        Ok(results)
    }

    /// Stores a content/embedding pair in the hot tier and indexes it as Hot.
    pub async fn store(&self, id: Uuid, content: &str, embedding: &[f32]) -> DualLayerResult<()> {
        {
            let mut hot = self.hot.write().await;
            hot.insert_content(id, content.to_string(), embedding.to_vec());
        }
        {
            let mut index = self.tier_index.write().await;
            index.set_tier(id, StorageTier::Hot);
        }
        Ok(())
    }

    /// Deletes `id` from the hot tier and forgets its tier/access bookkeeping.
    /// Returns whether the hot tier held the entry.
    pub async fn delete(&self, id: &Uuid) -> DualLayerResult<bool> {
        let removed_hot = {
            let mut hot = self.hot.write().await;
            hot.remove(id)
        };
        {
            let mut index = self.tier_index.write().await;
            index.remove(id);
        }
        {
            let mut tracker = self.access_tracker.write().await;
            tracker.remove(id);
        }
        Ok(removed_hot)
    }

    /// Snapshot of combined statistics across both tiers. The in-memory hot
    /// layer does not track documents or byte sizes, so those report 0.
    pub async fn stats(&self) -> DualLayerResult<DualLayerStats> {
        let hot = self.hot.read().await;
        let hot_count = hot.len();
        let context = AccessContext::new(
            "system".to_string(),
            AccessLevel::Admin,
            "stats".to_string(),
        );
        let cold_stats = self
            .cold
            .stats(&context)
            .await
            .map_err(|e| DualLayerError::ColdLayer(e.to_string()))?;
        // BUG FIX: guard against max_entries == 0, which previously produced
        // a NaN/inf capacity percentage.
        let max_entries = self.config.hot.max_entries;
        let capacity_used_pct = if max_entries == 0 {
            0.0
        } else {
            (hot_count as f32 / max_entries as f32) * 100.0
        };
        Ok(DualLayerStats {
            hot: TierStats {
                document_count: 0,
                chunk_count: hot_count,
                embedding_count: hot_count,
                size_bytes: 0,
                capacity_used_pct,
                hit_rate: 0.0,
                avg_latency_us: 0,
                p95_latency_us: 0,
            },
            cold: TierStats {
                document_count: cold_stats.document_count,
                chunk_count: cold_stats.chunk_count,
                embedding_count: cold_stats.embedding_count,
                size_bytes: cold_stats.size_bytes,
                capacity_used_pct: 0.0,
                hit_rate: 0.0,
                avg_latency_us: 0,
                p95_latency_us: 0,
            },
            wal: WalStats::default(),
            overall: OverallStats {
                total_documents: cold_stats.document_count,
                total_chunks: hot_count + cold_stats.chunk_count,
                total_embeddings: hot_count + cold_stats.embedding_count,
                total_size_bytes: cold_stats.size_bytes,
                uptime_secs: self.started_at.elapsed().as_secs(),
            },
            tiering: TieringStats::default(),
            performance: PerformanceMetrics::default(),
            collected_at: Utc::now(),
        })
    }
}
/// Cosine similarity of two vectors; returns 0.0 for mismatched lengths,
/// empty input, or a zero-magnitude operand.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single fused pass: dot product plus both squared norms.
    let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a.sqrt() * norm_b.sqrt())
}
/// Deterministic, dependency-free pseudo-embedding for `query`: each character
/// is hashed into one of 384 buckets with a position-decayed weight, then the
/// vector is L2-normalized (an empty query yields the zero vector).
fn generate_query_embedding(query: &str) -> Vec<f32> {
    const DIM: usize = 384;
    let mut embedding = vec![0.0f32; DIM];
    for (pos, ch) in query.chars().enumerate() {
        let bucket = (ch as usize + pos * 31) % DIM;
        embedding[bucket] += 1.0 / (pos as f32 + 1.0);
    }
    let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        embedding.iter_mut().for_each(|v| *v /= norm);
    }
    embedding
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical vectors score ~1; orthogonal vectors score ~0.
    #[test]
    fn test_cosine_similarity() {
        let unit_x = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &same) - 1.0).abs() < 0.001);
        let unit_y = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < 0.001);
    }

    /// Embeddings are 384-dim and unit-length for non-empty queries.
    #[test]
    fn test_generate_embedding() {
        let emb = generate_query_embedding("test query");
        assert_eq!(emb.len(), 384);
        let magnitude: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);
    }

    /// The default configuration enables the WAL.
    #[test]
    fn test_default_config() {
        assert!(DualLayerConfig::default().wal.enabled);
    }

    /// Tier assignments round-trip and are counted per tier.
    #[test]
    fn test_tier_index() {
        let mut index = TierIndex::new();
        let hot_id = Uuid::new_v4();
        let cold_id = Uuid::new_v4();
        index.set_tier(hot_id, StorageTier::Hot);
        index.set_tier(cold_id, StorageTier::Cold);
        assert_eq!(index.get_tier(&hot_id), Some(StorageTier::Hot));
        assert_eq!(index.get_tier(&cold_id), Some(StorageTier::Cold));
        assert_eq!(index.count_by_tier(), (1, 1));
    }

    /// Access counts accumulate and reset on removal.
    #[test]
    fn test_access_tracker() {
        let policy = TieringPolicy::default();
        let mut tracker = AccessTracker::new(&policy);
        let id = Uuid::new_v4();
        for _ in 0..5 {
            tracker.record_access(&id);
        }
        assert_eq!(tracker.get_access_count(&id), 5);
        tracker.remove(&id);
        assert_eq!(tracker.get_access_count(&id), 0);
    }

    /// Timeouts are retryable; missing documents are not.
    #[test]
    fn test_error_retryable() {
        assert!(DualLayerError::Timeout { duration_ms: 5000 }.is_retryable());
        assert!(!DualLayerError::NotFound(Uuid::new_v4()).is_retryable());
    }

    /// LRU is the default eviction policy.
    #[test]
    fn test_eviction_policy_default() {
        assert_eq!(EvictionPolicy::default(), EvictionPolicy::Lru);
    }

    /// Default WAL config: enabled, interval sync, 1 s interval.
    #[test]
    fn test_wal_config_default() {
        let config = WalConfig::default();
        assert!(config.enabled);
        assert_eq!(config.sync_mode, WalSyncMode::Interval);
        assert_eq!(config.sync_interval_ms, 1000);
    }
}