use crate::TursoStorage;
use do_memory_core::{Episode, Error, Result, semantic::EpisodeSummary};
use tracing::{debug, info};
use uuid::Uuid;
#[cfg(feature = "compression")]
use super::compression::compress_json_field;
impl TursoStorage {
    /// Persists an [`Episode`] with insert-or-replace semantics keyed on
    /// `episode_id`.
    ///
    /// Structured fields (`context`, `steps`, `outcome`, `reward`,
    /// `reflection`, `patterns`, `heuristics`, `checkpoints`, `metadata`)
    /// are serialized to JSON text. When the `compression` feature is
    /// enabled and `config.compress_episodes` is set, the `patterns`,
    /// `heuristics` and `metadata` payloads are additionally passed
    /// through `compress_json_field` before storage.
    ///
    /// On success, any adaptive-ttl cache entry for this episode is
    /// invalidated so the next read observes the new row.
    ///
    /// # Errors
    /// * [`Error::Serialization`] if any field fails JSON serialization.
    /// * [`Error::Storage`] if the SQL execution fails or a compressed
    ///   payload is not valid UTF-8.
    pub async fn store_episode(&self, episode: &Episode) -> Result<()> {
        debug!("Storing episode: {}", episode.episode_id);
        let (conn, conn_id) = self.get_connection_with_id().await?;
        // 17 columns / 17 placeholders; order here must match the
        // params! list below exactly.
        const SQL: &str = r#"
            INSERT OR REPLACE INTO episodes (
                episode_id, task_type, task_description, context,
                start_time, end_time, steps, outcome, reward,
                reflection, patterns, heuristics, checkpoints, metadata, domain, language,
                archived_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        "#;
        // Feature-gated config reads; the underscore-prefixed fallbacks
        // exist only to keep both cfg variants well-formed.
        #[cfg(feature = "compression")]
        let compression_threshold = self.config.compression_threshold;
        #[cfg(not(feature = "compression"))]
        let _compression_threshold = 0;
        let context_json = serde_json::to_string(&episode.context).map_err(Error::Serialization)?;
        let steps_json = serde_json::to_string(&episode.steps).map_err(Error::Serialization)?;
        // Optional fields serialize to NULL (None) rather than the JSON
        // literal "null" when absent.
        let outcome_json = episode
            .outcome
            .as_ref()
            .map(serde_json::to_string)
            .transpose()
            .map_err(Error::Serialization)?;
        let reward_json = episode
            .reward
            .as_ref()
            .map(serde_json::to_string)
            .transpose()
            .map_err(Error::Serialization)?;
        let reflection_json = episode
            .reflection
            .as_ref()
            .map(serde_json::to_string)
            .transpose()
            .map_err(Error::Serialization)?;
        #[cfg(feature = "compression")]
        let should_compress = self.config.compress_episodes;
        #[cfg(not(feature = "compression"))]
        let _should_compress = false;
        // patterns / heuristics / metadata: optionally compressed when the
        // feature is on, plain JSON bytes otherwise.
        #[cfg(feature = "compression")]
        let patterns_json = if should_compress {
            let data = serde_json::to_string(&episode.patterns).map_err(Error::Serialization)?;
            compress_json_field(data.as_bytes(), compression_threshold)?
        } else {
            serde_json::to_string(&episode.patterns)
                .map_err(Error::Serialization)?
                .into_bytes()
        };
        #[cfg(not(feature = "compression"))]
        let patterns_json: Vec<u8> = serde_json::to_string(&episode.patterns)
            .map_err(Error::Serialization)?
            .into_bytes();
        #[cfg(feature = "compression")]
        let heuristics_json = if should_compress {
            let data = serde_json::to_string(&episode.heuristics).map_err(Error::Serialization)?;
            compress_json_field(data.as_bytes(), compression_threshold)?
        } else {
            serde_json::to_string(&episode.heuristics)
                .map_err(Error::Serialization)?
                .into_bytes()
        };
        #[cfg(not(feature = "compression"))]
        let heuristics_json: Vec<u8> = serde_json::to_string(&episode.heuristics)
            .map_err(Error::Serialization)?
            .into_bytes();
        #[cfg(feature = "compression")]
        let metadata_json = if should_compress {
            let data = serde_json::to_string(&episode.metadata).map_err(Error::Serialization)?;
            compress_json_field(data.as_bytes(), compression_threshold)?
        } else {
            serde_json::to_string(&episode.metadata)
                .map_err(Error::Serialization)?
                .into_bytes()
        };
        #[cfg(not(feature = "compression"))]
        let metadata_json: Vec<u8> = serde_json::to_string(&episode.metadata)
            .map_err(Error::Serialization)?
            .into_bytes();
        let checkpoints_json =
            serde_json::to_string(&episode.checkpoints).map_err(Error::Serialization)?;
        // Denormalized archive timestamp sourced from the metadata map;
        // unparseable or absent values store as NULL. Presumably a unix
        // timestamp in seconds — TODO confirm against the reader side.
        let archived_at = episode
            .metadata
            .get("archived_at")
            .and_then(|v| v.parse::<i64>().ok());
        // NOTE(review): these conversions assume `compress_json_field`
        // yields valid UTF-8 bytes (e.g. a text encoding of the compressed
        // payload). Raw binary compression output would fail here at
        // runtime — confirm against super::compression.
        let patterns_str = String::from_utf8(patterns_json)
            .map_err(|e| Error::Storage(format!("Failed to convert patterns to UTF-8: {}", e)))?;
        let heuristics_str = String::from_utf8(heuristics_json)
            .map_err(|e| Error::Storage(format!("Failed to convert heuristics to UTF-8: {}", e)))?;
        let metadata_str = String::from_utf8(metadata_json)
            .map_err(|e| Error::Storage(format!("Failed to convert metadata to UTF-8: {}", e)))?;
        let stmt = self.prepare_cached(conn_id, &conn, SQL).await?;
        stmt.execute(libsql::params![
            episode.episode_id.to_string(),
            episode.task_type.to_string(),
            episode.task_description.clone(),
            context_json,
            episode.start_time.timestamp(),
            episode.end_time.map(|t| t.timestamp()),
            steps_json,
            outcome_json,
            reward_json,
            reflection_json,
            patterns_str,
            heuristics_str,
            checkpoints_json,
            metadata_str,
            episode.context.domain.clone(),
            episode.context.language.clone(),
            archived_at,
        ])
        .await
        .map_err(|e| Error::Storage(format!("Failed to store episode: {}", e)))?;
        // NOTE(review): the prepared-statement cache for this connection is
        // dropped after every call, so `prepare_cached` only helps within a
        // single operation — confirm this is intentional (e.g. tied to the
        // connection returning to the pool).
        self.clear_prepared_cache(conn_id);
        // Invalidate after the write succeeds so a failed write never
        // evicts a still-valid cached copy.
        #[cfg(feature = "adaptive-ttl")]
        if let Some(ref cache) = self.episode_cache {
            cache.remove(&episode.episode_id.to_string()).await;
        }
        info!("Successfully stored episode: {}", episode.episode_id);
        Ok(())
    }
    /// Fetches an episode by id, returning `Ok(None)` when no row exists.
    ///
    /// With the `adaptive-ttl` feature, a cache lookup is tried first and
    /// a successful DB read populates the cache.
    ///
    /// # Errors
    /// [`Error::Storage`] on query failure; row decoding errors are
    /// surfaced from `row_to_episode`.
    pub async fn get_episode(&self, episode_id: Uuid) -> Result<Option<Episode>> {
        debug!("Retrieving episode: {}", episode_id);
        #[cfg(feature = "adaptive-ttl")]
        if let Some(ref cache) = self.episode_cache {
            let cache_key = episode_id.to_string();
            if let Some(cached_episode) = cache.get(&cache_key).await {
                debug!("Episode cache hit for: {}", episode_id);
                return Ok(Some(cached_episode));
            }
            debug!("Episode cache miss for: {}", episode_id);
        }
        let (conn, conn_id) = self.get_connection_with_id().await?;
        // COALESCE keeps rows written before the checkpoints column was
        // populated decodable as an empty list.
        const SQL: &str = r#"
            SELECT episode_id, task_type, task_description, context,
                   start_time, end_time, steps, outcome, reward,
                   reflection, patterns, heuristics,
                   COALESCE(checkpoints, '[]') AS checkpoints,
                   metadata, domain, language,
                   archived_at
            FROM episodes WHERE episode_id = ?
        "#;
        let stmt = self.prepare_cached(conn_id, &conn, SQL).await?;
        let mut rows = stmt
            .query(libsql::params![episode_id.to_string()])
            .await
            .map_err(|e| Error::Storage(format!("Failed to query episode: {}", e)))?;
        // Bind the outcome so the cache is cleared on every path,
        // including decode errors.
        let result = if let Some(row) = rows
            .next()
            .await
            .map_err(|e| Error::Storage(format!("Failed to fetch episode row: {}", e)))?
        {
            let episode = Self::row_to_episode(&row)?;
            #[cfg(feature = "adaptive-ttl")]
            if let Some(ref cache) = self.episode_cache {
                cache.insert(episode_id.to_string(), episode.clone()).await;
                debug!("Cached episode: {}", episode_id);
            }
            Ok(Some(episode))
        } else {
            Ok(None)
        };
        self.clear_prepared_cache(conn_id);
        result
    }
    /// Deletes an episode row and evicts any cached copy.
    ///
    /// Deleting a nonexistent id is not an error (the DELETE simply
    /// affects zero rows).
    ///
    /// # Errors
    /// [`Error::Storage`] if the SQL execution fails.
    pub async fn delete_episode(&self, episode_id: Uuid) -> Result<()> {
        debug!("Deleting episode: {}", episode_id);
        let (conn, conn_id) = self.get_connection_with_id().await?;
        const SQL: &str = "DELETE FROM episodes WHERE episode_id = ?";
        let stmt = self.prepare_cached(conn_id, &conn, SQL).await?;
        stmt.execute(libsql::params![episode_id.to_string()])
            .await
            .map_err(|e| Error::Storage(format!("Failed to delete episode: {}", e)))?;
        self.clear_prepared_cache(conn_id);
        #[cfg(feature = "adaptive-ttl")]
        if let Some(ref cache) = self.episode_cache {
            cache.remove(&episode_id.to_string()).await;
        }
        info!("Successfully deleted episode: {}", episode_id);
        Ok(())
    }
    /// Persists an [`EpisodeSummary`] (insert-or-replace keyed on
    /// `episode_id`). Concepts, steps and the optional embedding are
    /// stored as JSON text; a missing embedding is stored as NULL.
    ///
    /// # Errors
    /// [`Error::Serialization`] on JSON failure, [`Error::Storage`] on
    /// SQL failure.
    pub async fn store_episode_summary(&self, summary: &EpisodeSummary) -> Result<()> {
        debug!("Storing episode summary: {}", summary.episode_id);
        let (conn, conn_id) = self.get_connection_with_id().await?;
        const SQL: &str = r#"
            INSERT OR REPLACE INTO episode_summaries (
                episode_id, summary_text, key_concepts, key_steps,
                summary_embedding, created_at
            ) VALUES (?, ?, ?, ?, ?, ?)
        "#;
        let key_concepts_json =
            serde_json::to_string(&summary.key_concepts).map_err(Error::Serialization)?;
        let key_steps_json =
            serde_json::to_string(&summary.key_steps).map_err(Error::Serialization)?;
        let embedding_json = summary
            .summary_embedding
            .as_ref()
            .map(serde_json::to_string)
            .transpose()
            .map_err(Error::Serialization)?;
        let stmt = self.prepare_cached(conn_id, &conn, SQL).await?;
        stmt.execute(libsql::params![
            summary.episode_id.to_string(),
            summary.summary_text.clone(),
            key_concepts_json,
            key_steps_json,
            embedding_json,
            summary.created_at.timestamp(),
        ])
        .await
        .map_err(|e| Error::Storage(format!("Failed to store summary: {}", e)))?;
        self.clear_prepared_cache(conn_id);
        info!(
            "Successfully stored summary for episode: {}",
            summary.episode_id
        );
        Ok(())
    }
    /// Fetches a summary by episode id, returning `Ok(None)` when absent.
    ///
    /// # Errors
    /// [`Error::Storage`] on query failure; row decoding errors are
    /// surfaced from `row_to_summary`.
    pub async fn get_episode_summary(&self, episode_id: Uuid) -> Result<Option<EpisodeSummary>> {
        debug!("Retrieving episode summary: {}", episode_id);
        let (conn, conn_id) = self.get_connection_with_id().await?;
        const SQL: &str = r#"
            SELECT episode_id, summary_text, key_concepts, key_steps,
                   summary_embedding, created_at
            FROM episode_summaries WHERE episode_id = ?
        "#;
        let stmt = self.prepare_cached(conn_id, &conn, SQL).await?;
        let mut rows = stmt
            .query(libsql::params![episode_id.to_string()])
            .await
            .map_err(|e| Error::Storage(format!("Failed to query summary: {}", e)))?;
        let result = if let Some(row) = rows
            .next()
            .await
            .map_err(|e| Error::Storage(format!("Failed to fetch summary row: {}", e)))?
        {
            let summary = Self::row_to_summary(&row)?;
            Ok(Some(summary))
        } else {
            Ok(None)
        };
        self.clear_prepared_cache(conn_id);
        result
    }
    /// Looks up an episode by exact `task_description` match.
    ///
    /// NOTE(review): unlike the id-based reads, this uses a plain
    /// (non-prepared-cached) query and no `LIMIT 1`; when several
    /// episodes share a description, the first row the engine yields is
    /// returned — confirm whether callers rely on descriptions being
    /// unique. The adaptive-ttl cache is also bypassed here.
    ///
    /// # Errors
    /// [`Error::Storage`] on query failure.
    pub async fn get_episode_by_task_desc(&self, task_desc: &str) -> Result<Option<Episode>> {
        debug!("Retrieving episode by task description: {}", task_desc);
        let conn = self.get_connection().await?;
        let sql = r#"
            SELECT episode_id, task_type, task_description, context,
                   start_time, end_time, steps, outcome, reward,
                   reflection, patterns, heuristics,
                   COALESCE(checkpoints, '[]') AS checkpoints,
                   metadata, domain, language,
                   archived_at
            FROM episodes WHERE task_description = ?
        "#;
        let mut rows = conn
            .query(sql, libsql::params![task_desc])
            .await
            .map_err(|e| Error::Storage(format!("Failed to query episode: {}", e)))?;
        if let Some(row) = rows
            .next()
            .await
            .map_err(|e| Error::Storage(format!("Failed to fetch episode row: {}", e)))?
        {
            let episode = Self::row_to_episode(&row)?;
            Ok(Some(episode))
        } else {
            Ok(None)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use do_memory_core::{Episode, TaskContext, TaskType, memory::checkpoint::CheckpointMeta};
    use tempfile::TempDir;

    /// Spins up a schema-initialized, file-backed storage inside a fresh
    /// temporary directory. The `TempDir` is returned so it outlives the
    /// test (dropping it deletes the database file).
    async fn create_test_storage() -> Result<(TursoStorage, TempDir)> {
        let tmp = TempDir::new().unwrap();
        let db_file = tmp.path().join("test.db");
        let database = libsql::Builder::new_local(&db_file)
            .build()
            .await
            .map_err(|e| Error::Storage(format!("Failed to create test database: {}", e)))?;
        let storage = TursoStorage::from_database(database)?;
        storage.initialize_schema().await?;
        Ok((storage, tmp))
    }

    /// A stored episode comes back with its task description intact.
    #[tokio::test]
    async fn test_store_and_get_episode() {
        let (storage, _guard) = create_test_storage().await.unwrap();
        let episode = Episode::new(
            "Test episode".to_string(),
            TaskContext::default(),
            TaskType::CodeGeneration,
        );
        let id = episode.episode_id;
        storage.store_episode(&episode).await.unwrap();
        let fetched = storage
            .get_episode(id)
            .await
            .unwrap()
            .expect("episode should exist after store");
        assert_eq!(fetched.task_description, "Test episode");
    }

    /// Deleting an episode makes subsequent lookups return `None`.
    #[tokio::test]
    async fn test_delete_episode() {
        let (storage, _guard) = create_test_storage().await.unwrap();
        let episode = Episode::new(
            "To delete".to_string(),
            TaskContext::default(),
            TaskType::Debugging,
        );
        let id = episode.episode_id;
        storage.store_episode(&episode).await.unwrap();
        storage.delete_episode(id).await.unwrap();
        assert!(storage.get_episode(id).await.unwrap().is_none());
    }

    /// Looking up an id that was never stored yields `None`, not an error.
    #[tokio::test]
    async fn test_get_nonexistent_episode() {
        let (storage, _guard) = create_test_storage().await.unwrap();
        let missing = storage.get_episode(Uuid::new_v4()).await.unwrap();
        assert!(missing.is_none());
    }

    /// Checkpoint metadata survives a store/load round trip.
    #[tokio::test]
    async fn test_store_and_get_episode_persists_checkpoints() {
        let (storage, _guard) = create_test_storage().await.unwrap();
        let mut episode = Episode::new(
            "Checkpoint test".to_string(),
            TaskContext::default(),
            TaskType::CodeGeneration,
        );
        episode.checkpoints.push(CheckpointMeta::new(
            "handoff".to_string(),
            2,
            Some("persist me".to_string()),
        ));
        let id = episode.episode_id;
        storage.store_episode(&episode).await.unwrap();
        let fetched = storage.get_episode(id).await.unwrap().unwrap();
        assert_eq!(fetched.checkpoints.len(), 1);
        let cp = &fetched.checkpoints[0];
        assert_eq!(cp.reason, "handoff");
        assert_eq!(cp.step_number, 2);
        assert_eq!(cp.note.as_deref(), Some("persist me"));
    }
}