use anyhow::Result;
use chrono::{DateTime, TimeZone, Utc};
use crate::data::{ExerciseTrial, UnitReward};
/// Turns the rewards recorded for a unit into a single score and decides whether a reward should
/// be applied at all.
pub trait RewardScorer {
/// Computes one score from two groups of rewards: those passed as course rewards and those
/// passed as lesson rewards.
///
/// # Errors
///
/// Implementations may return an error if the score cannot be computed.
fn score_rewards(
&self,
previous_course_rewards: &[UnitReward],
previous_lesson_rewards: &[UnitReward],
) -> Result<f32>;
/// Returns whether `reward` should be applied, judged against the given previous trials.
fn apply_reward(&self, reward: f32, previous_trials: &[ExerciseTrial]) -> bool;
}
/// Half-life, in days, of the exponential decay applied to rewards. `decay_factor` computes
/// `0.5^(days / REWARD_HALF_LIFE_DAYS)`.
const REWARD_HALF_LIFE_DAYS: f32 = 14.0;
/// Rewards whose decayed weight is below this threshold are skipped by `weighted_average`.
const MIN_EFFECTIVE_WEIGHT: f32 = 0.05;
/// Weight of the course score when `score_rewards` blends course and lesson scores.
const COURSE_REWARDS_WEIGHT: f32 = 0.3;
/// Weight of the lesson score when `score_rewards` blends course and lesson scores.
const LESSON_REWARDS_WEIGHT: f32 = 0.7;
/// A stateless [`RewardScorer`] that scores rewards with a time-decayed weighted average and
/// blends the course and lesson averages using `COURSE_REWARDS_WEIGHT` and
/// `LESSON_REWARDS_WEIGHT`.
pub struct WeightedRewardScorer {}
impl WeightedRewardScorer {
    /// Returns the age of `reward` relative to `now`, in whole days.
    ///
    /// The reward's timestamp is read as seconds since the Unix epoch. If it cannot be converted
    /// to a date, `DateTime::default()` is used instead. A timestamp later than `now` yields
    /// zero rather than a negative age.
    fn days_since(reward: &UnitReward, now: DateTime<Utc>) -> f32 {
        let rewarded_at = Utc
            .timestamp_opt(reward.timestamp, 0)
            .earliest()
            .unwrap_or_default();
        let whole_days = (now - rewarded_at).num_days();
        if whole_days < 0 {
            // Rewards stamped in the future are treated as if they were given right now.
            0.0
        } else {
            whole_days as f32
        }
    }

    /// Returns the exponential decay multiplier for something that is `days` days old: `1.0` at
    /// zero days, halving every `REWARD_HALF_LIFE_DAYS` days.
    fn decay_factor(days: f32) -> f32 {
        let half_lives = days / REWARD_HALF_LIFE_DAYS;
        0.5_f32.powf(half_lives)
    }

    /// Returns the `(value, weight)` of `reward` after both have been scaled by the decay factor
    /// for the reward's age at `now`.
    fn decayed_reward(reward: &UnitReward, now: DateTime<Utc>) -> (f32, f32) {
        let decay = Self::decay_factor(Self::days_since(reward, now));
        let effective_value = reward.value * decay;
        let effective_weight = reward.weight * decay;
        (effective_value, effective_weight)
    }

    /// Returns the weighted average of the decayed reward values, using the decayed weights.
    ///
    /// Rewards whose decayed weight is below `MIN_EFFECTIVE_WEIGHT` are ignored. If no reward
    /// contributes any weight, the result is `0.0`.
    fn weighted_average(rewards: &[UnitReward], now: DateTime<Utc>) -> f32 {
        let (weighted_sum, total_weight) = rewards
            .iter()
            .map(|reward| Self::decayed_reward(reward, now))
            // Written as a negated `<` so that a NaN weight is handled exactly like in a plain
            // `if weight < MIN_EFFECTIVE_WEIGHT { continue; }` loop.
            .filter(|&(_, weight)| !(weight < MIN_EFFECTIVE_WEIGHT))
            .fold((0.0_f32, 0.0_f32), |(sum, total), (value, weight)| {
                (sum + value * weight, total + weight)
            });

        if total_weight == 0.0 {
            return 0.0;
        }
        weighted_sum / total_weight
    }
}
impl RewardScorer for WeightedRewardScorer {
    /// Scores the given rewards as of the current time.
    ///
    /// - With no rewards at all, the score is `0.0`.
    /// - With rewards in only one of the two slices, the score is that slice's decayed weighted
    ///   average.
    /// - With rewards in both, the two averages are blended using `COURSE_REWARDS_WEIGHT` and
    ///   `LESSON_REWARDS_WEIGHT`.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    fn score_rewards(
        &self,
        previous_course_rewards: &[UnitReward],
        previous_lesson_rewards: &[UnitReward],
    ) -> Result<f32> {
        let now = Utc::now();
        let no_course_rewards = previous_course_rewards.is_empty();
        let no_lesson_rewards = previous_lesson_rewards.is_empty();

        let score = match (no_course_rewards, no_lesson_rewards) {
            (true, true) => 0.0,
            (true, false) => Self::weighted_average(previous_lesson_rewards, now),
            (false, true) => Self::weighted_average(previous_course_rewards, now),
            (false, false) => {
                let course_score = Self::weighted_average(previous_course_rewards, now);
                let lesson_score = Self::weighted_average(previous_lesson_rewards, now);
                let blended =
                    course_score * COURSE_REWARDS_WEIGHT + lesson_score * LESSON_REWARDS_WEIGHT;
                blended / (COURSE_REWARDS_WEIGHT + LESSON_REWARDS_WEIGHT)
            }
        };
        Ok(score)
    }

    /// Decides whether `reward` should be applied given the exercise's previous trials.
    ///
    /// Returns `false` when there are fewer than three trials. Otherwise the first three trials
    /// are averaged and the first trial's timestamp is taken as the time of the latest trial.
    /// The reward is rejected only when that latest trial is less than seven days old and the
    /// reward contradicts the average: a positive reward while the average is below `3.0`, or a
    /// negative reward while the average is above `3.5`.
    fn apply_reward(&self, reward: f32, previous_trials: &[ExerciseTrial]) -> bool {
        // `get(..3)` is `None` when fewer than three trials exist.
        let recent_trials = match previous_trials.get(..3) {
            Some(trials) => trials,
            None => return false,
        };
        // Safe to index: `recent_trials` has exactly three elements.
        let latest_trial = &recent_trials[0];

        let seconds_since_latest = Utc::now().timestamp() - latest_trial.timestamp;
        let days_since_latest = seconds_since_latest as f32 / 86_400.0;
        let average_score = recent_trials.iter().map(|trial| trial.score).sum::<f32>() / 3.0;

        let latest_is_recent = days_since_latest < 7.0;
        let boosts_low_scores = reward > 0.0 && average_score < 3.0;
        let penalizes_high_scores = reward < 0.0 && average_score > 3.5;

        !(latest_is_recent && (boosts_low_scores || penalizes_high_scores))
    }
}
#[cfg(test)]
#[cfg_attr(coverage, coverage(off))]
mod test {
    use chrono::Utc;
    use ustr::Ustr;

    use crate::{
        data::{ExerciseTrial, UnitReward},
        reward_scorer::{RewardScorer, WeightedRewardScorer},
    };

    const SECONDS_IN_DAY: i64 = 60 * 60 * 24;

    /// Returns the Unix timestamp of the moment `num_days` days before the current time.
    fn generate_timestamp(num_days: i64) -> i64 {
        let now = Utc::now().timestamp();
        now - num_days * SECONDS_IN_DAY
    }

    /// Returns the Unix timestamp of the moment `num_days` days after the current time.
    fn generate_future_timestamp(num_days: i64) -> i64 {
        let now = Utc::now().timestamp();
        now + num_days * SECONDS_IN_DAY
    }

    /// Builds a reward for the default unit with the given value, weight, and Unix timestamp.
    fn reward(value: f32, weight: f32, timestamp: i64) -> UnitReward {
        UnitReward {
            unit_id: Ustr::default(),
            value,
            weight,
            timestamp,
        }
    }

    /// Builds a trial with the given score that took place `num_days` days ago.
    fn trial(score: f32, num_days: i64) -> ExerciseTrial {
        ExerciseTrial {
            score,
            timestamp: generate_timestamp(num_days),
        }
    }

    /// Verifies the decay factor at zero, one, and two half-lives.
    #[test]
    fn test_decay_factor() {
        assert!((WeightedRewardScorer::decay_factor(0.0) - 1.0).abs() < 0.000_001);
        assert!((WeightedRewardScorer::decay_factor(14.0) - 0.5).abs() < 0.001);
        assert!((WeightedRewardScorer::decay_factor(28.0) - 0.25).abs() < 0.001);
    }

    /// Verifies that both the value and the weight of a reward are halved after one half-life.
    #[test]
    fn test_decayed_reward() {
        let now = Utc::now();
        // Derive the timestamp from the same clock reading as `now`. Reading the clock a second
        // time could land in the next second, which makes the reward slightly less than 14 days
        // old; the age is truncated to whole days, so it would count as 13 days and the
        // assertions below would fail.
        let two_weeks_ago = now.timestamp() - 14 * SECONDS_IN_DAY;

        let (value, weight) =
            WeightedRewardScorer::decayed_reward(&reward(1.0, 2.0, two_weeks_ago), now);
        assert!((value - 0.5).abs() < 0.001);
        assert!((weight - 1.0).abs() < 0.001);

        let (value, weight) =
            WeightedRewardScorer::decayed_reward(&reward(-1.0, 1.0, two_weeks_ago), now);
        assert!((value + 0.5).abs() < 0.001);
        assert!((weight - 0.5).abs() < 0.001);
    }

    /// Verifies that a reward with a future timestamp is not decayed.
    #[test]
    fn test_future_timestamp_is_clamped() {
        let now = Utc::now();
        let future_reward = reward(1.0, 2.0, generate_future_timestamp(3));
        let (value, weight) = WeightedRewardScorer::decayed_reward(&future_reward, now);
        assert!((value - 1.0).abs() < 0.001);
        assert!((weight - 2.0).abs() < 0.001);
    }

    /// Verifies that the score is zero when there are no rewards.
    #[test]
    fn test_no_rewards() {
        let scorer = WeightedRewardScorer {};
        let result = scorer.score_rewards(&[], &[]).unwrap();
        assert_eq!(result, 0.0);
    }

    /// Verifies the score when only lesson rewards are present.
    #[test]
    fn test_only_lesson_rewards() {
        let scorer = WeightedRewardScorer {};
        let lesson_rewards = vec![
            reward(1.0, 1.0, generate_timestamp(1)),
            reward(2.0, 1.0, generate_timestamp(2)),
        ];
        let result = scorer.score_rewards(&[], &lesson_rewards).unwrap();
        assert!((result - 1.371).abs() < 0.001);
    }

    /// Verifies the score when only course rewards are present.
    #[test]
    fn test_only_course_rewards() {
        let scorer = WeightedRewardScorer {};
        let course_rewards = vec![
            reward(1.0, 1.0, generate_timestamp(1)),
            reward(2.0, 1.0, generate_timestamp(2)),
        ];
        let result = scorer.score_rewards(&course_rewards, &[]).unwrap();
        assert!((result - 1.371).abs() < 0.001);
    }

    /// Verifies the blended score when both course and lesson rewards are present.
    #[test]
    fn test_both_rewards() {
        let scorer = WeightedRewardScorer {};
        let course_rewards = vec![
            reward(1.0, 1.0, generate_timestamp(1)),
            reward(2.0, 1.0, generate_timestamp(2)),
        ];
        let lesson_rewards = vec![
            reward(2.0, 1.0, generate_timestamp(1)),
            reward(4.0, 2.0, generate_timestamp(2)),
        ];
        let result = scorer
            .score_rewards(&course_rewards, &lesson_rewards)
            .unwrap();
        assert!((result - 2.533).abs() < 0.001);
    }

    /// Verifies that a reward whose weight is below the minimum is ignored.
    #[test]
    fn test_min_weight() {
        let scorer = WeightedRewardScorer {};
        let lesson_rewards = vec![
            reward(2.0, 1.0, generate_timestamp(0)),
            reward(1.0, 0.0001, generate_timestamp(0) - 1),
        ];
        let result = scorer.score_rewards(&[], &lesson_rewards).unwrap();
        assert!((result - 2.0).abs() < 0.001);
    }

    /// Verifies that an old, heavily weighted reward does not dominate a fresh one.
    #[test]
    fn test_stale_rewards_do_not_drag_denominator() {
        let scorer = WeightedRewardScorer {};
        let lesson_rewards = vec![
            reward(1.0, 10.0, generate_timestamp(70)),
            reward(1.0, 1.0, generate_timestamp(0)),
        ];
        let result = scorer.score_rewards(&[], &lesson_rewards).unwrap();
        assert!(result > 0.7);
    }

    /// Verifies when rewards are applied, based on the recent trials of the exercise.
    #[test]
    fn test_apply_rewards() {
        let scorer = WeightedRewardScorer {};

        // Too few trials: no reward is applied.
        let trials = vec![trial(2.0, 1)];
        assert!(!scorer.apply_reward(0.5, &trials));
        assert!(!scorer.apply_reward(-1.0, &trials));

        // Recent low scores: a positive reward is rejected, a negative one is applied.
        let trials = vec![trial(2.0, 1), trial(2.0, 8), trial(3.0, 10)];
        assert!(!scorer.apply_reward(0.5, &trials));
        assert!(scorer.apply_reward(-1.0, &trials));

        // Recent high scores: a negative reward is rejected, a positive one is applied.
        let trials = vec![trial(4.0, 1), trial(5.0, 8), trial(4.0, 10)];
        assert!(!scorer.apply_reward(-0.5, &trials));
        assert!(scorer.apply_reward(1.0, &trials));

        // Average score is not below 3.0: a positive reward is applied.
        let trials = vec![trial(3.0, 1), trial(3.0, 8), trial(4.0, 10)];
        assert!(scorer.apply_reward(0.5, &trials));

        // Average score is not above 3.5: a negative reward is applied.
        let trials = vec![trial(2.0, 1), trial(3.0, 8), trial(2.0, 10)];
        assert!(scorer.apply_reward(-0.5, &trials));
    }
}