pub mod domain_index;
pub mod types;
pub use domain_index::DomainIndex;
use crate::spatiotemporal::index::types::TemporalGranularity;
use crate::types::TaskType;
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use uuid::Uuid;
/// Top-level episode index that groups per-domain indices under one roof.
///
/// Episodes are routed to a [`DomainIndex`] keyed by the episode's
/// `context.domain` string; this struct additionally tracks an overall
/// episode count and creation/modification timestamps.
#[derive(Debug, Clone)]
pub struct SpatiotemporalIndex {
/// Per-domain indices, keyed by domain name.
pub domain_indices: HashMap<String, DomainIndex>,
/// Running episode count: incremented by `insert`, decremented (saturating)
/// by a successful `remove`, and reset to zero by `clear`.
pub total_episodes: usize,
/// UTC timestamp taken when the index was constructed.
pub created_at: DateTime<Utc>,
/// UTC timestamp of the most recent `insert`, successful `remove`, or `clear`.
pub last_modified: DateTime<Utc>,
}
impl Default for SpatiotemporalIndex {
/// Returns the same empty index as [`SpatiotemporalIndex::new`].
fn default() -> Self {
Self::new()
}
}
impl SpatiotemporalIndex {
    /// Creates an empty index with no domains and no episodes.
    ///
    /// `created_at` and `last_modified` are taken from a single clock reading,
    /// so they compare equal until the first mutation.
    #[must_use]
    pub fn new() -> Self {
        let now = Utc::now();
        Self {
            domain_indices: HashMap::new(),
            total_episodes: 0,
            created_at: now,
            last_modified: now,
        }
    }

    /// Adds `episode` to the index of its `context.domain`, creating that
    /// domain's [`DomainIndex`] on first use.
    ///
    /// `total_episodes` is incremented unconditionally and `last_modified`
    /// is refreshed.
    pub fn insert(&mut self, episode: &crate::Episode) {
        let domain = &episode.context.domain;
        self.domain_indices
            .entry(domain.clone())
            // The second clone only happens when the domain is seen for the first time.
            .or_insert_with(|| DomainIndex::new(domain.clone()))
            .insert_episode(episode);
        self.total_episodes += 1;
        self.last_modified = Utc::now();
    }

    /// Removes the episode with `episode_id` from every domain index.
    ///
    /// Returns `true` if at least one domain index reported a removal. In that
    /// case `total_episodes` is decremented by one (saturating at zero) and
    /// `last_modified` is refreshed; otherwise the index is left untouched.
    pub fn remove(&mut self, episode_id: Uuid) -> bool {
        let mut removed = false;
        for domain_index in self.domain_indices.values_mut() {
            // `|=` always evaluates its right-hand side, so every domain is
            // asked to drop the episode even after the first hit.
            removed |= domain_index.remove_episode(episode_id);
        }
        if removed {
            self.total_episodes = self.total_episodes.saturating_sub(1);
            self.last_modified = Utc::now();
        }
        removed
    }

    /// Returns up to `limit` episode ids from `domain`.
    ///
    /// * With a `task_type`, episodes of that type between `start_time` and
    ///   `end_time` are requested from the domain index. A missing start
    ///   defaults to the Unix epoch; a missing end defaults to the last
    ///   nanosecond of the year 9999 (UTC).
    /// * Without a `task_type`, the domain's most recent episodes are
    ///   requested instead.
    ///
    /// An unknown `domain` yields an empty vector.
    ///
    /// NOTE(review): when `task_type` is `None`, `start_time` and `end_time`
    /// are not consulted at all — confirm this is intended.
    #[must_use]
    pub fn query(
        &self,
        domain: &str,
        task_type: Option<TaskType>,
        start_time: Option<DateTime<Utc>>,
        end_time: Option<DateTime<Utc>>,
        limit: usize,
    ) -> Vec<Uuid> {
        // 9999-12-31T23:59:59.999999999Z, used as the open-ended upper bound.
        const MAX_TIMESTAMP_SECS: i64 = 253_402_300_799;
        const MAX_TIMESTAMP_NANOS: u32 = 999_999_999;

        self.domain_indices
            .get(domain)
            .map_or_else(Vec::new, |domain_index| {
                let episode_ids = match task_type {
                    Some(task_type) => {
                        let start = start_time.unwrap_or_else(|| {
                            DateTime::from_timestamp(0, 0).unwrap_or_else(Utc::now)
                        });
                        let end = end_time.unwrap_or_else(|| {
                            DateTime::from_timestamp(MAX_TIMESTAMP_SECS, MAX_TIMESTAMP_NANOS)
                                .unwrap_or_else(Utc::now)
                        });
                        domain_index.get_episodes_by_task_type_and_time(task_type, start, end)
                    }
                    None => domain_index.get_recent_episodes(limit),
                };
                // Enforce the cap here as well, whatever the domain index returned.
                episode_ids.into_iter().take(limit).collect()
            })
    }

    /// Lists the domains that have at least one task-type index reporting
    /// episodes between `start` and `end`.
    ///
    /// Each domain appears at most once; the order follows the iteration
    /// order of the underlying `HashMap` and is therefore unspecified.
    #[must_use]
    pub fn get_domains_in_time_range(
        &self,
        start: DateTime<Utc>,
        end: DateTime<Utc>,
    ) -> Vec<String> {
        self.domain_indices
            .iter()
            .filter(|(_, domain_index)| {
                // `any` stops at the first task type with a hit, like the
                // `break` in a hand-written loop would.
                domain_index
                    .task_type_indices
                    .values()
                    .any(|task_type_index| {
                        !task_type_index.get_episodes_in_range(start, end).is_empty()
                    })
            })
            .map(|(domain, _)| domain.clone())
            .collect()
    }

    /// Returns each domain name mapped to that domain index's `total_episodes`.
    #[must_use]
    pub fn get_domain_counts(&self) -> HashMap<String, usize> {
        self.domain_indices
            .iter()
            .map(|(domain, idx)| (domain.clone(), idx.total_episodes))
            .collect()
    }

    /// Sums `cluster_size` per [`TemporalGranularity`] over all temporal
    /// clusters of every task-type index in `domain`.
    ///
    /// An unknown `domain` yields an empty map.
    #[must_use]
    pub fn get_temporal_distribution(&self, domain: &str) -> HashMap<TemporalGranularity, usize> {
        let mut distribution = HashMap::new();
        if let Some(domain_index) = self.domain_indices.get(domain) {
            for task_type_index in domain_index.task_type_indices.values() {
                for cluster in &task_type_index.temporal_clusters {
                    *distribution.entry(cluster.granularity).or_insert(0) += cluster.cluster_size;
                }
            }
        }
        distribution
    }

    /// Drops every domain index, resets `total_episodes` to zero, and
    /// refreshes `last_modified`. `created_at` is left unchanged.
    pub fn clear(&mut self) {
        self.domain_indices.clear();
        self.total_episodes = 0;
        self.last_modified = Utc::now();
    }

    /// Returns the number of domains that currently have an index entry.
    #[must_use]
    pub fn num_domains(&self) -> usize {
        self.domain_indices.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{TaskContext, TaskType};

    /// Builds a minimal episode in `domain` with the given `task_type`.
    fn create_test_episode(domain: &str, task_type: TaskType) -> crate::Episode {
        let context = TaskContext {
            domain: domain.to_string(),
            complexity: crate::types::ComplexityLevel::Simple,
            tags: vec![],
            ..Default::default()
        };
        crate::Episode::new("Test episode".to_string(), context, task_type)
    }

    #[test]
    fn test_insert_and_query() {
        let mut index = SpatiotemporalIndex::new();
        let episode1 = create_test_episode("web-api", TaskType::CodeGeneration);
        let episode2 = create_test_episode("web-api", TaskType::CodeGeneration);
        let episode3 = create_test_episode("data-processing", TaskType::Analysis);
        index.insert(&episode1);
        index.insert(&episode2);
        index.insert(&episode3);
        assert_eq!(index.total_episodes, 3);
        assert_eq!(index.num_domains(), 2);

        // Only two episodes were inserted into "web-api", so an unfiltered
        // query of that domain must return exactly those two.
        let results = index.query("web-api", None, None, None, 10);
        assert_eq!(results.len(), 2);

        let results = index.query("web-api", Some(TaskType::CodeGeneration), None, None, 10);
        assert_eq!(results.len(), 2);

        let results = index.query("nonexistent", None, None, None, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_remove() {
        let mut index = SpatiotemporalIndex::new();
        let episode = create_test_episode("test-domain", TaskType::Debugging);
        let episode_id = episode.episode_id;
        index.insert(&episode);
        assert_eq!(index.total_episodes, 1);

        let removed = index.remove(episode_id);
        assert!(removed);
        assert_eq!(index.total_episodes, 0);

        // A second removal of the same id finds nothing.
        let removed = index.remove(episode_id);
        assert!(!removed);
    }

    #[test]
    fn test_domain_counts() {
        let mut index = SpatiotemporalIndex::new();
        for _ in 0..5 {
            index.insert(&create_test_episode("domain-a", TaskType::CodeGeneration));
        }
        for _ in 0..3 {
            index.insert(&create_test_episode("domain-b", TaskType::CodeGeneration));
        }
        let counts = index.get_domain_counts();
        assert_eq!(counts.get("domain-a"), Some(&5));
        assert_eq!(counts.get("domain-b"), Some(&3));
    }

    #[test]
    fn test_temporal_distribution() {
        let mut index = SpatiotemporalIndex::new();
        for _ in 0..3 {
            index.insert(&create_test_episode(
                "test-domain",
                TaskType::CodeGeneration,
            ));
        }
        let distribution = index.get_temporal_distribution("test-domain");
        assert!(distribution.contains_key(&TemporalGranularity::Weekly));
    }
}