use async_trait::async_trait;
use do_memory_core::memory::attribution::{
RecommendationFeedback, RecommendationSession, RecommendationStats,
};
use do_memory_core::storage::circuit_breaker::{
CircuitBreaker, CircuitBreakerConfig, CircuitState,
};
use do_memory_core::{Episode, Heuristic, Pattern, Result, StorageBackend};
use std::sync::Arc;
use tracing::{info, warn};
use uuid::Uuid;
#[cfg(test)]
use do_memory_core::Error;
use crate::TursoStorage;
/// A [`TursoStorage`] backend paired with a [`CircuitBreaker`] that storage
/// operations are executed through.
///
/// Both parts are held behind `Arc`s.
pub struct ResilientStorage {
    /// Underlying storage backend that every operation is delegated to.
    storage: Arc<TursoStorage>,
    /// Breaker guarding calls into `storage`.
    circuit_breaker: Arc<CircuitBreaker>,
}
impl ResilientStorage {
    /// Wraps `storage` together with a new circuit breaker built from `config`.
    ///
    /// Emits an `info!` log line on construction.
    pub fn new(storage: TursoStorage, config: CircuitBreakerConfig) -> Self {
        info!("Creating resilient storage with circuit breaker protection");
        let storage = Arc::new(storage);
        let circuit_breaker = Arc::new(CircuitBreaker::new(config));
        Self {
            storage,
            circuit_breaker,
        }
    }

    /// Returns the breaker's current state.
    pub async fn circuit_state(&self) -> CircuitState {
        let breaker = &self.circuit_breaker;
        breaker.state().await
    }

    /// Returns the breaker's call/failure statistics.
    pub async fn circuit_stats(
        &self,
    ) -> do_memory_core::storage::circuit_breaker::CircuitBreakerStats {
        let breaker = &self.circuit_breaker;
        breaker.stats().await
    }

    /// Resets the breaker via `CircuitBreaker::reset`.
    pub async fn reset_circuit(&self) {
        let breaker = &self.circuit_breaker;
        breaker.reset().await;
    }

    /// Reports whether the storage is usable.
    ///
    /// If the breaker is in any state other than `Closed`, this logs a warning and
    /// returns `Ok(false)` without contacting the backend. Otherwise it runs the
    /// backend's own health check through the breaker and returns that result.
    pub async fn health_check(&self) -> Result<bool> {
        // Bind first so the state future is dropped before the awaits below.
        let state = self.circuit_state().await;
        match state {
            CircuitState::Closed => {
                self.circuit_breaker
                    .call(|| async { self.storage.health_check().await })
                    .await
            }
            // NOTE(review): this arm also covers a half-open breaker, so a health
            // check never acts as a trial call there — confirm this is intended.
            other => {
                warn!("Health check: circuit breaker is {:?}", other);
                Ok(false)
            }
        }
    }
}
#[async_trait]
impl StorageBackend for ResilientStorage {
async fn store_episode(&self, episode: &Episode) -> Result<()> {
let storage = Arc::clone(&self.storage);
let episode = episode.clone();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_episode(&episode).await }
})
.await
}
async fn get_episode(&self, id: Uuid) -> Result<Option<Episode>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.get_episode(id).await }
})
.await
}
async fn delete_episode(&self, id: Uuid) -> Result<()> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.delete_episode(id).await }
})
.await
}
async fn store_pattern(&self, pattern: &Pattern) -> Result<()> {
let storage = Arc::clone(&self.storage);
let pattern = pattern.clone();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_pattern(&pattern).await }
})
.await
}
async fn get_pattern(&self, id: do_memory_core::episode::PatternId) -> Result<Option<Pattern>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.get_pattern(id).await }
})
.await
}
async fn store_heuristic(&self, heuristic: &Heuristic) -> Result<()> {
let storage = Arc::clone(&self.storage);
let heuristic = heuristic.clone();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_heuristic(&heuristic).await }
})
.await
}
async fn get_heuristic(&self, id: Uuid) -> Result<Option<Heuristic>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.get_heuristic(id).await }
})
.await
}
async fn query_episodes_since(
&self,
since: chrono::DateTime<chrono::Utc>,
limit: Option<usize>,
) -> Result<Vec<Episode>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.query_episodes_since(since, limit).await }
})
.await
}
async fn query_episodes_by_metadata(
&self,
key: &str,
value: &str,
limit: Option<usize>,
) -> Result<Vec<Episode>> {
let storage = Arc::clone(&self.storage);
let key_string = key.to_string();
let value_string = value.to_string();
let limit_param = limit;
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
let key_string = key_string;
let value_string = value_string;
async move {
storage
.query_episodes_by_metadata(&key_string, &value_string, limit_param)
.await
}
})
.await
}
async fn store_embedding(&self, id: &str, embedding: Vec<f32>) -> Result<()> {
let storage = Arc::clone(&self.storage);
let id_string = id.to_string();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_embedding(&id_string, embedding).await }
})
.await
}
async fn get_embedding(&self, id: &str) -> Result<Option<Vec<f32>>> {
let storage = Arc::clone(&self.storage);
let id_string = id.to_string();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
let id = id_string;
async move { storage.get_embedding(&id).await }
})
.await
}
async fn delete_embedding(&self, id: &str) -> Result<bool> {
let storage = Arc::clone(&self.storage);
let id_string = id.to_string();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
let id = id_string;
async move { storage.delete_embedding(&id).await }
})
.await
}
async fn store_embeddings_batch(&self, embeddings: Vec<(String, Vec<f32>)>) -> Result<()> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_embeddings_batch(embeddings).await }
})
.await
}
async fn get_embeddings_batch(&self, ids: &[String]) -> Result<Vec<Option<Vec<f32>>>> {
let storage = Arc::clone(&self.storage);
let ids_vec = ids.to_vec();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
let ids = ids_vec;
async move { storage.get_embeddings_batch(&ids).await }
})
.await
}
async fn store_recommendation_session(&self, session: &RecommendationSession) -> Result<()> {
let storage = Arc::clone(&self.storage);
let session = session.clone();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_recommendation_session(&session).await }
})
.await
}
async fn get_recommendation_session(
&self,
session_id: Uuid,
) -> Result<Option<RecommendationSession>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.get_recommendation_session(session_id).await }
})
.await
}
async fn get_recommendation_session_for_episode(
&self,
episode_id: Uuid,
) -> Result<Option<RecommendationSession>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move {
storage
.get_recommendation_session_for_episode(episode_id)
.await
}
})
.await
}
async fn store_recommendation_feedback(&self, feedback: &RecommendationFeedback) -> Result<()> {
let storage = Arc::clone(&self.storage);
let feedback = feedback.clone();
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.store_recommendation_feedback(&feedback).await }
})
.await
}
async fn get_recommendation_feedback(
&self,
session_id: Uuid,
) -> Result<Option<RecommendationFeedback>> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.get_recommendation_feedback(session_id).await }
})
.await
}
async fn get_recommendation_stats(&self) -> Result<RecommendationStats> {
let storage = Arc::clone(&self.storage);
self.circuit_breaker
.call(move || {
let storage = Arc::clone(&storage);
async move { storage.get_recommendation_stats().await }
})
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tempfile::TempDir;

    /// Builds a `ResilientStorage` over a fresh on-disk libSQL database created
    /// inside a temporary directory, with its schema initialized.
    ///
    /// The returned `TempDir` must stay alive while the storage is in use:
    /// dropping it removes the directory that holds the database file.
    ///
    /// # Errors
    /// Returns `Error::Storage` if the temp directory or database cannot be
    /// created, and propagates schema-initialization errors.
    async fn create_test_storage() -> Result<(ResilientStorage, TempDir)> {
        let dir = TempDir::new()
            .map_err(|e| Error::Storage(format!("Failed to create temp dir: {}", e)))?;
        let db_path = dir.path().join("test.db");
        let db = libsql::Builder::new_local(&db_path)
            .build()
            .await
            .map_err(|e| Error::Storage(format!("Failed to create test database: {}", e)))?;
        let turso = TursoStorage::from_database(db)?;
        turso.initialize_schema().await?;
        // Small threshold and short timeout keep any breaker transitions in tests fast.
        let config = CircuitBreakerConfig {
            failure_threshold: 3,
            timeout: Duration::from_secs(1),
            ..Default::default()
        };
        Ok((ResilientStorage::new(turso, config), dir))
    }

    #[tokio::test]
    async fn test_resilient_storage_creation() {
        // `expect` surfaces the underlying error if construction fails.
        let (_storage, _dir) = create_test_storage()
            .await
            .expect("resilient storage should be created");
    }

    #[tokio::test]
    async fn test_health_check_with_closed_circuit() {
        let (storage, _dir) = create_test_storage()
            .await
            .expect("resilient storage should be created");
        let healthy = storage
            .health_check()
            .await
            .expect("health check should not error");
        assert!(healthy);
        assert_eq!(storage.circuit_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_stats_tracking() {
        let (storage, _dir) = create_test_storage()
            .await
            .expect("resilient storage should be created");
        let episode = Episode::new(
            "test".to_string(),
            Default::default(),
            do_memory_core::TaskType::CodeGeneration,
        );
        storage
            .store_episode(&episode)
            .await
            .expect("store_episode should succeed");

        let stats = storage.circuit_stats().await;
        assert_eq!(stats.total_calls, 1);
        assert_eq!(stats.successful_calls, 1);
        assert_eq!(stats.failed_calls, 0);
    }

    #[tokio::test]
    async fn test_circuit_reset() {
        let (storage, _dir) = create_test_storage()
            .await
            .expect("resilient storage should be created");
        storage.reset_circuit().await;
        assert_eq!(storage.circuit_state().await, CircuitState::Closed);
        let stats = storage.circuit_stats().await;
        assert_eq!(stats.consecutive_failures, 0);
    }
}