use crate::episode::Episode;
use crate::spatiotemporal::index::types::{TaskTypeIndex, TemporalCluster};
use crate::types::TaskType;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use uuid::Uuid;
/// Index of the episodes belonging to a single domain, grouped per [`TaskType`] into
/// [`TaskTypeIndex`]es.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainIndex {
    /// Domain this index was created for (the value passed to `DomainIndex::new`).
    pub domain: String,
    /// One sub-index per task type, created on demand by `insert_episode`.
    pub task_type_indices: HashMap<TaskType, TaskTypeIndex>,
    /// Episode ids in insertion order.
    /// NOTE(review): despite the name, `insert_episode` appends *every* inserted episode id
    /// here (not only ones lacking a task type), and the tests rely on that — confirm intent.
    pub uncategorized_episodes: Vec<Uuid>,
    /// Running count: incremented by `insert_episode`, decremented (saturating) by a
    /// successful `remove_episode`.
    pub total_episodes: usize,
    /// Initialised by `new`; refreshed whenever `insert_episode` or `remove_episode`
    /// modifies the index.
    pub last_updated: DateTime<Utc>,
}
impl DomainIndex {
    /// Creates an empty index for `domain`, with `last_updated` set to the current time.
    #[must_use]
    pub fn new(domain: String) -> Self {
        Self {
            domain,
            task_type_indices: HashMap::new(),
            uncategorized_episodes: Vec::new(),
            total_episodes: 0,
            last_updated: Utc::now(),
        }
    }

    /// Adds `episode` to the index.
    ///
    /// The episode is forwarded to the [`TaskTypeIndex`] for its task type (created on demand),
    /// and its id is appended to `uncategorized_episodes` if not already present. That vec
    /// therefore holds every indexed id in insertion order.
    ///
    /// `total_episodes` counts distinct ids. Re-inserting an already indexed episode does not
    /// increment it, which keeps it balanced with [`Self::remove_episode`] (one decrement per id).
    /// NOTE(review): the task-type index still receives re-inserted episodes, as before; whether
    /// `TaskTypeIndex::insert_from_episode` is idempotent is not visible here.
    pub fn insert_episode(&mut self, episode: &Episode) {
        let task_type = episode.task_type;
        self.task_type_indices
            .entry(task_type)
            .or_insert_with(|| TaskTypeIndex::new(task_type))
            .insert_from_episode(episode);

        // Only count ids we have not seen before; counting re-insertions would leave
        // `total_episodes` permanently inflated since removal decrements only once.
        if !self.uncategorized_episodes.contains(&episode.episode_id) {
            self.uncategorized_episodes.push(episode.episode_id);
            self.total_episodes += 1;
        }
        self.last_updated = Utc::now();
    }

    /// Removes `episode_id` from every task-type index and from `uncategorized_episodes`.
    ///
    /// Returns `true` if any of them held the id. In that case `total_episodes` is decremented
    /// (saturating at zero) and `last_updated` is refreshed.
    pub fn remove_episode(&mut self, episode_id: Uuid) -> bool {
        // `|=` (not `||`) so that every task-type index gets the chance to drop the id.
        let mut removed = false;
        for task_type_index in self.task_type_indices.values_mut() {
            removed |= task_type_index.remove_episode(episode_id);
        }

        if let Some(pos) = self
            .uncategorized_episodes
            .iter()
            .position(|&id| id == episode_id)
        {
            // `remove` (not `swap_remove`) keeps the vec in insertion order.
            self.uncategorized_episodes.remove(pos);
            removed = true;
        }

        if removed {
            self.total_episodes = self.total_episodes.saturating_sub(1);
            self.last_updated = Utc::now();
        }
        removed
    }

    /// Returns the ids that the `task_type` sub-index reports for the range `start`..`end`.
    ///
    /// Bound semantics are those of `TaskTypeIndex::get_episodes_in_range`. The result is
    /// empty when no sub-index exists for `task_type`.
    #[must_use]
    pub fn get_episodes_by_task_type_and_time(
        &self,
        task_type: TaskType,
        start: DateTime<Utc>,
        end: DateTime<Utc>,
    ) -> Vec<Uuid> {
        self.task_type_indices
            .get(&task_type)
            .map(|idx| idx.get_episodes_in_range(start, end))
            .unwrap_or_default()
    }

    /// Returns up to `limit` distinct episode ids, most recently inserted first.
    ///
    /// `insert_episode` appends to `uncategorized_episodes`, so walking it backwards yields the
    /// newest insertions first. Any ids present only in temporal clusters follow. Their order
    /// follows `HashMap` iteration and is unspecified. Each id appears at most once, even though
    /// an indexed episode is normally stored both in `uncategorized_episodes` and in a cluster.
    #[must_use]
    pub fn get_recent_episodes(&self, limit: usize) -> Vec<Uuid> {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        let cluster_ids = self
            .task_type_indices
            .values()
            .flat_map(|idx| idx.temporal_clusters.iter())
            .flat_map(|cluster| cluster.episode_ids.iter());

        self.uncategorized_episodes
            .iter()
            .rev()
            .chain(cluster_ids)
            .copied()
            // `HashSet::insert` returns false for an id already emitted, dropping duplicates.
            .filter(|id| seen.insert(*id))
            .take(limit)
            .collect()
    }

    /// Borrows the temporal clusters of the `task_type` sub-index, or `None` if no sub-index
    /// exists for that task type.
    #[must_use]
    pub fn get_clusters_for_task_type(&self, task_type: TaskType) -> Option<&Vec<TemporalCluster>> {
        self.task_type_indices
            .get(&task_type)
            .map(|idx| &idx.temporal_clusters)
    }

    /// Calls `TaskTypeIndex::cleanup_empty_clusters` on every sub-index.
    ///
    /// Does not modify `total_episodes` or `last_updated`.
    pub fn cleanup_empty_clusters(&mut self) {
        self.task_type_indices
            .values_mut()
            .for_each(TaskTypeIndex::cleanup_empty_clusters);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{TaskContext, TaskType};
    use std::collections::HashSet;

    /// Builds a minimal episode in `domain` with the given task type.
    fn create_test_episode(domain: &str, task_type: TaskType) -> Episode {
        let context = TaskContext {
            domain: domain.to_string(),
            complexity: crate::types::ComplexityLevel::Simple,
            tags: vec![],
            ..Default::default()
        };
        Episode::new("Test episode".to_string(), context, task_type)
    }

    #[test]
    fn test_domain_index_insert() {
        let mut index = DomainIndex::new("test-domain".to_string());
        let episode1 = create_test_episode("test-domain", TaskType::CodeGeneration);
        let episode2 = create_test_episode("test-domain", TaskType::CodeGeneration);
        index.insert_episode(&episode1);
        index.insert_episode(&episode2);
        assert_eq!(index.total_episodes, 2);
        assert!(index.uncategorized_episodes.contains(&episode1.episode_id));
        assert!(index.uncategorized_episodes.contains(&episode2.episode_id));
    }

    #[test]
    fn test_domain_index_remove() {
        let mut index = DomainIndex::new("test-domain".to_string());
        let episode = create_test_episode("test-domain", TaskType::Debugging);
        let episode_id = episode.episode_id;
        index.insert_episode(&episode);
        assert_eq!(index.total_episodes, 1);

        assert!(index.remove_episode(episode_id));
        assert_eq!(index.total_episodes, 0);
        assert!(!index.uncategorized_episodes.contains(&episode_id));

        // A second removal finds nothing and must not touch the (already zero) counter.
        assert!(!index.remove_episode(episode_id));
        assert_eq!(index.total_episodes, 0);
    }

    #[test]
    fn test_get_recent_episodes() {
        let mut index = DomainIndex::new("test-domain".to_string());
        let mut inserted = HashSet::new();
        for _ in 0..5 {
            let episode = create_test_episode("test-domain", TaskType::CodeGeneration);
            inserted.insert(episode.episode_id);
            index.insert_episode(&episode);
        }

        let recent = index.get_recent_episodes(3);
        assert_eq!(recent.len(), 3);
        assert!(recent.iter().all(|id| inserted.contains(id)));
        let distinct: HashSet<_> = recent.iter().collect();
        assert_eq!(distinct.len(), recent.len());

        assert!(index.get_recent_episodes(0).is_empty());
    }

    #[test]
    fn test_cleanup_empty_clusters() {
        let mut index = DomainIndex::new("test-domain".to_string());
        let episode = create_test_episode("test-domain", TaskType::Analysis);
        index.insert_episode(&episode);
        assert!(index.remove_episode(episode.episode_id));
        assert_eq!(index.total_episodes, 0);
        assert!(index.uncategorized_episodes.is_empty());

        index.cleanup_empty_clusters();
        assert_eq!(index.total_episodes, 0);
        assert!(index.get_recent_episodes(10).is_empty());
    }
}