Skip to main content

gam_terms/smooth/
torch_dispatch.rs

//! Dispatch table for the torch fit entry — the single source of truth for
//! which Python `Smooth` subclasses the `torch.fit` autograd glue recognises.
//!
//! The Python side calls `torch_smooth_dispatch_key(type(smooth).__name__)`
//! to translate the spec's class name into a small enumeration. Tensor
//! construction itself stays in Python because the torch autograd VJP must
//! flow back through `points`, `centers`, and `by`.
//!
//! Every Python `Smooth` subclass re-exported from `gamfit.torch` must have a
//! matching variant here, so that dispatch never fails for a class the user
//! can legitimately import. Every variant resolves to a `fit.py` branch that
//! builds a concrete `(design, penalty)` tensor pair. This includes
//! `TensorBSpline` (`te` tensor product) and `Matern` (kernel-Gram penalty).
//! It also includes `Categorical`: a sum-to-zero contrast with an identity
//! ridge penalty, i.e. an i.i.d. Gaussian random effect matching the Rust
//! `RandomEffectTermSpec`.
17
/// Smooth kinds the torch fit path knows how to build.
///
/// Produced from a Python class name by [`dispatch_key`]; the stable
/// snake_case key for each kind is available through `as_str`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TorchSmoothEntry {
    /// Python `Duchon`; key `"duchon"`.
    Duchon,
    /// Python `BSpline`; key `"bspline"`.
    BSpline,
    /// Python `TensorBSpline` (`te` tensor product); key `"tensor_bspline"`.
    TensorBSpline,
    /// Python `Matern` (kernel-Gram penalty); key `"matern"`.
    Matern,
    /// Python `Sphere`; key `"sphere"`.
    Sphere,
    /// Python `PeriodicSplineCurve`; key `"periodic_spline_curve"`.
    PeriodicSplineCurve,
    /// Python `Pca`; key `"pca"`.
    Pca,
    /// Python `Categorical` (sum-to-zero contrast, identity ridge penalty);
    /// key `"categorical"`.
    Categorical,
}
29
30impl TorchSmoothEntry {
31    pub const fn as_str(self) -> &'static str {
32        match self {
33            TorchSmoothEntry::Duchon => "duchon",
34            TorchSmoothEntry::BSpline => "bspline",
35            TorchSmoothEntry::TensorBSpline => "tensor_bspline",
36            TorchSmoothEntry::Matern => "matern",
37            TorchSmoothEntry::Sphere => "sphere",
38            TorchSmoothEntry::PeriodicSplineCurve => "periodic_spline_curve",
39            TorchSmoothEntry::Pca => "pca",
40            TorchSmoothEntry::Categorical => "categorical",
41        }
42    }
43}
44
45/// Map a Python `Smooth` subclass name to the matching torch entry kind.
46///
47/// Returns `Ok(entry)` for every `Smooth` subclass that `gamfit.torch`
48/// re-exports. Each recognised entry has a matching `fit.py` branch that
49/// builds a concrete design/penalty tensor pair; the `NotImplementedError`
50/// fallback there is now only a defensive guard for a future Rust variant
51/// added without a torch branch. Truly unknown class names produce a
52/// `TypeError`-shaped message preserving the previous Python cascade's
53/// surface error.
54pub fn dispatch_key(spec_kind: &str) -> Result<TorchSmoothEntry, String> {
55    match spec_kind {
56        "Duchon" => Ok(TorchSmoothEntry::Duchon),
57        "BSpline" => Ok(TorchSmoothEntry::BSpline),
58        "TensorBSpline" => Ok(TorchSmoothEntry::TensorBSpline),
59        "Matern" => Ok(TorchSmoothEntry::Matern),
60        "Sphere" => Ok(TorchSmoothEntry::Sphere),
61        "PeriodicSplineCurve" => Ok(TorchSmoothEntry::PeriodicSplineCurve),
62        "Pca" => Ok(TorchSmoothEntry::Pca),
63        "Categorical" => Ok(TorchSmoothEntry::Categorical),
64        other => Err(format!("unknown Smooth subclass: {other}")),
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// One row per variant: the Python class name that selects it, the entry
    /// `dispatch_key` must return for it, and the key `as_str` must report.
    const CASES: [(&str, TorchSmoothEntry, &str); 8] = [
        ("Duchon", TorchSmoothEntry::Duchon, "duchon"),
        ("BSpline", TorchSmoothEntry::BSpline, "bspline"),
        ("TensorBSpline", TorchSmoothEntry::TensorBSpline, "tensor_bspline"),
        ("Matern", TorchSmoothEntry::Matern, "matern"),
        ("Sphere", TorchSmoothEntry::Sphere, "sphere"),
        ("PeriodicSplineCurve", TorchSmoothEntry::PeriodicSplineCurve, "periodic_spline_curve"),
        ("Pca", TorchSmoothEntry::Pca, "pca"),
        ("Categorical", TorchSmoothEntry::Categorical, "categorical"),
    ];

    #[test]
    fn known_specs_dispatch() {
        for &(class_name, expected, _) in &CASES {
            assert_eq!(dispatch_key(class_name), Ok(expected), "class {class_name}");
        }
    }

    #[test]
    fn unknown_spec_kind_is_distinguishable() {
        let err = dispatch_key("Banana").unwrap_err();
        assert!(err.contains("unknown Smooth subclass"));
        assert!(err.contains("Banana"));
    }

    /// The snake_case keys are output-only; feeding one back in as a class
    /// name must not silently dispatch (class names are matched exactly).
    #[test]
    fn snake_case_keys_are_not_class_names() {
        for &(_, _, key) in &CASES {
            assert!(dispatch_key(key).is_err(), "key {key} must not dispatch");
        }
    }

    /// Class name -> entry -> key must land on the expected key, and keys
    /// must be pairwise distinct so the Python side can branch on them.
    /// (The previous version only checked that each key was non-empty.)
    #[test]
    fn as_str_round_trips() {
        let mut seen = HashSet::new();
        for &(class_name, kind, key) in &CASES {
            assert_eq!(kind.as_str(), key);
            assert!(seen.insert(key), "duplicate dispatch key {key}");
            assert_eq!(dispatch_key(class_name).map(TorchSmoothEntry::as_str), Ok(key));
        }
    }
}