use super::types::{DistillationCandidate, QueryPattern, current_timestamp};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
/// A predicate over a `DistillationCandidate`, evaluated by `TriggerCondition::is_satisfied`.
///
/// Leaf variants compare a single candidate field against a stored value;
/// `Combined` and `Any` compose other conditions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TriggerCondition {
/// Holds when the candidate's `frequency` is at least the stored value.
FrequencyThreshold(u32),
/// Holds when the candidate has at least this many entries in `qa_pairs`.
QAPairCount(usize),
/// Holds when the candidate has at least one Q&A pair and its
/// `avg_confidence` is strictly below the stored value.
ConfidenceBelow(f32),
/// Holds when the candidate has at least one Q&A pair and its
/// `avg_confidence` is at least the stored value.
ConfidenceAbove(f32),
/// Holds when `current_timestamp()` minus the candidate's `first_seen`
/// (saturating at zero) is at least the stored value. The constructor fills
/// this from `Duration::as_secs`, so the unit is presumably seconds —
/// confirm against `current_timestamp`.
TimeSinceFirstSeen(u64),
/// Holds when every nested condition holds (also holds for an empty list).
Combined(Vec<TriggerCondition>),
/// Holds when at least one nested condition holds (never holds for an empty list).
Any(Vec<TriggerCondition>),
}
impl TriggerCondition {
#[must_use]
pub fn frequency(threshold: u32) -> Self {
Self::FrequencyThreshold(threshold)
}
#[must_use]
pub fn min_qa_pairs(count: usize) -> Self {
Self::QAPairCount(count)
}
#[must_use]
pub fn confidence_below(threshold: f32) -> Self {
Self::ConfidenceBelow(threshold.clamp(0.0, 1.0))
}
#[must_use]
pub fn confidence_above(threshold: f32) -> Self {
Self::ConfidenceAbove(threshold.clamp(0.0, 1.0))
}
#[must_use]
pub fn time_since_first_seen(duration: Duration) -> Self {
Self::TimeSinceFirstSeen(duration.as_secs())
}
#[must_use]
pub fn all(conditions: Vec<TriggerCondition>) -> Self {
Self::Combined(conditions)
}
#[must_use]
pub fn any_of(conditions: Vec<TriggerCondition>) -> Self {
Self::Any(conditions)
}
#[must_use]
pub fn is_satisfied(&self, candidate: &DistillationCandidate) -> bool {
match self {
Self::FrequencyThreshold(threshold) => candidate.frequency >= *threshold,
Self::QAPairCount(min_count) => candidate.qa_pairs.len() >= *min_count,
Self::ConfidenceBelow(threshold) => {
!candidate.qa_pairs.is_empty() && candidate.avg_confidence < *threshold
}
Self::ConfidenceAbove(threshold) => {
!candidate.qa_pairs.is_empty() && candidate.avg_confidence >= *threshold
}
Self::TimeSinceFirstSeen(secs) => {
let now = current_timestamp();
now.saturating_sub(candidate.first_seen) >= *secs
}
Self::Combined(conditions) => conditions.iter().all(|c| c.is_satisfied(candidate)),
Self::Any(conditions) => conditions.iter().any(|c| c.is_satisfied(candidate)),
}
}
#[must_use]
pub fn description(&self) -> String {
match self {
Self::FrequencyThreshold(t) => format!("frequency >= {t}"),
Self::QAPairCount(c) => format!("Q&A pairs >= {c}"),
Self::ConfidenceBelow(t) => format!("confidence < {t:.2}"),
Self::ConfidenceAbove(t) => format!("confidence >= {t:.2}"),
Self::TimeSinceFirstSeen(s) => format!("time since first seen >= {s}s"),
Self::Combined(conditions) => {
let descs: Vec<_> = conditions.iter().map(Self::description).collect();
format!("({})", descs.join(" AND "))
}
Self::Any(conditions) => {
let descs: Vec<_> = conditions.iter().map(Self::description).collect();
format!("({})", descs.join(" OR "))
}
}
}
}
impl Default for TriggerCondition {
fn default() -> Self {
Self::Combined(vec![
Self::FrequencyThreshold(5),
Self::QAPairCount(3),
Self::ConfidenceAbove(0.7),
])
}
}
/// Decides whether a candidate should trigger: a list of conditions plus a
/// per-pattern cooldown.
#[derive(Debug)]
pub struct DistillationTrigger {
/// Conditions that must all hold for `should_trigger` to return `true`.
conditions: Vec<TriggerCondition>,
/// Minimum time between two triggers of the same pattern; only its
/// whole-second part (`as_secs`) is used in comparisons.
cooldown: Duration,
/// Maps `QueryPattern::pattern_hash` to the `current_timestamp()` value
/// recorded by `mark_triggered`.
last_triggered: HashMap<u64, u64>,
}
impl DistillationTrigger {
    /// Creates a trigger that requires every entry of `conditions` to hold.
    ///
    /// The cooldown starts at 3600 seconds and no pattern is recorded as triggered.
    #[must_use]
    pub fn new(conditions: Vec<TriggerCondition>) -> Self {
        Self {
            conditions,
            cooldown: Duration::from_secs(3600),
            last_triggered: HashMap::new(),
        }
    }

    /// Creates a trigger with a single condition.
    #[must_use]
    pub fn with_condition(condition: TriggerCondition) -> Self {
        Self::new(vec![condition])
    }

    /// Creates a trigger whose only condition is `TriggerCondition::default()`.
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(vec![TriggerCondition::default()])
    }

    /// Replaces the cooldown and returns the trigger (builder style).
    ///
    /// Cooldown checks use `Duration::as_secs`, so any sub-second part is ignored.
    #[must_use]
    pub fn with_cooldown(mut self, cooldown: Duration) -> Self {
        self.cooldown = cooldown;
        self
    }

    /// Appends `condition` to the required conditions and returns the trigger.
    #[must_use]
    pub fn with_additional_condition(mut self, condition: TriggerCondition) -> Self {
        self.conditions.push(condition);
        self
    }

    /// Returns the configured conditions in insertion order.
    #[must_use]
    pub fn conditions(&self) -> &[TriggerCondition] {
        &self.conditions
    }

    /// Returns the configured cooldown.
    #[must_use]
    pub fn cooldown(&self) -> Duration {
        self.cooldown
    }

    /// Returns `true` when the candidate's pattern is not in cooldown and
    /// every configured condition is satisfied.
    ///
    /// With no conditions configured, any candidate outside cooldown triggers.
    #[must_use]
    pub fn should_trigger(&self, candidate: &DistillationCandidate) -> bool {
        if self.is_in_cooldown(&candidate.pattern) {
            return false;
        }
        self.conditions.iter().all(|c| c.is_satisfied(candidate))
    }

    /// Returns `true` while fewer than `cooldown` whole seconds have elapsed
    /// since `pattern` was last marked as triggered.
    ///
    /// Patterns that were never marked are not in cooldown.
    #[must_use]
    pub fn is_in_cooldown(&self, pattern: &QueryPattern) -> bool {
        // Delegating keeps the boundary rule (`elapsed < cooldown`) in one place.
        self.remaining_cooldown(pattern).is_some()
    }

    /// Returns how much cooldown is left for `pattern`.
    ///
    /// `None` means the pattern was never marked as triggered or its cooldown
    /// has already expired; a returned duration is always non-zero.
    #[must_use]
    pub fn remaining_cooldown(&self, pattern: &QueryPattern) -> Option<Duration> {
        let last_time = *self.last_triggered.get(&pattern.pattern_hash)?;
        // Saturating: a recorded time later than "now" counts as zero elapsed.
        let elapsed = current_timestamp().saturating_sub(last_time);
        let cooldown_secs = self.cooldown.as_secs();
        (elapsed < cooldown_secs).then(|| Duration::from_secs(cooldown_secs - elapsed))
    }

    /// Records the current timestamp as the last trigger time of `pattern`.
    pub fn mark_triggered(&mut self, pattern: &QueryPattern) {
        self.last_triggered
            .insert(pattern.pattern_hash, current_timestamp());
    }

    /// Forgets the last trigger time of `pattern`, ending any cooldown for it.
    pub fn clear_triggered(&mut self, pattern: &QueryPattern) {
        self.last_triggered.remove(&pattern.pattern_hash);
    }

    /// Returns the candidates for which [`Self::should_trigger`] is `true`,
    /// in input order.
    #[must_use]
    pub fn evaluate_all<'a>(
        &self,
        candidates: &'a [DistillationCandidate],
    ) -> Vec<&'a DistillationCandidate> {
        candidates
            .iter()
            .filter(|c| self.should_trigger(c))
            .collect()
    }

    /// Evaluates every candidate with [`Self::evaluate_candidate`], in input order.
    #[must_use]
    pub fn evaluate_with_details<'a>(
        &self,
        candidates: &'a [DistillationCandidate],
    ) -> Vec<TriggerEvaluation<'a>> {
        candidates
            .iter()
            .map(|c| self.evaluate_candidate(c))
            .collect()
    }

    /// Produces a detailed report for one candidate: cooldown state, the
    /// outcome of each condition, and the overall decision.
    ///
    /// Unlike [`Self::should_trigger`], every condition is evaluated even when
    /// the pattern is in cooldown.
    #[must_use]
    pub fn evaluate_candidate<'a>(
        &self,
        candidate: &'a DistillationCandidate,
    ) -> TriggerEvaluation<'a> {
        // Read the cooldown state once so `in_cooldown` and `remaining_cooldown`
        // cannot disagree if the clock advances during evaluation.
        let remaining_cooldown = self.remaining_cooldown(&candidate.pattern);
        let in_cooldown = remaining_cooldown.is_some();
        let condition_results: Vec<_> = self
            .conditions
            .iter()
            .map(|c| (c.description(), c.is_satisfied(candidate)))
            .collect();
        let all_conditions_met = condition_results.iter().all(|(_, met)| *met);
        TriggerEvaluation {
            candidate,
            should_trigger: all_conditions_met && !in_cooldown,
            in_cooldown,
            remaining_cooldown,
            condition_results,
        }
    }

    /// Summarises the trigger: number of conditions, number of patterns with
    /// a recorded trigger time, how many of those are still in cooldown, and
    /// the cooldown in whole seconds.
    #[must_use]
    pub fn statistics(&self) -> TriggerStatistics {
        let now = current_timestamp();
        let cooldown_secs = self.cooldown.as_secs();
        let patterns_in_cooldown = self
            .last_triggered
            .values()
            .filter(|&&t| now.saturating_sub(t) < cooldown_secs)
            .count();
        TriggerStatistics {
            total_conditions: self.conditions.len(),
            patterns_tracked: self.last_triggered.len(),
            patterns_in_cooldown,
            cooldown_secs,
        }
    }

    /// Forgets every recorded trigger time.
    pub fn clear_history(&mut self) {
        self.last_triggered.clear();
    }

    /// Drops recorded trigger times older than `max_age_secs`.
    ///
    /// Entries whose age equals `max_age_secs` exactly are kept.
    pub fn prune_history(&mut self, max_age_secs: u64) {
        let now = current_timestamp();
        self.last_triggered
            .retain(|_, t| now.saturating_sub(*t) <= max_age_secs);
    }
}
impl Default for DistillationTrigger {
fn default() -> Self {
Self::with_defaults()
}
}
impl Clone for DistillationTrigger {
fn clone(&self) -> Self {
Self {
conditions: self.conditions.clone(),
cooldown: self.cooldown,
last_triggered: self.last_triggered.clone(),
}
}
}
/// Detailed outcome of evaluating one candidate, as built by
/// `DistillationTrigger::evaluate_candidate`.
#[derive(Debug)]
pub struct TriggerEvaluation<'a> {
/// The candidate that was evaluated.
pub candidate: &'a DistillationCandidate,
/// `true` when every condition was met and the pattern was not in cooldown.
pub should_trigger: bool,
/// Whether the candidate's pattern was in cooldown at evaluation time.
pub in_cooldown: bool,
/// Cooldown time left for the candidate's pattern, if any.
pub remaining_cooldown: Option<Duration>,
/// One `(description, satisfied)` entry per configured condition.
pub condition_results: Vec<(String, bool)>,
}
impl TriggerEvaluation<'_> {
    /// Number of conditions that were satisfied.
    #[must_use]
    pub fn satisfied_count(&self) -> usize {
        self.condition_results
            .iter()
            .map(|(_, met)| usize::from(*met))
            .sum()
    }

    /// Number of conditions that were evaluated.
    #[must_use]
    pub fn total_conditions(&self) -> usize {
        self.condition_results.len()
    }

    /// Fraction of evaluated conditions that were satisfied, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when no conditions were evaluated.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn satisfaction_ratio(&self) -> f32 {
        let total = self.total_conditions();
        if total == 0 {
            return 1.0;
        }
        self.satisfied_count() as f32 / total as f32
    }
}
/// Snapshot of a trigger's configuration and cooldown bookkeeping, as
/// returned by `DistillationTrigger::statistics`.
#[derive(Debug, Clone, Default)]
pub struct TriggerStatistics {
/// Number of conditions configured on the trigger.
pub total_conditions: usize,
/// Number of patterns with a recorded trigger time.
pub patterns_tracked: usize,
/// Number of tracked patterns still inside the cooldown window.
pub patterns_in_cooldown: usize,
/// The trigger's cooldown, in whole seconds.
pub cooldown_secs: u64,
}
/// Unit tests for trigger conditions and the cooldown-aware trigger.
#[cfg(test)]
#[allow(clippy::similar_names)]
mod tests {
    use super::*;
    use crate::distillation::types::QAPair;

    /// Builds a candidate for `query` with `pairs_count` Q&A pairs, each
    /// carrying `confidence`.
    ///
    /// `record_occurrence` is called `frequency - 1` times, which presumably
    /// relies on a new candidate starting at frequency 1 — confirm against
    /// `DistillationCandidate::new`.
    fn create_test_candidate(
        query: &str,
        frequency: u32,
        pairs_count: usize,
        confidence: f32,
    ) -> DistillationCandidate {
        let pattern = QueryPattern::new(query);
        let mut candidate = DistillationCandidate::new(pattern.clone());
        for _ in 1..frequency {
            candidate.record_occurrence();
        }
        for i in 0..pairs_count {
            let pair = QAPair::new(query, &format!("Answer {i}"), confidence, pattern.clone());
            candidate.add_qa_pair(pair, 100);
        }
        candidate
    }

    #[test]
    fn test_frequency_threshold_condition() {
        let condition = TriggerCondition::frequency(5);
        let low = create_test_candidate("test", 3, 1, 0.9);
        let high = create_test_candidate("test", 7, 1, 0.9);
        assert!(!condition.is_satisfied(&low));
        assert!(condition.is_satisfied(&high));
    }

    #[test]
    fn test_qa_pair_count_condition() {
        let condition = TriggerCondition::min_qa_pairs(3);
        let few = create_test_candidate("test", 5, 2, 0.9);
        let many = create_test_candidate("test", 5, 5, 0.9);
        assert!(!condition.is_satisfied(&few));
        assert!(condition.is_satisfied(&many));
    }

    #[test]
    fn test_confidence_below_condition() {
        let condition = TriggerCondition::confidence_below(0.8);
        let high_conf = create_test_candidate("test", 5, 2, 0.9);
        let low_conf = create_test_candidate("test", 5, 2, 0.7);
        assert!(!condition.is_satisfied(&high_conf));
        assert!(condition.is_satisfied(&low_conf));
    }

    #[test]
    fn test_confidence_above_condition() {
        let condition = TriggerCondition::confidence_above(0.8);
        let high_conf = create_test_candidate("test", 5, 2, 0.9);
        let low_conf = create_test_candidate("test", 5, 2, 0.7);
        assert!(condition.is_satisfied(&high_conf));
        assert!(!condition.is_satisfied(&low_conf));
    }

    #[test]
    fn test_combined_conditions() {
        let condition = TriggerCondition::all(vec![
            TriggerCondition::frequency(5),
            TriggerCondition::min_qa_pairs(2),
        ]);
        let meets_none = create_test_candidate("test", 3, 1, 0.9);
        let meets_one = create_test_candidate("test", 6, 1, 0.9);
        let meets_all = create_test_candidate("test", 6, 3, 0.9);
        assert!(!condition.is_satisfied(&meets_none));
        assert!(!condition.is_satisfied(&meets_one));
        assert!(condition.is_satisfied(&meets_all));
    }

    #[test]
    fn test_any_conditions() {
        let condition = TriggerCondition::any_of(vec![
            TriggerCondition::frequency(10),
            TriggerCondition::min_qa_pairs(5),
        ]);
        let meets_none = create_test_candidate("test", 3, 2, 0.9);
        let meets_freq = create_test_candidate("test", 12, 2, 0.9);
        let meets_pairs = create_test_candidate("test", 3, 6, 0.9);
        assert!(!condition.is_satisfied(&meets_none));
        assert!(condition.is_satisfied(&meets_freq));
        assert!(condition.is_satisfied(&meets_pairs));
    }

    #[test]
    fn test_condition_description() {
        let freq = TriggerCondition::frequency(5);
        assert_eq!(freq.description(), "frequency >= 5");
        let combined = TriggerCondition::all(vec![
            TriggerCondition::frequency(5),
            TriggerCondition::min_qa_pairs(3),
        ]);
        assert!(combined.description().contains("AND"));
    }

    #[test]
    fn test_distillation_trigger_creation() {
        let trigger = DistillationTrigger::with_defaults();
        assert!(!trigger.conditions.is_empty());
    }

    #[test]
    fn test_distillation_trigger_with_cooldown() {
        let trigger = DistillationTrigger::with_defaults().with_cooldown(Duration::from_secs(1800));
        assert_eq!(trigger.cooldown().as_secs(), 1800);
    }

    #[test]
    fn test_distillation_trigger_should_trigger() {
        let trigger = DistillationTrigger::new(vec![
            TriggerCondition::frequency(3),
            TriggerCondition::min_qa_pairs(2),
        ]);
        let ready = create_test_candidate("test", 5, 3, 0.9);
        let not_ready = create_test_candidate("test", 2, 1, 0.9);
        assert!(trigger.should_trigger(&ready));
        // Restored `&not_ready`: the reference had been corrupted into `¬_ready`.
        assert!(!trigger.should_trigger(&not_ready));
    }

    #[test]
    fn test_distillation_trigger_cooldown() {
        let mut trigger = DistillationTrigger::new(vec![TriggerCondition::frequency(3)])
            .with_cooldown(Duration::from_secs(3600));
        let candidate = create_test_candidate("test", 5, 1, 0.9);
        assert!(trigger.should_trigger(&candidate));
        trigger.mark_triggered(&candidate.pattern);
        assert!(!trigger.should_trigger(&candidate));
        assert!(trigger.is_in_cooldown(&candidate.pattern));
    }

    #[test]
    fn test_distillation_trigger_clear_triggered() {
        let mut trigger = DistillationTrigger::new(vec![TriggerCondition::frequency(3)])
            .with_cooldown(Duration::from_secs(3600));
        let candidate = create_test_candidate("test", 5, 1, 0.9);
        trigger.mark_triggered(&candidate.pattern);
        assert!(trigger.is_in_cooldown(&candidate.pattern));
        trigger.clear_triggered(&candidate.pattern);
        assert!(!trigger.is_in_cooldown(&candidate.pattern));
    }

    #[test]
    fn test_distillation_trigger_evaluate_all() {
        let trigger = DistillationTrigger::new(vec![TriggerCondition::frequency(5)]);
        let candidates = vec![
            create_test_candidate("ready1", 6, 1, 0.9),
            create_test_candidate("not_ready", 3, 1, 0.9),
            create_test_candidate("ready2", 8, 1, 0.9),
        ];
        let triggered = trigger.evaluate_all(&candidates);
        assert_eq!(triggered.len(), 2);
    }

    #[test]
    fn test_distillation_trigger_evaluate_with_details() {
        let trigger = DistillationTrigger::new(vec![
            TriggerCondition::frequency(5),
            TriggerCondition::min_qa_pairs(2),
        ]);
        let candidate = create_test_candidate("test", 3, 1, 0.9);
        let eval = trigger.evaluate_candidate(&candidate);
        assert!(!eval.should_trigger);
        assert_eq!(eval.total_conditions(), 2);
        assert_eq!(eval.satisfied_count(), 0);
    }

    #[test]
    fn test_trigger_evaluation_satisfaction_ratio() {
        let trigger = DistillationTrigger::new(vec![
            TriggerCondition::frequency(3),
            TriggerCondition::min_qa_pairs(5),
        ]);
        let candidate = create_test_candidate("test", 5, 3, 0.9);
        let eval = trigger.evaluate_candidate(&candidate);
        assert!((eval.satisfaction_ratio() - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_distillation_trigger_statistics() {
        let mut trigger = DistillationTrigger::new(vec![
            TriggerCondition::frequency(3),
            TriggerCondition::min_qa_pairs(2),
        ])
        .with_cooldown(Duration::from_secs(3600));
        let candidate1 = create_test_candidate("test1", 5, 1, 0.9);
        let candidate2 = create_test_candidate("test2", 5, 1, 0.9);
        trigger.mark_triggered(&candidate1.pattern);
        trigger.mark_triggered(&candidate2.pattern);
        let stats = trigger.statistics();
        assert_eq!(stats.total_conditions, 2);
        assert_eq!(stats.patterns_tracked, 2);
        assert_eq!(stats.patterns_in_cooldown, 2);
    }

    #[test]
    fn test_distillation_trigger_clear_history() {
        let mut trigger = DistillationTrigger::with_defaults();
        let candidate = create_test_candidate("test", 5, 1, 0.9);
        trigger.mark_triggered(&candidate.pattern);
        assert!(!trigger.last_triggered.is_empty());
        trigger.clear_history();
        assert!(trigger.last_triggered.is_empty());
    }

    #[test]
    fn test_distillation_trigger_clone() {
        let mut trigger = DistillationTrigger::new(vec![TriggerCondition::frequency(5)]);
        let candidate = create_test_candidate("test", 5, 1, 0.9);
        trigger.mark_triggered(&candidate.pattern);
        let cloned = trigger.clone();
        assert_eq!(trigger.conditions.len(), cloned.conditions.len());
        assert!(cloned.is_in_cooldown(&candidate.pattern));
    }

    #[test]
    fn test_default_trigger_condition() {
        let default = TriggerCondition::default();
        if let TriggerCondition::Combined(conditions) = default {
            assert_eq!(conditions.len(), 3);
        } else {
            panic!("Default should be Combined");
        }
    }

    #[test]
    fn test_with_additional_condition() {
        let trigger = DistillationTrigger::with_defaults()
            .with_additional_condition(TriggerCondition::frequency(10));
        assert_eq!(trigger.conditions.len(), 2);
    }

    #[test]
    fn test_remaining_cooldown() {
        let mut trigger =
            DistillationTrigger::with_defaults().with_cooldown(Duration::from_secs(3600));
        let candidate = create_test_candidate("test", 5, 1, 0.9);
        assert!(trigger.remaining_cooldown(&candidate.pattern).is_none());
        trigger.mark_triggered(&candidate.pattern);
        let remaining = trigger.remaining_cooldown(&candidate.pattern);
        assert!(remaining.is_some());
        assert!(remaining.unwrap().as_secs() > 0);
    }

    #[test]
    fn test_time_since_first_seen_condition() {
        let condition = TriggerCondition::time_since_first_seen(Duration::from_secs(0));
        let candidate = create_test_candidate("test", 5, 1, 0.9);
        assert!(condition.is_satisfied(&candidate));
    }

    #[test]
    fn test_condition_with_no_qa_pairs() {
        let confidence_below = TriggerCondition::confidence_below(0.5);
        let confidence_above = TriggerCondition::confidence_above(0.5);
        let candidate = create_test_candidate("test", 5, 0, 0.0);
        assert!(!confidence_below.is_satisfied(&candidate));
        assert!(!confidence_above.is_satisfied(&candidate));
    }
}