use crate::fitts::{CalibrationResult, FittsModel};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// Per-card scheduling state carried between reviews.
///
/// The counters and interval are updated by `FittsScheduler::review`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CardState {
    /// Current review interval in days (`0.0` for a card that has never been scheduled).
    pub interval_days: f64,
    /// SM-2 ease factor; multiplies the interval once the card has two or more
    /// consecutive successful repetitions.
    pub ease_factor: f64,
    /// Consecutive successful reviews since the last `Again` (reset to 0 on a lapse).
    pub repetitions: u32,
    /// Number of reviews rated `Again`.
    pub lapses: u32,
    /// Time of the most recent review, or `None` if the card was never reviewed.
    pub last_review: Option<DateTime<Utc>>,
}

impl Default for CardState {
    /// A fresh card: zero interval, ease 2.5, no history.
    fn default() -> Self {
        Self {
            interval_days: 0.0,
            ease_factor: 2.5,
            repetitions: 0,
            lapses: 0,
            last_review: None,
        }
    }
}

impl CardState {
    /// Date at which the card becomes due.
    ///
    /// This is `last_review` plus `interval_days` rounded up to whole days. A card
    /// that has never been reviewed is measured from the current time.
    ///
    /// # Panics
    ///
    /// NOTE(review): `chrono::Duration::days` and `DateTime + Duration` panic when the
    /// result is out of range, so an enormous or infinite `interval_days` panics here.
    pub fn next_review_date(&self) -> DateTime<Utc> {
        self.next_review_date_from(Utc::now())
    }

    /// Whether the card is due at the current wall-clock time.
    pub fn is_due(&self) -> bool {
        // Read the clock exactly once. Previously `next_review_date` took a second,
        // later `Utc::now()` as the base for never-reviewed cards, so
        // `now >= now_later + 0 days` was almost always false. Brand-new cards
        // therefore looked "not due".
        let now = Utc::now();
        now >= self.next_review_date_from(now)
    }

    /// Computes the due date, using `now` as the base when the card has no `last_review`.
    fn next_review_date_from(&self, now: DateTime<Utc>) -> DateTime<Utc> {
        let base = self.last_review.unwrap_or(now);
        base + chrono::Duration::days(self.interval_days.ceil() as i64)
    }
}
/// Learner's self-assessed recall quality for one review.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Rating {
    /// Failed recall; the only rating that is not a success.
    Again = 0,
    /// Recalled with difficulty (still a success).
    Hard = 1,
    /// Recalled normally.
    Good = 2,
    /// Recalled effortlessly.
    Easy = 3,
}

impl Rating {
    /// Numeric quality on a 0..=3 scale (`Again` = 0 … `Easy` = 3).
    pub fn to_quality(self) -> u8 {
        // The discriminants above are declared explicitly, so the cast *is* the mapping.
        self as u8
    }

    /// `true` for every rating except [`Rating::Again`].
    pub fn is_success(self) -> bool {
        !matches!(self, Rating::Again)
    }

    /// Every rating, ordered from worst to best.
    pub fn all() -> [Self; 4] {
        [Rating::Again, Rating::Hard, Rating::Good, Rating::Easy]
    }
}
/// A single review event: the learner's rating and how long they took to answer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ReviewInput {
    /// Self-assessed recall quality.
    pub rating: Rating,
    /// Response time in milliseconds; `0` means "not recorded" and skips calibration.
    pub response_time_ms: u64,
}

impl ReviewInput {
    /// A review with no recorded response time.
    pub fn from_rating(rating: Rating) -> Self {
        Self::new(rating, 0)
    }

    /// A review with an explicit response time in milliseconds.
    pub fn new(rating: Rating, response_time_ms: u64) -> Self {
        ReviewInput {
            rating,
            response_time_ms,
        }
    }

    /// The response time converted to seconds.
    pub fn response_time_seconds(&self) -> f64 {
        const MILLIS_PER_SECOND: f64 = 1000.0;
        self.response_time_ms as f64 / MILLIS_PER_SECOND
    }
}

impl From<Rating> for ReviewInput {
    /// Lets callers pass a bare [`Rating`] wherever `impl Into<ReviewInput>` is accepted.
    fn from(rating: Rating) -> Self {
        ReviewInput::from_rating(rating)
    }
}
/// Outcome of a single `FittsScheduler::review` call.
#[derive(Debug, Clone)]
pub struct ReviewResult {
/// Card state after the rating has been applied.
pub card: CardState,
/// Response time predicted by the model for the card *before* this review
/// (presumably seconds, to match `actual_rt` — TODO confirm against `FittsModel::predict`).
pub predicted_rt: f64,
/// Observed response time in seconds; `None` when the input carried no response time.
pub actual_rt: Option<f64>,
/// The `error` field of the calibration result; `None` when no response time was recorded.
pub prediction_error: Option<f64>,
/// Second value returned by `FittsScheduler::predict` for the pre-review card.
pub retrievability: f64,
/// Full calibration result; `None` when no response time was recorded.
pub calibration: Option<CalibrationResult>,
}
/// Scheduler that combines SM-2 interval/ease updates with a Fitts-style
/// response-time model, used for difficulty prediction and online calibration.
#[derive(Debug, Clone)]
pub struct FittsScheduler {
    /// Response-time model used by `predict` and calibrated by `review`.
    pub fitts: FittsModel,
}

impl Default for FittsScheduler {
    fn default() -> Self {
        Self::new()
    }
}

impl FittsScheduler {
    /// Scheduler whose model is built with `FittsModel::new(0.5, 0.3)`.
    pub fn new() -> Self {
        Self {
            fitts: FittsModel::new(0.5, 0.3),
        }
    }

    /// Scheduler wrapping a caller-supplied model.
    pub fn with_fitts(fitts: FittsModel) -> Self {
        Self { fitts }
    }

    /// Same model parameters as [`FittsScheduler::new`], with a custom
    /// calibration learning rate.
    pub fn with_learning_rate(learning_rate: f64) -> Self {
        Self {
            fitts: FittsModel::with_learning_rate(0.5, 0.3, learning_rate),
        }
    }

    /// Applies one review to `card` and returns the updated state.
    ///
    /// Scheduling follows SM-2:
    /// - `Again` resets `repetitions` to 0, sets the interval to 1 day, and counts a lapse.
    /// - Any other rating advances the interval 1 → 6 → `interval * ease` (rounded) days.
    ///
    /// The ease factor is then updated with the SM-2 formula, with the 0..=3
    /// quality rescaled to 1..=5, and clamped to at least 1.3.
    ///
    /// The prediction and retrievability in the result describe the card
    /// *before* the update. When `response_time_ms > 0`, the model is also
    /// calibrated against the observed time (in seconds), using the same
    /// pre-review features.
    pub fn review(&mut self, mut card: CardState, input: impl Into<ReviewInput>) -> ReviewResult {
        let input = input.into();
        let quality = input.rating.to_quality();

        // Snapshot the pre-review features. Prediction and calibration must both
        // describe the state the learner was actually tested on.
        let (predicted_rt, retrievability) = self.predict(&card);
        let old_interval = card.interval_days.max(1.0);
        let old_ease = card.ease_factor;

        if quality < 1 {
            card.repetitions = 0;
            card.interval_days = 1.0;
            card.lapses = card.lapses.saturating_add(1);
        } else {
            card.interval_days = match card.repetitions {
                0 => 1.0,
                1 => 6.0,
                _ => (card.interval_days * card.ease_factor).round(),
            };
            card.repetitions = card.repetitions.saturating_add(1);
        }

        // Map the 0..=3 rating linearly onto SM-2's 1..=5 range (0 -> 1, 3 -> 5).
        let q_scaled = 1.0 + (f64::from(quality) * 4.0 / 3.0);
        let ease_delta = 0.1 - (5.0 - q_scaled) * (0.08 + (5.0 - q_scaled) * 0.02);
        card.ease_factor = (card.ease_factor + ease_delta).max(1.3);
        card.last_review = Some(Utc::now());

        // A response time of 0 means "not recorded", so no calibration happens.
        let (actual_rt, prediction_error, calibration) = if input.response_time_ms > 0 {
            let actual = input.response_time_seconds();
            // The interval doubles as the stability proxy, matching `predict`.
            let stability = old_interval;
            let cal = self
                .fitts
                .calibrate(old_interval, old_ease, stability, actual);
            (Some(actual), Some(cal.error), Some(cal))
        } else {
            (None, None, None)
        };

        ReviewResult {
            card,
            predicted_rt,
            actual_rt,
            prediction_error,
            retrievability,
            calibration,
        }
    }

    /// Returns the model's `(predicted_rt, retrievability)` for `card`.
    ///
    /// The model is fed the interval (floored at 1 day), the ease factor, and
    /// the interval again as a stability proxy.
    pub fn predict(&self, card: &CardState) -> (f64, f64) {
        let interval = card.interval_days.max(1.0);
        let ease = card.ease_factor;
        let stability = interval;
        self.fitts.predict(interval, ease, stability)
    }

    /// Sorts `cards` so the longest predicted response time (hardest) comes first.
    ///
    /// The sort is stable, so cards with equal predictions keep their relative
    /// order. Cards whose prediction is NaN are placed last.
    pub fn order_by_difficulty(&self, cards: &mut [CardState]) {
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once a NaN
        // appears, which leaves the slice arbitrarily ordered and lets `sort_by`
        // panic on Rust 1.81+. Map NaN to -inf and compare with `total_cmp`.
        let difficulty = |card: &CardState| {
            let (rt, _) = self.predict(card);
            if rt.is_nan() {
                f64::NEG_INFINITY
            } else {
                rt
            }
        };
        cards.sort_by(|a, b| difficulty(b).total_cmp(&difficulty(a)));
    }

    /// Current `(a, b)` parameters of the underlying model.
    pub fn model_params(&self) -> (f64, f64) {
        (self.fitts.params.a, self.fitts.params.b)
    }
}
// Unit tests covering rating mapping, review-input helpers, SM-2 scheduling
// transitions, and response-time calibration of `FittsScheduler`.
#[cfg(test)]
mod tests {
use super::*;
// Each rating maps to its 0..=3 quality value.
#[test]
fn test_rating_quality_mapping() {
assert_eq!(Rating::Again.to_quality(), 0);
assert_eq!(Rating::Hard.to_quality(), 1);
assert_eq!(Rating::Good.to_quality(), 2);
assert_eq!(Rating::Easy.to_quality(), 3);
}
// Only `Again` counts as a failed recall.
#[test]
fn test_rating_is_success() {
assert!(!Rating::Again.is_success());
assert!(Rating::Hard.is_success());
assert!(Rating::Good.is_success());
assert!(Rating::Easy.is_success());
}
// The constructor stores the fields, and the ms -> s conversion is exact.
#[test]
fn test_review_input_creation() {
let input = ReviewInput::new(Rating::Good, 2500);
assert_eq!(input.rating, Rating::Good);
assert_eq!(input.response_time_ms, 2500);
assert!((input.response_time_seconds() - 2.5).abs() < 0.001);
}
// A bare rating (response time 0) schedules the card but skips calibration.
#[test]
fn test_review_with_rating_only() {
let mut scheduler = FittsScheduler::new();
let card = CardState::default();
let result = scheduler.review(card, Rating::Good);
assert_eq!(result.card.interval_days, 1.0);
assert!(result.actual_rt.is_none());
assert!(result.calibration.is_none());
}
// A recorded response time is reported in seconds and triggers calibration.
#[test]
fn test_review_with_response_time() {
let mut scheduler = FittsScheduler::with_learning_rate(0.1);
let card = CardState::default();
let input = ReviewInput::new(Rating::Good, 3000);
let result = scheduler.review(card, input);
assert_eq!(result.card.interval_days, 1.0);
assert!(result.actual_rt.is_some());
assert!((result.actual_rt.unwrap() - 3.0).abs() < 0.001);
assert!(result.calibration.is_some());
}
// After 20 calibrations at a constant 2 s for the same card features, the
// prediction should land in (1, 4) s.
// NOTE(review): this checks an absolute range, not that the error actually
// shrank relative to the initial prediction.
#[test]
fn test_calibration_improves_predictions() {
let mut scheduler = FittsScheduler::with_learning_rate(0.1);
let actual_rt_ms = 2000;
for _ in 0..20 {
let card = CardState {
interval_days: 7.0,
ease_factor: 2.5,
..Default::default()
};
let input = ReviewInput::new(Rating::Good, actual_rt_ms);
scheduler.review(card, input);
}
let card = CardState {
interval_days: 7.0,
ease_factor: 2.5,
..Default::default()
};
let (predicted, _) = scheduler.predict(&card);
assert!(
predicted > 1.0 && predicted < 4.0,
"Prediction should be calibrated towards actual RT"
);
}
// First successful review of a new card: 1-day interval, one repetition.
#[test]
fn test_sm2_new_card() {
let mut scheduler = FittsScheduler::new();
let card = CardState::default();
let result = scheduler.review(card, Rating::Good);
assert_eq!(result.card.interval_days, 1.0);
assert_eq!(result.card.repetitions, 1);
}
// The SM-2 interval sequence is 1 -> 6 -> interval * ease.
#[test]
fn test_sm2_progression() {
let mut scheduler = FittsScheduler::new();
let mut card = CardState::default();
card = scheduler.review(card, Rating::Good).card;
assert_eq!(card.interval_days, 1.0);
card = scheduler.review(card, Rating::Good).card;
assert_eq!(card.interval_days, 6.0);
card = scheduler.review(card, Rating::Good).card;
assert!(card.interval_days > 6.0);
}
// `Hard` advances the schedule like any success and records no lapse.
#[test]
fn test_sm2_hard_is_success() {
let mut scheduler = FittsScheduler::new();
let card = CardState::default();
let result = scheduler.review(card, Rating::Hard);
assert_eq!(result.card.repetitions, 1);
assert_eq!(result.card.lapses, 0);
assert_eq!(result.card.interval_days, 1.0);
}
// `Again` on a mature card resets it to a 1-day interval and counts a lapse.
#[test]
fn test_sm2_lapse() {
let mut scheduler = FittsScheduler::new();
let card = CardState {
interval_days: 30.0,
ease_factor: 2.5,
repetitions: 5,
lapses: 0,
last_review: None,
};
let result = scheduler.review(card, Rating::Again);
assert_eq!(result.card.interval_days, 1.0);
assert_eq!(result.card.repetitions, 0);
assert_eq!(result.card.lapses, 1);
}
// Repeated failures must not push the ease factor below its 1.3 floor.
#[test]
fn test_ease_factor_bounds() {
let mut scheduler = FittsScheduler::new();
let mut card = CardState::default();
for _ in 0..10 {
card = scheduler.review(card, Rating::Again).card;
}
assert!(card.ease_factor >= 1.3);
}
// Documents the 0..=3 -> 1..=5 quality rescaling used in `review`.
// NOTE(review): this re-implements the formula instead of calling the
// scheduler, so it will not catch a change to the expression inside `review`.
#[test]
fn test_quality_scaling_to_sm2() {
let test_cases: [(f64, f64); 4] = [
(0.0, 1.0), (1.0, 2.333), (2.0, 3.667), (3.0, 5.0), ];
for (q, expected) in test_cases.iter() {
let q_scaled = 1.0 + (q * 4.0 / 3.0);
assert!(
(q_scaled - expected).abs() < 0.01,
"Quality {} should map to ~{}, got {}",
q,
expected,
q_scaled
);
}
}
// Hardest (largest predicted RT) card first; only the endpoints are checked.
#[test]
fn test_order_by_difficulty() {
let scheduler = FittsScheduler::new();
let mut cards = vec![
CardState {
interval_days: 1.0,
ease_factor: 2.5,
..Default::default()
},
CardState {
interval_days: 30.0,
ease_factor: 1.5,
..Default::default()
},
CardState {
interval_days: 7.0,
ease_factor: 2.0,
..Default::default()
},
];
scheduler.order_by_difficulty(&mut cards);
let (rt_first, _) = scheduler.predict(&cards[0]);
let (rt_last, _) = scheduler.predict(&cards[cards.len() - 1]);
assert!(rt_first >= rt_last);
}
}