use std::sync::Arc;
use super::episode_store::{EpisodeDto, EpisodeFilter, EpisodeMeta, EpisodeStore, StoreError};
use super::record_store::RecordStore;
use crate::learn::episode::{Episode, EpisodeContext, EpisodeId};
/// Persistence abstraction for [`Episode`]s.
///
/// The `Send + Sync` bound allows a repository to be shared across threads.
pub trait EpisodeRepository: Send + Sync {
    /// Persists `episode` and returns the id reported by the underlying store.
    fn save(&self, episode: &Episode) -> Result<EpisodeId, StoreError>;

    /// Loads the episode with the given id, returning `Ok(None)` when it is not found.
    fn find_by_id(&self, id: &EpisodeId) -> Result<Option<Episode>, StoreError>;

    /// Loads every episode matching `filter`.
    fn find_all(&self, filter: &EpisodeFilter) -> Result<Vec<Episode>, StoreError>;

    /// Counts episodes matching `filter`; `None` counts all episodes.
    fn count(&self, filter: Option<&EpisodeFilter>) -> Result<usize, StoreError>;

    /// Lists [`EpisodeMeta`] for episodes matching `filter`
    /// (presumably all episodes when `None`, as with [`count`](Self::count)).
    fn list_meta(&self, filter: Option<&EpisodeFilter>) -> Result<Vec<EpisodeMeta>, StoreError>;
}
/// [`EpisodeRepository`] implementation that splits storage across two backends.
///
/// The records in an episode's context go to a [`RecordStore`]. The episode itself
/// goes to an [`EpisodeStore`] as an [`EpisodeDto`], which references those records
/// by id.
pub struct DefaultEpisodeRepository {
    /// Store holding the individual records of each episode's context.
    record_store: Arc<dyn RecordStore>,
    /// Store holding episode DTOs, which reference their records by id.
    episode_store: Arc<dyn EpisodeStore>,
}
impl DefaultEpisodeRepository {
    /// Creates a repository backed by the given record store and episode store.
    pub fn new(record_store: Arc<dyn RecordStore>, episode_store: Arc<dyn EpisodeStore>) -> Self {
        Self {
            episode_store,
            record_store,
        }
    }

    /// Rebuilds a full [`Episode`] from its stored DTO.
    ///
    /// The records referenced by `dto.record_ids` are loaded from the record store
    /// and folded into a fresh [`EpisodeContext`].
    ///
    /// # Errors
    /// A record-store failure is wrapped in [`StoreError::Other`].
    fn reconstruct_episode(&self, dto: &EpisodeDto) -> Result<Episode, StoreError> {
        let loaded = match self.record_store.get_batch(&dto.record_ids) {
            Ok(batch) => batch,
            Err(err) => {
                return Err(StoreError::Other(format!("Failed to get records: {}", err)));
            }
        };

        // Each record is attached in the order the record store returned it.
        let context = loaded
            .into_iter()
            .fold(EpisodeContext::new(), |ctx, rec| ctx.with_record(rec));

        Ok(Episode::builder()
            .id(dto.id.clone())
            .learn_model(&dto.learn_model)
            .outcome(dto.outcome.clone())
            .metadata(dto.metadata.clone())
            .context(context)
            .build())
    }
}
impl EpisodeRepository for DefaultEpisodeRepository {
    /// Appends every record in the episode's context to the record store.
    /// Then appends the episode DTO, carrying the new record ids, to the episode store.
    ///
    /// NOTE: the writes are not atomic. Records are appended one by one before the
    /// episode. If a later record append or the episode append fails, the records
    /// already written are not rolled back and remain unreferenced by any episode.
    ///
    /// # Errors
    /// A record-store failure is wrapped in [`StoreError::Other`]. An episode-store
    /// failure is returned as-is.
    fn save(&self, episode: &Episode) -> Result<EpisodeId, StoreError> {
        // Collecting into `Result` stops at the first failing append, matching a
        // sequential loop with early return.
        let record_ids = episode
            .context
            .records
            .iter()
            .map(|record| {
                self.record_store
                    .append(record)
                    .map_err(|e| StoreError::Other(format!("Failed to save record: {}", e)))
            })
            .collect::<Result<Vec<_>, _>>()?;
        let dto = EpisodeDto::from_episode(episode).with_record_ids(record_ids);
        self.episode_store.append(&dto)
    }

    /// Looks up the DTO by id and, if present, reconstructs the full episode
    /// including its records.
    fn find_by_id(&self, id: &EpisodeId) -> Result<Option<Episode>, StoreError> {
        self.episode_store
            .get(id)?
            .map(|dto| self.reconstruct_episode(&dto))
            .transpose()
    }

    /// Queries DTOs matching `filter` and reconstructs each one.
    /// Fails on the first episode whose records cannot be loaded.
    fn find_all(&self, filter: &EpisodeFilter) -> Result<Vec<Episode>, StoreError> {
        self.episode_store
            .query(filter)?
            .into_iter()
            .map(|dto| self.reconstruct_episode(&dto))
            .collect()
    }

    /// Delegates directly to the episode store.
    fn count(&self, filter: Option<&EpisodeFilter>) -> Result<usize, StoreError> {
        self.episode_store.count(filter)
    }

    /// Delegates directly to the episode store; no records are loaded.
    fn list_meta(&self, filter: Option<&EpisodeFilter>) -> Result<Vec<EpisodeMeta>, StoreError> {
        self.episode_store.list_meta(filter)
    }
}
#[cfg(test)]
mod tests {
    use super::super::episode_store::InMemoryEpisodeStore;
    use super::super::record_store::InMemoryRecordStore;
    use super::*;
    use crate::learn::episode::Outcome;
    use crate::learn::record::ActionRecord;

    /// Builds a repository over fresh, empty in-memory stores.
    fn make_repo() -> DefaultEpisodeRepository {
        DefaultEpisodeRepository::new(
            Arc::new(InMemoryRecordStore::new()),
            Arc::new(InMemoryEpisodeStore::new()),
        )
    }

    /// Builds a three-record successful episode attributed to `worker_id`.
    fn make_test_episode(worker_id: usize) -> Episode {
        Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, worker_id, "CheckStatus").success(true))
            .record(ActionRecord::new(2, worker_id, "ReadLogs").success(true))
            .record(ActionRecord::new(3, worker_id, "done").success(true))
            .outcome(Outcome::success_binary())
            .scenario("test-scenario")
            .build()
    }

    #[test]
    fn test_repository_save_and_find() {
        let repo = make_repo();
        let episode = make_test_episode(0);
        let id = episode.id.clone();
        repo.save(&episode).unwrap();
        let found = repo
            .find_by_id(&id)
            .unwrap()
            .expect("saved episode should be found");
        assert_eq!(found.id, id);
        assert_eq!(found.context.records.len(), 3);
    }

    #[test]
    fn test_repository_find_all() {
        let repo = make_repo();
        for worker_id in 0..3 {
            repo.save(&make_test_episode(worker_id)).unwrap();
        }
        let found = repo.find_all(&EpisodeFilter::new()).unwrap();
        assert_eq!(found.len(), 3);
        for ep in found {
            assert_eq!(ep.context.records.len(), 3);
        }
    }

    #[test]
    fn test_repository_count() {
        let repo = make_repo();
        repo.save(&make_test_episode(0)).unwrap();
        repo.save(&make_test_episode(1)).unwrap();
        assert_eq!(repo.count(None).unwrap(), 2);
    }

    #[test]
    fn test_repository_empty() {
        let repo = make_repo();
        // Never saved, so the lookup must miss regardless of how ids are generated.
        let never_saved = make_test_episode(0);
        assert!(repo.find_by_id(&never_saved.id).unwrap().is_none());
        assert!(repo.find_all(&EpisodeFilter::new()).unwrap().is_empty());
        assert_eq!(repo.count(None).unwrap(), 0);
    }
}