use ustr::{Ustr, UstrMap};
use crate::{
data::{MasteryScore, UnitReward},
graph::UnitGraph,
scheduler::data::SchedulerData,
};
/// Rewards whose absolute value is below this threshold are not propagated further.
pub(super) const MIN_ABS_REWARD: f32 = 0.2;
/// Rewards whose weight is below this threshold are not propagated further.
pub(super) const MIN_WEIGHT: f32 = 0.2;
/// Multiplied, together with the edge weight, into a reward's weight on every hop after the
/// first one out of the exercise's lesson or course.
pub(super) const WEIGHT_FACTOR: f32 = 0.8;
/// Multiplied, together with the edge weight, into a reward's value on every hop after the
/// first one out of the exercise's lesson or course.
pub(super) const REWARD_FACTOR: f32 = 0.9;
/// Propagates the reward implied by an exercise's mastery score to related units of the unit
/// graph.
pub(super) struct RewardPropagator {
/// Scheduler data; the propagator reads the unit graph from `data.unit_graph`.
pub data: SchedulerData,
}
impl RewardPropagator {
    /// Maps a mastery score to the reward that seeds propagation.
    ///
    /// Scores of `Four` and `Five` produce a positive reward. Scores of `Three` and below
    /// produce a negative one, and `One` is the most negative.
    fn initial_reward(score: &MasteryScore) -> f32 {
        match score {
            MasteryScore::Five => 0.8,
            MasteryScore::Four => 0.4,
            MasteryScore::Three => -0.3,
            MasteryScore::Two => -0.5,
            MasteryScore::One => -1.0,
        }
    }

    /// Returns the `(unit_id, edge_weight)` pairs that a reward of the given sign flows to.
    ///
    /// Positive rewards follow `get_encompasses`. Zero or negative rewards follow
    /// `get_encompassed_by`. If the graph returns `None`, the result is an empty list.
    fn get_next_units(unit_graph: &dyn UnitGraph, unit_id: Ustr, reward: f32) -> Vec<(Ustr, f32)> {
        if reward > 0.0 {
            unit_graph.get_encompasses(unit_id).unwrap_or_default()
        } else {
            unit_graph.get_encompassed_by(unit_id).unwrap_or_default()
        }
    }

    /// Returns `true` when a reward should not be propagated any further.
    ///
    /// Propagation stops when the reward's absolute value is below `MIN_ABS_REWARD` or its
    /// weight is below `MIN_WEIGHT`. It also stops when either number is NaN.
    ///
    /// The explicit NaN checks are needed because `<` is false for NaN. Without them a NaN
    /// reward would pass this test and be recorded. The "keep the stronger reward" comparison
    /// in `propagate_rewards_helper` is also false for NaN, so on a cycle it could loop forever.
    pub(super) fn stop_propagation(reward: f32, weight: f32) -> bool {
        reward.is_nan() || weight.is_nan() || reward.abs() < MIN_ABS_REWARD || weight < MIN_WEIGHT
    }

    /// Looks up the lesson that contains `exercise_id` and the course that contains that lesson.
    ///
    /// Returns `None` if either lookup fails.
    fn resolve_roots(unit_graph: &dyn UnitGraph, exercise_id: Ustr) -> Option<(Ustr, Ustr)> {
        let lesson_id = unit_graph.get_exercise_lesson(exercise_id)?;
        let course_id = unit_graph.get_lesson_course(lesson_id)?;
        Some((lesson_id, course_id))
    }

    /// Propagates the reward for `score` outward from `lesson_id` and `course_id`.
    ///
    /// Hop scaling:
    /// - On the first hop out of either root, the value is the initial reward scaled by the
    ///   edge weight, and the weight is the edge weight itself.
    /// - On every later hop, the value is also multiplied by `REWARD_FACTOR` and the weight by
    ///   `WEIGHT_FACTOR`.
    /// - Any candidate for which `stop_propagation` returns `true` is pruned.
    ///
    /// Multiple paths: when a unit is reachable through several paths, only the reward with the
    /// largest absolute value is kept. A unit is re-expanded whenever a stronger reward replaces
    /// a weaker one, so its neighbours are re-evaluated against the stronger value.
    ///
    /// Roots: the roots are not seeded with a reward. They only appear in the output if the
    /// graph leads back to them.
    ///
    /// Returns an empty vector if either root id is empty. Every returned reward carries
    /// `timestamp`.
    ///
    /// NOTE(review): termination on cyclic graphs assumes edge weights are at most 1 (TODO
    /// confirm). With a weight above 1 / `REWARD_FACTOR`, a reward would grow on every lap of a
    /// cycle and keep replacing itself.
    fn propagate_rewards_helper(
        unit_graph: &dyn UnitGraph,
        lesson_id: Ustr,
        course_id: Ustr,
        score: &MasteryScore,
        timestamp: i64,
    ) -> Vec<UnitReward> {
        if lesson_id.is_empty() || course_id.is_empty() {
            return vec![];
        }

        // Seed the stack with the first hop out of the lesson and the course. This hop is
        // scaled only by the edge weight, not by REWARD_FACTOR / WEIGHT_FACTOR.
        let initial_reward = Self::initial_reward(score);
        let next_lessons = Self::get_next_units(unit_graph, lesson_id, initial_reward);
        let next_courses = Self::get_next_units(unit_graph, course_id, initial_reward);
        let mut stack: Vec<UnitReward> = next_lessons
            .iter()
            .chain(next_courses.iter())
            .filter_map(|&(unit_id, edge_weight)| {
                let value = edge_weight * initial_reward;
                let weight = edge_weight;
                if Self::stop_propagation(value, weight) {
                    None
                } else {
                    Some(UnitReward {
                        unit_id,
                        value,
                        weight,
                        timestamp,
                    })
                }
            })
            .collect();

        // Depth-first traversal that keeps only the strongest (largest |value|) reward per unit.
        let mut results: UstrMap<UnitReward> = UstrMap::default();
        while let Some(item) = stack.pop() {
            if let Some(existing_reward) = results.get(&item.unit_id)
                && existing_reward.value.abs() >= item.value.abs()
            {
                continue;
            }

            // This is a new or stronger reward: record it, then expand from it so downstream
            // units see the stronger value. The fields are copied out first so the item can be
            // moved into the map without cloning.
            let (unit_id, value, weight) = (item.unit_id, item.value, item.weight);
            results.insert(unit_id, item);
            for (next_unit_id, edge_weight) in Self::get_next_units(unit_graph, unit_id, value) {
                let next_value = edge_weight * REWARD_FACTOR * value;
                let next_weight = edge_weight * WEIGHT_FACTOR * weight;
                if Self::stop_propagation(next_value, next_weight) {
                    continue;
                }
                stack.push(UnitReward {
                    unit_id: next_unit_id,
                    value: next_value,
                    weight: next_weight,
                    timestamp,
                });
            }
        }
        results.into_values().collect()
    }

    /// Propagates the reward for a score on `exercise_id` to related units of the unit graph.
    ///
    /// Resolves the exercise's lesson and course and delegates to `propagate_rewards_helper`.
    /// Returns an empty vector if the lesson or its course cannot be found. The guard returned
    /// by `self.data.unit_graph.read()` is held for the whole propagation.
    pub(super) fn propagate_rewards(
        &self,
        exercise_id: Ustr,
        score: &MasteryScore,
        timestamp: i64,
    ) -> Vec<UnitReward> {
        let unit_graph = self.data.unit_graph.read();
        let Some((lesson_id, course_id)) = Self::resolve_roots(&*unit_graph, exercise_id) else {
            return vec![];
        };
        Self::propagate_rewards_helper(&*unit_graph, lesson_id, course_id, score, timestamp)
    }
}
// Unit tests for `RewardPropagator`. They call the associated helper functions directly against
// small in-memory graphs, so no `SchedulerData` has to be constructed.
#[cfg(test)]
#[cfg_attr(coverage, coverage(off))]
mod test {
use anyhow::Result;
use ustr::{Ustr, UstrMap};
use crate::{
data::{MasteryScore, UnitReward},
graph::{InMemoryUnitGraph, UnitGraph},
scheduler::reward_propagator::{MIN_ABS_REWARD, MIN_WEIGHT, RewardPropagator},
};
/// Builds course "0" with lessons "0::0" through "0::3" and exercise "0::0::0" in lesson "0::0".
///
/// `source_encompassed` is passed to `add_encompassed` for lesson "0::0". Lessons "0::1" and
/// "0::2" are each given "0::3" with weight 1.0. The weights in `source_encompassed` therefore
/// decide which path to "0::3" carries the stronger reward.
fn build_path_graph(source_encompassed: &[(Ustr, f32)]) -> Result<InMemoryUnitGraph> {
let mut graph = InMemoryUnitGraph::default();
graph.add_course(Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::0"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::1"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::2"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::3"), Ustr::from("0"))?;
graph.add_exercise(Ustr::from("0::0::0"), Ustr::from("0::0"))?;
graph.add_encompassed(Ustr::from("0::0"), &[], source_encompassed)?;
graph.add_encompassed(Ustr::from("0::1"), &[], &[(Ustr::from("0::3"), 1.0)])?;
graph.add_encompassed(Ustr::from("0::2"), &[], &[(Ustr::from("0::3"), 1.0)])?;
Ok(graph)
}
/// Runs `propagate_rewards_helper` with a `Five` score and timestamp 0.
///
/// The roots are lesson "0::0" and course "0". The resulting rewards are indexed by unit id.
fn propagate_five_rewards(unit_graph: &dyn UnitGraph) -> Result<UstrMap<UnitReward>> {
let rewards = RewardPropagator::propagate_rewards_helper(
unit_graph,
Ustr::from("0::0"),
Ustr::from("0"),
&MasteryScore::Five,
0,
);
Ok(rewards
.into_iter()
.map(|reward| (reward.unit_id, reward))
.collect())
}
/// Pins the score-to-reward table.
#[test]
fn initial_reward() {
assert_eq!(RewardPropagator::initial_reward(&MasteryScore::Five), 0.8);
assert_eq!(RewardPropagator::initial_reward(&MasteryScore::Four), 0.4);
assert_eq!(RewardPropagator::initial_reward(&MasteryScore::Three), -0.3);
assert_eq!(RewardPropagator::initial_reward(&MasteryScore::Two), -0.5);
assert_eq!(RewardPropagator::initial_reward(&MasteryScore::One), -1.0);
}
/// Checks the pruning thresholds.
///
/// Values exactly at both thresholds keep propagating. A value just below either threshold
/// stops propagation, and negative rewards are compared by absolute value.
#[test]
fn stop_propagation() {
assert!(!RewardPropagator::stop_propagation(
MIN_ABS_REWARD,
MIN_WEIGHT,
));
assert!(RewardPropagator::stop_propagation(
MIN_ABS_REWARD - 0.001,
MIN_WEIGHT,
));
assert!(RewardPropagator::stop_propagation(
-MIN_ABS_REWARD + 0.001,
MIN_WEIGHT,
));
assert!(RewardPropagator::stop_propagation(
MIN_ABS_REWARD,
MIN_WEIGHT - 0.001,
));
}
/// Checks that the stronger of two paths to "0::3" wins.
///
/// Via "0::1": 0.8 * 1.0 = 0.8, then 1.0 * 0.9 * 0.8 = 0.72 with weight 1.0 * 0.8 * 1.0 = 0.8.
/// Via "0::2": 0.8 * 0.5 = 0.4, then 1.0 * 0.9 * 0.4 = 0.36. The 0.72 reward must be kept.
#[test]
fn strongest_path_wins() -> Result<()> {
let graph = build_path_graph(&[(Ustr::from("0::1"), 1.0), (Ustr::from("0::2"), 0.5)])?;
let reward_map = propagate_five_rewards(&graph)?;
let reward = reward_map.get(&Ustr::from("0::3")).unwrap();
assert!((reward.value - 0.72).abs() < f32::EPSILON);
assert!((reward.weight - 0.8).abs() < f32::EPSILON);
Ok(())
}
/// Checks that the reward kept for "0::3" does not depend on the order of the edges out of
/// "0::0".
///
/// The order changes the order in which the traversal stack is popped.
#[test]
fn strongest_path_is_order_independent() -> Result<()> {
let first_order =
build_path_graph(&[(Ustr::from("0::1"), 1.0), (Ustr::from("0::2"), 0.5)])?;
let second_order =
build_path_graph(&[(Ustr::from("0::2"), 0.5), (Ustr::from("0::1"), 1.0)])?;
let first_reward = propagate_five_rewards(&first_order)?
.get(&Ustr::from("0::3"))
.cloned()
.unwrap();
let second_reward = propagate_five_rewards(&second_order)?
.get(&Ustr::from("0::3"))
.cloned()
.unwrap();
assert!((first_reward.value - second_reward.value).abs() < f32::EPSILON);
assert!((first_reward.weight - second_reward.weight).abs() < f32::EPSILON);
Ok(())
}
/// Checks how edge weights and factors compound along a two-hop chain.
///
/// First hop: value 0.8 * 0.8 = 0.64, weight 0.8.
/// Second hop: value 0.8 * 0.9 * 0.64 = 0.4608, weight 0.8 * 0.8 * 0.8 = 0.512.
#[test]
fn edge_weights_attenuate_reward_weight() -> Result<()> {
let mut graph = InMemoryUnitGraph::default();
graph.add_course(Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::0"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::1"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::2"), Ustr::from("0"))?;
graph.add_encompassed(Ustr::from("0::0"), &[], &[(Ustr::from("0::1"), 0.8)])?;
graph.add_encompassed(Ustr::from("0::1"), &[], &[(Ustr::from("0::2"), 0.8)])?;
let reward_map = propagate_five_rewards(&graph)?;
let first_hop = reward_map.get(&Ustr::from("0::1")).unwrap();
assert!((first_hop.value - 0.64).abs() < f32::EPSILON);
assert!((first_hop.weight - 0.8).abs() < f32::EPSILON);
let second_hop = reward_map.get(&Ustr::from("0::2")).unwrap();
assert!((second_hop.value - 0.4608).abs() < f32::EPSILON);
assert!((second_hop.weight - 0.512).abs() < f32::EPSILON);
Ok(())
}
/// Checks that a weak first-hop edge seeds nothing.
///
/// A 0.1 edge gives value 0.1 * 0.8 = 0.08 and weight 0.1, which are below both thresholds.
#[test]
fn weak_initial_edges_are_pruned() -> Result<()> {
let mut graph = InMemoryUnitGraph::default();
graph.add_course(Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::0"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::1"), Ustr::from("0"))?;
graph.add_encompassed(Ustr::from("0::0"), &[], &[(Ustr::from("0::1"), 0.1)])?;
let reward_map = propagate_five_rewards(&graph)?;
assert!(reward_map.is_empty());
Ok(())
}
/// Checks that a weak edge on a later hop is pruned without affecting earlier hops.
///
/// "0::1" gets 0.8. The 0.1 edge to "0::2" would give 0.1 * 0.9 * 0.8 = 0.072, so "0::2" is
/// pruned.
#[test]
fn weak_recursive_hops_are_pruned() -> Result<()> {
let mut graph = InMemoryUnitGraph::default();
graph.add_course(Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::0"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::1"), Ustr::from("0"))?;
graph.add_lesson(Ustr::from("0::2"), Ustr::from("0"))?;
graph.add_encompassed(Ustr::from("0::0"), &[], &[(Ustr::from("0::1"), 1.0)])?;
graph.add_encompassed(Ustr::from("0::1"), &[], &[(Ustr::from("0::2"), 0.1)])?;
let reward_map = propagate_five_rewards(&graph)?;
assert!(reward_map.contains_key(&Ustr::from("0::1")));
assert!(!reward_map.contains_key(&Ustr::from("0::2")));
Ok(())
}
/// Checks root resolution for a known exercise and for an unknown one.
///
/// A known exercise resolves to its lesson and that lesson's course. An unknown exercise
/// yields `None`.
#[test]
fn resolve_roots() -> Result<()> {
let mut graph = InMemoryUnitGraph::default();
graph.add_course(Ustr::from("course"))?;
graph.add_lesson(Ustr::from("lesson"), Ustr::from("course"))?;
graph.add_exercise(Ustr::from("exercise"), Ustr::from("lesson"))?;
assert_eq!(
RewardPropagator::resolve_roots(&graph, Ustr::from("exercise")),
Some((Ustr::from("lesson"), Ustr::from("course")))
);
assert_eq!(
RewardPropagator::resolve_roots(&graph, Ustr::from("missing")),
None
);
Ok(())
}
}