swarm-engine-core 0.1.6

Core types and orchestration for SwarmEngine
Documentation
//! EpisodeRepository - Domain 層に公開するインターフェース
//!
//! RecordStore と EpisodeStore を組み合わせて、
//! Domain Entity (Episode) の CRUD を提供する。

use std::sync::Arc;

use super::episode_store::{EpisodeDto, EpisodeFilter, EpisodeMeta, EpisodeStore, StoreError};
use super::record_store::RecordStore;
use crate::learn::episode::{Episode, EpisodeContext, EpisodeId};

// ============================================================================
// EpisodeRepository Trait
// ============================================================================

/// Repository interface for `Episode` entities, exposed to the Domain layer.
///
/// The Domain layer does not know about persistence details (`RecordStore`,
/// `EpisodeStore`) and uses only this interface.
pub trait EpisodeRepository: Send + Sync {
    /// Saves an Episode, persisting its Records as well.
    ///
    /// Returns the ID of the saved Episode.
    fn save(&self, episode: &Episode) -> Result<EpisodeId, StoreError>;

    /// Fetches an Episode by ID as a complete entity, including its Records.
    ///
    /// Returns `Ok(None)` when no Episode with that ID is found.
    fn find_by_id(&self, id: &EpisodeId) -> Result<Option<Episode>, StoreError>;

    /// Searches with a filter, returning complete entities including their Records.
    fn find_all(&self, filter: &EpisodeFilter) -> Result<Vec<Episode>, StoreError>;

    /// Returns the number of Episodes, optionally restricted by `filter`.
    fn count(&self, filter: Option<&EpisodeFilter>) -> Result<usize, StoreError>;

    /// Lists only Episode metadata (a lightweight alternative to `find_all`).
    fn list_meta(&self, filter: Option<&EpisodeFilter>) -> Result<Vec<EpisodeMeta>, StoreError>;
}

// ============================================================================
// DefaultEpisodeRepository
// ============================================================================

/// `EpisodeRepository` implementation that combines a `RecordStore` (holding the
/// individual Records) with an `EpisodeStore` (holding the Episode DTOs).
pub struct DefaultEpisodeRepository {
    record_store: Arc<dyn RecordStore>,
    episode_store: Arc<dyn EpisodeStore>,
}

impl DefaultEpisodeRepository {
    /// Creates a repository backed by the given record store and episode store.
    pub fn new(record_store: Arc<dyn RecordStore>, episode_store: Arc<dyn EpisodeStore>) -> Self {
        Self {
            episode_store,
            record_store,
        }
    }

    /// Rebuilds a full `Episode` entity from a stored DTO.
    ///
    /// The Records listed in `dto.record_ids` are loaded with a single
    /// `get_batch` call and attached, in the order returned, to a fresh
    /// `EpisodeContext`. The ID, learn model, outcome and metadata are copied
    /// from the DTO.
    ///
    /// # Errors
    /// Returns `StoreError::Other` when the record store fails to load the batch.
    fn reconstruct_episode(&self, dto: &EpisodeDto) -> Result<Episode, StoreError> {
        let loaded = match self.record_store.get_batch(&dto.record_ids) {
            Ok(batch) => batch,
            Err(err) => {
                return Err(StoreError::Other(format!("Failed to get records: {}", err)));
            }
        };

        // Thread each Record through the by-value `with_record` builder step.
        let context = loaded
            .into_iter()
            .fold(EpisodeContext::new(), |ctx, record| ctx.with_record(record));

        Ok(Episode::builder()
            .id(dto.id.clone())
            .learn_model(&dto.learn_model)
            .outcome(dto.outcome.clone())
            .metadata(dto.metadata.clone())
            .context(context)
            .build())
    }
}

impl EpisodeRepository for DefaultEpisodeRepository {
    /// Persists every Record of `episode` in the record store, then stores an
    /// `EpisodeDto` that references those Records by the IDs the store returned.
    ///
    /// # Errors
    /// Record-store failures are wrapped in `StoreError::Other`; episode-store
    /// failures are returned unchanged. The operation is not atomic: this method
    /// does not remove Records it already appended if a later Record or the
    /// Episode DTO fails to persist.
    fn save(&self, episode: &Episode) -> Result<EpisodeId, StoreError> {
        // 1. Persist the Records, keeping their IDs in insertion order.
        //    Collecting into `Result<Vec<_>, _>` stops at the first failure.
        let record_ids = episode
            .context
            .records
            .iter()
            .map(|record| {
                self.record_store
                    .append(record)
                    .map_err(|e| StoreError::Other(format!("Failed to save record: {}", e)))
            })
            .collect::<Result<Vec<_>, _>>()?;

        // 2. Build the Episode DTO pointing at those Records and persist it.
        let dto = EpisodeDto::from_episode(episode).with_record_ids(record_ids);
        self.episode_store.append(&dto)
    }

    /// Looks up the DTO by ID and, if present, reconstructs the full Episode
    /// (including its Records).
    fn find_by_id(&self, id: &EpisodeId) -> Result<Option<Episode>, StoreError> {
        self.episode_store
            .get(id)?
            .map(|dto| self.reconstruct_episode(&dto))
            .transpose()
    }

    /// Queries DTOs matching `filter` and reconstructs each into a full Episode.
    /// Fails on the first Episode whose Records cannot be loaded.
    fn find_all(&self, filter: &EpisodeFilter) -> Result<Vec<Episode>, StoreError> {
        self.episode_store
            .query(filter)?
            .into_iter()
            .map(|dto| self.reconstruct_episode(&dto))
            .collect()
    }

    /// Delegates directly to the episode store; Records are not loaded.
    fn count(&self, filter: Option<&EpisodeFilter>) -> Result<usize, StoreError> {
        self.episode_store.count(filter)
    }

    /// Delegates directly to the episode store; Records are not loaded.
    fn list_meta(&self, filter: Option<&EpisodeFilter>) -> Result<Vec<EpisodeMeta>, StoreError> {
        self.episode_store.list_meta(filter)
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::super::episode_store::InMemoryEpisodeStore;
    use super::super::record_store::InMemoryRecordStore;
    use super::*;
    use crate::learn::episode::Outcome;
    use crate::learn::record::ActionRecord;

    /// Builds a repository backed by fresh, empty in-memory stores.
    fn fresh_repo() -> DefaultEpisodeRepository {
        DefaultEpisodeRepository::new(
            Arc::new(InMemoryRecordStore::new()),
            Arc::new(InMemoryEpisodeStore::new()),
        )
    }

    /// Builds a successful three-step Episode performed by `worker_id`.
    fn make_test_episode(worker_id: usize) -> Episode {
        Episode::builder()
            .learn_model("worker_task")
            .record(ActionRecord::new(1, worker_id, "CheckStatus").success(true))
            .record(ActionRecord::new(2, worker_id, "ReadLogs").success(true))
            .record(ActionRecord::new(3, worker_id, "done").success(true))
            .outcome(Outcome::success_binary())
            .scenario("test-scenario")
            .build()
    }

    #[test]
    fn test_repository_save_and_find() {
        let repo = fresh_repo();
        let episode = make_test_episode(0);

        repo.save(&episode).unwrap();

        let restored = repo
            .find_by_id(&episode.id)
            .unwrap()
            .expect("saved episode should be found by its id");
        assert_eq!(restored.id, episode.id);
        assert_eq!(restored.context.records.len(), 3);
    }

    #[test]
    fn test_repository_find_all() {
        let repo = fresh_repo();
        for worker_id in 0..3 {
            repo.save(&make_test_episode(worker_id)).unwrap();
        }

        let all = repo.find_all(&EpisodeFilter::new()).unwrap();
        assert_eq!(all.len(), 3);

        // Every reconstructed Episode must carry its Records.
        for episode in &all {
            assert_eq!(episode.context.records.len(), 3);
        }
    }

    #[test]
    fn test_repository_count() {
        let repo = fresh_repo();
        for worker_id in 0..2 {
            repo.save(&make_test_episode(worker_id)).unwrap();
        }

        assert_eq!(repo.count(None).unwrap(), 2);
    }
}