use std::sync::Arc;
use std::time::Duration;
use super::store::{EpisodeStore, StoreError};
use crate::util::epoch_millis;
/// Everything a [`TrainTrigger`] may consult when deciding whether to train.
///
/// Start from [`TriggerContext::with_store`] or [`TriggerContext::with_count`]
/// and refine the result with the chained setters.
pub struct TriggerContext<'a> {
    /// Store queried for the total episode count when `event_count` is `None`.
    pub store: Option<&'a dyn EpisodeStore>,
    /// Pre-computed total count; takes precedence over `store`.
    pub event_count: Option<usize>,
    /// Timestamp of the previous training run, on the same clock as `epoch_millis()`.
    /// `None` means no previous run is known.
    pub last_train_at: Option<u64>,
    /// Total count observed at the previous training run (baseline for "new" episodes).
    pub last_train_count: usize,
    /// Quality statistics, required by triggers that inspect success rates.
    pub metrics: Option<&'a TriggerMetrics>,
}

impl<'a> TriggerContext<'a> {
    /// Context with no count source, no history and no metrics.
    fn blank() -> Self {
        Self {
            store: None,
            event_count: None,
            last_train_at: None,
            last_train_count: 0,
            metrics: None,
        }
    }

    /// Context whose episode count is read from `store` on demand.
    pub fn with_store(store: &'a dyn EpisodeStore) -> Self {
        Self {
            store: Some(store),
            ..Self::blank()
        }
    }

    /// Context with a fixed, already-known episode count.
    pub fn with_count(count: usize) -> Self {
        Self {
            event_count: Some(count),
            ..Self::blank()
        }
    }

    /// Records when training last ran.
    pub fn last_train_at(self, timestamp: u64) -> Self {
        Self {
            last_train_at: Some(timestamp),
            ..self
        }
    }

    /// Records the episode count at the last training run.
    pub fn last_train_count(self, count: usize) -> Self {
        Self {
            last_train_count: count,
            ..self
        }
    }

    /// Attaches quality metrics.
    pub fn metrics(self, metrics: &'a TriggerMetrics) -> Self {
        Self {
            metrics: Some(metrics),
            ..self
        }
    }

    /// Total number of episodes currently available.
    ///
    /// Prefers `event_count`, then asks `store`, and falls back to `0` when
    /// neither is set.
    ///
    /// # Errors
    /// Returns [`TriggerError::Store`] if the store lookup fails.
    pub fn current_count(&self) -> Result<usize, TriggerError> {
        match (self.event_count, self.store) {
            (Some(known), _) => Ok(known),
            (None, Some(store)) => Ok(store.count(None)?),
            (None, None) => Ok(0),
        }
    }
}
/// Outcome statistics handed to triggers through [`TriggerContext::metrics`].
#[derive(Clone, Debug, Default)]
pub struct TriggerMetrics {
    /// Success rate over the recent window; `QualityTrigger` compares it with its threshold.
    pub recent_success_rate: f64,
    /// Long-run success rate (not read by any trigger in this module).
    pub overall_success_rate: f64,
    /// Number of samples behind `recent_success_rate`; `QualityTrigger` ignores
    /// the rate while this is below its minimum.
    pub recent_sample_size: usize,
}
/// Failures that can occur while evaluating a [`TrainTrigger`].
#[derive(Debug)]
pub enum TriggerError {
    /// The episode store could not be read.
    Store(StoreError),
    /// A trigger needed [`TriggerMetrics`] but the context did not provide any.
    MetricsUnavailable(String),
}

impl std::fmt::Display for TriggerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TriggerError::Store(source) => write!(f, "Store error: {source}"),
            TriggerError::MetricsUnavailable(reason) => {
                write!(f, "Metrics unavailable: {reason}")
            }
        }
    }
}

impl std::error::Error for TriggerError {}

/// Lets `?` lift store failures into trigger errors.
impl From<StoreError> for TriggerError {
    fn from(source: StoreError) -> Self {
        TriggerError::Store(source)
    }
}
/// Policy that decides when a training run should start.
///
/// Implementors must be `Send + Sync`; the builders in this module hand them
/// out as `Arc<dyn TrainTrigger>`.
pub trait TrainTrigger: Send + Sync {
    /// Short identifier for the trigger kind (e.g. `"count"`, `"or"`).
    fn name(&self) -> &str;

    /// Human-readable summary of the trigger's configuration.
    fn describe(&self) -> String;

    /// Returns `Ok(true)` when training should run, given `ctx`.
    ///
    /// # Errors
    /// Returns a [`TriggerError`] when `ctx` lacks data the trigger needs or
    /// the backing store fails.
    fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError>;
}
/// Fires once at least `threshold` episodes have arrived since the last training run.
pub struct CountTrigger {
    /// Minimum number of new episodes required.
    threshold: usize,
}

impl CountTrigger {
    /// Creates a trigger that waits for `threshold` new episodes.
    pub fn new(threshold: usize) -> CountTrigger {
        CountTrigger { threshold }
    }
}
impl TrainTrigger for CountTrigger {
    /// Trains when `current_count - last_train_count` reaches the threshold.
    /// The difference is floored at zero, so a shrinking store never underflows.
    fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
        let baseline = ctx.last_train_count;
        ctx.current_count()
            .map(|total| total.saturating_sub(baseline) >= self.threshold)
    }

    fn name(&self) -> &str {
        "count"
    }

    fn describe(&self) -> String {
        format!(
            "Train when {threshold} new episodes accumulated",
            threshold = self.threshold
        )
    }
}
/// Fires once a fixed amount of wall-clock time has passed since the last
/// training run.
///
/// When no previous run is recorded, it fires as soon as at least one episode
/// exists.
pub struct TimeTrigger {
    /// Minimum time between runs, in milliseconds (the unit of `epoch_millis()`).
    ///
    /// Kept in milliseconds rather than whole seconds so a sub-second
    /// `Duration` is not truncated to zero, which would fire on every check.
    interval_ms: u64,
}

impl TimeTrigger {
    const MS_PER_SECOND: u64 = 1_000;
    const MS_PER_MINUTE: u64 = 60 * Self::MS_PER_SECOND;
    const MS_PER_HOUR: u64 = 60 * Self::MS_PER_MINUTE;

    /// Creates a trigger with the given interval (millisecond resolution).
    ///
    /// Intervals longer than `u64::MAX` milliseconds saturate.
    pub fn new(interval: Duration) -> Self {
        let millis = interval.as_millis().min(u128::from(u64::MAX));
        Self {
            // Lossless: clamped to the u64 range on the line above.
            interval_ms: millis as u64,
        }
    }

    /// Creates a trigger that fires every `hours` hours (saturating on overflow).
    pub fn hours(hours: u64) -> Self {
        Self {
            interval_ms: hours.saturating_mul(Self::MS_PER_HOUR),
        }
    }

    /// Creates a trigger that fires every `minutes` minutes (saturating on overflow).
    pub fn minutes(minutes: u64) -> Self {
        Self {
            interval_ms: minutes.saturating_mul(Self::MS_PER_MINUTE),
        }
    }
}

impl TrainTrigger for TimeTrigger {
    /// Trains when at least the configured interval has elapsed since
    /// `ctx.last_train_at`. With no recorded run, trains if any episode exists.
    ///
    /// # Errors
    /// Propagates store errors from `ctx.current_count()` (first-run case only).
    fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
        let Some(last_train) = ctx.last_train_at else {
            // No recorded run yet: train as soon as there is anything to learn from.
            let count = ctx.current_count()?;
            return Ok(count > 0);
        };
        // A `last_train_at` ahead of the local clock (skew) counts as
        // "no time elapsed" instead of underflowing.
        let elapsed_ms = epoch_millis().saturating_sub(last_train);
        // For whole-second intervals this is exactly `elapsed / 1000 >= secs`.
        Ok(elapsed_ms >= self.interval_ms)
    }

    fn name(&self) -> &str {
        "time"
    }

    /// Uses the largest unit that divides the interval exactly, so e.g. 90
    /// minutes is not misreported as "1 hours".
    fn describe(&self) -> String {
        let ms = self.interval_ms;
        if ms != 0 && ms % Self::MS_PER_HOUR == 0 {
            format!("Train every {} hours", ms / Self::MS_PER_HOUR)
        } else if ms != 0 && ms % Self::MS_PER_MINUTE == 0 {
            format!("Train every {} minutes", ms / Self::MS_PER_MINUTE)
        } else if ms % Self::MS_PER_SECOND == 0 {
            format!("Train every {} seconds", ms / Self::MS_PER_SECOND)
        } else {
            format!("Train every {} milliseconds", ms)
        }
    }
}
/// Fires when the recent success rate drops below a threshold, provided at
/// least `min_samples` recent samples back that rate.
pub struct QualityTrigger {
    /// Success-rate floor; training is requested when the recent rate is strictly below it.
    threshold: f64,
    /// Minimum `recent_sample_size` before the rate is considered.
    min_samples: usize,
}

impl QualityTrigger {
    /// Sample floor applied by [`QualityTrigger::new`].
    const DEFAULT_MIN_SAMPLES: usize = 10;

    /// Creates a trigger with the given success-rate `threshold`.
    pub fn new(threshold: f64) -> Self {
        QualityTrigger {
            threshold,
            min_samples: Self::DEFAULT_MIN_SAMPLES,
        }
    }

    /// Overrides the minimum number of recent samples.
    pub fn with_min_samples(self, min: usize) -> Self {
        QualityTrigger {
            min_samples: min,
            ..self
        }
    }
}
impl TrainTrigger for QualityTrigger {
    /// Trains when enough recent samples exist and their success rate is below
    /// the threshold.
    ///
    /// # Errors
    /// Returns [`TriggerError::MetricsUnavailable`] when `ctx.metrics` is `None`.
    fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
        let Some(metrics) = ctx.metrics else {
            return Err(TriggerError::MetricsUnavailable(
                "QualityTrigger requires metrics".into(),
            ));
        };
        let enough_samples = metrics.recent_sample_size >= self.min_samples;
        Ok(enough_samples && metrics.recent_success_rate < self.threshold)
    }

    fn name(&self) -> &str {
        "quality"
    }

    fn describe(&self) -> String {
        format!(
            "Train when success rate < {pct:.0}% (min {min} samples)",
            pct = self.threshold * 100.0,
            min = self.min_samples,
        )
    }
}
/// Trigger whose `should_train` always answers `false`; presumably training is
/// started explicitly by other code.
pub struct ManualTrigger;

impl TrainTrigger for ManualTrigger {
    fn name(&self) -> &str {
        "manual"
    }

    fn describe(&self) -> String {
        String::from("Manual trigger only")
    }

    fn should_train(&self, _: &TriggerContext) -> Result<bool, TriggerError> {
        // Automatic evaluation never requests training.
        Ok(false)
    }
}
/// Trigger that never requests training, regardless of context.
pub struct NeverTrigger;

impl TrainTrigger for NeverTrigger {
    fn name(&self) -> &str {
        "never"
    }

    fn describe(&self) -> String {
        String::from("Never triggers")
    }

    fn should_train(&self, _: &TriggerContext) -> Result<bool, TriggerError> {
        Ok(false)
    }
}
/// Trigger that requests training on every evaluation, regardless of context.
pub struct AlwaysTrigger;

impl TrainTrigger for AlwaysTrigger {
    fn name(&self) -> &str {
        "always"
    }

    fn describe(&self) -> String {
        String::from("Always triggers")
    }

    fn should_train(&self, _: &TriggerContext) -> Result<bool, TriggerError> {
        Ok(true)
    }
}
/// Composite trigger that fires when any member fires.
///
/// Members are evaluated in order. Evaluation stops at the first member that
/// returns `Ok(true)` or an error, and that outcome is returned. An empty
/// `OrTrigger` never fires.
pub struct OrTrigger {
    triggers: Vec<Arc<dyn TrainTrigger>>,
}

impl OrTrigger {
    /// Combines `triggers` with logical OR.
    pub fn new(triggers: Vec<Arc<dyn TrainTrigger>>) -> Self {
        OrTrigger { triggers }
    }
}

impl TrainTrigger for OrTrigger {
    fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
        self.triggers
            .iter()
            .map(|member| member.should_train(ctx))
            // The first decisive outcome (a `true` or an error) wins.
            .find(|outcome| !matches!(outcome, Ok(false)))
            .unwrap_or(Ok(false))
    }

    fn name(&self) -> &str {
        "or"
    }

    /// Lists member names, e.g. `OR(count, time)`.
    fn describe(&self) -> String {
        let members = self
            .triggers
            .iter()
            .map(|member| member.name())
            .collect::<Vec<&str>>();
        format!("OR({})", members.join(", "))
    }
}
/// Composite trigger that fires only when every member fires.
///
/// Members are evaluated in order. Evaluation stops at the first member that
/// returns `Ok(false)` or an error, and that outcome is returned. An empty
/// `AndTrigger` never fires; it does not vacuously fire.
pub struct AndTrigger {
    triggers: Vec<Arc<dyn TrainTrigger>>,
}

impl AndTrigger {
    /// Combines `triggers` with logical AND.
    pub fn new(triggers: Vec<Arc<dyn TrainTrigger>>) -> Self {
        AndTrigger { triggers }
    }
}

impl TrainTrigger for AndTrigger {
    fn should_train(&self, ctx: &TriggerContext) -> Result<bool, TriggerError> {
        if self.triggers.is_empty() {
            return Ok(false);
        }
        self.triggers
            .iter()
            .map(|member| member.should_train(ctx))
            // The first decisive outcome (a `false` or an error) wins.
            .find(|outcome| !matches!(outcome, Ok(true)))
            .unwrap_or(Ok(true))
    }

    fn name(&self) -> &str {
        "and"
    }

    /// Lists member names, e.g. `AND(count, quality)`.
    fn describe(&self) -> String {
        let members = self
            .triggers
            .iter()
            .map(|member| member.name())
            .collect::<Vec<&str>>();
        format!("AND({})", members.join(", "))
    }
}
/// Convenience constructors returning shareable `Arc<dyn TrainTrigger>` values.
pub struct TriggerBuilder;

impl TriggerBuilder {
    /// Episode threshold used by [`TriggerBuilder::default_watch`].
    const WATCH_EPISODES: usize = 100;
    /// Interval, in hours, used by [`TriggerBuilder::default_watch`].
    const WATCH_HOURS: u64 = 1;

    /// Fires after `n` new episodes (see [`CountTrigger`]).
    pub fn every_n_episodes(n: usize) -> Arc<dyn TrainTrigger> {
        let trigger = CountTrigger::new(n);
        Arc::new(trigger)
    }

    /// Fires every `hours` hours (see [`TimeTrigger::hours`]).
    pub fn every_hours(hours: u64) -> Arc<dyn TrainTrigger> {
        let trigger = TimeTrigger::hours(hours);
        Arc::new(trigger)
    }

    /// Fires every `minutes` minutes (see [`TimeTrigger::minutes`]).
    pub fn every_minutes(minutes: u64) -> Arc<dyn TrainTrigger> {
        let trigger = TimeTrigger::minutes(minutes);
        Arc::new(trigger)
    }

    /// Fires when the recent success rate drops below `threshold` (see [`QualityTrigger`]).
    pub fn on_quality_drop(threshold: f64) -> Arc<dyn TrainTrigger> {
        let trigger = QualityTrigger::new(threshold);
        Arc::new(trigger)
    }

    /// Fires after 100 new episodes OR once an hour, whichever comes first.
    pub fn default_watch() -> Arc<dyn TrainTrigger> {
        let members: Vec<Arc<dyn TrainTrigger>> = vec![
            Self::every_n_episodes(Self::WATCH_EPISODES),
            Self::every_hours(Self::WATCH_HOURS),
        ];
        Arc::new(OrTrigger::new(members))
    }

    /// Never fires automatically (see [`ManualTrigger`]).
    pub fn manual() -> Arc<dyn TrainTrigger> {
        Arc::new(ManualTrigger)
    }

    /// Never fires (see [`NeverTrigger`]).
    pub fn never() -> Arc<dyn TrainTrigger> {
        Arc::new(NeverTrigger)
    }

    /// Fires on every evaluation (see [`AlwaysTrigger`]).
    pub fn always() -> Arc<dyn TrainTrigger> {
        Arc::new(AlwaysTrigger)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::learn::store::{EpisodeDto, InMemoryEpisodeStore};
    use crate::learn::{EpisodeId, EpisodeMetadata, Outcome};

    /// In-memory store pre-filled with `count` successful "test" episodes.
    fn create_test_store(count: usize) -> InMemoryEpisodeStore {
        let store = InMemoryEpisodeStore::new();
        (0..count)
            .map(|_| EpisodeDto {
                id: EpisodeId::new(),
                learn_model: "test".to_string(),
                outcome: Outcome::success(1.0),
                metadata: EpisodeMetadata::new(),
                record_ids: vec![],
            })
            .for_each(|dto| {
                store.append(&dto).unwrap();
            });
        store
    }

    /// Store-backed context with the given training history and optional metrics.
    fn create_context<'a>(
        store: &'a dyn EpisodeStore,
        last_train_at: Option<u64>,
        last_train_count: usize,
        metrics: Option<&'a TriggerMetrics>,
    ) -> TriggerContext<'a> {
        let mut ctx = TriggerContext::with_store(store).last_train_count(last_train_count);
        ctx.last_train_at = last_train_at;
        ctx.metrics = metrics;
        ctx
    }

    /// Shorthand for building a `TriggerMetrics`.
    fn make_metrics(recent: f64, overall: f64, samples: usize) -> TriggerMetrics {
        TriggerMetrics {
            recent_success_rate: recent,
            overall_success_rate: overall,
            recent_sample_size: samples,
        }
    }

    /// Evaluates `trigger`, panicking if it reports an error.
    fn fires(trigger: &dyn TrainTrigger, ctx: &TriggerContext) -> bool {
        trigger.should_train(ctx).unwrap()
    }

    #[test]
    fn test_count_trigger_below_threshold() {
        let store = create_test_store(5);
        let ctx = create_context(&store, None, 0, None);
        assert!(!fires(&CountTrigger::new(10), &ctx));
    }

    #[test]
    fn test_count_trigger_at_threshold() {
        let store = create_test_store(10);
        let ctx = create_context(&store, None, 0, None);
        assert!(fires(&CountTrigger::new(10), &ctx));
    }

    #[test]
    fn test_count_trigger_with_previous_count() {
        let store = create_test_store(15);
        let trigger = CountTrigger::new(10);
        // 15 - 10 = 5 new episodes: not enough.
        assert!(!fires(&trigger, &create_context(&store, None, 10, None)));
        // 15 - 5 = 10 new episodes: exactly the threshold.
        assert!(fires(&trigger, &create_context(&store, None, 5, None)));
    }

    #[test]
    fn test_time_trigger_first_time_with_episodes() {
        let store = create_test_store(5);
        let ctx = create_context(&store, None, 0, None);
        assert!(fires(&TimeTrigger::hours(1), &ctx));
    }

    #[test]
    fn test_time_trigger_first_time_no_episodes() {
        let store = create_test_store(0);
        let ctx = create_context(&store, None, 0, None);
        assert!(!fires(&TimeTrigger::hours(1), &ctx));
    }

    #[test]
    fn test_time_trigger_not_elapsed() {
        let store = create_test_store(5);
        let one_second_ago = epoch_millis() - 1000;
        let ctx = create_context(&store, Some(one_second_ago), 0, None);
        assert!(!fires(&TimeTrigger::hours(1), &ctx));
    }

    #[test]
    fn test_time_trigger_elapsed() {
        let store = create_test_store(5);
        let just_over_an_hour_ago = epoch_millis() - 3601 * 1000;
        let ctx = create_context(&store, Some(just_over_an_hour_ago), 0, None);
        assert!(fires(&TimeTrigger::hours(1), &ctx));
    }

    #[test]
    fn test_quality_trigger_no_metrics() {
        let store = create_test_store(5);
        let ctx = create_context(&store, None, 0, None);
        assert!(QualityTrigger::new(0.5).should_train(&ctx).is_err());
    }

    #[test]
    fn test_quality_trigger_insufficient_samples() {
        let store = create_test_store(5);
        let trigger = QualityTrigger::new(0.5).with_min_samples(10);
        // The rate is below the threshold, but 5 samples < 10 required.
        let metrics = make_metrics(0.3, 0.5, 5);
        let ctx = create_context(&store, None, 0, Some(&metrics));
        assert!(!fires(&trigger, &ctx));
    }

    #[test]
    fn test_quality_trigger_above_threshold() {
        let store = create_test_store(5);
        let metrics = make_metrics(0.7, 0.7, 20);
        let ctx = create_context(&store, None, 0, Some(&metrics));
        assert!(!fires(&QualityTrigger::new(0.5), &ctx));
    }

    #[test]
    fn test_quality_trigger_below_threshold() {
        let store = create_test_store(5);
        let metrics = make_metrics(0.3, 0.5, 20);
        let ctx = create_context(&store, None, 0, Some(&metrics));
        assert!(fires(&QualityTrigger::new(0.5), &ctx));
    }

    #[test]
    fn test_or_trigger_all_false() {
        let store = create_test_store(5);
        let trigger = OrTrigger::new(vec![
            Arc::new(CountTrigger::new(100)),
            Arc::new(NeverTrigger),
        ]);
        let ctx = create_context(&store, None, 0, None);
        assert!(!fires(&trigger, &ctx));
    }

    #[test]
    fn test_or_trigger_one_true() {
        let store = create_test_store(5);
        let trigger = OrTrigger::new(vec![Arc::new(AlwaysTrigger), Arc::new(NeverTrigger)]);
        let ctx = create_context(&store, None, 0, None);
        assert!(fires(&trigger, &ctx));
    }

    #[test]
    fn test_and_trigger_empty() {
        let store = create_test_store(5);
        let trigger = AndTrigger::new(vec![]);
        let ctx = create_context(&store, None, 0, None);
        assert!(!fires(&trigger, &ctx));
    }

    #[test]
    fn test_and_trigger_all_true() {
        let store = create_test_store(5);
        let trigger = AndTrigger::new(vec![Arc::new(AlwaysTrigger), Arc::new(AlwaysTrigger)]);
        let ctx = create_context(&store, None, 0, None);
        assert!(fires(&trigger, &ctx));
    }

    #[test]
    fn test_and_trigger_one_false() {
        let store = create_test_store(5);
        let trigger = AndTrigger::new(vec![Arc::new(AlwaysTrigger), Arc::new(NeverTrigger)]);
        let ctx = create_context(&store, None, 0, None);
        assert!(!fires(&trigger, &ctx));
    }

    #[test]
    fn test_trigger_builder_default_watch() {
        let trigger = TriggerBuilder::default_watch();
        assert_eq!(trigger.name(), "or");
        assert!(trigger.describe().contains("OR"));
    }

    #[test]
    fn test_trigger_describe() {
        assert_eq!(
            CountTrigger::new(50).describe(),
            "Train when 50 new episodes accumulated"
        );
        assert_eq!(TimeTrigger::hours(2).describe(), "Train every 2 hours");
        assert_eq!(
            TimeTrigger::minutes(30).describe(),
            "Train every 30 minutes"
        );
        assert!(QualityTrigger::new(0.5).describe().contains("50%"));
    }

    #[test]
    fn test_context_with_count_no_store() {
        let ctx = TriggerContext::with_count(15);
        assert!(fires(&CountTrigger::new(10), &ctx));
    }

    #[test]
    fn test_context_with_count_below_threshold() {
        let ctx = TriggerContext::with_count(5);
        assert!(!fires(&CountTrigger::new(10), &ctx));
    }

    #[test]
    fn test_context_with_count_and_last_train_count() {
        // 20 - 15 = 5 new episodes, below the threshold of 10.
        let ctx = TriggerContext::with_count(20).last_train_count(15);
        assert!(!fires(&CountTrigger::new(10), &ctx));
    }

    #[test]
    fn test_context_builder_fluent() {
        let metrics = make_metrics(0.3, 0.5, 20);
        let an_hour_ago = epoch_millis() - 3600 * 1000;
        let ctx = TriggerContext::with_count(100)
            .last_train_at(an_hour_ago)
            .last_train_count(50)
            .metrics(&metrics);
        assert!(fires(&CountTrigger::new(10), &ctx));
        assert!(fires(&TimeTrigger::minutes(30), &ctx));
        assert!(fires(&QualityTrigger::new(0.5), &ctx));
    }

    #[test]
    fn test_time_trigger_with_count_first_time() {
        let ctx = TriggerContext::with_count(5);
        assert!(fires(&TimeTrigger::hours(1), &ctx));
    }

    #[test]
    fn test_time_trigger_with_count_first_time_no_events() {
        let ctx = TriggerContext::with_count(0);
        assert!(!fires(&TimeTrigger::hours(1), &ctx));
    }
}