Skip to main content

kernels_data/config/
mod.rs

1use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6mod deps;
7pub use deps::{Dependency, PythonDependency};
8
9mod compat;
10pub use compat::BuildCompat;
11
12mod name;
13pub use name::KernelName;
14
15pub mod v1;
16pub mod v2;
17pub mod v3;
18
19use itertools::Itertools;
20
21use crate::version::Version;
22
23pub struct Build {
24    pub general: General,
25    pub kernels: HashMap<String, Kernel>,
26    pub framework: Framework,
27}
28
29pub enum Framework {
30    Torch(Torch),
31    TorchNoarch,
32    TvmFfi(TvmFfi),
33}
34
35impl Framework {
36    pub fn torch(&self) -> Option<&Torch> {
37        match self {
38            Framework::Torch(torch) => Some(torch),
39            _ => None,
40        }
41    }
42
43    pub fn tvm_ffi(&self) -> Option<&TvmFfi> {
44        match self {
45            Framework::TvmFfi(tvm_ffi) => Some(tvm_ffi),
46            _ => None,
47        }
48    }
49}
50
51impl Build {
52    pub fn is_noarch(&self) -> bool {
53        self.kernels.is_empty()
54    }
55}
56
57pub struct General {
58    pub name: KernelName,
59    pub version: Option<usize>,
60
61    /// Hugging Face Hub license identifier.
62    pub license: Option<String>,
63
64    /// Source repository or reference for the kernel code.
65    pub upstream: Option<url::Url>,
66
67    pub backends: Vec<Backend>,
68    pub hub: Option<Hub>,
69    pub python_depends: Option<Vec<String>>,
70
71    pub cuda: Option<CudaGeneral>,
72    pub neuron: Option<NeuronGeneral>,
73    pub xpu: Option<XpuGeneral>,
74}
75
76impl General {
77    pub fn python_depends(
78        &self,
79    ) -> Box<dyn Iterator<Item = Result<(&str, &PythonDependency)>> + '_> {
80        let general_python_deps = match self.python_depends.as_ref() {
81            Some(deps) => deps,
82            None => {
83                return Box::new(std::iter::empty());
84            }
85        };
86
87        Box::new(general_python_deps.iter().map(move |dep| {
88            match deps::PYTHON_DEPENDENCIES.get_dependency(dep) {
89                Ok(resolved_deps) => Ok((dep.as_str(), resolved_deps)),
90                Err(e) => Err(e.into()),
91            }
92        }))
93    }
94
95    pub fn backend_python_depends(
96        &self,
97        backend: Backend,
98    ) -> Box<dyn Iterator<Item = Result<(&str, &PythonDependency)>> + '_> {
99        let backend_python_deps = match backend {
100            Backend::Cuda => self
101                .cuda
102                .as_ref()
103                .and_then(|cuda| cuda.python_depends.as_ref()),
104            Backend::Xpu => self
105                .xpu
106                .as_ref()
107                .and_then(|xpu| xpu.python_depends.as_ref()),
108            _ => None,
109        };
110
111        let backend_python_deps = match backend_python_deps {
112            Some(deps) => deps,
113            None => {
114                return Box::new(std::iter::empty());
115            }
116        };
117
118        Box::new(backend_python_deps.iter().map(move |dep| {
119            match deps::PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) {
120                Ok(resolved_deps) => Ok((dep.as_str(), resolved_deps)),
121                Err(e) => Err(e.into()),
122            }
123        }))
124    }
125}
126
127pub struct CudaGeneral {
128    pub minver: Option<Version>,
129    pub maxver: Option<Version>,
130    pub python_depends: Option<Vec<String>>,
131}
132
133pub struct XpuGeneral {
134    pub python_depends: Option<Vec<String>>,
135}
136
137pub struct NeuronGeneral {
138    pub python_depends: Option<Vec<String>>,
139}
140
141pub struct Hub {
142    pub repo_id: Option<String>,
143    pub branch: Option<String>,
144}
145
146pub struct Torch {
147    pub include: Option<Vec<String>>,
148    pub minver: Option<Version>,
149    pub maxver: Option<Version>,
150    pub pyext: Option<Vec<String>>,
151    pub src: Vec<PathBuf>,
152}
153
154fn data_extensions(py_ext: Option<&[String]>) -> Option<Vec<String>> {
155    match py_ext {
156        Some(exts) => {
157            let extensions = exts
158                .iter()
159                .filter(|&ext| ext != "py" && ext != "pyi")
160                .cloned()
161                .collect_vec();
162            if extensions.is_empty() {
163                None
164            } else {
165                Some(extensions)
166            }
167        }
168
169        None => None,
170    }
171}
172
173impl Torch {
174    pub fn data_extensions(&self) -> Option<Vec<String>> {
175        data_extensions(self.pyext.as_deref())
176    }
177}
178
179pub struct TvmFfi {
180    pub include: Option<Vec<String>>,
181    pub pyext: Option<Vec<String>>,
182    pub src: Vec<PathBuf>,
183}
184
185impl TvmFfi {
186    pub fn data_extensions(&self) -> Option<Vec<String>> {
187        data_extensions(self.pyext.as_deref())
188    }
189}
190
191pub enum Kernel {
192    Cpu {
193        cxx_flags: Option<Vec<String>>,
194        depends: Vec<Dependency>,
195        include: Option<Vec<String>>,
196        src: Vec<String>,
197    },
198    Cuda {
199        cuda_capabilities: Option<Vec<String>>,
200        cuda_flags: Option<Vec<String>>,
201        cuda_minver: Option<Version>,
202        cxx_flags: Option<Vec<String>>,
203        depends: Vec<Dependency>,
204        include: Option<Vec<String>>,
205        src: Vec<String>,
206    },
207    Metal {
208        cxx_flags: Option<Vec<String>>,
209        depends: Vec<Dependency>,
210        include: Option<Vec<String>>,
211        src: Vec<String>,
212    },
213    Rocm {
214        cxx_flags: Option<Vec<String>>,
215        depends: Vec<Dependency>,
216        rocm_archs: Option<Vec<String>>,
217        hip_flags: Option<Vec<String>>,
218        include: Option<Vec<String>>,
219        src: Vec<String>,
220    },
221    Xpu {
222        cxx_flags: Option<Vec<String>>,
223        depends: Vec<Dependency>,
224        sycl_flags: Option<Vec<String>>,
225        include: Option<Vec<String>>,
226        src: Vec<String>,
227    },
228}
229
230impl Kernel {
231    pub fn cxx_flags(&self) -> Option<&[String]> {
232        match self {
233            Kernel::Cpu { cxx_flags, .. }
234            | Kernel::Cuda { cxx_flags, .. }
235            | Kernel::Metal { cxx_flags, .. }
236            | Kernel::Rocm { cxx_flags, .. }
237            | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(),
238        }
239    }
240
241    pub fn include(&self) -> Option<&[String]> {
242        match self {
243            Kernel::Cpu { include, .. }
244            | Kernel::Cuda { include, .. }
245            | Kernel::Metal { include, .. }
246            | Kernel::Rocm { include, .. }
247            | Kernel::Xpu { include, .. } => include.as_deref(),
248        }
249    }
250
251    pub fn backend(&self) -> Backend {
252        match self {
253            Kernel::Cpu { .. } => Backend::Cpu,
254            Kernel::Cuda { .. } => Backend::Cuda,
255            Kernel::Metal { .. } => Backend::Metal,
256            Kernel::Rocm { .. } => Backend::Rocm,
257            Kernel::Xpu { .. } => Backend::Xpu,
258        }
259    }
260
261    pub fn depends(&self) -> &[Dependency] {
262        match self {
263            Kernel::Cpu { depends, .. }
264            | Kernel::Cuda { depends, .. }
265            | Kernel::Metal { depends, .. }
266            | Kernel::Rocm { depends, .. }
267            | Kernel::Xpu { depends, .. } => depends,
268        }
269    }
270
271    pub fn src(&self) -> &[String] {
272        match self {
273            Kernel::Cpu { src, .. }
274            | Kernel::Cuda { src, .. }
275            | Kernel::Metal { src, .. }
276            | Kernel::Rocm { src, .. }
277            | Kernel::Xpu { src, .. } => src,
278        }
279    }
280}
281
282#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
283#[serde(deny_unknown_fields, rename_all = "kebab-case")]
284pub enum Backend {
285    Cpu,
286    Cuda,
287    Metal,
288    Neuron,
289    Rocm,
290    Xpu,
291}
292
293impl Backend {
294    pub const fn all() -> [Backend; 6] {
295        [
296            Backend::Cpu,
297            Backend::Cuda,
298            Backend::Metal,
299            Backend::Neuron,
300            Backend::Rocm,
301            Backend::Xpu,
302        ]
303    }
304}
305
306impl Display for Backend {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        match self {
309            Backend::Cpu => write!(f, "cpu"),
310            Backend::Cuda => write!(f, "cuda"),
311            Backend::Metal => write!(f, "metal"),
312            Backend::Neuron => write!(f, "neuron"),
313            Backend::Rocm => write!(f, "rocm"),
314            Backend::Xpu => write!(f, "xpu"),
315        }
316    }
317}
318
319impl FromStr for Backend {
320    type Err = String;
321
322    fn from_str(s: &str) -> Result<Self, Self::Err> {
323        match s.to_lowercase().as_str() {
324            "cpu" => Ok(Backend::Cpu),
325            "cuda" => Ok(Backend::Cuda),
326            "metal" => Ok(Backend::Metal),
327            "neuron" => Ok(Backend::Neuron),
328            "rocm" => Ok(Backend::Rocm),
329            "xpu" => Ok(Backend::Xpu),
330            _ => Err(format!("Unknown backend: {s}")),
331        }
332    }
333}