1use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5
6mod deps;
7pub use deps::{Dependency, PythonDependency};
8
9mod compat;
10pub use compat::BuildCompat;
11
12mod name;
13pub use name::KernelName;
14
15pub mod v1;
16pub mod v2;
17pub mod v3;
18
19use itertools::Itertools;
20
21use crate::version::Version;
22
/// In-memory build configuration: general metadata, the kernels to compile,
/// and the framework they are built for.
///
/// NOTE(review): presumably the common representation that the `v1`/`v2`/`v3`
/// config modules are converted into — confirm.
pub struct Build {
    /// Kernel-wide settings.
    pub general: General,
    /// Kernels keyed by a string identifier (presumably the kernel's name in the
    /// config — confirm). An empty map makes the build "noarch"
    /// (see [`Build::is_noarch`]).
    pub kernels: HashMap<String, Kernel>,
    /// Framework the build targets.
    pub framework: Framework,
}
28
/// Framework a build targets, together with its framework-specific settings.
pub enum Framework {
    /// Torch build with its Torch-specific settings.
    Torch(Torch),
    /// Torch build that carries no settings payload.
    /// NOTE(review): name suggests an architecture-independent Torch build — confirm.
    TorchNoarch,
    /// TVM FFI build with its settings.
    TvmFfi(TvmFfi),
}
34
35impl Framework {
36 pub fn torch(&self) -> Option<&Torch> {
37 match self {
38 Framework::Torch(torch) => Some(torch),
39 _ => None,
40 }
41 }
42
43 pub fn tvm_ffi(&self) -> Option<&TvmFfi> {
44 match self {
45 Framework::TvmFfi(tvm_ffi) => Some(tvm_ffi),
46 _ => None,
47 }
48 }
49}
50
51impl Build {
52 pub fn is_noarch(&self) -> bool {
53 self.kernels.is_empty()
54 }
55}
56
/// Kernel-wide settings shared by all kernels of a build.
pub struct General {
    /// Kernel name.
    pub name: KernelName,
    /// Optional numeric version.
    /// NOTE(review): whether this is a kernel or config version is not visible here — confirm.
    pub version: Option<usize>,

    /// License string, if given (format not established here).
    pub license: Option<String>,

    /// Upstream project URL, if given.
    pub upstream: Option<url::Url>,

    /// Backends listed in the configuration.
    pub backends: Vec<Backend>,
    /// Hub repository settings, if any.
    pub hub: Option<Hub>,
    /// Backend-independent Python dependency names; resolved by
    /// [`General::python_depends`].
    pub python_depends: Option<Vec<String>>,

    /// CUDA-specific general settings.
    pub cuda: Option<CudaGeneral>,
    /// Neuron-specific general settings.
    pub neuron: Option<NeuronGeneral>,
    /// XPU-specific general settings.
    pub xpu: Option<XpuGeneral>,
}
75
76impl General {
77 pub fn python_depends(
78 &self,
79 ) -> Box<dyn Iterator<Item = Result<(&str, &PythonDependency)>> + '_> {
80 let general_python_deps = match self.python_depends.as_ref() {
81 Some(deps) => deps,
82 None => {
83 return Box::new(std::iter::empty());
84 }
85 };
86
87 Box::new(general_python_deps.iter().map(move |dep| {
88 match deps::PYTHON_DEPENDENCIES.get_dependency(dep) {
89 Ok(resolved_deps) => Ok((dep.as_str(), resolved_deps)),
90 Err(e) => Err(e.into()),
91 }
92 }))
93 }
94
95 pub fn backend_python_depends(
96 &self,
97 backend: Backend,
98 ) -> Box<dyn Iterator<Item = Result<(&str, &PythonDependency)>> + '_> {
99 let backend_python_deps = match backend {
100 Backend::Cuda => self
101 .cuda
102 .as_ref()
103 .and_then(|cuda| cuda.python_depends.as_ref()),
104 Backend::Xpu => self
105 .xpu
106 .as_ref()
107 .and_then(|xpu| xpu.python_depends.as_ref()),
108 _ => None,
109 };
110
111 let backend_python_deps = match backend_python_deps {
112 Some(deps) => deps,
113 None => {
114 return Box::new(std::iter::empty());
115 }
116 };
117
118 Box::new(backend_python_deps.iter().map(move |dep| {
119 match deps::PYTHON_DEPENDENCIES.get_backend_dependency(backend, dep) {
120 Ok(resolved_deps) => Ok((dep.as_str(), resolved_deps)),
121 Err(e) => Err(e.into()),
122 }
123 }))
124 }
125}
126
/// CUDA-specific part of the general settings.
pub struct CudaGeneral {
    /// Lower version bound (presumably the CUDA toolkit version; inclusivity
    /// not established here — confirm).
    pub minver: Option<Version>,
    /// Upper version bound (presumably the CUDA toolkit version; inclusivity
    /// not established here — confirm).
    pub maxver: Option<Version>,
    /// CUDA-specific Python dependency names; resolved by
    /// [`General::backend_python_depends`] for `Backend::Cuda`.
    pub python_depends: Option<Vec<String>>,
}
132
/// XPU-specific part of the general settings.
pub struct XpuGeneral {
    /// XPU-specific Python dependency names; resolved by
    /// [`General::backend_python_depends`] for `Backend::Xpu`.
    pub python_depends: Option<Vec<String>>,
}
136
/// Neuron-specific part of the general settings.
pub struct NeuronGeneral {
    /// Neuron-specific Python dependency names.
    pub python_depends: Option<Vec<String>>,
}
140
/// Hub repository settings.
pub struct Hub {
    /// Repository identifier, if given.
    pub repo_id: Option<String>,
    /// Branch name, if given.
    pub branch: Option<String>,
}
145
/// Torch framework settings.
pub struct Torch {
    /// Include entries (presumably include directories — confirm).
    pub include: Option<Vec<String>>,
    /// Lower Torch version bound (inclusivity not established here — confirm).
    pub minver: Option<Version>,
    /// Upper Torch version bound (inclusivity not established here — confirm).
    pub maxver: Option<Version>,
    /// File extensions of the Python side. Entries other than `py`/`pyi` are
    /// reported by [`Torch::data_extensions`].
    pub pyext: Option<Vec<String>>,
    /// Source files.
    pub src: Vec<PathBuf>,
}
153
154fn data_extensions(py_ext: Option<&[String]>) -> Option<Vec<String>> {
155 match py_ext {
156 Some(exts) => {
157 let extensions = exts
158 .iter()
159 .filter(|&ext| ext != "py" && ext != "pyi")
160 .cloned()
161 .collect_vec();
162 if extensions.is_empty() {
163 None
164 } else {
165 Some(extensions)
166 }
167 }
168
169 None => None,
170 }
171}
172
173impl Torch {
174 pub fn data_extensions(&self) -> Option<Vec<String>> {
175 data_extensions(self.pyext.as_deref())
176 }
177}
178
/// TVM FFI framework settings.
pub struct TvmFfi {
    /// Include entries (presumably include directories — confirm).
    pub include: Option<Vec<String>>,
    /// File extensions of the Python side. Entries other than `py`/`pyi` are
    /// reported by [`TvmFfi::data_extensions`].
    pub pyext: Option<Vec<String>>,
    /// Source files.
    pub src: Vec<PathBuf>,
}
184
185impl TvmFfi {
186 pub fn data_extensions(&self) -> Option<Vec<String>> {
187 data_extensions(self.pyext.as_deref())
188 }
189}
190
/// A single kernel to build, with one variant per target backend.
///
/// Every variant carries `cxx_flags`, `depends`, `include` and `src`; these
/// are exposed uniformly through the accessor methods on `Kernel`.
/// Backend-specific options live only on their variant.
pub enum Kernel {
    /// Kernel for [`Backend::Cpu`].
    Cpu {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Kernel for [`Backend::Cuda`].
    Cuda {
        /// NOTE(review): presumably target compute capabilities — confirm.
        cuda_capabilities: Option<Vec<String>>,
        /// NOTE(review): presumably extra CUDA compiler flags — confirm.
        cuda_flags: Option<Vec<String>>,
        /// Per-kernel CUDA version lower bound (inclusivity not established here).
        cuda_minver: Option<Version>,
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Kernel for [`Backend::Metal`].
    Metal {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Kernel for [`Backend::Rocm`].
    Rocm {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        /// NOTE(review): presumably target GPU architectures — confirm.
        rocm_archs: Option<Vec<String>>,
        /// NOTE(review): presumably extra HIP compiler flags — confirm.
        hip_flags: Option<Vec<String>>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Kernel for [`Backend::Xpu`].
    Xpu {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        /// NOTE(review): presumably extra SYCL compiler flags — confirm.
        sycl_flags: Option<Vec<String>>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
}
229
230impl Kernel {
231 pub fn cxx_flags(&self) -> Option<&[String]> {
232 match self {
233 Kernel::Cpu { cxx_flags, .. }
234 | Kernel::Cuda { cxx_flags, .. }
235 | Kernel::Metal { cxx_flags, .. }
236 | Kernel::Rocm { cxx_flags, .. }
237 | Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(),
238 }
239 }
240
241 pub fn include(&self) -> Option<&[String]> {
242 match self {
243 Kernel::Cpu { include, .. }
244 | Kernel::Cuda { include, .. }
245 | Kernel::Metal { include, .. }
246 | Kernel::Rocm { include, .. }
247 | Kernel::Xpu { include, .. } => include.as_deref(),
248 }
249 }
250
251 pub fn backend(&self) -> Backend {
252 match self {
253 Kernel::Cpu { .. } => Backend::Cpu,
254 Kernel::Cuda { .. } => Backend::Cuda,
255 Kernel::Metal { .. } => Backend::Metal,
256 Kernel::Rocm { .. } => Backend::Rocm,
257 Kernel::Xpu { .. } => Backend::Xpu,
258 }
259 }
260
261 pub fn depends(&self) -> &[Dependency] {
262 match self {
263 Kernel::Cpu { depends, .. }
264 | Kernel::Cuda { depends, .. }
265 | Kernel::Metal { depends, .. }
266 | Kernel::Rocm { depends, .. }
267 | Kernel::Xpu { depends, .. } => depends,
268 }
269 }
270
271 pub fn src(&self) -> &[String] {
272 match self {
273 Kernel::Cpu { src, .. }
274 | Kernel::Cuda { src, .. }
275 | Kernel::Metal { src, .. }
276 | Kernel::Rocm { src, .. }
277 | Kernel::Xpu { src, .. } => src,
278 }
279 }
280}
281
/// Compute backend a kernel can target.
///
/// Serde (de)serializes variants in kebab-case (e.g. `cuda`). The `Display`
/// and `FromStr` impls below use the same lowercase names, and `FromStr`
/// matches case-insensitively.
///
/// NOTE: the derived `PartialOrd`/`Ord` follow variant declaration order, so
/// reordering variants changes comparison and sort results.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum Backend {
    Cpu,
    Cuda,
    Metal,
    Neuron,
    Rocm,
    Xpu,
}
292
293impl Backend {
294 pub const fn all() -> [Backend; 6] {
295 [
296 Backend::Cpu,
297 Backend::Cuda,
298 Backend::Metal,
299 Backend::Neuron,
300 Backend::Rocm,
301 Backend::Xpu,
302 ]
303 }
304}
305
306impl Display for Backend {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 match self {
309 Backend::Cpu => write!(f, "cpu"),
310 Backend::Cuda => write!(f, "cuda"),
311 Backend::Metal => write!(f, "metal"),
312 Backend::Neuron => write!(f, "neuron"),
313 Backend::Rocm => write!(f, "rocm"),
314 Backend::Xpu => write!(f, "xpu"),
315 }
316 }
317}
318
319impl FromStr for Backend {
320 type Err = String;
321
322 fn from_str(s: &str) -> Result<Self, Self::Err> {
323 match s.to_lowercase().as_str() {
324 "cpu" => Ok(Backend::Cpu),
325 "cuda" => Ok(Backend::Cuda),
326 "metal" => Ok(Backend::Metal),
327 "neuron" => Ok(Backend::Neuron),
328 "rocm" => Ok(Backend::Rocm),
329 "xpu" => Ok(Backend::Xpu),
330 _ => Err(format!("Unknown backend: {s}")),
331 }
332 }
333}