Skip to main content

kernels_data/config/
name.rs

1use std::fmt;
2
3use regex::Regex;
4use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
5
/// A kernel name that has been checked against `^[a-z][-a-z0-9]*[a-z0-9]$`.
///
/// The wrapped `String` is private, so code outside this module obtains values
/// through [`KernelName::new`] (deserialization also routes through it).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KernelName(String);
9
10impl KernelName {
11    pub fn new(name: impl Into<String>) -> Result<Self, KernelNameError> {
12        let name = name.into();
13        let pattern = Regex::new(r"^[a-z][-a-z0-9]*[a-z0-9]$").unwrap();
14
15        if !pattern.is_match(&name) {
16            return Err(KernelNameError(name));
17        }
18
19        Ok(Self(name))
20    }
21
22    pub fn as_str(&self) -> &str {
23        &self.0
24    }
25
26    pub fn python_name(&self) -> String {
27        self.0.replace("-", "_")
28    }
29}
30
31impl AsRef<str> for KernelName {
32    fn as_ref(&self) -> &str {
33        &self.0
34    }
35}
36
37impl fmt::Display for KernelName {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        write!(f, "{}", self.0)
40    }
41}
42
/// Error produced when a candidate kernel name fails validation.
///
/// The wrapped `String` is the rejected input; it is echoed back in the
/// `Display` output together with the naming rules.
#[derive(Debug)]
pub struct KernelNameError(String);

impl fmt::Display for KernelNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Human-readable restatement of the validation rules, printed after the header line.
        const RULES: &str = "- Start with a lowercase letter (a-z)\n\
                             - Contain only lowercase letters, digits, and dashes\n\
                             - End with a lowercase letter or digit\n\
                             - Be at least 2 characters long\n\
                             Examples: `my-kernel`, `relu2d`, `flash-attention`";

        let Self(rejected) = self;
        writeln!(f, "Invalid kernel name `{}`. Name must:", rejected)?;
        f.write_str(RULES)
    }
}

impl std::error::Error for KernelNameError {}
62
63impl<'de> Deserialize<'de> for KernelName {
64    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
65    where
66        D: Deserializer<'de>,
67    {
68        let s = String::deserialize(deserializer)?;
69        KernelName::new(s).map_err(de::Error::custom)
70    }
71}
72
73impl Serialize for KernelName {
74    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
75    where
76        S: Serializer,
77    {
78        serializer.serialize_str(&self.0)
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_name_valid() {
        assert!(KernelName::new("my-kernel").is_ok());
        assert!(KernelName::new("relu2d").is_ok());
        assert!(KernelName::new("flash-attention").is_ok());
        assert!(KernelName::new("a1").is_ok());
        assert!(KernelName::new("ab").is_ok());
        assert!(KernelName::new("my--kernel").is_ok());
    }

    #[test]
    fn test_kernel_name_invalid() {
        assert!(KernelName::new("my_kernel").is_err());
        assert!(KernelName::new("MyKernel").is_err());
        assert!(KernelName::new("a").is_err());
        assert!(KernelName::new("my-kernel-").is_err());
        assert!(KernelName::new("-my-kernel").is_err());
        assert!(KernelName::new("1kernel").is_err());
        assert!(KernelName::new("").is_err());
        assert!(KernelName::new("my kernel").is_err());
        assert!(KernelName::new("my.kernel").is_err());
        // The pattern is anchored at both ends, so trailing whitespace must not slip through.
        assert!(KernelName::new("my-kernel\n").is_err());
    }

    #[test]
    fn test_kernel_name_accessors() {
        let name = KernelName::new("flash-attention").unwrap();
        assert_eq!(name.as_str(), "flash-attention");
        assert_eq!(AsRef::<str>::as_ref(&name), "flash-attention");
        assert_eq!(name.to_string(), "flash-attention");
        assert_eq!(name.python_name(), "flash_attention");
        assert_eq!(KernelName::new("relu2d").unwrap().python_name(), "relu2d");
    }

    #[test]
    fn test_kernel_name_error_message() {
        let err = KernelName::new("Bad_Name").unwrap_err();
        let msg = err.to_string();
        assert!(msg.starts_with("Invalid kernel name `Bad_Name`. Name must:\n"));
        assert!(msg.contains("- Be at least 2 characters long\n"));
        assert!(msg.ends_with("Examples: `my-kernel`, `relu2d`, `flash-attention`"));
    }

    #[test]
    fn test_kernel_name_validates_repeatedly() {
        // Repeated construction must keep validating consistently.
        for _ in 0..3 {
            assert!(KernelName::new("my-kernel").is_ok());
            assert!(KernelName::new("My-Kernel").is_err());
        }
    }
}