/// Tunable parameters for the FSRS-style power-law forgetting curve used by [`FsrsState`].
///
/// The curve is `R(t) = (1 + F * t / S)^(-w20)`, where `t` is elapsed days, `S` is the
/// item's stability in days and `F = 0.9^(-1 / w20) - 1`.
#[derive(Debug, Clone, PartialEq)]
pub struct FsrsParams {
    /// Decay exponent `w20` of the forgetting curve (also used to derive the factor `F`).
    pub w20: f64,
    /// Target probability of recall.
    ///
    /// NOTE(review): `FsrsState::next_review_interval_days` takes its own
    /// `desired_retention` argument and does not read this field; presumably callers
    /// pass this value through explicitly — confirm.
    pub desired_retention: f64,
    /// Stability (in days) intended for newly created items.
    ///
    /// NOTE(review): `FsrsState::new` receives the initial stability as an explicit
    /// argument and does not read this field.
    pub initial_stability_days: f64,
    /// Lower bound (in days) that `FsrsState::record_review` clamps stability to.
    pub min_stability_days: f64,
}

impl Default for FsrsParams {
    /// Returns `w20 = 0.2`, `desired_retention = 0.9`, `initial_stability_days = 1.0`
    /// and `min_stability_days = 0.1`.
    fn default() -> Self {
        Self {
            w20: 0.2,
            desired_retention: 0.9,
            initial_stability_days: 1.0,
            min_stability_days: 0.1,
        }
    }
}
/// Memory state of a single item under the FSRS-style power-law forgetting curve.
///
/// Timestamps passed to and stored by this type are in seconds; elapsed time is
/// converted to days with [`SECS_PER_DAY`] before evaluating the curve.
#[derive(Debug, Clone, PartialEq)]
pub struct FsrsState {
    /// Stability `S` in days. Because `F = 0.9^(-1/w20) - 1`, retrievability equals
    /// exactly 0.9 when `S` days have elapsed since the last review.
    stability_days: f64,
    /// Timestamp (seconds) of the most recent review, or of creation if never reviewed.
    last_review_at: f64,
}

/// Number of seconds in one day, used to convert timestamp deltas into days.
const SECS_PER_DAY: f64 = 86_400.0;

impl FsrsState {
    /// Creates a state whose clock starts at `created_at` (seconds) with the given
    /// stability in days.
    ///
    /// `initial_stability_days` is stored as-is; it is not clamped to
    /// `FsrsParams::min_stability_days` (clamping only happens in [`Self::record_review`]).
    pub fn new(created_at: f64, initial_stability_days: f64) -> Self {
        Self {
            stability_days: initial_stability_days,
            last_review_at: created_at,
        }
    }

    /// Current stability `S`, in days.
    pub fn stability_days(&self) -> f64 {
        self.stability_days
    }

    /// Timestamp (seconds) of the last review, or of creation if never reviewed.
    pub fn last_review_at(&self) -> f64 {
        self.last_review_at
    }

    /// Curve factor `F = 0.9^(-1 / w20) - 1`, chosen so that `R(t = S) = 0.9`.
    fn compute_f(w20: f64) -> f64 {
        0.9_f64.powf(-1.0 / w20) - 1.0
    }

    /// Probability of recall at time `now` (seconds): `(1 + F * t / S)^(-w20)`.
    ///
    /// Elapsed time is clamped at zero, so a `now` before the last review yields 1.0.
    pub fn current_retrievability(&self, now: f64, params: &FsrsParams) -> f64 {
        // `f64::max` returns the non-NaN operand, so `t_days` is always >= 0.
        let t_days = ((now - self.last_review_at) / SECS_PER_DAY).max(0.0);
        if t_days <= 0.0 {
            // No time has elapsed: recall is certain. Returning early also avoids
            // evaluating 0/0 = NaN when stability is zero.
            return 1.0;
        }
        let f = Self::compute_f(params.w20);
        (1.0 + f * t_days / self.stability_days).powf(-params.w20)
    }

    /// Days until retrievability decays to `desired_retention`, obtained by inverting
    /// the forgetting curve: `S / F * (r^(-1 / w20) - 1)`.
    ///
    /// `params.desired_retention` is *not* consulted; only the explicit argument is used.
    pub fn next_review_interval_days(&self, desired_retention: f64, params: &FsrsParams) -> f64 {
        let f = Self::compute_f(params.w20);
        if f.abs() < f64::EPSILON {
            // F -> 0 only as |w20| -> infinity. Both F and the bracket below then vanish
            // like ln(1/x) / w20, so their ratio tends to ln(r) / ln(0.9). Returning plain
            // `S` would only be correct for r = 0.9.
            return self.stability_days * desired_retention.ln() / 0.9_f64.ln();
        }
        self.stability_days / f * (desired_retention.powf(-1.0 / params.w20) - 1.0)
    }

    /// Records a review at `now` (seconds): multiplies stability by `boost`, clamps it to
    /// at least `params.min_stability_days`, and restarts the decay clock at `now`.
    pub fn record_review(&mut self, now: f64, boost: f64, params: &FsrsParams) {
        self.stability_days = (self.stability_days * boost).max(params.min_stability_days);
        self.last_review_at = now;
    }
}
/// Unit tests for the forgetting-curve model.
#[cfg(test)]
mod tests {
    use super::*;

    fn default_params() -> FsrsParams {
        FsrsParams::default()
    }

    #[test]
    fn test_retrievability_at_t0_is_1() {
        let state = FsrsState::new(1000.0, 1.0);
        let params = default_params();
        let r = state.current_retrievability(1000.0, &params);
        assert!((r - 1.0).abs() < 1e-10, "R(t=0) should be 1.0, got {r}");
    }

    #[test]
    fn test_retrievability_decays_over_time() {
        let state = FsrsState::new(0.0, 1.0);
        let params = default_params();
        let r_1d = state.current_retrievability(SECS_PER_DAY, &params);
        let r_7d = state.current_retrievability(7.0 * SECS_PER_DAY, &params);
        let r_30d = state.current_retrievability(30.0 * SECS_PER_DAY, &params);
        assert!(r_1d < 1.0, "R after 1 day should be < 1.0, got {r_1d}");
        assert!(r_7d < r_1d, "R after 7d={r_7d} should be < R after 1d={r_1d}");
        assert!(r_30d < r_7d, "R after 30d={r_30d} should be < R after 7d={r_7d}");
        assert!(r_30d > 0.0, "R should still be > 0");
    }

    #[test]
    fn test_retrievability_approaches_zero() {
        let state = FsrsState::new(0.0, 1.0);
        let params = default_params();
        let r = state.current_retrievability(365.0 * 10.0 * SECS_PER_DAY, &params);
        assert!(r < 0.3, "R after 10 years should be low, got {r}");
        assert!(r > 0.0, "R should still be > 0");
    }

    #[test]
    fn test_retrievability_at_stability_is_0_9() {
        // F is defined so that R(t = S) equals 0.9, independent of S.
        let params = default_params();
        for s in [0.5, 1.0, 10.0] {
            let state = FsrsState::new(0.0, s);
            let r = state.current_retrievability(s * SECS_PER_DAY, &params);
            assert!((r - 0.9).abs() < 1e-10, "R(t=S={s}) should be 0.9, got {r}");
        }
    }

    #[test]
    fn test_stability_increases_with_review_boost() {
        let mut state = FsrsState::new(0.0, 1.0);
        let params = default_params();
        assert!((state.stability_days() - 1.0).abs() < f64::EPSILON);
        state.record_review(SECS_PER_DAY, 1.2, &params);
        assert!(
            (state.stability_days() - 1.2).abs() < f64::EPSILON,
            "S should be 1.2 after boost, got {}",
            state.stability_days()
        );
        state.record_review(2.0 * SECS_PER_DAY, 1.5, &params);
        assert!(
            (state.stability_days() - 1.8).abs() < 0.001,
            "S should be 1.2 * 1.5 = 1.8, got {}",
            state.stability_days()
        );
    }

    #[test]
    fn test_next_review_interval_increases_with_stability() {
        let params = default_params();
        let state_low = FsrsState::new(0.0, 1.0);
        let state_high = FsrsState::new(0.0, 10.0);
        let interval_low = state_low.next_review_interval_days(0.9, &params);
        let interval_high = state_high.next_review_interval_days(0.9, &params);
        assert!(
            interval_high > interval_low,
            "Higher stability should give longer interval: low={interval_low}, high={interval_high}"
        );
    }

    #[test]
    fn test_next_review_interval_for_default_retention() {
        let state = FsrsState::new(0.0, 1.0);
        let params = default_params();
        let interval = state.next_review_interval_days(0.9, &params);
        assert!(
            (interval - 1.0).abs() < 0.01,
            "With r=0.9, interval should ≈ S=1.0 day, got {interval}"
        );
    }

    #[test]
    fn test_review_respects_min_stability() {
        let mut state = FsrsState::new(0.0, 0.05);
        let params = FsrsParams {
            min_stability_days: 0.1,
            ..Default::default()
        };
        state.record_review(SECS_PER_DAY, 0.5, &params);
        assert!(
            (state.stability_days() - 0.1).abs() < f64::EPSILON,
            "Stability should be clamped to min 0.1, got {}",
            state.stability_days()
        );
    }

    #[test]
    fn test_review_updates_last_review_at() {
        let mut state = FsrsState::new(0.0, 1.0);
        let params = default_params();
        state.record_review(5000.0, 1.0, &params);
        assert!((state.last_review_at() - 5000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_retrievability_resets_after_review() {
        let mut state = FsrsState::new(0.0, 1.0);
        let params = default_params();
        let r_before = state.current_retrievability(100.0 * SECS_PER_DAY, &params);
        assert!(r_before < 0.5, "R at 100 days should be < 0.5, got {r_before}");
        state.record_review(100.0 * SECS_PER_DAY, 1.0, &params);
        let r_after = state.current_retrievability(100.0 * SECS_PER_DAY, &params);
        assert!(
            (r_after - 1.0).abs() < 1e-10,
            "R should be 1.0 right after review, got {r_after}"
        );
    }

    #[test]
    fn test_f_constant_value() {
        let f = FsrsState::compute_f(0.2);
        let expected = 0.9_f64.powf(-5.0) - 1.0;
        assert!((f - expected).abs() < 1e-10, "F should be {expected}, got {f}");
    }
}