use super::super::episodes::row_to_episode;
use super::super::patterns::row_to_pattern;
use crate::TursoStorage;
use do_memory_core::{Episode, Error, Heuristic, Pattern, Result, episode::PatternId};
use tracing::{debug, info};
use uuid::Uuid;
impl TursoStorage {
/// Retrieves multiple episodes by ID using batched `IN (...)` queries.
///
/// The result is positionally aligned with `ids`: slot `i` holds
/// `Some(episode)` when a row with `ids[i]` was found and `None` otherwise.
/// A repeated ID yields a clone of the same episode in each of its slots.
/// An empty `ids` slice returns an empty vector without opening a connection.
///
/// IDs are bound in chunks of at most `MAX_IDS_PER_QUERY` so that very large
/// batches do not exceed SQLite's bound-parameter limit. Batches at or below
/// that size still run as a single query.
///
/// # Errors
///
/// Returns [`Error::Storage`] if a query or a row fetch fails, and propagates
/// errors from `get_connection` and `row_to_episode`.
pub async fn get_episodes_batch(&self, ids: &[Uuid]) -> Result<Vec<Option<Episode>>> {
    // One `?` is bound per ID; 32766 is the default SQLITE_MAX_VARIABLE_NUMBER
    // for SQLite >= 3.32.0.
    const MAX_IDS_PER_QUERY: usize = 32_766;

    if ids.is_empty() {
        debug!("Empty IDs batch received for episode retrieval");
        return Ok(Vec::new());
    }
    debug!("Retrieving episodes batch: {} items", ids.len());
    let conn = self.get_connection().await?;
    let mut episode_map = std::collections::HashMap::with_capacity(ids.len());
    for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
        let placeholders = vec!["?"; chunk.len()].join(", ");
        let sql = format!(
            r#"
SELECT episode_id, task_type, task_description, context,
start_time, end_time, steps, outcome, reward,
reflection, patterns, heuristics,
COALESCE(checkpoints, '[]') AS checkpoints,
metadata, domain, language,
archived_at
FROM episodes WHERE episode_id IN ({})
"#,
            placeholders
        );
        let params: Vec<libsql::Value> = chunk
            .iter()
            .map(|id| libsql::Value::Text(id.to_string()))
            .collect();
        let mut rows = conn
            .query(&sql, libsql::params_from_iter(params))
            .await
            .map_err(|e| Error::Storage(format!("Failed to query episodes batch: {}", e)))?;
        while let Some(row) = rows
            .next()
            .await
            .map_err(|e| Error::Storage(format!("Failed to fetch episode row: {}", e)))?
        {
            let episode = row_to_episode(&row)?;
            episode_map.insert(episode.episode_id, episode);
        }
    }
    // Rows come back in arbitrary order; re-align them with the caller's ID order.
    // Cloning (rather than removing) keeps duplicate IDs in `ids` working.
    let result: Vec<Option<Episode>> =
        ids.iter().map(|id| episode_map.get(id).cloned()).collect();
    info!(
        "Retrieved {} of {} requested episodes",
        result.iter().filter(|e| e.is_some()).count(),
        ids.len()
    );
    Ok(result)
}
/// Retrieves multiple patterns by ID using batched `IN (...)` queries.
///
/// The result is positionally aligned with `ids`: slot `i` holds
/// `Some(pattern)` when a row with `ids[i]` was found and `None` otherwise.
/// A repeated ID yields a clone of the same pattern in each of its slots.
/// An empty `ids` slice returns an empty vector without opening a connection.
///
/// IDs are bound in chunks of at most `MAX_IDS_PER_QUERY` so that very large
/// batches do not exceed SQLite's bound-parameter limit. Batches at or below
/// that size still run as a single query.
///
/// # Errors
///
/// Returns [`Error::Storage`] if a query or a row fetch fails, and propagates
/// errors from `get_connection` and `row_to_pattern`.
pub async fn get_patterns_batch(&self, ids: &[PatternId]) -> Result<Vec<Option<Pattern>>> {
    // One `?` is bound per ID; 32766 is the default SQLITE_MAX_VARIABLE_NUMBER
    // for SQLite >= 3.32.0.
    const MAX_IDS_PER_QUERY: usize = 32_766;

    if ids.is_empty() {
        debug!("Empty IDs batch received for pattern retrieval");
        return Ok(Vec::new());
    }
    debug!("Retrieving patterns batch: {} items", ids.len());
    let conn = self.get_connection().await?;
    let mut pattern_map = std::collections::HashMap::with_capacity(ids.len());
    for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
        let placeholders = vec!["?"; chunk.len()].join(", ");
        let sql = format!(
            r#"
SELECT pattern_id, pattern_type, pattern_data, success_rate,
context_domain, context_language, context_tags, occurrence_count,
created_at, updated_at
FROM patterns WHERE pattern_id IN ({})
"#,
            placeholders
        );
        let params: Vec<libsql::Value> = chunk
            .iter()
            .map(|id| libsql::Value::Text(id.to_string()))
            .collect();
        let mut rows = conn
            .query(&sql, libsql::params_from_iter(params))
            .await
            .map_err(|e| Error::Storage(format!("Failed to query patterns batch: {}", e)))?;
        while let Some(row) = rows
            .next()
            .await
            .map_err(|e| Error::Storage(format!("Failed to fetch pattern row: {}", e)))?
        {
            let pattern = row_to_pattern(&row)?;
            pattern_map.insert(pattern.id(), pattern);
        }
    }
    // Rows come back in arbitrary order; re-align them with the caller's ID order.
    // Cloning (rather than removing) keeps duplicate IDs in `ids` working.
    let result: Vec<Option<Pattern>> =
        ids.iter().map(|id| pattern_map.get(id).cloned()).collect();
    info!(
        "Retrieved {} of {} requested patterns",
        result.iter().filter(|e| e.is_some()).count(),
        ids.len()
    );
    Ok(result)
}
/// Retrieves multiple heuristics by ID using batched `IN (...)` queries.
///
/// The result is positionally aligned with `ids`: slot `i` holds
/// `Some(heuristic)` when a row with `ids[i]` was found and `None` otherwise.
/// A repeated ID yields a clone of the same heuristic in each of its slots.
/// An empty `ids` slice returns an empty vector without opening a connection.
///
/// IDs are bound in chunks of at most `MAX_IDS_PER_QUERY` so that very large
/// batches do not exceed SQLite's bound-parameter limit. Batches at or below
/// that size still run as a single query.
///
/// # Errors
///
/// Returns [`Error::Storage`] if a query or a row fetch fails, and propagates
/// errors from `get_connection` and `row_to_heuristic`.
pub async fn get_heuristics_batch(&self, ids: &[Uuid]) -> Result<Vec<Option<Heuristic>>> {
    // One `?` is bound per ID; 32766 is the default SQLITE_MAX_VARIABLE_NUMBER
    // for SQLite >= 3.32.0.
    const MAX_IDS_PER_QUERY: usize = 32_766;

    if ids.is_empty() {
        debug!("Empty IDs batch received for heuristic retrieval");
        return Ok(Vec::new());
    }
    debug!("Retrieving heuristics batch: {} items", ids.len());
    let conn = self.get_connection().await?;
    let mut heuristic_map = std::collections::HashMap::with_capacity(ids.len());
    for chunk in ids.chunks(MAX_IDS_PER_QUERY) {
        let placeholders = vec!["?"; chunk.len()].join(", ");
        let sql = format!(
            r#"
SELECT heuristic_id, condition_text, action_text, confidence, evidence, created_at, updated_at
FROM heuristics WHERE heuristic_id IN ({})
"#,
            placeholders
        );
        let params: Vec<libsql::Value> = chunk
            .iter()
            .map(|id| libsql::Value::Text(id.to_string()))
            .collect();
        let mut rows = conn
            .query(&sql, libsql::params_from_iter(params))
            .await
            .map_err(|e| Error::Storage(format!("Failed to query heuristics batch: {}", e)))?;
        while let Some(row) = rows
            .next()
            .await
            .map_err(|e| Error::Storage(format!("Failed to fetch heuristic row: {}", e)))?
        {
            let heuristic = super::super::heuristics::row_to_heuristic(&row)?;
            heuristic_map.insert(heuristic.heuristic_id, heuristic);
        }
    }
    // Rows come back in arbitrary order; re-align them with the caller's ID order.
    // Cloning (rather than removing) keeps duplicate IDs in `ids` working.
    let result: Vec<Option<Heuristic>> = ids
        .iter()
        .map(|id| heuristic_map.get(id).cloned())
        .collect();
    info!(
        "Retrieved {} of {} requested heuristics",
        result.iter().filter(|e| e.is_some()).count(),
        ids.len()
    );
    Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use do_memory_core::{Episode, TaskContext, TaskType, memory::checkpoint::CheckpointMeta};
use tempfile::TempDir;
use uuid::Uuid;
async fn create_test_storage() -> Result<(TursoStorage, TempDir)> {
let dir = TempDir::new().unwrap();
let db_path = dir.path().join("test.db");
let db = libsql::Builder::new_local(&db_path)
.build()
.await
.map_err(|e| Error::Storage(format!("Failed to create test database: {}", e)))?;
let storage = TursoStorage::from_database(db)?;
storage.initialize_schema().await?;
Ok((storage, dir))
}
#[tokio::test]
async fn test_get_episodes_batch_empty() {
let (storage, _dir) = create_test_storage().await.unwrap();
let result = storage.get_episodes_batch(&[]).await.unwrap();
assert!(result.is_empty());
}
/// A stored episode ID resolves to `Some` and an unknown ID to `None`,
/// in the same order as the request.
#[tokio::test]
async fn test_get_episodes_batch_with_missing() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let stored_id = Uuid::new_v4();
    let absent_id = Uuid::new_v4();

    // Only `stored_id` is ever written; `absent_id` stays unknown to the database.
    let stored_episode = Episode {
        episode_id: stored_id,
        task_type: TaskType::CodeGeneration,
        task_description: String::from("Test task"),
        context: TaskContext::default(),
        start_time: chrono::Utc::now(),
        end_time: None,
        steps: vec![],
        outcome: None,
        reward: None,
        reflection: None,
        patterns: vec![],
        heuristics: vec![],
        applied_patterns: vec![],
        salient_features: None,
        metadata: std::collections::HashMap::new(),
        tags: vec![],
        checkpoints: vec![],
    };
    storage
        .store_episodes_batch(vec![stored_episode])
        .await
        .unwrap();

    let fetched = storage
        .get_episodes_batch(&[stored_id, absent_id])
        .await
        .unwrap();

    assert_eq!(fetched.len(), 2);
    let (first, second) = (&fetched[0], &fetched[1]);
    assert!(first.is_some());
    assert!(second.is_none());
}
#[tokio::test]
async fn test_get_patterns_batch_empty() {
let (storage, _dir) = create_test_storage().await.unwrap();
let result = storage.get_patterns_batch(&[]).await.unwrap();
assert!(result.is_empty());
}
/// Stores three episodes in one batch and checks that a batch read returns
/// every one of them, aligned with the order of the requested IDs.
#[tokio::test]
async fn test_store_and_get_multiple_episodes_batch() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let episodes = vec![
        Episode::new(
            "Task 1".to_string(),
            TaskContext::default(),
            TaskType::CodeGeneration,
        ),
        Episode::new(
            "Task 2".to_string(),
            TaskContext::default(),
            TaskType::Debugging,
        ),
        Episode::new(
            "Task 3".to_string(),
            TaskContext::default(),
            TaskType::Refactoring,
        ),
    ];
    // Capture the generated IDs before the episodes are moved into storage;
    // querying unrelated hard-coded UUIDs would never exercise the round trip.
    let ids: Vec<Uuid> = episodes.iter().map(|e| e.episode_id).collect();

    storage.store_episodes_batch(episodes).await.unwrap();

    let retrieved = storage.get_episodes_batch(&ids).await.unwrap();
    assert_eq!(retrieved.len(), ids.len());
    for (expected_id, slot) in ids.iter().zip(&retrieved) {
        let episode = slot
            .as_ref()
            .expect("every stored episode should be retrievable");
        assert_eq!(episode.episode_id, *expected_id);
    }
    let descriptions: Vec<&str> = retrieved
        .iter()
        .flatten()
        .map(|e| e.task_description.as_str())
        .collect();
    assert_eq!(descriptions, ["Task 1", "Task 2", "Task 3"]);
}
/// A checkpoint attached to an episode before storing is still present, with
/// the same reason text, after a batch read.
#[tokio::test]
async fn test_get_episodes_batch_preserves_checkpoints() {
    let (storage, _dir) = create_test_storage().await.unwrap();

    let mut episode = Episode::new(
        "Task with checkpoint".to_string(),
        TaskContext::default(),
        TaskType::Analysis,
    );
    let episode_id = episode.episode_id;
    let checkpoint = CheckpointMeta::new("batch checkpoint".to_string(), 1, None);
    episode.checkpoints.push(checkpoint);
    storage.store_episodes_batch(vec![episode]).await.unwrap();

    let retrieved = storage.get_episodes_batch(&[episode_id]).await.unwrap();
    assert_eq!(retrieved.len(), 1);

    let stored = retrieved.first().and_then(Option::as_ref).unwrap();
    assert_eq!(stored.checkpoints.len(), 1);
    assert_eq!(stored.checkpoints[0].reason, "batch checkpoint");
}
#[tokio::test]
async fn test_get_heuristics_batch_empty() {
let (storage, _dir) = create_test_storage().await.unwrap();
let result = storage.get_heuristics_batch(&[]).await.unwrap();
assert!(result.is_empty());
}
/// A stored heuristic ID resolves to `Some` and an unknown ID to `None`,
/// in the same order as the request.
#[tokio::test]
async fn test_get_heuristics_batch_with_missing() {
    let (storage, _dir) = create_test_storage().await.unwrap();
    let stored_id = Uuid::new_v4();
    let absent_id = Uuid::new_v4();

    // Only `stored_id` is ever written; `absent_id` stays unknown to the database.
    let to_store = vec![create_test_heuristic_with_id(stored_id)];
    storage.store_heuristics_batch(to_store).await.unwrap();

    let fetched = storage
        .get_heuristics_batch(&[stored_id, absent_id])
        .await
        .unwrap();

    assert_eq!(fetched.len(), 2);
    let (first, second) = (&fetched[0], &fetched[1]);
    assert!(first.is_some());
    assert!(second.is_none());
}
#[tokio::test]
async fn test_store_and_get_heuristics_batch() {
let (storage, _dir) = create_test_storage().await.unwrap();
let id1 = Uuid::new_v4();
let id2 = Uuid::new_v4();
let id3 = Uuid::new_v4();
let heuristics = vec![
create_test_heuristic_with_id(id1),
create_test_heuristic_with_id(id2),
create_test_heuristic_with_id(id3),
];
storage
.store_heuristics_batch(heuristics.clone())
.await
.unwrap();
let retrieved = storage
.get_heuristics_batch(&[id1, id2, id3])
.await
.unwrap();
assert_eq!(retrieved.len(), 3);
assert!(retrieved[0].is_some());
assert!(retrieved[1].is_some());
assert!(retrieved[2].is_some());
assert_eq!(retrieved[0].as_ref().unwrap().heuristic_id, id1);
assert_eq!(retrieved[1].as_ref().unwrap().heuristic_id, id2);
assert_eq!(retrieved[2].as_ref().unwrap().heuristic_id, id3);
}
/// Builds a `Heuristic` fixture with the given ID.
///
/// The condition and action texts embed the ID; confidence and evidence
/// values are fixed, and the evidence lists no episode IDs.
fn create_test_heuristic_with_id(id: Uuid) -> Heuristic {
    use do_memory_core::types::Evidence;

    let condition = format!("test condition {}", id);
    let action = format!("test action {}", id);
    let evidence = Evidence {
        episode_ids: Vec::new(),
        success_rate: 0.9,
        sample_size: 10,
    };

    Heuristic {
        heuristic_id: id,
        condition,
        action,
        confidence: 0.85,
        evidence,
        created_at: chrono::Utc::now(),
        updated_at: chrono::Utc::now(),
    }
}
}