mod scoring;
mod types;
#[cfg(test)]
mod tests;
pub use types::{HierarchicalScore, RetrievalQuery};
use crate::episode::Episode;
use anyhow::Result;
use std::sync::Arc;
use tracing::{debug, instrument};
/// Multi-level episode retriever.
///
/// Build one with [`HierarchicalRetriever::new`] (built-in defaults) or
/// [`HierarchicalRetriever::with_config`] (explicit tuning), then call
/// [`HierarchicalRetriever::retrieve`].
#[derive(Clone, Debug)]
pub struct HierarchicalRetriever {
    /// Temporal bias weight; `new()` sets it to `0.3`.
    /// NOTE(review): how it is applied lives in the scoring helpers outside
    /// this file — confirm the expected range there.
    temporal_bias_weight: f32,
    /// Upper bound on temporal clusters to search; `new()` sets it to `5`.
    /// NOTE(review): consumed by the cluster-selection helper defined
    /// outside this file.
    max_clusters_to_search: usize,
}
impl Default for HierarchicalRetriever {
fn default() -> Self {
Self::new()
}
}
impl HierarchicalRetriever {
    /// Creates a retriever with the built-in defaults: a temporal bias weight
    /// of `0.3` and at most `5` temporal clusters searched.
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(0.3, 5)
    }

    /// Creates a retriever with explicit tuning parameters.
    ///
    /// Both values are stored as given; no range validation happens here.
    ///
    /// * `temporal_bias_weight` — temporal bias weight used by the scoring
    ///   helpers (defined outside this file).
    /// * `max_clusters_to_search` — cap on temporal clusters considered by the
    ///   cluster-selection helper (defined outside this file).
    #[must_use]
    pub fn with_config(temporal_bias_weight: f32, max_clusters_to_search: usize) -> Self {
        Self {
            temporal_bias_weight,
            max_clusters_to_search,
        }
    }

    /// Runs the hierarchical retrieval pipeline over `all_episodes`.
    ///
    /// Stages, in order (each helper is implemented outside this file):
    /// 1. `filter_by_domain` — domain filter.
    /// 2. `filter_by_task_type` — task-type filter over the level-1 output.
    /// 3. `select_temporal_clusters` — temporal-cluster selection over the
    ///    level-2 output.
    /// 4. `score_episodes` — similarity scoring of the level-3 output.
    ///
    /// The scored episodes are ordered by `rank_by_combined_score`, and the
    /// result is truncated to at most `query.limit` entries (a limit of `0`
    /// yields an empty vector). Candidate counts after each stage are emitted
    /// at `debug` level inside a tracing span carrying the query's text,
    /// domain, task type, limit and the total episode count.
    ///
    /// # Errors
    ///
    /// The current body never produces `Err`; the `Result` return type is part
    /// of the public signature.
    #[instrument(skip(self, all_episodes), fields(
        query_text = %query.query_text,
        query_domain = ?query.domain,
        query_task_type = ?query.task_type,
        total_episodes = all_episodes.len(),
        limit = query.limit
    ))]
    pub async fn retrieve(
        &self,
        query: &RetrievalQuery,
        all_episodes: &[Arc<Episode>],
    ) -> Result<Vec<HierarchicalScore>> {
        debug!("Starting hierarchical retrieval");

        // Level 1: keep only episodes matching the query's domain.
        let by_domain = self.filter_by_domain(all_episodes, query);
        debug!("Level 1 (domain filter): {} episodes", by_domain.len());

        // Level 2: narrow the survivors by task type.
        let by_task = self.filter_by_task_type(&by_domain, query);
        debug!("Level 2 (task type filter): {} episodes", by_task.len());

        // Level 3: restrict to the selected temporal clusters.
        let in_clusters = self.select_temporal_clusters(&by_task, query);
        debug!("Level 3 (temporal clusters): {} episodes", in_clusters.len());

        // Level 4: score what remains.
        let scored = self.score_episodes(&in_clusters, query);
        debug!("Level 4 (similarity scoring): {} episodes", scored.len());

        // Rank everything first, then cut down to the caller's limit.
        let mut results = self.rank_by_combined_score(scored);
        results.truncate(query.limit);
        debug!(
            "Hierarchical retrieval complete: {} results returned",
            results.len()
        );

        Ok(results)
    }
}