1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use super::{Dependency, KernelName};
7use crate::version::Version;
8
9#[derive(Debug, Deserialize, Serialize)]
10#[serde(deny_unknown_fields, rename_all = "kebab-case")]
11pub struct Build {
12 pub general: General,
13
14 #[serde(flatten)]
19 pub framework: Option<Framework>,
20
21 #[serde(rename = "kernel", default)]
22 pub kernels: HashMap<String, Kernel>,
23}
24
25#[derive(Debug, Deserialize, Serialize)]
26#[serde(rename_all = "kebab-case")]
27pub enum Framework {
28 Torch(Torch),
29 TvmFfi(TvmFfi),
30}
31
32#[derive(Debug, Deserialize, Serialize)]
33#[serde(deny_unknown_fields, rename_all = "kebab-case")]
34pub struct General {
35 pub name: KernelName,
36
37 pub version: Option<usize>,
38
39 pub license: Option<String>,
40
41 pub upstream: Option<url::Url>,
42
43 pub backends: Vec<Backend>,
44
45 pub cuda: Option<CudaGeneral>,
46
47 pub hub: Option<Hub>,
48
49 pub neuron: Option<NeuronGeneral>,
50
51 pub python_depends: Option<Vec<String>>,
52
53 pub xpu: Option<XpuGeneral>,
54}
55
56#[derive(Debug, Deserialize, Serialize)]
57#[serde(deny_unknown_fields, rename_all = "kebab-case")]
58pub struct CudaGeneral {
59 pub minver: Option<Version>,
60 pub maxver: Option<Version>,
61 pub python_depends: Option<Vec<String>>,
62}
63
64#[derive(Debug, Deserialize, Serialize)]
65#[serde(deny_unknown_fields, rename_all = "kebab-case")]
66pub struct NeuronGeneral {
67 pub python_depends: Option<Vec<String>>,
68}
69
70#[derive(Debug, Deserialize, Serialize)]
71#[serde(deny_unknown_fields, rename_all = "kebab-case")]
72pub struct XpuGeneral {
73 pub python_depends: Option<Vec<String>>,
74}
75
76#[derive(Debug, Deserialize, Serialize)]
77#[serde(deny_unknown_fields, rename_all = "kebab-case")]
78pub struct Hub {
79 pub repo_id: Option<String>,
80 pub branch: Option<String>,
81}
82
83#[derive(Debug, Deserialize, Clone, Serialize)]
84#[serde(deny_unknown_fields)]
85pub struct Torch {
86 pub include: Option<Vec<String>>,
87 pub minver: Option<Version>,
88 pub maxver: Option<Version>,
89 pub pyext: Option<Vec<String>>,
90
91 #[serde(default)]
92 pub src: Vec<PathBuf>,
93}
94
95#[derive(Debug, Deserialize, Clone, Serialize)]
96#[serde(deny_unknown_fields)]
97pub struct TvmFfi {
98 pub include: Option<Vec<String>>,
99 pub pyext: Option<Vec<String>>,
100 pub src: Vec<PathBuf>,
101}
102
103#[derive(Debug, Deserialize, Serialize)]
104#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")]
105pub enum Kernel {
106 #[serde(rename_all = "kebab-case")]
107 Cpu {
108 cxx_flags: Option<Vec<String>>,
109 depends: Vec<Dependency>,
110 include: Option<Vec<String>>,
111 src: Vec<String>,
112 },
113 #[serde(rename_all = "kebab-case")]
114 Cuda {
115 cuda_capabilities: Option<Vec<String>>,
116 cuda_flags: Option<Vec<String>>,
117 cuda_minver: Option<Version>,
118 cxx_flags: Option<Vec<String>>,
119 depends: Vec<Dependency>,
120 include: Option<Vec<String>>,
121 src: Vec<String>,
122 },
123 #[serde(rename_all = "kebab-case")]
124 Metal {
125 cxx_flags: Option<Vec<String>>,
126 depends: Vec<Dependency>,
127 include: Option<Vec<String>>,
128 src: Vec<String>,
129 },
130 #[serde(rename_all = "kebab-case")]
131 Rocm {
132 cxx_flags: Option<Vec<String>>,
133 depends: Vec<Dependency>,
134 rocm_archs: Option<Vec<String>>,
135 hip_flags: Option<Vec<String>>,
136 include: Option<Vec<String>>,
137 src: Vec<String>,
138 },
139 #[serde(rename_all = "kebab-case")]
140 Xpu {
141 cxx_flags: Option<Vec<String>>,
142 depends: Vec<Dependency>,
143 sycl_flags: Option<Vec<String>>,
144 include: Option<Vec<String>>,
145 src: Vec<String>,
146 },
147}
148
149#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
150#[serde(deny_unknown_fields, rename_all = "kebab-case")]
151pub enum Backend {
152 Cpu,
153 Cuda,
154 Metal,
155 Neuron,
156 Rocm,
157 Xpu,
158}
159
160impl From<Build> for super::Build {
161 fn from(build: Build) -> Self {
162 let kernels: HashMap<String, super::Kernel> = build
163 .kernels
164 .into_iter()
165 .map(|(k, v)| (k, v.into()))
166 .collect();
167
168 let framework = match build.framework {
169 Some(Framework::Torch(torch)) => super::Framework::Torch(torch.into()),
170 Some(Framework::TvmFfi(tvm_ffi)) => super::Framework::TvmFfi(tvm_ffi.into()),
171 None => super::Framework::TorchNoarch,
172 };
173
174 Self {
175 general: build.general.into(),
176 framework,
177 kernels,
178 }
179 }
180}
181
182impl From<General> for super::General {
183 fn from(general: General) -> Self {
184 Self {
185 name: general.name,
186 version: general.version,
187 license: general.license,
188 upstream: general.upstream,
189 backends: general.backends.into_iter().map(Into::into).collect(),
190 cuda: general.cuda.map(Into::into),
191 hub: general.hub.map(Into::into),
192 neuron: general.neuron.map(Into::into),
193 python_depends: general.python_depends,
194 xpu: general.xpu.map(Into::into),
195 }
196 }
197}
198
199impl From<CudaGeneral> for super::CudaGeneral {
200 fn from(cuda: CudaGeneral) -> Self {
201 Self {
202 minver: cuda.minver,
203 maxver: cuda.maxver,
204 python_depends: cuda.python_depends,
205 }
206 }
207}
208
209impl From<NeuronGeneral> for super::NeuronGeneral {
210 fn from(neuron: NeuronGeneral) -> Self {
211 Self {
212 python_depends: neuron.python_depends,
213 }
214 }
215}
216
217impl From<XpuGeneral> for super::XpuGeneral {
218 fn from(xpu: XpuGeneral) -> Self {
219 Self {
220 python_depends: xpu.python_depends,
221 }
222 }
223}
224
225impl From<Hub> for super::Hub {
226 fn from(hub: Hub) -> Self {
227 Self {
228 repo_id: hub.repo_id,
229 branch: hub.branch,
230 }
231 }
232}
233
234impl From<Torch> for super::Torch {
235 fn from(torch: Torch) -> Self {
236 Self {
237 include: torch.include,
238 minver: torch.minver,
239 maxver: torch.maxver,
240 pyext: torch.pyext,
241 src: torch.src,
242 }
243 }
244}
245
246impl From<TvmFfi> for super::TvmFfi {
247 fn from(tvm_ffi: TvmFfi) -> Self {
248 Self {
249 include: tvm_ffi.include,
250 pyext: tvm_ffi.pyext,
251 src: tvm_ffi.src,
252 }
253 }
254}
255
256impl From<Backend> for super::Backend {
257 fn from(backend: Backend) -> Self {
258 match backend {
259 Backend::Cpu => super::Backend::Cpu,
260 Backend::Cuda => super::Backend::Cuda,
261 Backend::Metal => super::Backend::Metal,
262 Backend::Neuron => super::Backend::Neuron,
263 Backend::Rocm => super::Backend::Rocm,
264 Backend::Xpu => super::Backend::Xpu,
265 }
266 }
267}
268
269impl From<Kernel> for super::Kernel {
270 fn from(kernel: Kernel) -> Self {
271 match kernel {
272 Kernel::Cpu {
273 cxx_flags,
274 depends,
275 include,
276 src,
277 } => super::Kernel::Cpu {
278 cxx_flags,
279 depends,
280 include,
281 src,
282 },
283 Kernel::Cuda {
284 cuda_capabilities,
285 cuda_flags,
286 cuda_minver,
287 cxx_flags,
288 depends,
289 include,
290 src,
291 } => super::Kernel::Cuda {
292 cuda_capabilities,
293 cuda_flags,
294 cuda_minver,
295 cxx_flags,
296 depends,
297 include,
298 src,
299 },
300 Kernel::Metal {
301 cxx_flags,
302 depends,
303 include,
304 src,
305 } => super::Kernel::Metal {
306 cxx_flags,
307 depends,
308 include,
309 src,
310 },
311 Kernel::Rocm {
312 cxx_flags,
313 depends,
314 rocm_archs,
315 hip_flags,
316 include,
317 src,
318 } => super::Kernel::Rocm {
319 cxx_flags,
320 depends,
321 rocm_archs,
322 hip_flags,
323 include,
324 src,
325 },
326 Kernel::Xpu {
327 cxx_flags,
328 depends,
329 sycl_flags,
330 include,
331 src,
332 } => super::Kernel::Xpu {
333 cxx_flags,
334 depends,
335 sycl_flags,
336 include,
337 src,
338 },
339 }
340 }
341}
342
343impl From<super::Build> for Build {
344 fn from(build: super::Build) -> Self {
345 let framework = match build.framework {
346 super::Framework::Torch(torch) => Some(Framework::Torch(torch.into())),
347 super::Framework::TorchNoarch => None,
348 super::Framework::TvmFfi(tvm_ffi) => Some(Framework::TvmFfi(tvm_ffi.into())),
349 };
350
351 Self {
352 general: build.general.into(),
353 framework,
354 kernels: build
355 .kernels
356 .into_iter()
357 .map(|(k, v)| (k, v.into()))
358 .collect(),
359 }
360 }
361}
362
363impl From<super::General> for General {
364 fn from(general: super::General) -> Self {
365 Self {
366 name: general.name,
367 version: general.version,
368 license: general.license,
369 upstream: general.upstream,
370 backends: general.backends.into_iter().map(Into::into).collect(),
371 cuda: general.cuda.map(Into::into),
372 hub: general.hub.map(Into::into),
373 neuron: general.neuron.map(Into::into),
374 python_depends: general.python_depends,
375 xpu: general.xpu.map(Into::into),
376 }
377 }
378}
379
380impl From<super::CudaGeneral> for CudaGeneral {
381 fn from(cuda: super::CudaGeneral) -> Self {
382 Self {
383 minver: cuda.minver,
384 maxver: cuda.maxver,
385 python_depends: cuda.python_depends,
386 }
387 }
388}
389
390impl From<super::NeuronGeneral> for NeuronGeneral {
391 fn from(neuron: super::NeuronGeneral) -> Self {
392 Self {
393 python_depends: neuron.python_depends,
394 }
395 }
396}
397
398impl From<super::XpuGeneral> for XpuGeneral {
399 fn from(xpu: super::XpuGeneral) -> Self {
400 Self {
401 python_depends: xpu.python_depends,
402 }
403 }
404}
405
406impl From<super::Hub> for Hub {
407 fn from(hub: super::Hub) -> Self {
408 Self {
409 repo_id: hub.repo_id,
410 branch: hub.branch,
411 }
412 }
413}
414
415impl From<super::Torch> for Torch {
416 fn from(torch: super::Torch) -> Self {
417 Self {
418 include: torch.include,
419 minver: torch.minver,
420 maxver: torch.maxver,
421 pyext: torch.pyext,
422 src: torch.src,
423 }
424 }
425}
426
427impl From<super::TvmFfi> for TvmFfi {
428 fn from(tvm_ffi: super::TvmFfi) -> Self {
429 Self {
430 include: tvm_ffi.include,
431 pyext: tvm_ffi.pyext,
432 src: tvm_ffi.src,
433 }
434 }
435}
436
437impl From<super::Backend> for Backend {
438 fn from(backend: super::Backend) -> Self {
439 match backend {
440 super::Backend::Cpu => Backend::Cpu,
441 super::Backend::Cuda => Backend::Cuda,
442 super::Backend::Metal => Backend::Metal,
443 super::Backend::Neuron => Backend::Neuron,
444 super::Backend::Rocm => Backend::Rocm,
445 super::Backend::Xpu => Backend::Xpu,
446 }
447 }
448}
449
450impl From<super::Kernel> for Kernel {
451 fn from(kernel: super::Kernel) -> Self {
452 match kernel {
453 super::Kernel::Cpu {
454 cxx_flags,
455 depends,
456 include,
457 src,
458 } => Kernel::Cpu {
459 cxx_flags,
460 depends,
461 include,
462 src,
463 },
464 super::Kernel::Cuda {
465 cuda_capabilities,
466 cuda_flags,
467 cuda_minver,
468 cxx_flags,
469 depends,
470 include,
471 src,
472 } => Kernel::Cuda {
473 cuda_capabilities,
474 cuda_flags,
475 cuda_minver,
476 cxx_flags,
477 depends,
478 include,
479 src,
480 },
481 super::Kernel::Metal {
482 cxx_flags,
483 depends,
484 include,
485 src,
486 } => Kernel::Metal {
487 cxx_flags,
488 depends,
489 include,
490 src,
491 },
492 super::Kernel::Rocm {
493 cxx_flags,
494 depends,
495 rocm_archs,
496 hip_flags,
497 include,
498 src,
499 } => Kernel::Rocm {
500 cxx_flags,
501 depends,
502 rocm_archs,
503 hip_flags,
504 include,
505 src,
506 },
507 super::Kernel::Xpu {
508 cxx_flags,
509 depends,
510 sycl_flags,
511 include,
512 src,
513 } => Kernel::Xpu {
514 cxx_flags,
515 depends,
516 sycl_flags,
517 include,
518 src,
519 },
520 }
521 }
522}