/// The set of smooth kinds that a spec name can be resolved to by `dispatch_key`.
///
/// This is a fieldless enum: it is `Copy`, comparable for equality, and
/// hashable, so it can be passed by value and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TorchSmoothEntry {
    Duchon,
    BSpline,
    TensorBSpline,
    Matern,
    Sphere,
    PeriodicSplineCurve,
    Pca,
    Categorical,
}

impl TorchSmoothEntry {
    /// Returns the lowercase `snake_case` key for this variant.
    ///
    /// The returned keys are *not* the strings accepted by `dispatch_key`:
    /// that function matches the CamelCase variant names (e.g. `"BSpline"`),
    /// whereas this method yields e.g. `"bspline"`.
    ///
    /// Usable in `const` contexts.
    pub const fn as_str(self) -> &'static str {
        // Deliberately no wildcard arm: adding a variant must fail to compile
        // here until it is given a key.
        match self {
            Self::Duchon => "duchon",
            Self::BSpline => "bspline",
            Self::TensorBSpline => "tensor_bspline",
            Self::Matern => "matern",
            Self::Sphere => "sphere",
            Self::PeriodicSplineCurve => "periodic_spline_curve",
            Self::Pca => "pca",
            Self::Categorical => "categorical",
        }
    }
}
/// Resolves a spec kind name to the matching [`TorchSmoothEntry`].
///
/// Matching is exact and case-sensitive against the CamelCase variant names
/// (e.g. `"TensorBSpline"`). The `snake_case` keys produced by
/// `TorchSmoothEntry::as_str` are not accepted.
///
/// # Errors
///
/// Returns `Err` carrying the message `unknown Smooth subclass: <spec_kind>`
/// when `spec_kind` is not one of the recognised names.
pub fn dispatch_key(spec_kind: &str) -> Result<TorchSmoothEntry, String> {
    let entry = match spec_kind {
        "Duchon" => TorchSmoothEntry::Duchon,
        "BSpline" => TorchSmoothEntry::BSpline,
        "TensorBSpline" => TorchSmoothEntry::TensorBSpline,
        "Matern" => TorchSmoothEntry::Matern,
        "Sphere" => TorchSmoothEntry::Sphere,
        "PeriodicSplineCurve" => TorchSmoothEntry::PeriodicSplineCurve,
        "Pca" => TorchSmoothEntry::Pca,
        "Categorical" => TorchSmoothEntry::Categorical,
        // Anything else is rejected, echoing the offending name back to the caller.
        other => return Err(format!("unknown Smooth subclass: {other}")),
    };
    Ok(entry)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every recognised spec name resolves to its corresponding variant.
    #[test]
    fn known_specs_dispatch() {
        assert_eq!(dispatch_key("Duchon").unwrap(), TorchSmoothEntry::Duchon);
        assert_eq!(dispatch_key("BSpline").unwrap(), TorchSmoothEntry::BSpline);
        assert_eq!(
            dispatch_key("TensorBSpline").unwrap(),
            TorchSmoothEntry::TensorBSpline
        );
        assert_eq!(dispatch_key("Matern").unwrap(), TorchSmoothEntry::Matern);
        assert_eq!(dispatch_key("Sphere").unwrap(), TorchSmoothEntry::Sphere);
        assert_eq!(
            dispatch_key("PeriodicSplineCurve").unwrap(),
            TorchSmoothEntry::PeriodicSplineCurve
        );
        assert_eq!(dispatch_key("Pca").unwrap(), TorchSmoothEntry::Pca);
        assert_eq!(
            dispatch_key("Categorical").unwrap(),
            TorchSmoothEntry::Categorical
        );
    }

    /// An unrecognised name yields an error that names the offending input.
    #[test]
    fn unknown_spec_kind_is_distinguishable() {
        let err = dispatch_key("Banana").unwrap_err();
        assert!(err.contains("unknown Smooth subclass"));
        assert!(err.contains("Banana"));
    }

    /// Pins the exact key `as_str` returns for every variant.
    ///
    /// A bare non-emptiness check would let swapped, duplicated or misspelled
    /// keys slip through, so each variant is compared against its expected
    /// string instead.
    #[test]
    fn as_str_returns_expected_keys() {
        let expected = [
            (TorchSmoothEntry::Duchon, "duchon"),
            (TorchSmoothEntry::BSpline, "bspline"),
            (TorchSmoothEntry::TensorBSpline, "tensor_bspline"),
            (TorchSmoothEntry::Matern, "matern"),
            (TorchSmoothEntry::Sphere, "sphere"),
            (
                TorchSmoothEntry::PeriodicSplineCurve,
                "periodic_spline_curve",
            ),
            (TorchSmoothEntry::Pca, "pca"),
            (TorchSmoothEntry::Categorical, "categorical"),
        ];
        for (kind, key) in expected {
            assert_eq!(kind.as_str(), key, "unexpected key for {kind:?}");
        }
    }
}