Skip to main content

ggplot_rs/stat/
qq.rs

1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
/// Inverse of the standard normal CDF (the probit / quantile function).
///
/// Uses Peter J. Acklam's rational approximation:
/// - a rational function in `q = p - 0.5` on the central region `[0.02425, 0.97575]`;
/// - a rational function in `sqrt(-2 ln p_tail)` on each tail.
///
/// Its relative error is below 1.15e-9 over `(0, 1)`. That is far tighter than the
/// Abramowitz & Stegun 26.2.23 formula, whose absolute error reaches 4.5e-4.
///
/// Returns `-inf` for `p <= 0`, `+inf` for `p >= 1`, and `NaN` for a `NaN` input.
fn qnorm(p: f64) -> f64 {
    // Central-region numerator (A) and denominator (B) coefficients.
    const A: [f64; 6] = [
        -3.969683028665376e1,
        2.209460984245205e2,
        -2.759285104469687e2,
        1.383577518672690e2,
        -3.066479806614716e1,
        2.506628277459239e0,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e1,
        1.615858368580409e2,
        -1.556989798598866e2,
        6.680131188771972e1,
        -1.328068155288572e1,
    ];
    // Tail-region numerator (C) and denominator (D) coefficients.
    const C: [f64; 6] = [
        -7.784894002430293e-3,
        -3.223964580411365e-1,
        -2.400758277161838e0,
        -2.549732539343734e0,
        4.374664141464968e0,
        2.938163982698783e0,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-3,
        3.224671290700398e-1,
        2.445134137142996e0,
        3.754408661907416e0,
    ];
    // Boundaries between the lower tail, central region and upper tail.
    const P_LOW: f64 = 0.02425;
    const P_HIGH: f64 = 1.0 - P_LOW;

    // Every comparison below is false for NaN; handle it explicitly instead of
    // letting it fall through to a bogus finite result.
    if p.is_nan() {
        return f64::NAN;
    }
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    // Lower-tail quantile for q = sqrt(-2 ln p_tail); the upper tail reuses it
    // negated, by symmetry of the normal distribution.
    let tail = |q: f64| {
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };

    if p < P_LOW {
        tail((-2.0 * p.ln()).sqrt())
    } else if p <= P_HIGH {
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        -tail((-2.0 * (1.0 - p).ln()).sqrt())
    }
}
38
/// R-compatible type-7 quantile interpolation (R's default `quantile()` method).
///
/// `sorted` must already be sorted ascending. The probability `p` is clamped to
/// `[0, 1]`. Out-of-range requests therefore return the minimum/maximum instead of
/// extrapolating (`p < 0`) or indexing past the end of the slice (`p > 1`).
///
/// A `NaN` `p` yields `NaN`. An empty slice returns `0.0`
/// (NOTE(review): R would return `NA`; the sentinel is kept for existing callers).
fn quantile_type7(sorted: &[f64], p: f64) -> f64 {
    match sorted {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = sorted.len() - 1;
            // Continuous 0-based position in the sample: h = (n - 1) * p.
            let h = last as f64 * p.clamp(0.0, 1.0);
            // `min` is purely defensive; after clamping, `h <= last` already.
            let lo = (h.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = h - lo as f64;
            sorted[lo] + frac * (sorted[hi] - sorted[lo])
        }
    }
}
54
55/// StatQQ: sort sample, compute theoretical normal quantiles.
56/// Output: x (theoretical quantiles), y (sample sorted).
57pub struct StatQQ;
58
59impl Stat for StatQQ {
60    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
61        let y_col = match data.column("y") {
62            Some(c) => c,
63            None => return DataFrame::new(),
64        };
65
66        let mut values: Vec<f64> = y_col.iter().filter_map(|v| v.as_f64()).collect();
67        if values.is_empty() {
68            return DataFrame::new();
69        }
70
71        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
72        let n = values.len();
73
74        let mut x_vals = Vec::with_capacity(n);
75        let mut y_vals = Vec::with_capacity(n);
76
77        for (i, &val) in values.iter().enumerate() {
78            // R's ppoints(): (i + 1 - a) / (n + 1 - 2*a) where a = 3/8 for n > 10
79            let a = if n > 10 { 3.0 / 8.0 } else { 0.5 };
80            let p = (i as f64 + 1.0 - a) / (n as f64 + 1.0 - 2.0 * a);
81            let theoretical = qnorm(p);
82            x_vals.push(Value::Float(theoretical));
83            y_vals.push(Value::Float(val));
84        }
85
86        let mut result = DataFrame::new();
87        result.add_column("x".to_string(), x_vals);
88        result.add_column("y".to_string(), y_vals);
89
90        // Carry over grouping columns
91        for col_name in &["color", "fill", "group"] {
92            if let Some(col) = data.column(col_name) {
93                if let Some(first) = col.first() {
94                    result.add_column(col_name.to_string(), vec![first.clone(); n]);
95                }
96            }
97        }
98
99        result
100    }
101
102    fn required_aes(&self) -> Vec<Aesthetic> {
103        vec![Aesthetic::Y]
104    }
105
106    fn name(&self) -> &str {
107        "qq"
108    }
109}
110
111/// StatQQLine: fit line through Q1/Q3 of sample vs theoretical.
112/// Output: x, y (two points defining the reference line).
113pub struct StatQQLine;
114
115impl Stat for StatQQLine {
116    fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
117        let y_col = match data.column("y") {
118            Some(c) => c,
119            None => return DataFrame::new(),
120        };
121
122        let mut values: Vec<f64> = y_col.iter().filter_map(|v| v.as_f64()).collect();
123        if values.len() < 4 {
124            return DataFrame::new();
125        }
126
127        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
128        let n = values.len();
129
130        // Sample Q1 and Q3 using R-compatible type-7 quantile interpolation
131        let sample_q1 = quantile_type7(&values, 0.25);
132        let sample_q3 = quantile_type7(&values, 0.75);
133
134        // Theoretical Q1 and Q3
135        let theo_q1 = qnorm(0.25);
136        let theo_q3 = qnorm(0.75);
137
138        // Line through (theo_q1, sample_q1) and (theo_q3, sample_q3)
139        let slope = (sample_q3 - sample_q1) / (theo_q3 - theo_q1);
140        let intercept = sample_q1 - slope * theo_q1;
141
142        // Extend line to cover full theoretical range using R's ppoints formula
143        let a = if n > 10 { 3.0 / 8.0 } else { 0.5 };
144        let x_min = qnorm((1.0 - a) / (n as f64 + 1.0 - 2.0 * a));
145        let x_max = qnorm((n as f64 - a) / (n as f64 + 1.0 - 2.0 * a));
146
147        let mut result = DataFrame::new();
148        result.add_column(
149            "x".to_string(),
150            vec![Value::Float(x_min), Value::Float(x_max)],
151        );
152        result.add_column(
153            "y".to_string(),
154            vec![
155                Value::Float(intercept + slope * x_min),
156                Value::Float(intercept + slope * x_max),
157            ],
158        );
159
160        // Carry over grouping columns
161        for col_name in &["color", "fill", "group"] {
162            if let Some(col) = data.column(col_name) {
163                if let Some(first) = col.first() {
164                    result.add_column(col_name.to_string(), vec![first.clone(); 2]);
165                }
166            }
167        }
168
169        result
170    }
171
172    fn required_aes(&self) -> Vec<Aesthetic> {
173        vec![Aesthetic::Y]
174    }
175
176    fn name(&self) -> &str {
177        "qq_line"
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qnorm_symmetry() {
        let q = qnorm(0.5);
        assert!((q).abs() < 0.01, "qnorm(0.5) should be ~0, got {q}");

        let q1 = qnorm(0.25);
        let q3 = qnorm(0.75);
        assert!((q1 + q3).abs() < 0.01, "qnorm should be symmetric");
        assert!(q1 < 0.0);
        assert!(q3 > 0.0);
    }

    #[test]
    fn test_quantile_type7_matches_r() {
        // R: quantile(1:4, c(0.25, 0.5, 0.75)) == c(1.75, 2.5, 3.25)
        let v = [1.0, 2.0, 3.0, 4.0];
        assert!((quantile_type7(&v, 0.25) - 1.75).abs() < 1e-12);
        assert!((quantile_type7(&v, 0.5) - 2.5).abs() < 1e-12);
        assert!((quantile_type7(&v, 0.75) - 3.25).abs() < 1e-12);
        assert_eq!(quantile_type7(&[7.0], 0.3), 7.0);
    }

    #[test]
    fn test_stat_qq() {
        let mut data = DataFrame::new();
        let y_vals: Vec<Value> = (0..100).map(|i| Value::Float(i as f64)).collect();
        data.add_column("y".to_string(), y_vals);

        let stat = StatQQ;
        let scales = ScaleSet::new();
        let result = stat.compute_group(&data, &scales);

        assert_eq!(result.nrows(), 100);
        let x = result.column("x").unwrap();
        let y = result.column("y").unwrap();
        // y should be sorted
        for i in 1..y.len() {
            assert!(y[i].as_f64().unwrap() >= y[i - 1].as_f64().unwrap());
        }
        // x should be sorted (theoretical quantiles)
        for i in 1..x.len() {
            assert!(x[i].as_f64().unwrap() >= x[i - 1].as_f64().unwrap());
        }
        // ppoints are symmetric about 0.5, so the theoretical quantiles are
        // symmetric about 0.
        let n = x.len();
        for i in 0..n {
            let lo = x[i].as_f64().unwrap();
            let hi = x[n - 1 - i].as_f64().unwrap();
            assert!((lo + hi).abs() < 1e-9, "x[{i}] + x[{}] = {}", n - 1 - i, lo + hi);
        }
    }

    #[test]
    fn test_stat_qq_line() {
        let mut data = DataFrame::new();
        let y_vals: Vec<Value> = (0..100).map(|i| Value::Float(i as f64)).collect();
        data.add_column("y".to_string(), y_vals);

        let stat = StatQQLine;
        let scales = ScaleSet::new();
        let result = stat.compute_group(&data, &scales);

        assert_eq!(result.nrows(), 2);
        let x = result.column("x").unwrap();
        let y = result.column("y").unwrap();
        let (x0, x1) = (x[0].as_f64().unwrap(), x[1].as_f64().unwrap());
        let (y0, y1) = (y[0].as_f64().unwrap(), y[1].as_f64().unwrap());
        // The line spans a theoretical range symmetric about 0...
        assert!(x0 < 0.0 && x1 > 0.0);
        assert!((x0 + x1).abs() < 1e-9);
        // ...rises for an increasing sample...
        assert!(y1 > y0);
        // ...and, for a symmetric sample, passes through its median (49.5) at x = 0.
        assert!(((y0 + y1) / 2.0 - 49.5).abs() < 1e-9);
    }

    #[test]
    fn test_missing_or_short_input_yields_empty_frame() {
        let scales = ScaleSet::new();

        // No "y" column at all.
        let no_y = DataFrame::new();
        assert_eq!(StatQQ.compute_group(&no_y, &scales).nrows(), 0);
        assert_eq!(StatQQLine.compute_group(&no_y, &scales).nrows(), 0);

        // Too few points for a reference line, but still a valid QQ sample.
        let mut short = DataFrame::new();
        short.add_column(
            "y".to_string(),
            vec![Value::Float(1.0), Value::Float(2.0), Value::Float(3.0)],
        );
        assert_eq!(StatQQ.compute_group(&short, &scales).nrows(), 3);
        assert_eq!(StatQQLine.compute_group(&short, &scales).nrows(), 0);
    }
}
232}