// gam_terms/smooth/torch_dispatch.rs
/// Torch-side smooth implementation selected for a smooth term.
///
/// Values are produced by `dispatch_key` from a CamelCase spec-kind name
/// (e.g. `"TensorBSpline"`). [`TorchSmoothEntry::as_str`] gives a separate
/// snake_case key (e.g. `"tensor_bspline"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TorchSmoothEntry {
    /// Spec kind `"Duchon"`, key `"duchon"`.
    Duchon,
    /// Spec kind `"BSpline"`, key `"bspline"`.
    BSpline,
    /// Spec kind `"TensorBSpline"`, key `"tensor_bspline"`.
    TensorBSpline,
    /// Spec kind `"Matern"`, key `"matern"`.
    Matern,
    /// Spec kind `"Sphere"`, key `"sphere"`.
    Sphere,
    /// Spec kind `"PeriodicSplineCurve"`, key `"periodic_spline_curve"`.
    PeriodicSplineCurve,
    /// Spec kind `"Pca"`, key `"pca"`.
    Pca,
    /// Spec kind `"Categorical"`, key `"categorical"`.
    Categorical,
}

impl TorchSmoothEntry {
    /// Every variant exactly once, in declaration order.
    ///
    /// Lets callers (and tests) iterate all entries without re-listing them by
    /// hand. Keep in sync with the enum declaration when adding a variant.
    pub const ALL: [TorchSmoothEntry; 8] = [
        Self::Duchon,
        Self::BSpline,
        Self::TensorBSpline,
        Self::Matern,
        Self::Sphere,
        Self::PeriodicSplineCurve,
        Self::Pca,
        Self::Categorical,
    ];

    /// Returns the snake_case key for this entry (e.g. `"periodic_spline_curve"`).
    ///
    /// This is *not* the string accepted by `dispatch_key`, which expects the
    /// CamelCase spec-kind name (e.g. `"PeriodicSplineCurve"`). Usable in
    /// `const` contexts.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Duchon => "duchon",
            Self::BSpline => "bspline",
            Self::TensorBSpline => "tensor_bspline",
            Self::Matern => "matern",
            Self::Sphere => "sphere",
            Self::PeriodicSplineCurve => "periodic_spline_curve",
            Self::Pca => "pca",
            Self::Categorical => "categorical",
        }
    }
}
44
45pub fn dispatch_key(spec_kind: &str) -> Result<TorchSmoothEntry, String> {
55 match spec_kind {
56 "Duchon" => Ok(TorchSmoothEntry::Duchon),
57 "BSpline" => Ok(TorchSmoothEntry::BSpline),
58 "TensorBSpline" => Ok(TorchSmoothEntry::TensorBSpline),
59 "Matern" => Ok(TorchSmoothEntry::Matern),
60 "Sphere" => Ok(TorchSmoothEntry::Sphere),
61 "PeriodicSplineCurve" => Ok(TorchSmoothEntry::PeriodicSplineCurve),
62 "Pca" => Ok(TorchSmoothEntry::Pca),
63 "Categorical" => Ok(TorchSmoothEntry::Categorical),
64 other => Err(format!("unknown Smooth subclass: {other}")),
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// `(spec kind accepted by dispatch_key, expected entry, expected as_str key)`
    /// for every variant. This is the single source of truth for the tests below.
    const CASES: [(&str, TorchSmoothEntry, &str); 8] = [
        ("Duchon", TorchSmoothEntry::Duchon, "duchon"),
        ("BSpline", TorchSmoothEntry::BSpline, "bspline"),
        ("TensorBSpline", TorchSmoothEntry::TensorBSpline, "tensor_bspline"),
        ("Matern", TorchSmoothEntry::Matern, "matern"),
        ("Sphere", TorchSmoothEntry::Sphere, "sphere"),
        (
            "PeriodicSplineCurve",
            TorchSmoothEntry::PeriodicSplineCurve,
            "periodic_spline_curve",
        ),
        ("Pca", TorchSmoothEntry::Pca, "pca"),
        ("Categorical", TorchSmoothEntry::Categorical, "categorical"),
    ];

    #[test]
    fn known_specs_dispatch() {
        for (spec_kind, expected, _) in CASES {
            // Compare the whole Result so a failure prints the error message.
            assert_eq!(dispatch_key(spec_kind), Ok(expected), "spec_kind = {spec_kind}");
        }
    }

    #[test]
    fn unknown_spec_kind_is_distinguishable() {
        for bogus in ["Banana", ""] {
            let err = dispatch_key(bogus).unwrap_err();
            assert!(err.contains("unknown Smooth subclass"), "{err}");
            assert!(err.contains(bogus), "{err}");
        }
    }

    /// spec kind -> entry -> derived `Debug` name must return the original spec
    /// kind, so the accepted names stay in lock-step with the variant names.
    #[test]
    fn dispatch_key_round_trips_through_debug_name() {
        for (spec_kind, _, _) in CASES {
            let entry = dispatch_key(spec_kind).unwrap();
            assert_eq!(format!("{entry:?}"), spec_kind);
        }
    }

    #[test]
    fn as_str_keys_are_expected_and_unique() {
        let mut seen = HashSet::new();
        for (_, entry, key) in CASES {
            assert_eq!(entry.as_str(), key);
            assert!(seen.insert(entry.as_str()), "duplicate as_str key {key}");
        }
    }
}