use crate::episode::Episode;
use crate::pattern::Pattern;
use serde::{Deserialize, Serialize};
/// Default affinity cut-off.
///
/// Used as the affinity threshold by [`PatternAffinityClassifier::new`] and as the minimum
/// affinity clarity by [`EpisodeAssignmentGuard::new`].
pub const DEFAULT_AFFINITY_THRESHOLD: f32 = 0.25_f32;
/// Default minimum success rate used by [`PatternAffinityClassifier::new`] and
/// [`EpisodeAssignmentGuard::new`].
pub const DEFAULT_MIN_SUCCESS_RATE: f32 = 0.7_f32;
/// Relative affinity of an episode towards an "old" and a "new" set of patterns.
///
/// `drel` is the normalised gap between the episode's best match in each set:
/// `|score_new - score_old| / max(score_old, score_new, 1e-6)`. A small `drel` means the
/// episode resembles both sets about equally (ambiguous); a large `drel` means it clearly
/// leans towards one of them.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RelativeAffinity {
    /// Best similarity between the episode and any pattern in the old set.
    pub score_old: f32,
    /// Best similarity between the episode and any pattern in the new set.
    pub score_new: f32,
    /// Normalised difference between `score_new` and `score_old`.
    pub drel: f32,
}

impl RelativeAffinity {
    /// Scores `episode` against both pattern sets and derives `drel`.
    ///
    /// `episode_embedding` is forwarded to the per-pattern similarity; an empty pattern set
    /// scores `0.0`.
    #[must_use]
    pub fn compute(
        episode: &Episode,
        old_patterns: &[Pattern],
        new_patterns: &[Pattern],
        episode_embedding: Option<&[f32]>,
    ) -> Self {
        let score_old = max_cosine_similarity(episode, old_patterns, episode_embedding);
        let score_new = max_cosine_similarity(episode, new_patterns, episode_embedding);
        // Floor the denominator so two zero scores give drel = 0 instead of NaN (0 / 0).
        let denom = score_old.max(score_new).max(1e-6);
        let drel = (score_new - score_old).abs() / denom;
        Self {
            score_old,
            score_new,
            drel,
        }
    }

    /// Returns `true` when `drel` is strictly below `threshold`, i.e. the episode does not
    /// favour either pattern set decisively.
    #[must_use]
    pub fn is_ambiguous(&self, threshold: f32) -> bool {
        self.drel < threshold
    }

    /// How decisively the episode favours one pattern set; higher means clearer.
    ///
    /// This is `drel` itself, so `clarity() < t` holds exactly when `is_ambiguous(t)` does.
    /// (Returning `1.0 - drel` inverted the scale: [`EpisodeAssignmentGuard`] rejects low
    /// clarity as ambiguous, so clearly separated episodes were rejected while ambiguous ones
    /// passed.)
    #[must_use]
    pub fn clarity(&self) -> f32 {
        self.drel
    }
}
/// Gate built from an episode's success rate and affinity clarity, together with the
/// minimum each must reach.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct EpisodeAssignmentGuard {
    /// Success rate observed for the episode.
    pub success_rate: f32,
    /// Clarity of the episode's affinity; values below `min_affinity_clarity` are reported
    /// as [`RejectionReason::AmbiguousAffinity`].
    pub affinity_clarity: f32,
    /// Minimum `success_rate` required by [`Self::allows_mutation`].
    pub min_success_rate: f32,
    /// Minimum `affinity_clarity` required by [`Self::allows_mutation`].
    pub min_affinity_clarity: f32,
}

impl EpisodeAssignmentGuard {
    /// Builds a guard whose minimums are [`DEFAULT_MIN_SUCCESS_RATE`] and
    /// [`DEFAULT_AFFINITY_THRESHOLD`].
    #[must_use]
    pub fn new(success_rate: f32, affinity_clarity: f32) -> Self {
        Self::with_thresholds(
            success_rate,
            affinity_clarity,
            DEFAULT_MIN_SUCCESS_RATE,
            DEFAULT_AFFINITY_THRESHOLD,
        )
    }

    /// Builds a guard with explicit minimums; all values are stored as given.
    #[must_use]
    pub fn with_thresholds(
        success_rate: f32,
        affinity_clarity: f32,
        min_success_rate: f32,
        min_affinity_clarity: f32,
    ) -> Self {
        Self {
            success_rate,
            affinity_clarity,
            min_success_rate,
            min_affinity_clarity,
        }
    }

    /// Returns `true` when both the success rate and the affinity clarity reach their minimums.
    ///
    /// Defined as "no rejection reason" so this method and [`Self::rejection_reason`] can never
    /// disagree. A NaN value fails its check, so it is rejected here and reported there.
    #[must_use]
    pub fn allows_mutation(&self) -> bool {
        self.rejection_reason().is_none()
    }

    /// Returns `true` when the success rate reaches half of `min_success_rate`.
    ///
    /// Affinity clarity is not considered.
    #[must_use]
    pub fn allows_retrieval(&self) -> bool {
        self.success_rate >= self.min_success_rate * 0.5
    }

    /// Explains why [`Self::allows_mutation`] returns `false`, or `None` when it returns `true`.
    ///
    /// A low success rate is reported in preference to an ambiguous affinity.
    #[must_use]
    pub fn rejection_reason(&self) -> Option<RejectionReason> {
        if !Self::meets_minimum(self.success_rate, self.min_success_rate) {
            return Some(RejectionReason::LowSuccessRate {
                actual: self.success_rate,
                required: self.min_success_rate,
            });
        }
        if !Self::meets_minimum(self.affinity_clarity, self.min_affinity_clarity) {
            return Some(RejectionReason::AmbiguousAffinity {
                actual: self.affinity_clarity,
                required: self.min_affinity_clarity,
            });
        }
        None
    }

    /// `value >= minimum`. NaN on either side yields `false`, so a NaN value fails the check
    /// instead of slipping through a `value < minimum` test.
    fn meets_minimum(value: f32, minimum: f32) -> bool {
        value >= minimum
    }
}
/// Reason returned by [`EpisodeAssignmentGuard::rejection_reason`] for a failed check.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum RejectionReason {
    /// The success rate was below the required minimum.
    LowSuccessRate {
        /// Observed success rate.
        actual: f32,
        /// Minimum success rate that was required.
        required: f32,
    },
    /// The affinity clarity was below the required minimum.
    AmbiguousAffinity {
        /// Observed affinity clarity.
        actual: f32,
        /// Minimum affinity clarity that was required.
        required: f32,
    },
}
/// Scores episodes against an "old" and a "new" pattern set and turns the result into gating
/// decisions.
///
/// Holds two thresholds:
/// - the affinity threshold, passed to [`RelativeAffinity::is_ambiguous`] and used as the
///   guard's minimum clarity;
/// - the minimum success rate, passed to every guard it builds.
#[derive(Debug, Clone)]
pub struct PatternAffinityClassifier {
    affinity_threshold: f32,
    min_success_rate: f32,
}

impl Default for PatternAffinityClassifier {
    /// Same as [`PatternAffinityClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl PatternAffinityClassifier {
    /// Creates a classifier using [`DEFAULT_AFFINITY_THRESHOLD`] and [`DEFAULT_MIN_SUCCESS_RATE`].
    #[must_use]
    pub fn new() -> Self {
        Self::with_config(DEFAULT_AFFINITY_THRESHOLD, DEFAULT_MIN_SUCCESS_RATE)
    }

    /// Creates a classifier with explicit thresholds; values are stored as given.
    #[must_use]
    pub fn with_config(affinity_threshold: f32, min_success_rate: f32) -> Self {
        Self {
            affinity_threshold,
            min_success_rate,
        }
    }

    /// Computes the [`RelativeAffinity`] of `episode`.
    ///
    /// The classifier's thresholds play no part in this computation.
    #[must_use]
    pub fn compute_affinity(
        &self,
        episode: &Episode,
        old_patterns: &[Pattern],
        new_patterns: &[Pattern],
        episode_embedding: Option<&[f32]>,
    ) -> RelativeAffinity {
        RelativeAffinity::compute(episode, old_patterns, new_patterns, episode_embedding)
    }

    /// Builds an [`EpisodeAssignmentGuard`] for `episode` using this classifier's thresholds.
    ///
    /// The guard's inputs are derived as follows:
    /// - affinity clarity is [`RelativeAffinity::clarity`];
    /// - success rate is half of the episode's `reward.total`, or `0.5` when it has no reward.
    ///
    /// NOTE(review): the halving presumably maps a `[0, 2]` reward onto `[0, 1]`; confirm.
    #[must_use]
    pub fn create_guard(
        &self,
        episode: &Episode,
        old_patterns: &[Pattern],
        new_patterns: &[Pattern],
        episode_embedding: Option<&[f32]>,
    ) -> EpisodeAssignmentGuard {
        let clarity = self
            .compute_affinity(episode, old_patterns, new_patterns, episode_embedding)
            .clarity();
        let success_rate = episode.reward.as_ref().map_or(0.5, |r| r.total / 2.0);
        EpisodeAssignmentGuard::with_thresholds(
            success_rate,
            clarity,
            self.min_success_rate,
            self.affinity_threshold,
        )
    }

    /// Returns `true` when the episode's affinity is ambiguous under `affinity_threshold`.
    #[must_use]
    pub fn should_gate_episode(
        &self,
        episode: &Episode,
        old_patterns: &[Pattern],
        new_patterns: &[Pattern],
        episode_embedding: Option<&[f32]>,
    ) -> bool {
        self.compute_affinity(episode, old_patterns, new_patterns, episode_embedding)
            .is_ambiguous(self.affinity_threshold)
    }

    /// Configured affinity threshold.
    #[must_use]
    pub fn affinity_threshold(&self) -> f32 {
        self.affinity_threshold
    }

    /// Configured minimum success rate.
    #[must_use]
    pub fn min_success_rate(&self) -> f32 {
        self.min_success_rate
    }
}
/// Best similarity between `episode` and any pattern in `patterns`.
///
/// Each pattern is scored with `pattern_embedding_similarity`; when that yields no value,
/// `context_similarity` is used instead.
///
/// The result is the maximum over all patterns, floored at `0.0`, so an empty slice yields
/// `0.0`. `f32::max` ignores NaN scores.
fn max_cosine_similarity(
    episode: &Episode,
    patterns: &[Pattern],
    episode_embedding: Option<&[f32]>,
) -> f32 {
    let mut best = 0.0_f32;
    for pattern in patterns {
        let similarity = pattern_embedding_similarity(episode_embedding, pattern)
            .unwrap_or_else(|| context_similarity(episode, pattern));
        best = best.max(similarity);
    }
    best
}
/// Embedding-based similarity between an episode and a pattern.
///
/// Currently a stub: both arguments are ignored and the result is always `None`. As a result,
/// `max_cosine_similarity` always falls back to `context_similarity`.
fn pattern_embedding_similarity(
    _episode_embedding: Option<&[f32]>,
    _pattern: &Pattern,
) -> Option<f32> {
    Option::<f32>::None
}
/// Heuristic similarity between an episode's context and a pattern's context.
///
/// The result is the average of up to three components:
/// - domain: `1.0` on an exact match, `0.0` otherwise (always counted);
/// - tags: Jaccard index of the two tag sets (counted only when at least one side has tags);
/// - language: `0.5` on an exact match, `0.0` otherwise (always counted).
///
/// Because the language component contributes at most `0.5`, even a perfect match averages
/// below `1.0`. Patterns without a context get a fixed score of `0.3`.
fn context_similarity(episode: &Episode, pattern: &Pattern) -> f32 {
    use std::collections::HashSet;

    // Fixed score for patterns that carry no context to compare against.
    const NO_CONTEXT_SCORE: f32 = 0.3;

    let ep_context = &episode.context;
    let pat_ctx = match pattern.context() {
        Some(pat_ctx) => pat_ctx,
        None => return NO_CONTEXT_SCORE,
    };

    let mut score = 0.0_f32;
    let mut components = 0_u32;

    // Domain: always part of the average.
    if ep_context.domain == pat_ctx.domain {
        score += 1.0;
    }
    components += 1;

    // Tags: Jaccard index, skipped entirely when neither side has any tags.
    let ep_tags: HashSet<_> = ep_context.tags.iter().collect();
    let pat_tags: HashSet<_> = pat_ctx.tags.iter().collect();
    let union = ep_tags.union(&pat_tags).count();
    if union > 0 {
        let intersection = ep_tags.intersection(&pat_tags).count();
        score += intersection as f32 / union as f32;
        components += 1;
    }

    // Language: counted whether or not it matches. Counting it only on a match made a
    // matching language pull the average *down* (e.g. 1.5 / 2 vs 1.0 / 1), so a mismatch
    // scored higher than a match.
    if ep_context.language == pat_ctx.language {
        score += 0.5;
    }
    components += 1;

    score / components as f32
}
// Unit tests for this module, loaded from a separate file-backed `tests` module and compiled
// only when `cfg(test)` is set.
#[cfg(test)]
mod tests;