use crate::episode::Episode;
use crate::types::TaskContext;
use std::sync::Arc;
use super::super::SelfLearningMemory;
impl SelfLearningMemory {
/// Decides whether a stored episode is worth considering for the current task.
///
/// The episode is relevant as soon as any one of these holds, checked in order:
///
/// 1. its domain equals the domain of `context`;
/// 2. both sides specify the same language (a missing language never matches);
/// 3. both sides specify the same framework (a missing framework never matches);
/// 4. at least one of the episode's tags is also present in `context.tags`;
/// 5. some whitespace-separated word of `task_description` that is longer than
///    three bytes occurs, case-insensitively, as a substring of the episode's
///    own task description.
///
/// Returns `false` only when none of the checks succeed.
pub(super) fn is_relevant_episode(
    &self,
    episode: &Arc<Episode>,
    context: &TaskContext,
    task_description: &str,
) -> bool {
    let past = &episode.context;

    // Structural context matches are the cheapest checks, so try them first.
    if past.domain == context.domain
        || (past.language.is_some() && past.language == context.language)
        || (past.framework.is_some() && past.framework == context.framework)
    {
        return true;
    }

    // A single shared tag is enough.
    if past.tags.iter().any(|tag| context.tags.contains(tag)) {
        return true;
    }

    // Last resort: loose textual overlap between the two descriptions.
    // Words of three bytes or fewer are ignored as too unspecific.
    let wanted = task_description.to_lowercase();
    let seen = episode.task_description.to_lowercase();
    wanted
        .split_whitespace()
        .any(|word| word.len() > 3 && seen.contains(word))
}
/// Computes a numeric relevance score of a stored episode for the current task.
///
/// The score is the sum of three parts:
///
/// * **Reward** – `reward.total * 0.3` when the episode has a reward,
///   otherwise nothing.
/// * **Context** – `0.4` for an equal domain, `0.3` for an equal language,
///   `0.2` for an equal framework (a missing language/framework never
///   matches) and `0.1` per episode tag that also appears in `context.tags`;
///   the combined context part is capped at `0.4`.
/// * **Description** – the number of words in `task_description` that are
///   longer than three bytes and occur (case-insensitively, as substrings) in
///   the episode's description, divided by the total number of words in
///   `task_description`, times `0.3`. Skipped when `task_description` has no
///   words.
pub(super) fn calculate_relevance_score(
    &self,
    episode: &Arc<Episode>,
    context: &TaskContext,
    task_description: &str,
) -> f32 {
    let past = &episode.context;

    // Reward part.
    let reward_part = episode
        .reward
        .as_ref()
        .map_or(0.0, |reward| reward.total * 0.3);
    let mut score: f32 = 0.0;
    score += reward_part;

    // Context part: accumulate the weight of every matching signal, in a
    // fixed order, then cap the whole contribution.
    let signals = [
        (past.domain == context.domain, 0.4),
        (
            past.language.is_some() && past.language == context.language,
            0.3,
        ),
        (
            past.framework.is_some() && past.framework == context.framework,
            0.2,
        ),
    ];
    let mut context_score: f32 = 0.0;
    for (matched, weight) in signals {
        if matched {
            context_score += weight;
        }
    }
    let shared_tags = past
        .tags
        .iter()
        .filter(|t| context.tags.contains(t))
        .count();
    if shared_tags > 0 {
        context_score += 0.1 * shared_tags as f32;
    }
    score += context_score.min(0.4);

    // Description part: share of query words found in the episode text.
    let wanted = task_description.to_lowercase();
    let seen = episode.task_description.to_lowercase();
    let total_words = wanted.split_whitespace().count();
    if total_words > 0 {
        let overlapping = wanted
            .split_whitespace()
            .filter(|word| word.len() > 3)
            .filter(|word| seen.contains(*word))
            .count();
        let similarity = overlapping as f32 / total_words as f32;
        score += similarity * 0.3;
    }

    score
}
/// Scores how well a heuristic applies to the given task context.
///
/// The heuristic's `condition` text is searched, case-insensitively, for the
/// context's terms. Each term found as a substring adds its weight:
///
/// * domain: `1.0`
/// * language (when set): `0.8`
/// * framework (when set): `0.5`
/// * every tag: `0.3` each
///
/// Blank terms are ignored. When nothing matches, a base relevance of `0.1`
/// is returned instead of `0.0`.
pub(super) fn calculate_heuristic_relevance(
    &self,
    heuristic: &crate::pattern::Heuristic,
    context: &TaskContext,
) -> f32 {
    let condition_lower = heuristic.condition.to_lowercase();

    // `str::contains` reports `true` for an empty needle, so a blank domain,
    // language, framework or tag would otherwise "match" every heuristic and
    // hand out its bonus unconditionally. Treat blank terms as no match.
    let mentions = |term_lower: String| {
        !term_lower.is_empty() && condition_lower.contains(&term_lower)
    };

    let mut score: f32 = 0.0;

    if mentions(context.domain.to_lowercase()) {
        score += 1.0;
    }
    if let Some(lang) = &context.language {
        if mentions(lang.to_lowercase()) {
            score += 0.8;
        }
    }
    if let Some(framework) = &context.framework {
        if mentions(framework.to_lowercase()) {
            score += 0.5;
        }
    }
    for tag in &context.tags {
        if mentions(tag.to_lowercase()) {
            score += 0.3;
        }
    }

    // Base relevance for heuristics whose condition mentions nothing from
    // the context.
    if score == 0.0 {
        score = 0.1;
    }
    score
}
}