use super::*;
use do_memory_core::StorageBackend;
use std::sync::Arc;
use tempfile::TempDir;
/// Produces a deterministic 384-element embedding fixture.
///
/// Element `i` equals `0.01 * ((i mod 100) + 1)`, so the values cycle through
/// `0.01..=1.00` with a period of 100 entries.
fn create_test_embedding_384() -> Vec<f32> {
    (0..384)
        .map(|i| 0.01_f32 * (i as f32 % 100.0 + 1.0))
        .collect()
}
/// Produces a deterministic 384-element embedding fixture offset by `seed`.
///
/// Element `i` equals `seed + 0.001 * i`, so different seeds yield
/// distinguishable vectors of the same length.
fn create_test_embedding_384_with_seed(seed: f32) -> Vec<f32> {
    (0..384).map(|i| seed + 0.001_f32 * (i as f32)).collect()
}
async fn create_test_storage() -> Result<(TursoStorage, TempDir)> {
let dir = TempDir::new().unwrap();
let db_path = dir.path().join("test.db");
let db = libsql::Builder::new_local(&db_path)
.build()
.await
.map_err(|e| Error::Storage(format!("Failed to create test database: {}", e)))?;
let storage = TursoStorage {
db: Arc::new(db),
pool: None,
#[cfg(feature = "keepalive-pool")]
keepalive_pool: None,
adaptive_pool: None,
caching_pool: None,
prepared_cache: Arc::new(crate::PreparedStatementCache::with_config(
crate::PreparedCacheConfig::default(),
)),
config: TursoConfig::default(),
#[cfg(feature = "compression")]
compression_stats: Arc::new(std::sync::Mutex::new(
crate::CompressionStatistics::default(),
)),
#[cfg(feature = "adaptive-ttl")]
episode_cache: None,
};
storage.initialize_schema().await?;
Ok((storage, dir))
}
/// Smoke test: the test storage helper must succeed end to end.
#[tokio::test]
async fn test_storage_creation() {
    assert!(create_test_storage().await.is_ok());
}
/// A freshly initialized storage must report itself as healthy.
#[tokio::test]
async fn test_health_check() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    assert!(storage.health_check().await.unwrap());
}
/// A freshly initialized storage must report zero episodes, patterns and
/// heuristics.
#[tokio::test]
async fn test_statistics() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let snapshot = storage.get_statistics().await.unwrap();
    let counts = (
        snapshot.episode_count,
        snapshot.pattern_count,
        snapshot.heuristic_count,
    );
    assert_eq!(counts, (0, 0, 0));
}
/// Round trip: an embedding stored under an id must be returned unchanged
/// when fetched by the same id.
#[tokio::test]
async fn test_store_and_get_embedding() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let id = "test_embedding_1";
    let embedding = create_test_embedding_384();

    // The clone is required: the original is compared against the stored copy below.
    storage
        .store_embedding(id, embedding.clone())
        .await
        .unwrap();

    // A single `expect` replaces the `is_some()` + `unwrap()` pair and tells
    // the reader what went wrong if the row is missing.
    let retrieved = storage
        .get_embedding(id)
        .await
        .unwrap()
        .expect("embedding should be retrievable right after being stored");
    assert_eq!(retrieved, embedding);
}
/// Looking up an id that was never stored must yield `None`, not an error.
#[tokio::test]
async fn test_get_nonexistent_embedding() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    assert!(storage
        .get_embedding("nonexistent")
        .await
        .unwrap()
        .is_none());
}
/// Deleting a stored embedding must report `true` and make the id
/// unresolvable afterwards.
#[tokio::test]
async fn test_delete_embedding() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let id = "test_embedding_delete";

    // The fixture is never read again, so it is moved into the call rather
    // than cloned.
    storage
        .store_embedding(id, create_test_embedding_384())
        .await
        .unwrap();

    // Precondition: the row exists before deletion.
    assert!(storage.get_embedding(id).await.unwrap().is_some());

    let deleted = storage.delete_embedding(id).await.unwrap();
    assert!(deleted);

    // Postcondition: the row is gone.
    assert!(storage.get_embedding(id).await.unwrap().is_none());
}
/// Deleting an id that was never stored must succeed and report `false`.
#[tokio::test]
async fn test_delete_nonexistent_embedding() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    assert!(!storage.delete_embedding("nonexistent").await.unwrap());
}
/// Every embedding written through the batch API must be individually
/// retrievable and equal to what was stored.
#[tokio::test]
async fn test_store_embeddings_batch() {
    let (storage, _dir) = create_test_storage().await.unwrap();

    // (id, seed) table expanded into the (id, vector) pairs the batch API takes.
    let fixtures: Vec<(String, Vec<f32>)> = [
        ("batch_1", 0.1_f32),
        ("batch_2", 0.2_f32),
        ("batch_3", 0.3_f32),
    ]
    .iter()
    .map(|&(id, seed)| (id.to_string(), create_test_embedding_384_with_seed(seed)))
    .collect();

    storage
        .store_embeddings_batch(fixtures.clone())
        .await
        .unwrap();

    for (id, expected) in &fixtures {
        let fetched = storage.get_embedding(id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap(), *expected);
    }
}
/// The batch getter must return one slot per requested id, in request order,
/// with `None` for ids that were never stored.
#[tokio::test]
async fn test_get_embeddings_batch() {
    let (storage, _dir) = create_test_storage().await.unwrap();

    let embeddings: Vec<(String, Vec<f32>)> = [
        ("get_batch_1", 0.1_f32),
        ("get_batch_2", 0.2_f32),
        ("get_batch_3", 0.3_f32),
    ]
    .iter()
    .map(|&(id, seed)| (id.to_string(), create_test_embedding_384_with_seed(seed)))
    .collect();

    storage
        .store_embeddings_batch(embeddings.clone())
        .await
        .unwrap();

    // Request the three stored ids followed by one that does not exist.
    let ids: Vec<String> = embeddings
        .iter()
        .map(|(id, _)| id.clone())
        .chain(std::iter::once("nonexistent".to_string()))
        .collect();

    let results = storage.get_embeddings_batch(&ids).await.unwrap();
    assert_eq!(results.len(), 4);

    // The zip stops after the three stored fixtures; the trailing slot is
    // checked separately below.
    for (slot, (_, expected)) in results.iter().zip(&embeddings) {
        assert!(slot.is_some());
        assert_eq!(slot.as_ref().unwrap(), expected);
    }
    assert!(results[3].is_none());
}
/// Embeddings of several dimensionalities must coexist in one store and come
/// back with their original length.
///
/// Only the length is asserted, matching the scope of this test; value
/// fidelity is covered by the round-trip tests.
#[tokio::test]
#[cfg(feature = "turso_multi_dimension")]
async fn test_different_embedding_dimensions() {
    const DIMENSIONS: [usize; 3] = [384, 1024, 1536];

    let (storage, _dir) = create_test_storage().await.unwrap();

    // Store every dimension first so the vectors coexist before any read.
    for &dim in &DIMENSIONS {
        let id = format!("dim_{}", dim);
        let embedding: Vec<f32> = (0..dim).map(|i| i as f32 / dim as f32).collect();
        // The vector is moved into the call: it is not needed afterwards, so
        // no clone is made.
        storage
            .store_embedding(id.as_str(), embedding)
            .await
            .unwrap();
    }

    for &dim in &DIMENSIONS {
        let id = format!("dim_{}", dim);
        let retrieved = storage
            .get_embedding(id.as_str())
            .await
            .unwrap()
            .unwrap_or_else(|| panic!("embedding {} should be retrievable", id));
        assert_eq!(retrieved.len(), dim);
    }
}
/// Storing under an id that already exists must overwrite the previous
/// vector: each read returns the most recently written version.
#[tokio::test]
async fn test_update_existing_embedding() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let id = "update_test";
    let first_version = create_test_embedding_384_with_seed(0.1);
    let second_version = create_test_embedding_384_with_seed(0.9);

    // Write each version in turn and confirm it is what a read now returns.
    for version in [&first_version, &second_version] {
        storage
            .store_embedding(id, version.clone())
            .await
            .unwrap();
        let fetched = storage.get_embedding(id).await.unwrap();
        assert_eq!(fetched.unwrap(), *version);
    }
}
/// Both batch APIs must accept empty input: storing nothing succeeds and
/// fetching nothing yields an empty result.
#[tokio::test]
async fn test_empty_embeddings_batch() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    storage.store_embeddings_batch(Vec::new()).await.unwrap();
    assert!(storage
        .get_embeddings_batch(&[])
        .await
        .unwrap()
        .is_empty());
}
/// Tests that only make sense when the `compression` feature is enabled.
#[cfg(feature = "compression")]
mod compression_tests {
    use super::*;
    use do_memory_core::StorageBackend;

    /// Builds an in-progress episode (no outcome, reward, reflection or end
    /// time) in the "test" domain with language "rust".
    ///
    /// Only the fields that differ between tests are parameters; everything
    /// else is empty or `None`.
    fn build_test_episode(
        task_type: do_memory_core::TaskType,
        task_description: &str,
        complexity: do_memory_core::types::ComplexityLevel,
        context_tags: Vec<String>,
        steps: Vec<do_memory_core::episode::ExecutionStep>,
    ) -> do_memory_core::Episode {
        do_memory_core::Episode {
            episode_id: uuid::Uuid::new_v4(),
            task_type,
            task_description: task_description.to_string(),
            context: do_memory_core::TaskContext {
                domain: "test".to_string(),
                language: Some("rust".to_string()),
                framework: None,
                complexity,
                tags: context_tags,
            },
            steps,
            outcome: None,
            reward: None,
            reflection: None,
            patterns: vec![],
            heuristics: vec![],
            applied_patterns: vec![],
            salient_features: None,
            tags: vec![],
            checkpoints: vec![],
            start_time: chrono::Utc::now(),
            end_time: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// An episode with 100 steps, each carrying a 100-character payload,
    /// must survive a store/load round trip with all steps intact.
    #[tokio::test]
    async fn test_large_episode_compression() {
        let (storage, _dir) = create_test_storage().await.unwrap();

        let steps: Vec<_> = (0..100)
            .map(|i| do_memory_core::episode::ExecutionStep {
                step_number: i,
                tool: format!("tool_{}", i % 10),
                action: format!("action_{}", i),
                parameters: serde_json::json!({
                    "param": format!("value_{}", i),
                    "data": "x".repeat(100)
                }),
                result: Some(do_memory_core::types::ExecutionResult::Success {
                    output: format!("output_{}", i),
                }),
                latency_ms: i as u64,
                timestamp: chrono::Utc::now(),
                tokens_used: None,
                metadata: std::collections::HashMap::new(),
            })
            .collect();

        let episode = build_test_episode(
            do_memory_core::TaskType::CodeGeneration,
            "Test large episode compression",
            do_memory_core::types::ComplexityLevel::Complex,
            vec!["compression".to_string()],
            steps,
        );

        storage.store_episode(&episode).await.unwrap();

        let retrieved_episode = storage
            .get_episode(episode.episode_id)
            .await
            .unwrap()
            .expect("large episode should be retrievable after being stored");
        assert_eq!(retrieved_episode.steps.len(), 100);
        assert_eq!(retrieved_episode.task_description, episode.task_description);
    }

    /// An episode with no steps must also round-trip unchanged.
    #[tokio::test]
    async fn test_small_episode_no_compression() {
        let (storage, _dir) = create_test_storage().await.unwrap();

        let episode = build_test_episode(
            do_memory_core::TaskType::Analysis,
            "Test small episode without compression",
            do_memory_core::types::ComplexityLevel::Simple,
            vec![],
            vec![],
        );

        storage.store_episode(&episode).await.unwrap();

        let retrieved_episode = storage
            .get_episode(episode.episode_id)
            .await
            .unwrap()
            .expect("small episode should be retrievable after being stored");
        assert_eq!(retrieved_episode.task_description, episode.task_description);
    }

    /// A 384-dimensional embedding must round-trip with every component
    /// within `1e-5` of the original.
    #[tokio::test]
    async fn test_embedding_compression() {
        let (storage, _dir) = create_test_storage().await.unwrap();
        let embedding: Vec<f32> = (0..384).map(|i| (i as f32 / 384.0).sin()).collect();

        // The clone is required: the original is compared element-wise below.
        storage
            .store_embedding("test_compressed_embedding", embedding.clone())
            .await
            .unwrap();

        let retrieved_embedding = storage
            .get_embedding("test_compressed_embedding")
            .await
            .unwrap()
            .expect("embedding should be retrievable after being stored");
        assert_eq!(retrieved_embedding.len(), 384);

        // Tolerance-based comparison rather than exact equality.
        for (original, restored) in embedding.iter().zip(retrieved_embedding.iter()) {
            assert!((original - restored).abs() < 1e-5);
        }
    }

    /// Exercises the `CompressionStatistics` accumulator directly: byte
    /// totals, event counters, timing sums, and the derived ratio and
    /// savings figures.
    ///
    /// This test awaits nothing, so it runs as a plain synchronous test.
    #[test]
    fn test_compression_statistics() {
        use crate::CompressionStatistics;

        let mut stats = CompressionStatistics::new();
        stats.record_compression(1000, 400, 50);
        stats.record_compression(2000, 800, 100);
        stats.record_skipped();
        stats.record_decompression(75);

        assert_eq!(stats.total_original_bytes, 3000);
        assert_eq!(stats.total_compressed_bytes, 1200);
        assert_eq!(stats.compression_count, 2);
        assert_eq!(stats.skipped_count, 1);
        assert_eq!(stats.compression_time_us, 150);
        assert_eq!(stats.decompression_time_us, 75);

        // 1200 / 3000 = 0.4, checked with a tolerance band.
        let ratio = stats.compression_ratio();
        assert!(ratio > 0.35 && ratio < 0.45);

        // 0.4 compressed means roughly 60% saved.
        let savings = stats.bandwidth_savings_percent();
        assert!(savings > 55.0 && savings < 65.0);
    }
}