use crate::aes::Aesthetic;
use crate::data::{DataFrame, Value};
use crate::scale::ScaleSet;
use super::Stat;
/// Standard normal quantile function (inverse CDF), `Φ⁻¹(p)`.
///
/// Uses Peter J. Acklam's piecewise rational approximation, whose relative
/// error is below 1.15e-9 on the open interval (0, 1): one rational function
/// for the central region and another for the two tails.
///
/// Edge cases:
/// * `p <= 0` → `-∞`
/// * `p >= 1` → `+∞`
/// * `p` is NaN → NaN
fn qnorm(p: f64) -> f64 {
    // Boundary between the central approximation and the tail approximation.
    const P_LOW: f64 = 0.024_25;
    const P_HIGH: f64 = 1.0 - P_LOW;

    // Checked first: NaN fails every comparison below and would otherwise
    // fall into the central branch.
    if p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p < P_LOW {
        qnorm_lower_tail(p)
    } else if p > P_HIGH {
        // Φ⁻¹(p) = -Φ⁻¹(1 - p); feed the tail formula the small probability.
        -qnorm_lower_tail(1.0 - p)
    } else {
        qnorm_central(p)
    }
}

/// Acklam's approximation for the central region `0.02425 <= p <= 0.97575`.
fn qnorm_central(p: f64) -> f64 {
    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.383577518672690e2,
        -3.066479806614716e1,
        2.506628277459239,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    let q = p - 0.5;
    let r = q * q;
    // Horner evaluation of both polynomials in r; the denominator's
    // constant term is 1.
    let num = A.iter().fold(0.0, |acc, &coef| acc * r + coef);
    let den = B.iter().fold(0.0, |acc, &coef| acc * r + coef) * r + 1.0;
    num * q / den
}

/// Acklam's approximation for the lower tail `0 < p < 0.02425`.
/// The result is negative.
fn qnorm_lower_tail(p: f64) -> f64 {
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838,
        -2.549732539343734,
        4.374664141464968,
        2.938163982698783,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996,
        3.754408661907416,
    ];
    let q = (-2.0 * p.ln()).sqrt();
    let num = C.iter().fold(0.0, |acc, &coef| acc * q + coef);
    let den = D.iter().fold(0.0, |acc, &coef| acc * q + coef) * q + 1.0;
    num / den
}
/// Sample quantile using R's default definition (type 7): linear
/// interpolation between order statistics at position `h = (n - 1) * p`.
///
/// `sorted` must already be in ascending order.
///
/// `p` is clamped to `[0, 1]`, so an out-of-range probability yields the
/// minimum or maximum rather than indexing past the end of the slice.
///
/// Returns:
/// * NaN for an empty slice, since there is no quantile of no data.
/// * The single value for a one-element slice.
/// * NaN for a NaN `p` when there are at least two values.
fn quantile_type7(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => f64::NAN,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            let h = last as f64 * p.clamp(0.0, 1.0);
            // `as usize` saturates (NaN -> 0); `h <= last` after clamping,
            // so `lo` is always a valid index.
            let lo = h.floor() as usize;
            let hi = (lo + 1).min(last);
            let frac = h - lo as f64;
            sorted[lo] + frac * (sorted[hi] - sorted[lo])
        }
    }
}
pub struct StatQQ;
impl Stat for StatQQ {
fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
let y_col = match data.column("y") {
Some(c) => c,
None => return DataFrame::new(),
};
let mut values: Vec<f64> = y_col.iter().filter_map(|v| v.as_f64()).collect();
if values.is_empty() {
return DataFrame::new();
}
values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = values.len();
let mut x_vals = Vec::with_capacity(n);
let mut y_vals = Vec::with_capacity(n);
for (i, &val) in values.iter().enumerate() {
let a = if n > 10 { 3.0 / 8.0 } else { 0.5 };
let p = (i as f64 + 1.0 - a) / (n as f64 + 1.0 - 2.0 * a);
let theoretical = qnorm(p);
x_vals.push(Value::Float(theoretical));
y_vals.push(Value::Float(val));
}
let mut result = DataFrame::new();
result.add_column("x".to_string(), x_vals);
result.add_column("y".to_string(), y_vals);
for col_name in &["color", "fill", "group"] {
if let Some(col) = data.column(col_name) {
if let Some(first) = col.first() {
result.add_column(col_name.to_string(), vec![first.clone(); n]);
}
}
}
result
}
fn required_aes(&self) -> Vec<Aesthetic> {
vec![Aesthetic::Y]
}
fn name(&self) -> &str {
"qq"
}
}
/// Reference line for a normal Q-Q plot (the analogue of ggplot2's
/// `stat_qq_line`).
///
/// The line passes through the first- and third-quartile pairs: sample
/// type-7 quantile versus standard-normal quantile.
pub struct StatQQLine;

impl Stat for StatQQLine {
    /// Computes the two endpoints of the quartile line.
    ///
    /// The endpoints sit at the first and last plotting positions of R's
    /// `ppoints(n)`, with `a = 3/8` for `n <= 10` and `a = 1/2` otherwise.
    ///
    /// Output columns:
    /// * `x` — the two endpoint positions.
    /// * `y` — the line evaluated at those positions.
    /// * `color` / `fill` / `group` — present only if present in `data`; each
    ///   holds the input's first value, repeated twice.
    ///
    /// Returns an empty frame when `y` is absent or has fewer than four
    /// numeric, non-NaN values.
    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
        let y_col = match data.column("y") {
            Some(c) => c,
            None => return DataFrame::new(),
        };
        // Drop NaN: it has no rank. Leaving it in would also make the
        // comparator below non-total, and std sorts may panic on that.
        let mut values: Vec<f64> = y_col
            .iter()
            .filter_map(|v| v.as_f64())
            .filter(|v| !v.is_nan())
            .collect();
        if values.len() < 4 {
            return DataFrame::new();
        }
        // With NaN removed, `partial_cmp` always returns `Some`.
        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let n = values.len();

        // Line through (qnorm(0.25), Q1) and (qnorm(0.75), Q3).
        let sample_q1 = quantile_type7(&values, 0.25);
        let sample_q3 = quantile_type7(&values, 0.75);
        let theo_q1 = qnorm(0.25);
        let theo_q3 = qnorm(0.75);
        let slope = (sample_q3 - sample_q1) / (theo_q3 - theo_q1);
        let intercept = sample_q1 - slope * theo_q1;

        // Endpoints: qnorm of the first and last `ppoints(n)` values. R uses
        // the 3/8 offset only for small samples.
        let a = if n <= 10 { 3.0 / 8.0 } else { 0.5 };
        let denom = n as f64 + 1.0 - 2.0 * a;
        let x_min = qnorm((1.0 - a) / denom);
        let x_max = qnorm((n as f64 - a) / denom);

        let mut result = DataFrame::new();
        result.add_column(
            "x".to_string(),
            vec![Value::Float(x_min), Value::Float(x_max)],
        );
        result.add_column(
            "y".to_string(),
            vec![
                Value::Float(intercept + slope * x_min),
                Value::Float(intercept + slope * x_max),
            ],
        );
        // Group-level aesthetics: broadcast the first row's value to both endpoints.
        for col_name in ["color", "fill", "group"] {
            if let Some(col) = data.column(col_name) {
                if let Some(first) = col.first() {
                    result.add_column(col_name.to_string(), vec![first.clone(); 2]);
                }
            }
        }
        result
    }

    fn required_aes(&self) -> Vec<Aesthetic> {
        vec![Aesthetic::Y]
    }

    fn name(&self) -> &str {
        "qq_line"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `got` is within `tol` of `expected`.
    fn assert_close(got: f64, expected: f64, tol: f64) {
        assert!(
            (got - expected).abs() < tol,
            "expected {expected}, got {got} (tol {tol})"
        );
    }

    /// Reads a numeric column of `df` as plain `f64`s.
    fn floats(df: &DataFrame, name: &str) -> Vec<f64> {
        let col = df
            .column(name)
            .unwrap_or_else(|| panic!("missing column {name}"));
        col.iter()
            .map(|v| v.as_f64().expect("numeric value"))
            .collect()
    }

    /// Builds a frame whose `y` column is `0.0, 1.0, ..., (n - 1) as f64`.
    fn sequence_frame(n: usize) -> DataFrame {
        let mut data = DataFrame::new();
        let y_vals: Vec<Value> = (0..n).map(|i| Value::Float(i as f64)).collect();
        data.add_column("y".to_string(), y_vals);
        data
    }

    #[test]
    fn test_qnorm_symmetry() {
        let q = qnorm(0.5);
        assert!((q).abs() < 0.01, "qnorm(0.5) should be ~0, got {q}");
        let q1 = qnorm(0.25);
        let q3 = qnorm(0.75);
        assert!((q1 + q3).abs() < 0.01, "qnorm should be symmetric");
        assert!(q1 < 0.0);
        assert!(q3 > 0.0);
        assert_close(qnorm(0.975), 1.959_963_985, 1e-3);
        assert_eq!(qnorm(0.0), f64::NEG_INFINITY);
        assert_eq!(qnorm(1.0), f64::INFINITY);
    }

    #[test]
    fn test_quantile_type7() {
        let v = [1.0, 2.0, 3.0, 4.0];
        assert_eq!(quantile_type7(&v, 0.25), 1.75);
        assert_eq!(quantile_type7(&v, 0.75), 3.25);
        assert_eq!(quantile_type7(&v, 0.0), 1.0);
        assert_eq!(quantile_type7(&v, 1.0), 4.0);
    }

    #[test]
    fn test_stat_qq() {
        let data = sequence_frame(100);
        let result = StatQQ.compute_group(&data, &ScaleSet::new());
        assert_eq!(result.nrows(), 100);
        let x = floats(&result, "x");
        let y = floats(&result, "y");
        // Sample values come back in ascending order.
        assert_eq!(y, (0..100).map(|i| i as f64).collect::<Vec<_>>());
        // Theoretical quantiles strictly increase and are symmetric about 0.
        assert!(x.windows(2).all(|w| w[1] > w[0]));
        assert_close(x[0] + x[99], 0.0, 1e-9);
    }

    #[test]
    fn test_stat_qq_sorts_input() {
        let mut data = DataFrame::new();
        data.add_column(
            "y".to_string(),
            vec![Value::Float(3.0), Value::Float(1.0), Value::Float(2.0)],
        );
        let result = StatQQ.compute_group(&data, &ScaleSet::new());
        assert_eq!(floats(&result, "y"), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_stat_qq_missing_y() {
        let result = StatQQ.compute_group(&DataFrame::new(), &ScaleSet::new());
        assert_eq!(result.nrows(), 0);
    }

    #[test]
    fn test_stat_qq_carries_group_aesthetics() {
        let mut data = sequence_frame(5);
        data.add_column("group".to_string(), vec![Value::Float(7.0); 5]);
        let result = StatQQ.compute_group(&data, &ScaleSet::new());
        assert_eq!(floats(&result, "group"), vec![7.0; 5]);
    }

    #[test]
    fn test_stat_qq_line() {
        let data = sequence_frame(100);
        let result = StatQQLine.compute_group(&data, &ScaleSet::new());
        assert_eq!(result.nrows(), 2);
        let x = floats(&result, "x");
        let y = floats(&result, "y");
        let slope = (y[1] - y[0]) / (x[1] - x[0]);
        let intercept = y[0] - slope * x[0];
        // For 0..=99 the type-7 sample quartiles are 24.75 and 74.25; the line
        // must pass through them at the theoretical quartiles.
        assert_close(intercept + slope * qnorm(0.25), 24.75, 1e-9);
        assert_close(intercept + slope * qnorm(0.75), 74.25, 1e-9);
    }

    #[test]
    fn test_stat_qq_line_too_few_points() {
        let data = sequence_frame(3);
        let result = StatQQLine.compute_group(&data, &ScaleSet::new());
        assert_eq!(result.nrows(), 0);
    }
}