1use crate::error::FdarError;
15use crate::matrix::FdMatrix;
16use std::f64::consts::PI;
17
18#[derive(Debug, Clone, PartialEq)]
20#[non_exhaustive]
21pub struct AndrewsResult {
22 pub curves: FdMatrix,
24 pub argvals: Vec<f64>,
26 pub n_vars: usize,
28}
29
30#[derive(Debug, Clone, PartialEq)]
36#[non_exhaustive]
37pub struct AndrewsLoadings {
38 pub loadings: FdMatrix,
40 pub argvals: Vec<f64>,
42 pub n_vars: usize,
44}
45
/// Evaluates the `k`-th Andrews basis function at angle `t`.
///
/// The sequence is `1/sqrt(2), sin(t), cos(t), sin(2t), cos(2t), ...`:
/// index 0 is the constant term, odd indices are sines and even indices are
/// cosines, both with frequency `ceil(k / 2)`.
#[inline]
fn andrews_basis(t: f64, k: usize) -> f64 {
    match k {
        0 => std::f64::consts::FRAC_1_SQRT_2,
        _ => {
            // Indices 1 and 2 share frequency 1, indices 3 and 4 frequency 2, ...
            let angle = k.div_ceil(2) as f64 * t;
            if k % 2 == 0 {
                angle.cos()
            } else {
                angle.sin()
            }
        }
    }
}
67
68#[must_use = "returns the Andrews curves without modifying the input"]
79pub fn andrews_transform(data: &FdMatrix, n_grid: usize) -> Result<AndrewsResult, FdarError> {
80 let (n, p) = data.shape();
81 if n == 0 || p == 0 {
82 return Err(FdarError::InvalidDimension {
83 parameter: "data",
84 expected: "non-zero rows and columns".to_string(),
85 actual: format!("{n} x {p}"),
86 });
87 }
88 if n_grid == 0 {
89 return Err(FdarError::InvalidParameter {
90 parameter: "n_grid",
91 message: "must be at least 1".to_string(),
92 });
93 }
94
95 let argvals = make_grid(n_grid);
96 let mut curves = FdMatrix::zeros(n, n_grid);
97
98 let basis_vals: Vec<Vec<f64>> = argvals
100 .iter()
101 .map(|&t| (0..p).map(|k| andrews_basis(t, k)).collect())
102 .collect();
103
104 for i in 0..n {
105 for (g, bv) in basis_vals.iter().enumerate() {
106 let mut val = 0.0;
107 for k in 0..p {
108 val += data[(i, k)] * bv[k];
109 }
110 curves[(i, g)] = val;
111 }
112 }
113
114 Ok(AndrewsResult {
115 curves,
116 argvals,
117 n_vars: p,
118 })
119}
120
121#[must_use = "returns the Andrews loadings without modifying the input"]
134pub fn andrews_loadings(rotation: &FdMatrix, n_grid: usize) -> Result<AndrewsLoadings, FdarError> {
135 let (m, ncomp) = rotation.shape();
136 if m == 0 || ncomp == 0 {
137 return Err(FdarError::InvalidDimension {
138 parameter: "rotation",
139 expected: "non-zero rows and columns".to_string(),
140 actual: format!("{m} x {ncomp}"),
141 });
142 }
143 if n_grid == 0 {
144 return Err(FdarError::InvalidParameter {
145 parameter: "n_grid",
146 message: "must be at least 1".to_string(),
147 });
148 }
149
150 let mut transposed_data = vec![0.0; ncomp * m];
153 for j in 0..ncomp {
154 let col = rotation.column(j);
155 for i in 0..m {
156 transposed_data[j + i * ncomp] = col[i];
159 }
160 }
161 let transposed = FdMatrix::from_column_major(transposed_data, ncomp, m)?;
162
163 let result = andrews_transform(&transposed, n_grid)?;
164
165 Ok(AndrewsLoadings {
166 loadings: result.curves,
167 argvals: result.argvals,
168 n_vars: m,
169 })
170}
171
/// Builds `n` equally spaced angles covering `[-PI, PI]`.
///
/// * `n == 0` yields an empty grid. Computing `n - 1` would otherwise
///   underflow `usize`, which panics when overflow checks are enabled.
/// * `n == 1` yields the single angle `0.0`, the midpoint of the interval.
/// * `n >= 2` yields `-PI + i * step` for `i` in `0..n`, with
///   `step = 2 * PI / (n - 1)`, so the first point is `-PI` and the last is
///   `PI` up to floating-point rounding.
fn make_grid(n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        1 => vec![0.0],
        _ => {
            let step = 2.0 * PI / (n - 1) as f64;
            (0..n).map(|i| -PI + i as f64 * step).collect()
        }
    }
}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_1_SQRT_2;

    /// Builds an `nrows x ncols` matrix from values listed row by row.
    ///
    /// `FdMatrix::from_column_major` expects column-major storage, so the
    /// values are re-read column by column before being handed over.
    fn row_major_matrix(data: &[f64], nrows: usize, ncols: usize) -> FdMatrix {
        let col_major: Vec<f64> = (0..ncols)
            .flat_map(|j| (0..nrows).map(move |i| data[i * ncols + j]))
            .collect();
        FdMatrix::from_column_major(col_major, nrows, ncols).unwrap()
    }

    // The basis sequence is 1/sqrt(2), sin t, cos t, sin 2t, cos 2t, sin 3t, cos 3t.
    #[test]
    fn andrews_basis_values() {
        let t = 1.0_f64;
        let expected = [
            FRAC_1_SQRT_2,
            t.sin(),
            t.cos(),
            (2.0 * t).sin(),
            (2.0 * t).cos(),
            (3.0 * t).sin(),
            (3.0 * t).cos(),
        ];
        for (k, &want) in expected.iter().enumerate() {
            assert!((andrews_basis(t, k) - want).abs() < 1e-12);
        }
    }

    // A unit weight on the first variable selects the constant basis term.
    #[test]
    fn constant_curve_from_unit_first_var() {
        let n_grid = 50;
        let data = row_major_matrix(&[1.0, 0.0, 0.0], 1, 3);
        let result = andrews_transform(&data, n_grid).unwrap();
        assert_eq!(result.curves.nrows(), 1);
        assert_eq!(result.curves.ncols(), n_grid);
        for g in 0..n_grid {
            assert!(
                (result.curves[(0, g)] - FRAC_1_SQRT_2).abs() < 1e-12,
                "grid point {g}: expected 1/sqrt(2), got {}",
                result.curves[(0, g)]
            );
        }
    }

    // A unit weight on the second variable selects sin(t).
    #[test]
    fn sin_curve_from_unit_second_var() {
        let data = row_major_matrix(&[0.0, 1.0, 0.0], 1, 3);
        let result = andrews_transform(&data, 100).unwrap();
        for (g, &t) in result.argvals.iter().enumerate() {
            assert!(
                (result.curves[(0, g)] - t.sin()).abs() < 1e-12,
                "at t={t}: expected sin(t)={}, got {}",
                t.sin(),
                result.curves[(0, g)]
            );
        }
    }

    // A unit weight on the third variable selects cos(t).
    #[test]
    fn cos_curve_from_unit_third_var() {
        let data = row_major_matrix(&[0.0, 0.0, 1.0], 1, 3);
        let result = andrews_transform(&data, 100).unwrap();
        for (g, &t) in result.argvals.iter().enumerate() {
            assert!(
                (result.curves[(0, g)] - t.cos()).abs() < 1e-12,
                "at t={t}: expected cos(t)={}, got {}",
                t.cos(),
                result.curves[(0, g)]
            );
        }
    }

    #[test]
    fn correct_output_dimensions() {
        let data = row_major_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3);
        let result = andrews_transform(&data, 75).unwrap();
        assert_eq!(
            (result.curves.nrows(), result.curves.ncols()),
            (2, 75)
        );
        assert_eq!(result.argvals.len(), 75);
        assert_eq!(result.n_vars, 3);
    }

    #[test]
    fn error_on_empty_data() {
        let outcome = andrews_transform(&FdMatrix::zeros(0, 0), 50);
        assert!(matches!(
            outcome.unwrap_err(),
            FdarError::InvalidDimension { .. }
        ));
    }

    #[test]
    fn error_on_zero_grid() {
        let data = row_major_matrix(&[1.0, 2.0], 1, 2);
        let outcome = andrews_transform(&data, 0);
        assert!(matches!(
            outcome.unwrap_err(),
            FdarError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn error_on_zero_rows() {
        let outcome = andrews_transform(&FdMatrix::zeros(0, 3), 50);
        assert!(matches!(
            outcome.unwrap_err(),
            FdarError::InvalidDimension { .. }
        ));
    }

    #[test]
    fn error_on_zero_cols() {
        let outcome = andrews_transform(&FdMatrix::zeros(5, 0), 50);
        assert!(matches!(
            outcome.unwrap_err(),
            FdarError::InvalidDimension { .. }
        ));
    }

    // A 10 x 3 rotation produces one curve per component (3 rows).
    #[test]
    fn andrews_loadings_correct_shape() {
        let rotation = FdMatrix::zeros(10, 3);
        let result = andrews_loadings(&rotation, 50).unwrap();
        assert_eq!(
            (result.loadings.nrows(), result.loadings.ncols()),
            (3, 50)
        );
        assert_eq!(result.argvals.len(), 50);
        assert_eq!(result.n_vars, 10);
    }

    // Loadings of a single column must equal the transform of that column
    // laid out as a single row.
    #[test]
    fn andrews_loadings_identity_column() {
        let coefficients = [1.0, 0.0, 0.0];
        let rotation = row_major_matrix(&coefficients, 3, 1);
        let loadings = andrews_loadings(&rotation, 50).unwrap();

        let data = row_major_matrix(&coefficients, 1, 3);
        let direct = andrews_transform(&data, 50).unwrap();

        for g in 0..50 {
            assert!(
                (loadings.loadings[(0, g)] - direct.curves[(0, g)]).abs() < 1e-12,
                "grid point {g}: loadings {} vs direct {}",
                loadings.loadings[(0, g)],
                direct.curves[(0, g)]
            );
        }
    }

    #[test]
    fn deterministic_output() {
        let values: Vec<f64> = (1..=9).map(f64::from).collect();
        let data = row_major_matrix(&values, 3, 3);
        let first = andrews_transform(&data, 50).unwrap();
        let second = andrews_transform(&data, 50).unwrap();
        assert_eq!(first.curves, second.curves);
        assert_eq!(first.argvals, second.argvals);
    }

    #[test]
    fn grid_endpoints() {
        let data = row_major_matrix(&[1.0], 1, 1);
        let result = andrews_transform(&data, 101).unwrap();
        let (first, last) = (result.argvals[0], result.argvals[100]);
        assert!((first - (-PI)).abs() < 1e-12);
        assert!((last - PI).abs() < 1e-12);
    }

    // With one grid point the only angle is 0, where sin vanishes and only
    // the constant term of the first variable remains.
    #[test]
    fn single_grid_point() {
        let data = row_major_matrix(&[1.0, 2.0], 1, 2);
        let result = andrews_transform(&data, 1).unwrap();
        assert_eq!(result.curves.ncols(), 1);
        assert_eq!(result.argvals.len(), 1);
        assert!((result.curves[(0, 0)] - FRAC_1_SQRT_2).abs() < 1e-12);
    }

    // transform(2x + 3y) must equal 2 * transform(x) + 3 * transform(y).
    #[test]
    fn linearity() {
        let xs = [1.0, 2.0, 3.0];
        let ys = [4.0, -1.0, 0.5];
        let mixed: Vec<f64> = xs
            .iter()
            .zip(&ys)
            .map(|(a, b)| 2.0 * a + 3.0 * b)
            .collect();

        let x = row_major_matrix(&xs, 1, 3);
        let y = row_major_matrix(&ys, 1, 3);
        let combined = row_major_matrix(&mixed, 1, 3);

        let n_grid = 50;
        let tx = andrews_transform(&x, n_grid).unwrap();
        let ty = andrews_transform(&y, n_grid).unwrap();
        let tc = andrews_transform(&combined, n_grid).unwrap();

        for g in 0..n_grid {
            let expected = 2.0 * tx.curves[(0, g)] + 3.0 * ty.curves[(0, g)];
            assert!(
                (tc.curves[(0, g)] - expected).abs() < 1e-10,
                "linearity failed at grid point {g}: {} vs {expected}",
                tc.curves[(0, g)]
            );
        }
    }

    #[test]
    fn andrews_loadings_error_on_empty() {
        let outcome = andrews_loadings(&FdMatrix::zeros(0, 0), 50);
        assert!(matches!(
            outcome.unwrap_err(),
            FdarError::InvalidDimension { .. }
        ));
    }

    #[test]
    fn andrews_loadings_error_on_zero_grid() {
        let outcome = andrews_loadings(&FdMatrix::zeros(5, 2), 0);
        assert!(matches!(
            outcome.unwrap_err(),
            FdarError::InvalidParameter { .. }
        ));
    }
}