Skip to main content

fdars_core/
wire.rs

1//! Unified FDA data container for pipeline interchange.
2//!
3//! [`FdaData`] is a single layered container that flows between pipeline nodes.
4//! Nodes read from existing layers and add new ones — data is additive, never
5//! destructive. This replaces per-type wire enums with a composable structure.
6//!
7//! # Design
8//!
9//! - **Core**: curves (FdMatrix) + argvals + metadata (grouping, scalars)
10//! - **Layers**: optional analysis results keyed by [`LayerKey`]
11//! - Nodes declare what they *require* via `require_*` helpers
12//! - Nodes add results via `set_layer`
13//! - Layers compose: FPCA + Depth + Outliers can all coexist on one `FdaData`
14//!
15//! # Example
16//!
17//! ```
18//! use fdars_core::wire::*;
19//! use fdars_core::matrix::FdMatrix;
20//!
21//! let mut fd = FdaData::from_curves(
22//!     FdMatrix::zeros(10, 50),
23//!     (0..50).map(|i| i as f64 / 49.0).collect(),
24//! );
25//!
26//! // A depth node reads curves, adds a Depth layer
27//! let scores = vec![0.5; 10];
28//! fd.set_layer(LayerKey::Depth, Layer::Depth(DepthLayer {
29//!     scores,
30//!     method: "fraiman_muniz".into(),
31//! }));
32//!
33//! // Downstream node checks what's available
34//! assert!(fd.has_layer(&LayerKey::Depth));
35//! assert!(!fd.has_layer(&LayerKey::Fpca));
36//! ```
37
38use crate::matrix::FdMatrix;
39use std::collections::HashMap;
40
41// ─── Core Container ─────────────────────────────────────────────────────────
42
/// Unified FDA data object for pipeline interchange.
///
/// Carries functional data (curves + domain) plus optional analysis layers.
/// Nodes read what they need and add their results as new layers.
///
/// Every field is public and the constructors in this file perform no shape
/// validation, so the dimensions quoted below (n observations, m grid points,
/// p tabular columns) are conventions that producers are expected to uphold.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FdaData {
    // ── Core functional data ──
    /// Functional observations (n × m). `None` for tabular-only data
    /// (i.e. when built via `from_tabular` or `empty`).
    pub curves: Option<FdMatrix>,
    /// Evaluation grid (length m). `None` when the container was built
    /// without curves; `n_points()` reports 0 in that case.
    pub argvals: Option<Vec<f64>>,

    // ── Metadata ──
    /// Named grouping variables (multiple allowed).
    pub grouping: Vec<GroupVar>,
    /// Named scalar variables (each length n).
    ///
    /// The first entry doubles as the observation-count fallback in `n_obs()`
    /// when neither `curves` nor `tabular` is present.
    pub scalar_vars: Vec<NamedVec>,
    /// Tabular data for non-functional variables (n × p).
    pub tabular: Option<FdMatrix>,
    /// Column names for tabular data.
    ///
    /// NOTE(review): presumably length p, one per column of `tabular`;
    /// nothing in this file checks that.
    pub column_names: Option<Vec<String>>,

    // ── Analysis layers ──
    /// Analysis results keyed by layer type.
    ///
    /// At most one layer is stored per key; inserting under an existing key
    /// replaces the previous value.
    pub layers: HashMap<LayerKey, Layer>,
}
70
/// A named vector of f64 values (e.g., a scalar covariate or response).
///
/// Stored in [`FdaData::scalar_vars`] and looked up by `name`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NamedVec {
    /// Variable name used for lookup; uniqueness is not enforced.
    pub name: String,
    /// Values of the variable (by convention one per observation).
    pub values: Vec<f64>,
}
78
/// Named grouping variable with string labels.
///
/// `FdaData::add_grouping` fills `unique` from `labels`; because the fields
/// are public, a value constructed by hand must keep the two consistent itself.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GroupVar {
    /// Variable name (e.g., "treatment", "sex").
    pub name: String,
    /// Per-observation labels (length n).
    pub labels: Vec<String>,
    /// Unique labels in order of first appearance.
    pub unique: Vec<String>,
}
90
91// ─── Layer Keys & Types ─────────────────────────────────────────────────────
92
/// Key identifying a layer type.
///
/// Used as the key of [`FdaData::layers`], hence the `Eq` + `Hash` derives.
/// Two `Custom` keys are equal exactly when their strings are equal.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LayerKey {
    /// Functional PCA decomposition.
    Fpca,
    /// PLS decomposition.
    Pls,
    /// Elastic alignment (Karcher mean + warps).
    Alignment,
    /// Precomputed n×n distance matrix.
    Distances,
    /// Functional depth scores.
    Depth,
    /// Outlier detection flags.
    Outliers,
    /// Cluster assignments.
    Clusters,
    /// Scalar-on-function regression fit.
    Regression,
    /// Function-on-scalar regression fit.
    FunctionOnScalar,
    /// Tolerance / confidence bands.
    Tolerance,
    /// Mean curve.
    Mean,
    /// SPM Phase I chart.
    SpmChart,
    /// SPM Phase II monitoring result.
    SpmMonitor,
    /// Explainability result (SHAP, PDP, etc.).
    Explain,
    /// User-defined extension, identified by an arbitrary string.
    Custom(String),
}
129
/// Analysis result attached to an [`FdaData`].
///
/// Each variant wraps the payload struct of the same name. The pairing
/// between a [`LayerKey`] and a `Layer` variant is a convention only:
/// `FdaData::set_layer` stores whatever pair it is given, and the typed
/// accessors (`fpca()`, `depth()`, …) return `None` when the variant stored
/// under their key is not the one they expect.
///
/// The enum is `#[non_exhaustive]`: code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Layer {
    /// Wraps an [`FpcaLayer`].
    Fpca(FpcaLayer),
    /// Wraps a [`PlsLayer`].
    Pls(PlsLayer),
    /// Wraps an [`AlignmentLayer`].
    Alignment(AlignmentLayer),
    /// Wraps a [`DistancesLayer`].
    Distances(DistancesLayer),
    /// Wraps a [`DepthLayer`].
    Depth(DepthLayer),
    /// Wraps an [`OutlierLayer`].
    Outliers(OutlierLayer),
    /// Wraps a [`ClusterLayer`].
    Clusters(ClusterLayer),
    /// Wraps a [`RegressionLayer`].
    Regression(RegressionLayer),
    /// Wraps a [`FosrLayer`].
    FunctionOnScalar(FosrLayer),
    /// Wraps a [`ToleranceLayer`].
    Tolerance(ToleranceLayer),
    /// Wraps a [`MeanLayer`].
    Mean(MeanLayer),
    /// Wraps an [`SpmChartLayer`].
    SpmChart(SpmChartLayer),
    /// Wraps an [`SpmMonitorLayer`].
    SpmMonitor(SpmMonitorLayer),
    /// Wraps an [`ExplainLayer`].
    Explain(ExplainLayer),
    /// Wraps a [`CustomLayer`].
    Custom(CustomLayer),
}
151
152// ─── Layer Structs ──────────────────────────────────────────────────────────
153
/// FPCA decomposition.
///
/// Plain data carrier: the dimensions noted on the fields are not validated
/// by this type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FpcaLayer {
    /// Eigenvalues of the decomposition.
    /// NOTE(review): presumably one per retained component (length ncomp) — confirm with the producer.
    pub eigenvalues: Vec<f64>,
    /// Variance explained per component.
    /// NOTE(review): whether this is a per-component proportion or a cumulative value is not visible here.
    pub variance_explained: Vec<f64>,
    /// Eigenfunctions (m × ncomp), each column is one eigenfunction.
    pub eigenfunctions: FdMatrix,
    /// Scores (n × ncomp).
    pub scores: FdMatrix,
    /// Mean function (length m).
    pub mean: Vec<f64>,
    /// Integration weights (length m).
    pub weights: Vec<f64>,
    /// Number of components (the `ncomp` in the matrix shapes above).
    pub ncomp: usize,
}
170
/// PLS decomposition.
///
/// Plain data carrier: the dimensions noted on the fields are not validated
/// by this type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PlsLayer {
    /// Weight vectors (m × ncomp).
    pub weights: FdMatrix,
    /// Scores (n × ncomp).
    pub scores: FdMatrix,
    /// Loadings (m × ncomp).
    pub loadings: FdMatrix,
    /// Number of components (the `ncomp` in the matrix shapes above).
    pub ncomp: usize,
}
183
/// Elastic alignment result.
///
/// Holds the aligned curves together with the warping functions and the
/// estimated mean. The two diagnostic fields are optional.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct AlignmentLayer {
    /// Aligned curves (n × m).
    pub aligned: FdMatrix,
    /// Warping functions (n × m).
    pub warps: FdMatrix,
    /// Karcher mean (length m).
    pub mean: Vec<f64>,
    /// Mean SRSF (length m).
    pub mean_srsf: Vec<f64>,
    /// Optional: number of alignment iterations performed.
    /// `None` presumably means the producer did not record it.
    pub n_iter: Option<usize>,
    /// Optional: whether the alignment converged.
    /// `None` presumably means the producer did not record it.
    pub converged: Option<bool>,
}
201
/// Precomputed n×n distance matrix with method metadata.
///
/// Lets downstream nodes reuse pairwise distances instead of recomputing them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DistancesLayer {
    /// Symmetric n×n distance matrix (symmetry is not checked by this type).
    pub dist_mat: FdMatrix,
    /// Distance method used (e.g., "elastic", "l2", "dtw", "amplitude", "phase").
    /// Free-form string; no fixed set of names is enforced here.
    pub method: String,
}
211
/// Functional depth scores.
///
/// One score per observation plus the name of the depth notion that
/// produced them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DepthLayer {
    /// Depth score per observation (length n).
    pub scores: Vec<f64>,
    /// Method name (free-form string, e.g. "fraiman_muniz").
    pub method: String,
}
221
/// Outlier detection result.
///
/// `flags`, `threshold` and `method` are always present; the remaining
/// fields are method-specific diagnostics and are `None` when not applicable
/// (the tests in this file construct an "lrt" result with all of them `None`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct OutlierLayer {
    /// Outlier flag per observation (length n).
    pub flags: Vec<bool>,
    /// Detection threshold.
    /// NOTE(review): scale and direction depend on `method`; not specified here.
    pub threshold: f64,
    /// Method name.
    pub method: String,
    /// Optional: MEI scores (for outliergram).
    pub mei: Option<Vec<f64>>,
    /// Optional: MBD scores (for outliergram).
    pub mbd: Option<Vec<f64>>,
    /// Optional: magnitude outlyingness.
    pub magnitude: Option<Vec<f64>>,
    /// Optional: shape outlyingness.
    pub shape: Option<Vec<f64>>,
    /// Optional: outliergram parabola intercept coefficient.
    pub outliergram_a0: Option<f64>,
    /// Optional: outliergram parabola linear coefficient.
    pub outliergram_a1: Option<f64>,
    /// Optional: outliergram parabola quadratic coefficient.
    pub outliergram_a2: Option<f64>,
}
247
/// Cluster assignments.
///
/// `labels`, `k` and `method` are always present; centers, medoids and
/// silhouette scores are optional extras that depend on the clustering method.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ClusterLayer {
    /// Cluster label per observation (0-indexed, length n).
    pub labels: Vec<usize>,
    /// Number of clusters.
    pub k: usize,
    /// Method name.
    pub method: String,
    /// Optional: cluster centers (k rows × m cols).
    pub centers: Option<FdMatrix>,
    /// Optional: medoid indices (length k).
    /// NOTE(review): presumably row indices into the observations — confirm with the producer.
    pub medoid_indices: Option<Vec<usize>>,
    /// Optional: silhouette scores (length n).
    pub silhouette: Option<Vec<f64>>,
}
265
/// Scalar-on-function regression fit.
///
/// Covers several model families (see `method`); fields that only make sense
/// for some of them are `Option`s.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RegressionLayer {
    /// Method name (e.g., "fregre_lm", "fregre_pls", "fregre_np", "elastic").
    pub method: String,
    /// Functional coefficient β(t) (length m). `None` for nonparametric.
    pub beta_t: Option<Vec<f64>>,
    /// Fitted values (length n).
    pub fitted_values: Vec<f64>,
    /// Residuals (length n).
    pub residuals: Vec<f64>,
    /// Observed response (length n).
    pub observed_y: Vec<f64>,
    /// R².
    pub r_squared: f64,
    /// Adjusted R² (`None` when the producer does not report it).
    pub adj_r_squared: Option<f64>,
    /// Intercept.
    pub intercept: f64,
    /// Number of components used (0 for nonparametric).
    pub ncomp: usize,
    /// Evaluation grid for β(t).
    pub argvals: Option<Vec<f64>>,
    /// Pointwise standard errors of β(t).
    pub beta_se: Option<Vec<f64>>,
    /// Optional: human-readable model name.
    pub model_name: Option<String>,
    /// Optional: number of training observations.
    pub n_obs: Option<usize>,
}
297
/// Function-on-scalar regression fit.
///
/// Plain data carrier: the dimensions noted on the fields are not validated
/// by this type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FosrLayer {
    /// Coefficient functions (p × m), one per predictor.
    pub coefficients: FdMatrix,
    /// Fitted curves (n × m).
    pub fitted: FdMatrix,
    /// R² per grid point (length m).
    pub r_squared_t: Vec<f64>,
}
309
/// Tolerance / confidence band.
///
/// Three curves on the same grid: a center and the band's two edges.
/// This type does not check that `lower <= center <= upper` pointwise.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ToleranceLayer {
    /// Lower bound (length m).
    pub lower: Vec<f64>,
    /// Upper bound (length m).
    pub upper: Vec<f64>,
    /// Center (length m).
    pub center: Vec<f64>,
    /// Method name.
    pub method: String,
}
323
/// Mean curve.
///
/// Single-field wrapper so the mean can travel as its own layer.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MeanLayer {
    /// Mean function (length m).
    pub mean: Vec<f64>,
}
331
/// SPM Phase I chart.
///
/// Control limits and the Phase I statistics they were derived from, plus
/// optional FPCA quantities.
/// NOTE(review): the optional FPCA fields are presumably carried so Phase II
/// monitoring can project new observations — confirm with the consuming node.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpmChartLayer {
    /// T² control limit.
    pub t2_limit: f64,
    /// SPE control limit.
    pub spe_limit: f64,
    /// Phase I T² statistics.
    /// NOTE(review): presumably one per Phase I observation.
    pub t2_stats: Vec<f64>,
    /// Phase I SPE statistics.
    /// NOTE(review): presumably one per Phase I observation.
    pub spe_stats: Vec<f64>,
    /// Number of FPC components.
    pub ncomp: usize,
    /// Significance level.
    pub alpha: f64,
    /// Optional: eigenvalues from FPCA (length ncomp).
    pub eigenvalues: Option<Vec<f64>>,
    /// Optional: FPCA mean function (length m).
    pub fpca_mean: Option<Vec<f64>>,
    /// Optional: FPCA integration weights (length m).
    pub fpca_weights: Option<Vec<f64>>,
}
355
/// SPM Phase II monitoring result.
///
/// Statistics for new observations alongside the control limits and the
/// resulting alarm flags.
/// NOTE(review): the four vectors are presumably parallel (one entry per
/// monitored observation); this type does not enforce equal lengths.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpmMonitorLayer {
    /// T² statistics for new observations.
    pub t2_stats: Vec<f64>,
    /// SPE statistics for new observations.
    pub spe_stats: Vec<f64>,
    /// T² control limit.
    pub t2_limit: f64,
    /// SPE control limit.
    pub spe_limit: f64,
    /// T² alarm flags.
    pub t2_alarms: Vec<bool>,
    /// SPE alarm flags.
    pub spe_alarms: Vec<bool>,
}
373
/// Explainability result.
///
/// Deliberately generic: the meaning of `values`, `labels` and `extra` is
/// defined by `method`, not by this type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ExplainLayer {
    /// Method name (e.g., "shap", "pdp", "ale", "permutation_importance").
    pub method: String,
    /// Values (interpretation depends on method).
    pub values: Vec<f64>,
    /// Labels for the values.
    /// NOTE(review): presumably parallel to `values`; equal length is not enforced.
    pub labels: Vec<String>,
    /// Additional method-specific data, as named numeric arrays.
    pub extra: Option<HashMap<String, Vec<f64>>>,
}
387
/// User-defined layer for extensions.
///
/// Payload of `Layer::Custom`; its content is opaque to this module.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomLayer {
    /// Name of the custom layer.
    /// NOTE(review): presumably mirrors the string in `LayerKey::Custom`; not enforced in this file.
    pub name: String,
    /// Named numeric arrays; keys and their meaning are defined by the extension.
    pub data: HashMap<String, Vec<f64>>,
}
395
396// ─── FdaData Constructors ───────────────────────────────────────────────────
397
398impl FdaData {
399    /// Create from functional curves + grid.
400    pub fn from_curves(curves: FdMatrix, argvals: Vec<f64>) -> Self {
401        Self {
402            curves: Some(curves),
403            argvals: Some(argvals),
404            grouping: Vec::new(),
405            scalar_vars: Vec::new(),
406            tabular: None,
407            column_names: None,
408            layers: HashMap::new(),
409        }
410    }
411
412    /// Create from tabular (non-functional) data.
413    pub fn from_tabular(tabular: FdMatrix, column_names: Vec<String>) -> Self {
414        Self {
415            curves: None,
416            argvals: None,
417            grouping: Vec::new(),
418            scalar_vars: Vec::new(),
419            tabular: Some(tabular),
420            column_names: Some(column_names),
421            layers: HashMap::new(),
422        }
423    }
424
425    /// Create empty container.
426    pub fn empty() -> Self {
427        Self {
428            curves: None,
429            argvals: None,
430            grouping: Vec::new(),
431            scalar_vars: Vec::new(),
432            tabular: None,
433            column_names: None,
434            layers: HashMap::new(),
435        }
436    }
437
438    // ── Requirement checks ──
439
440    /// Require functional curves to be present.
441    pub fn require_curves(&self) -> Result<(&FdMatrix, &[f64]), String> {
442        match (&self.curves, &self.argvals) {
443            (Some(c), Some(a)) => Ok((c, a)),
444            _ => Err("FdaData requires functional curves + argvals".into()),
445        }
446    }
447
448    /// Require a specific layer to be present.
449    pub fn require_layer(&self, key: &LayerKey) -> Result<&Layer, String> {
450        self.layers
451            .get(key)
452            .ok_or_else(|| format!("FdaData missing required layer: {key:?}"))
453    }
454
455    // ── Layer access ──
456
457    /// Check if a layer is present.
458    pub fn has_layer(&self, key: &LayerKey) -> bool {
459        self.layers.contains_key(key)
460    }
461
462    /// Get a layer by key.
463    pub fn get_layer(&self, key: &LayerKey) -> Option<&Layer> {
464        self.layers.get(key)
465    }
466
467    /// Set (add or replace) a layer.
468    pub fn set_layer(&mut self, key: LayerKey, layer: Layer) {
469        self.layers.insert(key, layer);
470    }
471
472    /// Remove a layer.
473    pub fn remove_layer(&mut self, key: &LayerKey) -> Option<Layer> {
474        self.layers.remove(key)
475    }
476
477    /// List all layer keys present.
478    pub fn layer_keys(&self) -> Vec<&LayerKey> {
479        self.layers.keys().collect()
480    }
481
482    // ── Typed layer accessors ──
483
484    /// Get FPCA layer if present.
485    pub fn fpca(&self) -> Option<&FpcaLayer> {
486        match self.layers.get(&LayerKey::Fpca)? {
487            Layer::Fpca(l) => Some(l),
488            _ => None,
489        }
490    }
491
492    /// Get distances layer if present.
493    pub fn distances(&self) -> Option<&DistancesLayer> {
494        match self.layers.get(&LayerKey::Distances)? {
495            Layer::Distances(l) => Some(l),
496            _ => None,
497        }
498    }
499
500    /// Get alignment layer if present.
501    pub fn alignment(&self) -> Option<&AlignmentLayer> {
502        match self.layers.get(&LayerKey::Alignment)? {
503            Layer::Alignment(l) => Some(l),
504            _ => None,
505        }
506    }
507
508    /// Get regression layer if present.
509    pub fn regression(&self) -> Option<&RegressionLayer> {
510        match self.layers.get(&LayerKey::Regression)? {
511            Layer::Regression(l) => Some(l),
512            _ => None,
513        }
514    }
515
516    /// Get cluster layer if present.
517    pub fn clusters(&self) -> Option<&ClusterLayer> {
518        match self.layers.get(&LayerKey::Clusters)? {
519            Layer::Clusters(l) => Some(l),
520            _ => None,
521        }
522    }
523
524    /// Get depth layer if present.
525    pub fn depth(&self) -> Option<&DepthLayer> {
526        match self.layers.get(&LayerKey::Depth)? {
527            Layer::Depth(l) => Some(l),
528            _ => None,
529        }
530    }
531
532    /// Get outlier layer if present.
533    pub fn outliers(&self) -> Option<&OutlierLayer> {
534        match self.layers.get(&LayerKey::Outliers)? {
535            Layer::Outliers(l) => Some(l),
536            _ => None,
537        }
538    }
539
540    // ── Metadata helpers ──
541
542    /// Number of observations (from curves, tabular, or first scalar var).
543    pub fn n_obs(&self) -> usize {
544        if let Some(c) = &self.curves {
545            return c.nrows();
546        }
547        if let Some(t) = &self.tabular {
548            return t.nrows();
549        }
550        self.scalar_vars.first().map_or(0, |v| v.values.len())
551    }
552
553    /// Number of grid points (0 if no functional data).
554    pub fn n_points(&self) -> usize {
555        self.argvals.as_ref().map_or(0, |a| a.len())
556    }
557
558    /// Add a scalar variable.
559    pub fn add_scalar(&mut self, name: impl Into<String>, values: Vec<f64>) {
560        self.scalar_vars.push(NamedVec {
561            name: name.into(),
562            values,
563        });
564    }
565
566    /// Get a scalar variable by name.
567    pub fn get_scalar(&self, name: &str) -> Option<&[f64]> {
568        self.scalar_vars
569            .iter()
570            .find(|v| v.name == name)
571            .map(|v| v.values.as_slice())
572    }
573
574    /// Add a grouping variable with per-observation string labels.
575    ///
576    /// Unique labels are computed automatically in order of first appearance.
577    pub fn add_grouping(&mut self, name: impl Into<String>, labels: Vec<String>) {
578        let mut unique = Vec::new();
579        for lab in &labels {
580            if !unique.contains(lab) {
581                unique.push(lab.clone());
582            }
583        }
584        self.grouping.push(GroupVar {
585            name: name.into(),
586            labels,
587            unique,
588        });
589    }
590
591    /// Look up a grouping variable by name.
592    pub fn get_grouping(&self, name: &str) -> Option<&GroupVar> {
593        self.grouping.iter().find(|g| g.name == name)
594    }
595}
596
597// ─── Tests ──────────────────────────────────────────────────────────────────
598
#[cfg(test)]
mod tests {
    use super::*;

    /// A curves-backed container reports its dimensions and starts without layers.
    #[test]
    fn from_curves_basic() {
        let fd = FdaData::from_curves(
            FdMatrix::zeros(10, 50),
            (0..50).map(|i| i as f64 / 49.0).collect(),
        );
        assert_eq!(fd.n_obs(), 10);
        assert_eq!(fd.n_points(), 50);
        assert!(fd.require_curves().is_ok());
        assert!(!fd.has_layer(&LayerKey::Fpca));
    }

    /// A layer set under a key is visible through both generic and typed access.
    #[test]
    fn add_and_retrieve_layers() {
        let mut fd = FdaData::from_curves(
            FdMatrix::zeros(5, 20),
            (0..20).map(|i| i as f64 / 19.0).collect(),
        );

        fd.set_layer(
            LayerKey::Depth,
            Layer::Depth(DepthLayer {
                scores: vec![0.5; 5],
                method: "fraiman_muniz".into(),
            }),
        );

        assert!(fd.has_layer(&LayerKey::Depth));
        assert!(!fd.has_layer(&LayerKey::Fpca));
        assert!(fd.depth().is_some());
        assert_eq!(fd.depth().unwrap().scores.len(), 5);
        assert_eq!(fd.layer_keys().len(), 1);
    }

    /// Requiring an absent layer yields an error rather than panicking.
    #[test]
    fn require_missing_layer_errors() {
        let fd = FdaData::from_curves(FdMatrix::zeros(3, 10), vec![0.0; 10]);
        assert!(fd.require_layer(&LayerKey::Fpca).is_err());
    }

    /// Scalar variables are retrievable by name and drive `n_obs` when no matrix is present.
    #[test]
    fn scalar_vars() {
        let mut fd = FdaData::empty();
        fd.add_scalar("height", vec![170.0, 180.0, 165.0]);
        assert_eq!(fd.get_scalar("height").unwrap(), &[170.0, 180.0, 165.0]);
        assert!(fd.get_scalar("weight").is_none());
        assert_eq!(fd.n_obs(), 3);
    }

    /// Independent layers coexist on one container.
    #[test]
    fn multiple_layers_compose() {
        let mut fd = FdaData::from_curves(FdMatrix::zeros(10, 30), vec![0.0; 30]);

        fd.set_layer(
            LayerKey::Depth,
            Layer::Depth(DepthLayer {
                scores: vec![0.5; 10],
                method: "fm".into(),
            }),
        );
        fd.set_layer(
            LayerKey::Outliers,
            Layer::Outliers(OutlierLayer {
                flags: vec![false; 10],
                threshold: 0.1,
                method: "lrt".into(),
                mei: None,
                mbd: None,
                magnitude: None,
                shape: None,
                outliergram_a0: None,
                outliergram_a1: None,
                outliergram_a2: None,
            }),
        );
        fd.set_layer(
            LayerKey::Distances,
            Layer::Distances(DistancesLayer {
                dist_mat: FdMatrix::zeros(10, 10),
                method: "elastic".into(),
            }),
        );

        assert_eq!(fd.layer_keys().len(), 3);
        assert!(fd.depth().is_some());
        assert!(fd.outliers().is_some());
        assert!(fd.distances().is_some());
    }

    /// `add_grouping` keeps the labels as given and derives the unique labels
    /// in order of first appearance.
    #[test]
    fn grouping_unique_keeps_first_appearance_order() {
        let mut fd = FdaData::empty();
        let labels: Vec<String> = ["b", "a", "b", "c", "a"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        fd.add_grouping("treatment", labels.clone());

        let group = fd.get_grouping("treatment").unwrap();
        assert_eq!(group.name, "treatment");
        assert_eq!(group.labels, labels);
        assert_eq!(group.unique, vec!["b", "a", "c"]);
        assert!(fd.get_grouping("sex").is_none());
    }

    /// An empty label vector yields an empty unique set.
    #[test]
    fn grouping_with_no_labels() {
        let mut fd = FdaData::empty();
        fd.add_grouping("empty", Vec::new());
        let group = fd.get_grouping("empty").unwrap();
        assert!(group.labels.is_empty());
        assert!(group.unique.is_empty());
    }

    /// A tabular-only container counts rows but exposes no functional data.
    #[test]
    fn tabular_only_has_no_curves() {
        let fd = FdaData::from_tabular(FdMatrix::zeros(4, 2), vec!["x".into(), "y".into()]);
        assert_eq!(fd.n_obs(), 4);
        assert_eq!(fd.n_points(), 0);
        assert!(fd.require_curves().is_err());
    }

    /// An empty container has no observations, no grid and no layers.
    #[test]
    fn empty_container_defaults() {
        let fd = FdaData::empty();
        assert_eq!(fd.n_obs(), 0);
        assert_eq!(fd.n_points(), 0);
        assert!(fd.layer_keys().is_empty());
        assert!(fd.require_curves().is_err());
    }

    /// Setting a layer under an existing key replaces it; removing returns it once.
    #[test]
    fn set_layer_replaces_and_remove_returns_it() {
        let mut fd = FdaData::empty();
        fd.set_layer(LayerKey::Mean, Layer::Mean(MeanLayer { mean: vec![0.0; 3] }));
        fd.set_layer(LayerKey::Mean, Layer::Mean(MeanLayer { mean: vec![1.0; 4] }));
        assert_eq!(fd.layer_keys().len(), 1);

        match fd.require_layer(&LayerKey::Mean) {
            Ok(Layer::Mean(m)) => assert_eq!(m.mean, vec![1.0; 4]),
            other => panic!("unexpected layer: {other:?}"),
        }

        assert!(fd.remove_layer(&LayerKey::Mean).is_some());
        assert!(!fd.has_layer(&LayerKey::Mean));
        assert!(fd.remove_layer(&LayerKey::Mean).is_none());
    }

    /// Key and variant are not tied together: a mismatched pair is stored, but
    /// the typed accessor for that key refuses to return it.
    #[test]
    fn typed_accessor_rejects_mismatched_variant() {
        let mut fd = FdaData::empty();
        fd.set_layer(LayerKey::Fpca, Layer::Mean(MeanLayer { mean: vec![0.0] }));
        assert!(fd.has_layer(&LayerKey::Fpca));
        assert!(fd.get_layer(&LayerKey::Fpca).is_some());
        assert!(fd.fpca().is_none());
    }
}