1use crate::aes::Aesthetic;
2use crate::data::{DataFrame, Value};
3use crate::scale::ScaleSet;
4
5use super::Stat;
6
7fn qnorm(p: f64) -> f64 {
10 if p <= 0.0 {
11 return f64::NEG_INFINITY;
12 }
13 if p >= 1.0 {
14 return f64::INFINITY;
15 }
16
17 if p < 0.5 {
19 -rational_approx((-2.0 * p.ln()).sqrt())
20 } else if p > 0.5 {
21 rational_approx((-2.0 * (1.0 - p).ln()).sqrt())
22 } else {
23 0.0
24 }
25}
26
/// Evaluates `t - N(t) / D(t)`, where `N` is a quadratic and `D` is a cubic
/// polynomial in `t` with fixed coefficients.
///
/// `qnorm` feeds this with `t = sqrt(-2 * ln(tail_probability))` to obtain the
/// magnitude of a standard-normal quantile.
///
/// NOTE(review): the coefficients look like those of the classic
/// Abramowitz & Stegun rational approximation for the normal quantile —
/// confirm against the reference before relying on a specific error bound.
fn rational_approx(t: f64) -> f64 {
    // Coefficients of 1, t, t^2 in the numerator.
    const NUM: [f64; 3] = [2.515_517, 0.802_853, 0.010_328];
    // Coefficients of t, t^2, t^3 in the denominator (constant term is 1).
    const DEN: [f64; 3] = [1.432_788, 0.189_269, 0.001_308];

    let numerator = NUM[0] + NUM[1] * t + NUM[2] * t * t;
    let denominator = 1.0 + DEN[0] * t + DEN[1] * t * t + DEN[2] * t * t * t;

    t - numerator / denominator
}
38
/// Returns the sample quantile of `sorted` at probability `p`, linearly
/// interpolating between the two order statistics that bracket the
/// fractional position `(n - 1) * p`.
///
/// `sorted` must already be in ascending order; this function does not sort.
///
/// * An empty slice yields `0.0`; a single-element slice yields that element.
/// * `p` is clamped to `[0, 1]`, so out-of-range probabilities return the
///   minimum or maximum instead of extrapolating or indexing out of bounds.
/// * A NaN `p` yields NaN.
fn quantile_type7(sorted: &[f64], p: f64) -> f64 {
    let n = sorted.len();
    if n == 0 {
        return 0.0;
    }
    if n == 1 {
        return sorted[0];
    }

    // Keep the fractional position inside [0, n - 1]. `clamp` passes NaN
    // through, which then propagates to the result via `frac`.
    let p = p.clamp(0.0, 1.0);
    let h = (n - 1) as f64 * p;

    // The float-to-int cast saturates (NaN becomes 0); the extra `min` keeps
    // the index valid regardless of rounding in the float product.
    let lo = (h.floor() as usize).min(n - 1);
    let hi = (lo + 1).min(n - 1);
    let frac = h - lo as f64;

    // Exactly on an order statistic: return it directly so an infinite
    // neighbour cannot turn `0 * inf` into NaN.
    if frac == 0.0 {
        return sorted[lo];
    }
    sorted[lo] + frac * (sorted[hi] - sorted[lo])
}
54
55pub struct StatQQ;
58
59impl Stat for StatQQ {
60 fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
61 let y_col = match data.column("y") {
62 Some(c) => c,
63 None => return DataFrame::new(),
64 };
65
66 let mut values: Vec<f64> = y_col.iter().filter_map(|v| v.as_f64()).collect();
67 if values.is_empty() {
68 return DataFrame::new();
69 }
70
71 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
72 let n = values.len();
73
74 let mut x_vals = Vec::with_capacity(n);
75 let mut y_vals = Vec::with_capacity(n);
76
77 for (i, &val) in values.iter().enumerate() {
78 let a = if n > 10 { 3.0 / 8.0 } else { 0.5 };
80 let p = (i as f64 + 1.0 - a) / (n as f64 + 1.0 - 2.0 * a);
81 let theoretical = qnorm(p);
82 x_vals.push(Value::Float(theoretical));
83 y_vals.push(Value::Float(val));
84 }
85
86 let mut result = DataFrame::new();
87 result.add_column("x".to_string(), x_vals);
88 result.add_column("y".to_string(), y_vals);
89
90 for col_name in &["color", "fill", "group"] {
92 if let Some(col) = data.column(col_name) {
93 if let Some(first) = col.first() {
94 result.add_column(col_name.to_string(), vec![first.clone(); n]);
95 }
96 }
97 }
98
99 result
100 }
101
102 fn required_aes(&self) -> Vec<Aesthetic> {
103 vec![Aesthetic::Y]
104 }
105
106 fn name(&self) -> &str {
107 "qq"
108 }
109}
110
111pub struct StatQQLine;
114
115impl Stat for StatQQLine {
116 fn compute_group(&self, data: &DataFrame, _scales: &ScaleSet) -> DataFrame {
117 let y_col = match data.column("y") {
118 Some(c) => c,
119 None => return DataFrame::new(),
120 };
121
122 let mut values: Vec<f64> = y_col.iter().filter_map(|v| v.as_f64()).collect();
123 if values.len() < 4 {
124 return DataFrame::new();
125 }
126
127 values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
128 let n = values.len();
129
130 let sample_q1 = quantile_type7(&values, 0.25);
132 let sample_q3 = quantile_type7(&values, 0.75);
133
134 let theo_q1 = qnorm(0.25);
136 let theo_q3 = qnorm(0.75);
137
138 let slope = (sample_q3 - sample_q1) / (theo_q3 - theo_q1);
140 let intercept = sample_q1 - slope * theo_q1;
141
142 let a = if n > 10 { 3.0 / 8.0 } else { 0.5 };
144 let x_min = qnorm((1.0 - a) / (n as f64 + 1.0 - 2.0 * a));
145 let x_max = qnorm((n as f64 - a) / (n as f64 + 1.0 - 2.0 * a));
146
147 let mut result = DataFrame::new();
148 result.add_column(
149 "x".to_string(),
150 vec![Value::Float(x_min), Value::Float(x_max)],
151 );
152 result.add_column(
153 "y".to_string(),
154 vec![
155 Value::Float(intercept + slope * x_min),
156 Value::Float(intercept + slope * x_max),
157 ],
158 );
159
160 for col_name in &["color", "fill", "group"] {
162 if let Some(col) = data.column(col_name) {
163 if let Some(first) = col.first() {
164 result.add_column(col_name.to_string(), vec![first.clone(); 2]);
165 }
166 }
167 }
168
169 result
170 }
171
172 fn required_aes(&self) -> Vec<Aesthetic> {
173 vec![Aesthetic::Y]
174 }
175
176 fn name(&self) -> &str {
177 "qq_line"
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame whose only column, `y`, holds `0.0, 1.0, ..., len - 1`.
    fn ramp_frame(len: usize) -> DataFrame {
        let column: Vec<Value> = (0..len).map(|i| Value::Float(i as f64)).collect();
        let mut frame = DataFrame::new();
        frame.add_column("y".to_string(), column);
        frame
    }

    /// `qnorm` is centred on 0.5 and mirrors around it.
    #[test]
    fn test_qnorm_symmetry() {
        let median = qnorm(0.5);
        assert!(
            median.abs() < 0.01,
            "qnorm(0.5) should be ~0, got {q}",
            q = median
        );

        let (lower, upper) = (qnorm(0.25), qnorm(0.75));
        assert!((lower + upper).abs() < 0.01, "qnorm should be symmetric");
        assert!(lower < 0.0);
        assert!(upper > 0.0);
    }

    /// `StatQQ` keeps every observation and emits both columns in
    /// non-decreasing order.
    #[test]
    fn test_stat_qq() {
        let frame = ramp_frame(100);
        let output = StatQQ.compute_group(&frame, &ScaleSet::new());

        assert_eq!(output.nrows(), 100);

        let xs = output.column("x").unwrap();
        let ys = output.column("y").unwrap();

        // Sample values come out sorted.
        for i in 1..ys.len() {
            let (prev, curr) = (ys[i - 1].as_f64().unwrap(), ys[i].as_f64().unwrap());
            assert!(curr >= prev);
        }
        // Theoretical quantiles increase with rank.
        for i in 1..xs.len() {
            let (prev, curr) = (xs[i - 1].as_f64().unwrap(), xs[i].as_f64().unwrap());
            assert!(curr >= prev);
        }
    }

    /// `StatQQLine` reduces the sample to the two endpoints of the line.
    #[test]
    fn test_stat_qq_line() {
        let frame = ramp_frame(100);
        let output = StatQQLine.compute_group(&frame, &ScaleSet::new());

        assert_eq!(output.nrows(), 2);
    }
}