use super::types::{
HierarchicalScore, RetrievalQuery, calculate_text_similarity, get_or_generate_episode_embedding,
};
use crate::episode::Episode;
use chrono::Utc;
use std::sync::Arc;
impl super::HierarchicalRetriever {
/// Narrows `episodes` to those whose context domain equals `query.domain`.
///
/// When the query carries no domain, every episode is passed through.
/// The returned vector holds borrows into `episodes`, in the original order.
pub(super) fn filter_by_domain<'a>(
    &self,
    episodes: &'a [Arc<Episode>],
    query: &RetrievalQuery,
) -> Vec<&'a Arc<Episode>> {
    match &query.domain {
        // No domain constraint: keep everything.
        None => episodes.iter().collect(),
        Some(wanted) => episodes
            .iter()
            .filter(|episode| episode.context.domain == *wanted)
            .collect(),
    }
}
/// Narrows `candidates` to those whose task type equals `query.task_type`.
///
/// When the query carries no task type, the candidates are returned
/// unchanged (as a fresh vector of the same borrows, in the same order).
pub(super) fn filter_by_task_type<'a>(
    &self,
    candidates: &[&'a Arc<Episode>],
    query: &RetrievalQuery,
) -> Vec<&'a Arc<Episode>> {
    match query.task_type {
        // No task-type constraint: keep everything.
        None => candidates.to_vec(),
        Some(wanted) => candidates
            .iter()
            .copied()
            .filter(|episode| episode.task_type == wanted)
            .collect(),
    }
}
/// Keeps only the most recently started candidates.
///
/// Candidates are ordered newest-first by `start_time` (ties keep their
/// incoming relative order) and the leading
/// `max(len / max_clusters_to_search, 10)` entries are retained, capped at
/// the number of candidates available. `_query` is currently unused.
pub(super) fn select_temporal_clusters<'a>(
    &self,
    candidates: &[&'a Arc<Episode>],
    _query: &RetrievalQuery,
) -> Vec<&'a Arc<Episode>> {
    let total = candidates.len();
    if total == 0 {
        return Vec::new();
    }

    // `.max(1)` guards the division against a zero cluster setting.
    let per_cluster = total / self.max_clusters_to_search.max(1);
    let keep = usize::min(total, usize::max(per_cluster, 10));

    let mut newest_first = candidates.to_vec();
    // Stable sort with the operands swapped yields descending start times.
    newest_first.sort_by(|lhs, rhs| rhs.start_time.cmp(&lhs.start_time));
    newest_first.truncate(keep);
    newest_first
}
/// Computes a [`HierarchicalScore`] for every candidate, in candidate order.
///
/// Per-level scores:
/// - **level 1 (domain)**: `1.0` on match, `0.0` on mismatch, `0.5` when
///   the query has no domain.
/// - **level 2 (task type)**: same scheme, keyed on `query.task_type`.
/// - **level 3 (recency)**: linear decay from `1.0` at age zero down to
///   `0.0` at 30 days or older; start times in the future count as age zero.
/// - **level 4 (similarity)**: when the query has an embedding, the cosine
///   similarity `s` against the episode embedding rescaled as `(s + 1) / 2`;
///   otherwise `calculate_text_similarity` over the lowercased query text
///   and episode task description.
///
/// The combined relevance is
/// `0.3 * l1 + 0.3 * l2 + w_t * l3 + max(1 - w_t - 0.6, 0.1) * l4`,
/// where `w_t` is `self.temporal_bias_weight`.
pub(super) fn score_episodes(
    &self,
    candidates: &[&Arc<Episode>],
    query: &RetrievalQuery,
) -> Vec<HierarchicalScore> {
    /// Score given to a level the query places no constraint on.
    const NEUTRAL_SCORE: f32 = 0.5;
    /// Weight of the domain match in the combined score.
    const DOMAIN_WEIGHT: f32 = 0.3;
    /// Weight of the task-type match in the combined score.
    const TASK_TYPE_WEIGHT: f32 = 0.3;
    /// Lower bound on the similarity weight.
    const MIN_SIMILARITY_WEIGHT: f32 = 0.1;
    /// Age (30 days, in seconds) at which the recency score reaches zero.
    const MAX_AGE_SECONDS: f32 = 30.0 * 24.0 * 3600.0;

    let now = Utc::now();

    // The weights do not depend on the episode, so compute them once.
    let temporal_weight = self.temporal_bias_weight;
    let similarity_weight = (1.0 - temporal_weight - (DOMAIN_WEIGHT + TASK_TYPE_WEIGHT))
        .max(MIN_SIMILARITY_WEIGHT);

    // Lowercase the query text once instead of once per candidate. It is
    // only needed on the text-similarity path, so skip the allocation when
    // an embedding is available (`String::new()` does not allocate).
    let query_text_lower = if query.query_embedding.is_none() {
        query.query_text.to_lowercase()
    } else {
        String::new()
    };

    candidates
        .iter()
        .map(|episode| {
            let level_1_score = match &query.domain {
                Some(domain) if episode.context.domain == *domain => 1.0,
                Some(_) => 0.0,
                None => NEUTRAL_SCORE,
            };

            let level_2_score = match query.task_type {
                Some(task_type) if episode.task_type == task_type => 1.0,
                Some(_) => 0.0,
                None => NEUTRAL_SCORE,
            };

            // Negative ages (start time after `now`) are clamped to zero.
            let age_seconds = (now - episode.start_time).num_seconds().max(0) as f32;
            let level_3_score = 1.0 - (age_seconds / MAX_AGE_SECONDS).min(1.0);

            let level_4_score = match &query.query_embedding {
                Some(query_emb) => {
                    let episode_emb =
                        get_or_generate_episode_embedding(episode, &query.episode_embeddings);
                    let similarity =
                        crate::embeddings::cosine_similarity(query_emb, &episode_emb);
                    // Shift and halve; maps [-1, 1] onto [0, 1], assuming
                    // `cosine_similarity` returns a value in that range.
                    (similarity + 1.0) / 2.0
                }
                None => calculate_text_similarity(
                    &query_text_lower,
                    &episode.task_description.to_lowercase(),
                ),
            };

            let relevance_score = DOMAIN_WEIGHT * level_1_score
                + TASK_TYPE_WEIGHT * level_2_score
                + temporal_weight * level_3_score
                + similarity_weight * level_4_score;

            HierarchicalScore {
                episode_id: episode.episode_id,
                relevance_score,
                level_1_score,
                level_2_score,
                level_3_score,
                level_4_score,
            }
        })
        .collect()
}
/// Sorts `scored` by `relevance_score`, highest first, and returns it.
///
/// The sort is stable, so entries with equal scores keep their incoming
/// relative order. Entries whose score is NaN are placed after all
/// others. Treating NaN as "equal to everything" would not be a
/// consistent total order, which `sort_by` requires in order to produce a
/// well-defined result.
pub(super) fn rank_by_combined_score(
    &self,
    mut scored: Vec<HierarchicalScore>,
) -> Vec<HierarchicalScore> {
    scored.sort_by(|a, b| {
        match (a.relevance_score.is_nan(), b.relevance_score.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            // A NaN score ranks below any real score.
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Both are real numbers, so `partial_cmp` is always `Some`;
            // operands are swapped for descending order.
            (false, false) => b
                .relevance_score
                .partial_cmp(&a.relevance_score)
                .unwrap_or(std::cmp::Ordering::Equal),
        }
    });
    scored
}
}