use serde::{Deserialize, Serialize};
use std::fmt;
/// Tunable coefficients consumed by [`FittsModel`].
///
/// The struct itself enforces nothing: every field is public and can hold any
/// `f64`. Use [`FittsParameters::new`] or [`FittsParameters::validate`] to
/// check that a value set is finite, positive and has a learning rate in
/// `[0, 1]`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct FittsParameters {
/// Intercept of the response-time line `rt = a + b * id`; also the value
/// `FittsModel::response_time` returns for rejected inputs.
/// NOTE(review): presumably seconds, since calibration compares predictions
/// against an argument named `actual_rt_seconds` - confirm.
pub a: f64,
/// Slope of the response-time line: extra response time per unit of the
/// index of difficulty.
pub b: f64,
/// Multiplier applied to the result of `FittsModel::memory_distance`
/// before that result is clamped.
pub distance_factor: f64,
/// Multiplier applied to the result of `FittsModel::memory_accessibility`
/// before that result is floored.
pub accessibility_factor: f64,
/// Response time at which `FittsModel::retrievability` evaluates to 0.5
/// (the midpoint of its logistic curve). That method floors the value at
/// 0.1 before using it.
pub threshold: f64,
/// Step size used by `FittsModel::calibrate`; a value of `0.0` disables
/// parameter updates.
pub learning_rate: f64,
}
impl FittsParameters {
    /// Builds a parameter set and rejects it unless [`Self::validate`] accepts it.
    ///
    /// # Errors
    ///
    /// Returns [`FittsError::InvalidParameters`] under the same conditions as
    /// [`Self::validate`].
    pub fn new(
        a: f64,
        b: f64,
        distance_factor: f64,
        accessibility_factor: f64,
        threshold: f64,
        learning_rate: f64,
    ) -> Result<Self, FittsError> {
        let candidate = Self {
            a,
            b,
            distance_factor,
            accessibility_factor,
            threshold,
            learning_rate,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Checks the invariants the model expects from its parameters.
    ///
    /// The checks run in a fixed order and the first failure wins:
    /// 1. every field is finite,
    /// 2. `a`, `b`, `distance_factor`, `accessibility_factor` and `threshold`
    ///    are strictly positive,
    /// 3. `learning_rate` lies in `[0, 1]`.
    ///
    /// # Errors
    ///
    /// Returns [`FittsError::InvalidParameters`] describing the first violated rule.
    pub fn validate(&self) -> Result<(), FittsError> {
        let reject = |reason: &str| Err(FittsError::InvalidParameters(reason.into()));

        // The five coefficients that must be strictly positive; the learning
        // rate is handled separately because zero is a legal value for it.
        let core = [
            self.a,
            self.b,
            self.distance_factor,
            self.accessibility_factor,
            self.threshold,
        ];

        let everything_finite =
            core.iter().all(|value| value.is_finite()) && self.learning_rate.is_finite();
        if !everything_finite {
            return reject("Parameters must be finite");
        }

        if core.iter().any(|&value| value <= 0.0) {
            return reject("Core parameters must be positive");
        }

        // NaN was already excluded above, so the range test is exact here.
        if !(0.0..=1.0).contains(&self.learning_rate) {
            return reject("Learning rate must be in [0, 1]");
        }

        Ok(())
    }
}
impl Default for FittsParameters {
fn default() -> Self {
Self {
a: 0.5, b: 0.3, distance_factor: 1.0,
accessibility_factor: 1.0,
threshold: 2.0, learning_rate: 0.05, }
}
}
/// Response-time model whose predictions follow `rt = a + b * id`, where `id`
/// is the index of difficulty derived from interval, ease and stability.
///
/// The model is a thin, copyable wrapper around its [`FittsParameters`];
/// the derived `Default` uses `FittsParameters::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Default)]
pub struct FittsModel {
/// Coefficients used by every prediction; `calibrate` rewrites `a` and `b`
/// in place.
pub params: FittsParameters,
}
impl FittsModel {
pub fn new(a: f64, b: f64) -> Self {
Self {
params: FittsParameters {
a: a.clamp(0.01, 10.0),
b: b.clamp(0.01, 5.0),
..Default::default()
},
}
}
pub fn with_learning_rate(a: f64, b: f64, learning_rate: f64) -> Self {
Self {
params: FittsParameters {
a: a.clamp(0.01, 10.0),
b: b.clamp(0.01, 5.0),
learning_rate: learning_rate.clamp(0.0, 1.0),
..Default::default()
},
}
}
pub fn with_validated_params(params: FittsParameters) -> Result<Self, FittsError> {
params.validate()?;
Ok(Self { params })
}
pub fn with_params(params: FittsParameters) -> Self {
Self { params }
}
pub fn memory_distance(&self, interval_days: f64, ease: f64) -> f64 {
let interval = interval_days.clamp(0.0, 1e6);
let safe_ease = ease.clamp(1.0, 5.0);
let time_factor = (1.0 + interval).ln();
let ease_penalty = 1.0 / safe_ease;
let distance = self.params.distance_factor * time_factor * ease_penalty;
distance.clamp(0.0, 100.0)
}
pub fn memory_accessibility(&self, stability: f64, ease: f64) -> f64 {
let safe_stability = stability.clamp(0.01, 1000.0);
let safe_ease = ease.clamp(1.0, 5.0);
let accessibility = self.params.accessibility_factor * safe_stability * safe_ease;
accessibility.max(0.01)
}
pub fn index_of_difficulty(&self, interval_days: f64, ease: f64, stability: f64) -> f64 {
let distance = self.memory_distance(interval_days, ease);
let accessibility = self.memory_accessibility(stability, ease);
let ratio = distance / accessibility;
(ratio + 1.0).log2().max(0.0)
}
pub fn response_time(&self, interval_days: f64, ease: f64, stability: f64) -> f64 {
if interval_days < 0.0 || ease <= 0.0 || stability <= 0.0 {
return self.params.a;
}
let id = self.index_of_difficulty(interval_days, ease, stability);
if !id.is_finite() {
return self.params.a;
}
let rt = self.params.a + self.params.b * id;
rt.max(0.0)
}
pub fn retrievability(&self, interval_days: f64, ease: f64, stability: f64) -> f64 {
let rt = self.response_time(interval_days, ease, stability);
let threshold = self.params.threshold.max(0.1);
let scale = threshold / 2.0; let exponent = ((rt - threshold) / scale).clamp(-700.0, 700.0);
let r = 1.0 / (1.0 + exponent.exp());
r.clamp(0.0, 1.0)
}
pub fn predict(&self, interval_days: f64, ease: f64, stability: f64) -> (f64, f64) {
let rt = self.response_time(interval_days, ease, stability);
let r = self.retrievability(interval_days, ease, stability);
(rt, r)
}
pub fn calibrate(
&mut self,
interval_days: f64,
ease: f64,
stability: f64,
actual_rt_seconds: f64,
) -> CalibrationResult {
let predicted_rt = self.response_time(interval_days, ease, stability);
let error = actual_rt_seconds - predicted_rt;
let id = self.index_of_difficulty(interval_days, ease, stability);
let lr = self.params.learning_rate;
if lr > 0.0 {
let a_update = lr * error;
let b_update = lr * error * id;
self.params.a = (self.params.a + a_update).clamp(0.01, 10.0);
self.params.b = (self.params.b + b_update).clamp(0.01, 5.0);
}
CalibrationResult {
predicted_rt,
actual_rt: actual_rt_seconds,
error,
new_a: self.params.a,
new_b: self.params.b,
}
}
}
/// Outcome of a single `FittsModel::calibrate` call.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CalibrationResult {
/// Response time the model predicted before any parameter update.
pub predicted_rt: f64,
/// Observed response time exactly as passed to `calibrate`.
pub actual_rt: f64,
/// `actual_rt - predicted_rt`; positive when the model predicted too low.
pub error: f64,
/// Value of the model's `a` parameter after the call.
pub new_a: f64,
/// Value of the model's `b` parameter after the call.
pub new_b: f64,
}
impl fmt::Display for FittsModel {
    /// Renders the model as `FittsModel(a=…, b=…, lr=…)`, printing each value
    /// with three digits after the decimal point.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let p = &self.params;
        f.write_fmt(format_args!(
            "FittsModel(a={:.3}, b={:.3}, lr={:.3})",
            p.a, p.b, p.learning_rate
        ))
    }
}
/// Errors reported by the Fitts model types; each variant carries a
/// human-readable message.
#[derive(Debug, Clone, PartialEq)]
pub enum FittsError {
/// A parameter set failed `FittsParameters::validate`; the message names
/// the violated rule.
InvalidParameters(String),
/// A computation could not produce a usable result.
/// NOTE(review): no code visible in this file constructs this variant -
/// confirm whether other modules rely on it.
ComputationError(String),
}
impl fmt::Display for FittsError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
FittsError::InvalidParameters(msg) => write!(f, "Invalid parameters: {msg}"),
FittsError::ComputationError(msg) => write!(f, "Computation error: {msg}"),
}
}
}
// Marks `FittsError` as a standard error type so it works with `?` and
// `Box<dyn std::error::Error>`. The body is empty: the trait's methods all have
// defaults and it only requires the `Debug` and `Display` impls defined above.
impl std::error::Error for FittsError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// In-range coefficients pass through the clamping constructor unchanged.
    #[test]
    fn test_model_creation() {
        let model = FittsModel::new(0.5, 0.3);
        assert_eq!(model.params.a, 0.5);
        assert_eq!(model.params.b, 0.3);
    }

    /// Distance grows with the interval and shrinks as ease rises.
    #[test]
    fn test_memory_distance() {
        let model = FittsModel::default();

        let short_gap = model.memory_distance(1.0, 2.0);
        let long_gap = model.memory_distance(10.0, 2.0);
        assert!(long_gap > short_gap, "Distance should increase with interval");

        let dist_low = model.memory_distance(7.0, 1.5);
        let dist_high = model.memory_distance(7.0, 2.5);
        assert!(dist_low > dist_high, "Distance should decrease with ease");
    }

    /// Accessibility grows with both stability and ease.
    #[test]
    fn test_memory_accessibility() {
        let model = FittsModel::default();

        let acc1 = model.memory_accessibility(1.0, 2.0);
        let acc2 = model.memory_accessibility(10.0, 2.0);
        assert!(acc2 > acc1);

        let acc_low = model.memory_accessibility(5.0, 1.5);
        let acc_high = model.memory_accessibility(5.0, 2.5);
        assert!(acc_high > acc_low);
    }

    /// A valid prediction is positive and never below the intercept.
    #[test]
    fn test_response_time_properties() {
        let model = FittsModel::new(0.5, 0.3);
        let rt = model.response_time(7.0, 2.0, 5.0);
        assert!(rt > 0.0);
        assert!(rt >= model.params.a);
    }

    /// Retrievability stays inside `[0, 1]` across a wide span of intervals.
    #[test]
    fn test_retrievability_bounds() {
        let model = FittsModel::default();
        for interval in [0.0, 1.0, 7.0, 30.0, 365.0] {
            let r = model.retrievability(interval, 2.0, 5.0);
            assert!((0.0..=1.0).contains(&r), "Retrievability must be in [0,1]");
        }
    }

    /// A short-interval, high-ease, high-stability item is more retrievable
    /// than a long-interval, low-ease, low-stability one.
    #[test]
    fn test_retrievability_decreases_with_rt() {
        let model = FittsModel::default();
        let r_easy = model.retrievability(1.0, 2.5, 10.0);
        let r_hard = model.retrievability(30.0, 1.5, 1.0);
        assert!(
            r_easy > r_hard,
            "Easy cards should have higher retrievability: {} vs {}",
            r_easy,
            r_hard
        );
    }

    /// Repeated calibration against one observation shrinks the prediction error.
    #[test]
    fn test_calibration_reduces_error() {
        let mut model = FittsModel::with_learning_rate(0.5, 0.3, 0.1);
        let (interval, ease, stability) = (7.0, 2.0, 5.0);
        let actual_rt = 2.0;

        let initial_error = (actual_rt - model.response_time(interval, ease, stability)).abs();
        for _ in 0..10 {
            model.calibrate(interval, ease, stability, actual_rt);
        }
        let final_error = (actual_rt - model.response_time(interval, ease, stability)).abs();

        assert!(
            final_error < initial_error,
            "Calibration should reduce error"
        );
    }

    /// A zero learning rate leaves both coefficients untouched.
    #[test]
    fn test_calibration_disabled_when_lr_zero() {
        let mut model = FittsModel::with_learning_rate(0.5, 0.3, 0.0);
        let before = model.params;
        model.calibrate(7.0, 2.0, 5.0, 10.0);
        assert_eq!(model.params.a, before.a);
        assert_eq!(model.params.b, before.b);
    }

    /// The validating constructor accepts sane values and rejects a negative
    /// intercept or an out-of-range learning rate.
    #[test]
    fn test_parameter_validation() {
        assert!(FittsParameters::new(0.5, 0.3, 1.0, 1.0, 2.0, 0.05).is_ok());
        assert!(FittsParameters::new(-0.5, 0.3, 1.0, 1.0, 2.0, 0.05).is_err());
        assert!(FittsParameters::new(0.5, 0.3, 1.0, 1.0, 2.0, 1.5).is_err());
    }

    /// Extreme inputs must yield finite, in-range predictions, not merely
    /// avoid panicking.
    #[test]
    fn test_numerical_stability() {
        let model = FittsModel::default();
        let extremes = [
            (0.0, 1.3, 0.01),
            (1_000_000.0, 5.0, 1000.0),
            (f64::MAX, 2.5, 1.0),
        ];
        for (interval, ease, stability) in extremes {
            let rt = model.response_time(interval, ease, stability);
            assert!(
                rt.is_finite(),
                "response time must be finite for ({interval}, {ease}, {stability})"
            );
            assert!(
                rt >= model.params.a,
                "response time must not drop below the intercept"
            );

            let r = model.retrievability(interval, ease, stability);
            assert!(
                (0.0..=1.0).contains(&r),
                "retrievability must stay in [0, 1] for extreme inputs"
            );
        }
    }
}