use serde::{Deserialize, Serialize};
use crate::error::{BodhError, Result, validate_finite, validate_positive};
#[inline]
#[must_use = "returns the probability of correct response without side effects"]
pub fn rasch_probability(ability: f64, difficulty: f64) -> Result<f64> {
validate_finite(ability, "ability")?;
validate_finite(difficulty, "difficulty")?;
Ok(logistic(ability - difficulty))
}
/// Probability of a correct response under the two-parameter logistic (2PL) model.
///
/// Computes `σ(a · (θ − b))`, where `θ` is `ability`, `b` is `difficulty`,
/// `a` is `discrimination`, and `σ` is the standard logistic function.
///
/// # Errors
///
/// - the error from `validate_finite` if `ability` or `difficulty` is rejected;
/// - the error from `validate_positive` if `discrimination` is rejected.
///
/// Checks run in parameter order, and the first failure is returned.
#[inline]
#[must_use = "returns the probability of correct response without side effects"]
pub fn two_pl_probability(ability: f64, difficulty: f64, discrimination: f64) -> Result<f64> {
    validate_finite(ability, "ability")?;
    validate_finite(difficulty, "difficulty")?;
    validate_positive(discrimination, "discrimination")?;

    let distance = ability - difficulty;
    let scaled_logit = discrimination * distance;
    Ok(logistic(scaled_logit))
}
/// Probability of a correct response under the three-parameter logistic (3PL) model.
///
/// Computes `c + (1 − c) · σ(a · (θ − b))`. The guessing parameter `c` is a lower
/// asymptote: even very low abilities answer correctly with probability about `c`.
///
/// # Errors
///
/// - the error from `validate_finite` for `ability`, `difficulty` or `guessing`;
/// - the error from `validate_positive` for `discrimination`;
/// - [`BodhError::InvalidParameter`] if `guessing` lies outside `[0, 1)`.
///
/// Validation happens in the order listed in the signature, and the range check
/// on `guessing` comes last.
#[must_use = "returns the probability of correct response without side effects"]
pub fn three_pl_probability(
    ability: f64,
    difficulty: f64,
    discrimination: f64,
    guessing: f64,
) -> Result<f64> {
    validate_finite(ability, "ability")?;
    validate_finite(difficulty, "difficulty")?;
    validate_positive(discrimination, "discrimination")?;
    validate_finite(guessing, "guessing")?;

    if (0.0..1.0).contains(&guessing) {
        // Two-parameter curve, then rescaled onto [c, 1).
        let base = logistic(discrimination * (ability - difficulty));
        Ok(guessing + (1.0 - guessing) * base)
    } else {
        Err(BodhError::InvalidParameter(
            "guessing must be in [0, 1)".into(),
        ))
    }
}
/// Fisher information of a single item under the 2PL model.
///
/// Computes `a² · P · (1 − P)`, where `P` is [`two_pl_probability`]. The maximum is
/// `a² / 4`, reached at `ability == difficulty`, where `P = 0.5`.
///
/// # Errors
///
/// Propagates any validation error from [`two_pl_probability`].
#[inline]
#[must_use = "returns the item information without side effects"]
pub fn item_information_2pl(ability: f64, difficulty: f64, discrimination: f64) -> Result<f64> {
    let p = two_pl_probability(ability, difficulty, discrimination)?;
    let slope_sq = discrimination * discrimination;
    let q = 1.0 - p;
    Ok(slope_sq * p * q)
}
#[must_use = "returns the item information without side effects"]
pub fn item_information_3pl(
ability: f64,
difficulty: f64,
discrimination: f64,
guessing: f64,
) -> Result<f64> {
let p = three_pl_probability(ability, difficulty, discrimination, guessing)?;
if p < 1e-15 {
return Ok(0.0);
}
let ratio = (p - guessing) / ((1.0 - guessing) * p);
Ok(discrimination * discrimination * ratio * ratio * p * (1.0 - p))
}
/// Test information under the 2PL model: the sum of the item informations at `ability`.
///
/// `items` holds `(difficulty, discrimination)` pairs. An empty slice yields `0.0`.
///
/// # Errors
///
/// - the error from `validate_finite` if `ability` is rejected (checked even when
///   `items` is empty);
/// - otherwise, the first error from [`item_information_2pl`], scanning `items`
///   front to back.
#[must_use = "returns the test information without side effects"]
pub fn test_information_2pl(ability: f64, items: &[(f64, f64)]) -> Result<f64> {
    validate_finite(ability, "ability")?;
    // Left-to-right fold from +0.0 keeps the exact summation order and the
    // empty-input result; short-circuits on the first failing item.
    items
        .iter()
        .try_fold(0.0_f64, |running, &(difficulty, discrimination)| {
            item_information_2pl(ability, difficulty, discrimination).map(|info| running + info)
        })
}
/// Standard error of an ability estimate, derived from the test information.
///
/// Computes `1 / √I`.
///
/// # Errors
///
/// - the error from `validate_finite` if `test_info` is rejected;
/// - [`BodhError::ComputationError`] if `test_info <= 0.0`, because the standard
///   error is undefined without positive information.
#[inline]
#[must_use = "returns the standard error without side effects"]
pub fn ability_standard_error(test_info: f64) -> Result<f64> {
    validate_finite(test_info, "test_info")?;
    if test_info <= 0.0 {
        Err(BodhError::ComputationError(
            "test information must be positive for SE computation".into(),
        ))
    } else {
        Ok(test_info.sqrt().recip())
    }
}
/// Parameters of a single item under the three-parameter logistic (3PL) model.
///
/// Fields are public and are not validated on construction. Every method forwards
/// them to the 3PL functions, which validate them on each call. Invalid values
/// therefore surface as errors when the item is evaluated.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ItemParameters {
    /// Location `b` on the ability scale; must be finite.
    pub difficulty: f64,
    /// Slope `a`; must pass the positivity check.
    pub discrimination: f64,
    /// Lower asymptote `c`; must lie in `[0, 1)`.
    pub guessing: f64,
}

impl ItemParameters {
    /// Probability of a correct response at `ability`; see [`three_pl_probability`].
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`three_pl_probability`].
    #[must_use = "returns the probability without side effects"]
    pub fn probability(&self, ability: f64) -> Result<f64> {
        let Self {
            difficulty,
            discrimination,
            guessing,
        } = *self;
        three_pl_probability(ability, difficulty, discrimination, guessing)
    }

    /// Fisher information of this item at `ability`; see [`item_information_3pl`].
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`item_information_3pl`].
    #[must_use = "returns the information without side effects"]
    pub fn information(&self, ability: f64) -> Result<f64> {
        let Self {
            difficulty,
            discrimination,
            guessing,
        } = *self;
        item_information_3pl(ability, difficulty, discrimination, guessing)
    }
}
/// Standard logistic function `σ(x) = 1 / (1 + e^(−x))`.
///
/// For large negative `x`, `e^(−x)` overflows to `+∞`, and the result is `0.0`.
/// For large positive `x`, the exponential underflows to `0.0`, and the result is `1.0`.
/// Infinite inputs therefore map to `0.0` / `1.0` rather than NaN.
#[inline]
fn logistic(x: f64) -> f64 {
    let tail = (-x).exp();
    (tail + 1.0).recip()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rasch_at_difficulty() {
        let p = rasch_probability(1.0, 1.0).unwrap();
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_rasch_high_ability() {
        let p = rasch_probability(3.0, 0.0).unwrap();
        assert!(p > 0.95);
    }

    #[test]
    fn test_rasch_low_ability() {
        let p = rasch_probability(-3.0, 0.0).unwrap();
        assert!(p < 0.05);
    }

    #[test]
    fn test_rasch_monotonic() {
        let p1 = rasch_probability(0.0, 1.0).unwrap();
        let p2 = rasch_probability(1.0, 1.0).unwrap();
        let p3 = rasch_probability(2.0, 1.0).unwrap();
        assert!(p1 < p2);
        assert!(p2 < p3);
    }

    #[test]
    fn test_2pl_at_difficulty() {
        let p = two_pl_probability(1.0, 1.0, 1.5).unwrap();
        assert!((p - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_2pl_high_discrimination_steeper() {
        let p_low_a = two_pl_probability(1.5, 1.0, 0.5).unwrap();
        let p_high_a = two_pl_probability(1.5, 1.0, 2.0).unwrap();
        assert!(p_high_a > p_low_a);
    }

    #[test]
    fn test_2pl_matches_rasch_at_a1() {
        let rasch = rasch_probability(1.5, 0.5).unwrap();
        let twopl = two_pl_probability(1.5, 0.5, 1.0).unwrap();
        assert!((rasch - twopl).abs() < 1e-10);
    }

    #[test]
    fn test_3pl_guessing_floor() {
        let p = three_pl_probability(-10.0, 0.0, 1.0, 0.25).unwrap();
        assert!((p - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_3pl_high_ability() {
        let p = three_pl_probability(10.0, 0.0, 1.0, 0.25).unwrap();
        assert!((p - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_3pl_no_guessing_matches_2pl() {
        let twopl = two_pl_probability(1.0, 0.5, 1.5).unwrap();
        let threepl = three_pl_probability(1.0, 0.5, 1.5, 0.0).unwrap();
        assert!((twopl - threepl).abs() < 1e-10);
    }

    #[test]
    fn test_3pl_invalid_guessing() {
        assert!(three_pl_probability(1.0, 0.0, 1.0, -0.1).is_err());
        assert!(three_pl_probability(1.0, 0.0, 1.0, 1.0).is_err());
        // NaN is neither finite nor inside [0, 1), so it must be rejected either way.
        assert!(three_pl_probability(1.0, 0.0, 1.0, f64::NAN).is_err());
    }

    #[test]
    fn test_info_2pl_peaks_at_difficulty() {
        let info_at_b = item_information_2pl(1.0, 1.0, 1.5).unwrap();
        let info_away = item_information_2pl(3.0, 1.0, 1.5).unwrap();
        assert!(info_at_b > info_away);
    }

    #[test]
    fn test_info_2pl_known_value() {
        // At ability == difficulty, P = 0.5, so I = a² / 4 = 4 / 4 = 1.
        let info = item_information_2pl(0.0, 0.0, 2.0).unwrap();
        assert!((info - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_info_increases_with_discrimination() {
        let low = item_information_2pl(0.0, 0.0, 0.5).unwrap();
        let high = item_information_2pl(0.0, 0.0, 2.0).unwrap();
        assert!(high > low);
    }

    #[test]
    fn test_info_3pl_less_than_2pl() {
        let info_2pl = item_information_2pl(0.0, 0.0, 1.5).unwrap();
        let info_3pl = item_information_3pl(0.0, 0.0, 1.5, 0.25).unwrap();
        assert!(info_2pl > info_3pl);
    }

    #[test]
    fn test_info_3pl_no_guessing_matches_2pl() {
        let info_2pl = item_information_2pl(0.5, 0.0, 1.5).unwrap();
        let info_3pl = item_information_3pl(0.5, 0.0, 1.5, 0.0).unwrap();
        assert!((info_2pl - info_3pl).abs() < 1e-12);
    }

    #[test]
    fn test_info_3pl_low_ability_is_finite_and_non_negative() {
        let info = item_information_3pl(-10.0, 0.0, 1.0, 0.25).unwrap();
        assert!(info.is_finite());
        assert!(info >= 0.0);
    }

    #[test]
    fn test_test_information_additive() {
        let items = [(0.0, 1.0), (1.0, 1.0), (-1.0, 1.0)];
        let ti = test_information_2pl(0.0, &items).unwrap();
        let sum: f64 = items
            .iter()
            .map(|&(b, a)| item_information_2pl(0.0, b, a).unwrap())
            .sum();
        assert!((ti - sum).abs() < 1e-10);
    }

    #[test]
    fn test_test_information_empty_is_zero() {
        let ti = test_information_2pl(0.0, &[]).unwrap();
        assert!(ti.abs() < f64::EPSILON);
    }

    #[test]
    fn test_ability_se() {
        let se = ability_standard_error(4.0).unwrap();
        assert!((se - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_ability_se_zero_info() {
        assert!(ability_standard_error(0.0).is_err());
    }

    #[test]
    fn test_ability_se_negative_info() {
        assert!(ability_standard_error(-1.0).is_err());
    }

    #[test]
    fn test_item_params_probability() {
        let item = ItemParameters {
            difficulty: 0.0,
            discrimination: 1.0,
            guessing: 0.25,
        };
        let p = item.probability(-10.0).unwrap();
        assert!((p - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_item_params_information() {
        let item = ItemParameters {
            difficulty: 0.0,
            discrimination: 1.5,
            guessing: 0.0,
        };
        let info = item.information(0.0).unwrap();
        let expected = item_information_3pl(0.0, 0.0, 1.5, 0.0).unwrap();
        assert!((info - expected).abs() < 1e-10);
    }

    #[test]
    fn test_item_parameters_serde_roundtrip() {
        let item = ItemParameters {
            difficulty: 1.2,
            discrimination: 0.8,
            guessing: 0.2,
        };
        let json = serde_json::to_string(&item).unwrap();
        let back: ItemParameters = serde_json::from_str(&json).unwrap();
        // Tolerance rather than exact equality: serde_json only guarantees exact
        // float round-trips with its `float_roundtrip` feature.
        assert!((item.difficulty - back.difficulty).abs() < 1e-10);
        assert!((item.discrimination - back.discrimination).abs() < 1e-10);
        assert!((item.guessing - back.guessing).abs() < 1e-10);
    }
}