Skip to main content

kernels_data/config/
v3.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use serde::{Deserialize, Serialize};
5
6use super::{Dependency, KernelName};
7use crate::version::Version;
8
/// Top-level build configuration in the v3 config format.
///
/// Unknown top-level keys are rejected. It is converted to and from the
/// version-independent `super::Build` by the `From` impls in this module.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct Build {
    /// The `general` section (required).
    pub general: General,

    // NOTE: In v3, the absence of a framework section means torch-noarch.
    //       However, this won't work if we have support for other noarch
    //       frameworks in the future, so in v4, we probably have to make
    //       a torch-noarch framework variant.
    //
    // NOTE(review): serde's `flatten` on an `Option` may turn a
    //       deserialization error inside the `torch`/`tvm-ffi` section into
    //       `None` (i.e. torch-noarch) instead of reporting it. Confirm that
    //       a malformed framework section is actually rejected.
    #[serde(flatten)]
    pub framework: Option<Framework>,

    /// Kernels keyed by name. They are read from and written to the `kernel`
    /// key, and the map is empty when that key is absent.
    #[serde(rename = "kernel", default)]
    pub kernels: HashMap<String, Kernel>,
}
24
/// Framework-specific section of a [`Build`].
///
/// This is a serde externally tagged enum with kebab-case tags (`torch`,
/// `tvm-ffi`). Because [`Build`] flattens it, the tag appears as a top-level
/// key of the config rather than under a `framework` key.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum Framework {
    /// The `torch` section.
    Torch(Torch),
    /// The `tvm-ffi` section.
    TvmFfi(TvmFfi),
}
31
/// The `general` section of a v3 config.
///
/// Keys are kebab-case (e.g. `python-depends`) and unknown keys are rejected.
/// `name` and `backends` are non-`Option` with no serde default, so they must
/// be present. Every other key may be omitted.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct General {
    pub name: KernelName,

    pub version: Option<usize>,

    pub license: Option<String>,

    /// Optional upstream location, parsed as a URL.
    pub upstream: Option<url::Url>,

    /// Backends listed for this kernel.
    pub backends: Vec<Backend>,

    /// Optional CUDA-specific settings.
    pub cuda: Option<CudaGeneral>,

    /// Optional Hub settings.
    pub hub: Option<Hub>,

    /// Optional Neuron-specific settings.
    pub neuron: Option<NeuronGeneral>,

    /// Optional Python dependencies, stored as plain strings.
    pub python_depends: Option<Vec<String>>,

    /// Optional XPU-specific settings.
    pub xpu: Option<XpuGeneral>,
}
55
/// CUDA settings nested in the `general` section.
///
/// All keys are optional: `minver`, `maxver` and `python-depends`. Unknown
/// keys are rejected.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct CudaGeneral {
    // Presumably lower/upper bounds on the supported CUDA version; confirm.
    pub minver: Option<Version>,
    pub maxver: Option<Version>,
    pub python_depends: Option<Vec<String>>,
}
63
/// Neuron settings nested in the `general` section.
///
/// The only key is the optional `python-depends`. Unknown keys are rejected.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct NeuronGeneral {
    pub python_depends: Option<Vec<String>>,
}
69
/// XPU settings nested in the `general` section.
///
/// The only key is the optional `python-depends`. Unknown keys are rejected.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct XpuGeneral {
    pub python_depends: Option<Vec<String>>,
}
75
/// Hub settings nested in the `general` section.
///
/// Both keys are optional: `repo-id` and `branch`. Unknown keys are rejected.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct Hub {
    pub repo_id: Option<String>,
    pub branch: Option<String>,
}
82
/// The `torch` framework section.
///
/// Every key is optional. `src` defaults to an empty list when omitted.
/// Unknown keys are rejected. No `rename_all` is needed because every field
/// name is a single word.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Torch {
    pub include: Option<Vec<String>>,
    // Presumably bounds on the supported Torch version; confirm.
    pub minver: Option<Version>,
    pub maxver: Option<Version>,
    pub pyext: Option<Vec<String>>,

    #[serde(default)]
    pub src: Vec<PathBuf>,
}
94
/// The `tvm-ffi` framework section.
///
/// Unlike [`Torch`], `src` has no serde default, so it must be present.
/// `include` and `pyext` are optional. Unknown keys are rejected.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct TvmFfi {
    pub include: Option<Vec<String>>,
    pub pyext: Option<Vec<String>>,
    pub src: Vec<PathBuf>,
}
102
/// A single kernel entry, internally tagged by its `backend` key.
///
/// The `backend` value selects the variant (`cpu`, `cuda`, `metal`, `rocm`,
/// `xpu`). Field names are kebab-case (e.g. `cxx-flags`, `cuda-capabilities`)
/// and unknown keys are rejected. `depends` and `src` are required in every
/// variant; all other fields are optional.
///
/// NOTE(review): there is no `neuron` variant even though `Backend::Neuron`
/// exists. Confirm this is intended.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")]
pub enum Kernel {
    /// Entry with `backend` set to `cpu`.
    #[serde(rename_all = "kebab-case")]
    Cpu {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Entry with `backend` set to `cuda`.
    #[serde(rename_all = "kebab-case")]
    Cuda {
        cuda_capabilities: Option<Vec<String>>,
        cuda_flags: Option<Vec<String>>,
        cuda_minver: Option<Version>,
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Entry with `backend` set to `metal`.
    #[serde(rename_all = "kebab-case")]
    Metal {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Entry with `backend` set to `rocm`.
    #[serde(rename_all = "kebab-case")]
    Rocm {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        rocm_archs: Option<Vec<String>>,
        hip_flags: Option<Vec<String>>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// Entry with `backend` set to `xpu`.
    #[serde(rename_all = "kebab-case")]
    Xpu {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        sycl_flags: Option<Vec<String>>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
}
148
/// A backend name as listed in `general.backends`.
///
/// Serialized in kebab-case (`cpu`, `cuda`, `metal`, `neuron`, `rocm`, `xpu`).
/// The derived `Ord` follows declaration order, so reordering the variants
/// changes how backends sort.
// NOTE(review): `deny_unknown_fields` appears to have no effect on an enum
// whose variants are all unit variants; confirm whether it is still wanted.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum Backend {
    Cpu,
    Cuda,
    Metal,
    Neuron,
    Rocm,
    Xpu,
}
159
160impl From<Build> for super::Build {
161    fn from(build: Build) -> Self {
162        let kernels: HashMap<String, super::Kernel> = build
163            .kernels
164            .into_iter()
165            .map(|(k, v)| (k, v.into()))
166            .collect();
167
168        let framework = match build.framework {
169            Some(Framework::Torch(torch)) => super::Framework::Torch(torch.into()),
170            Some(Framework::TvmFfi(tvm_ffi)) => super::Framework::TvmFfi(tvm_ffi.into()),
171            None => super::Framework::TorchNoarch,
172        };
173
174        Self {
175            general: build.general.into(),
176            framework,
177            kernels,
178        }
179    }
180}
181
182impl From<General> for super::General {
183    fn from(general: General) -> Self {
184        Self {
185            name: general.name,
186            version: general.version,
187            license: general.license,
188            upstream: general.upstream,
189            backends: general.backends.into_iter().map(Into::into).collect(),
190            cuda: general.cuda.map(Into::into),
191            hub: general.hub.map(Into::into),
192            neuron: general.neuron.map(Into::into),
193            python_depends: general.python_depends,
194            xpu: general.xpu.map(Into::into),
195        }
196    }
197}
198
199impl From<CudaGeneral> for super::CudaGeneral {
200    fn from(cuda: CudaGeneral) -> Self {
201        Self {
202            minver: cuda.minver,
203            maxver: cuda.maxver,
204            python_depends: cuda.python_depends,
205        }
206    }
207}
208
209impl From<NeuronGeneral> for super::NeuronGeneral {
210    fn from(neuron: NeuronGeneral) -> Self {
211        Self {
212            python_depends: neuron.python_depends,
213        }
214    }
215}
216
217impl From<XpuGeneral> for super::XpuGeneral {
218    fn from(xpu: XpuGeneral) -> Self {
219        Self {
220            python_depends: xpu.python_depends,
221        }
222    }
223}
224
225impl From<Hub> for super::Hub {
226    fn from(hub: Hub) -> Self {
227        Self {
228            repo_id: hub.repo_id,
229            branch: hub.branch,
230        }
231    }
232}
233
234impl From<Torch> for super::Torch {
235    fn from(torch: Torch) -> Self {
236        Self {
237            include: torch.include,
238            minver: torch.minver,
239            maxver: torch.maxver,
240            pyext: torch.pyext,
241            src: torch.src,
242        }
243    }
244}
245
246impl From<TvmFfi> for super::TvmFfi {
247    fn from(tvm_ffi: TvmFfi) -> Self {
248        Self {
249            include: tvm_ffi.include,
250            pyext: tvm_ffi.pyext,
251            src: tvm_ffi.src,
252        }
253    }
254}
255
256impl From<Backend> for super::Backend {
257    fn from(backend: Backend) -> Self {
258        match backend {
259            Backend::Cpu => super::Backend::Cpu,
260            Backend::Cuda => super::Backend::Cuda,
261            Backend::Metal => super::Backend::Metal,
262            Backend::Neuron => super::Backend::Neuron,
263            Backend::Rocm => super::Backend::Rocm,
264            Backend::Xpu => super::Backend::Xpu,
265        }
266    }
267}
268
269impl From<Kernel> for super::Kernel {
270    fn from(kernel: Kernel) -> Self {
271        match kernel {
272            Kernel::Cpu {
273                cxx_flags,
274                depends,
275                include,
276                src,
277            } => super::Kernel::Cpu {
278                cxx_flags,
279                depends,
280                include,
281                src,
282            },
283            Kernel::Cuda {
284                cuda_capabilities,
285                cuda_flags,
286                cuda_minver,
287                cxx_flags,
288                depends,
289                include,
290                src,
291            } => super::Kernel::Cuda {
292                cuda_capabilities,
293                cuda_flags,
294                cuda_minver,
295                cxx_flags,
296                depends,
297                include,
298                src,
299            },
300            Kernel::Metal {
301                cxx_flags,
302                depends,
303                include,
304                src,
305            } => super::Kernel::Metal {
306                cxx_flags,
307                depends,
308                include,
309                src,
310            },
311            Kernel::Rocm {
312                cxx_flags,
313                depends,
314                rocm_archs,
315                hip_flags,
316                include,
317                src,
318            } => super::Kernel::Rocm {
319                cxx_flags,
320                depends,
321                rocm_archs,
322                hip_flags,
323                include,
324                src,
325            },
326            Kernel::Xpu {
327                cxx_flags,
328                depends,
329                sycl_flags,
330                include,
331                src,
332            } => super::Kernel::Xpu {
333                cxx_flags,
334                depends,
335                sycl_flags,
336                include,
337                src,
338            },
339        }
340    }
341}
342
343impl From<super::Build> for Build {
344    fn from(build: super::Build) -> Self {
345        let framework = match build.framework {
346            super::Framework::Torch(torch) => Some(Framework::Torch(torch.into())),
347            super::Framework::TorchNoarch => None,
348            super::Framework::TvmFfi(tvm_ffi) => Some(Framework::TvmFfi(tvm_ffi.into())),
349        };
350
351        Self {
352            general: build.general.into(),
353            framework,
354            kernels: build
355                .kernels
356                .into_iter()
357                .map(|(k, v)| (k, v.into()))
358                .collect(),
359        }
360    }
361}
362
363impl From<super::General> for General {
364    fn from(general: super::General) -> Self {
365        Self {
366            name: general.name,
367            version: general.version,
368            license: general.license,
369            upstream: general.upstream,
370            backends: general.backends.into_iter().map(Into::into).collect(),
371            cuda: general.cuda.map(Into::into),
372            hub: general.hub.map(Into::into),
373            neuron: general.neuron.map(Into::into),
374            python_depends: general.python_depends,
375            xpu: general.xpu.map(Into::into),
376        }
377    }
378}
379
380impl From<super::CudaGeneral> for CudaGeneral {
381    fn from(cuda: super::CudaGeneral) -> Self {
382        Self {
383            minver: cuda.minver,
384            maxver: cuda.maxver,
385            python_depends: cuda.python_depends,
386        }
387    }
388}
389
390impl From<super::NeuronGeneral> for NeuronGeneral {
391    fn from(neuron: super::NeuronGeneral) -> Self {
392        Self {
393            python_depends: neuron.python_depends,
394        }
395    }
396}
397
398impl From<super::XpuGeneral> for XpuGeneral {
399    fn from(xpu: super::XpuGeneral) -> Self {
400        Self {
401            python_depends: xpu.python_depends,
402        }
403    }
404}
405
406impl From<super::Hub> for Hub {
407    fn from(hub: super::Hub) -> Self {
408        Self {
409            repo_id: hub.repo_id,
410            branch: hub.branch,
411        }
412    }
413}
414
415impl From<super::Torch> for Torch {
416    fn from(torch: super::Torch) -> Self {
417        Self {
418            include: torch.include,
419            minver: torch.minver,
420            maxver: torch.maxver,
421            pyext: torch.pyext,
422            src: torch.src,
423        }
424    }
425}
426
427impl From<super::TvmFfi> for TvmFfi {
428    fn from(tvm_ffi: super::TvmFfi) -> Self {
429        Self {
430            include: tvm_ffi.include,
431            pyext: tvm_ffi.pyext,
432            src: tvm_ffi.src,
433        }
434    }
435}
436
437impl From<super::Backend> for Backend {
438    fn from(backend: super::Backend) -> Self {
439        match backend {
440            super::Backend::Cpu => Backend::Cpu,
441            super::Backend::Cuda => Backend::Cuda,
442            super::Backend::Metal => Backend::Metal,
443            super::Backend::Neuron => Backend::Neuron,
444            super::Backend::Rocm => Backend::Rocm,
445            super::Backend::Xpu => Backend::Xpu,
446        }
447    }
448}
449
450impl From<super::Kernel> for Kernel {
451    fn from(kernel: super::Kernel) -> Self {
452        match kernel {
453            super::Kernel::Cpu {
454                cxx_flags,
455                depends,
456                include,
457                src,
458            } => Kernel::Cpu {
459                cxx_flags,
460                depends,
461                include,
462                src,
463            },
464            super::Kernel::Cuda {
465                cuda_capabilities,
466                cuda_flags,
467                cuda_minver,
468                cxx_flags,
469                depends,
470                include,
471                src,
472            } => Kernel::Cuda {
473                cuda_capabilities,
474                cuda_flags,
475                cuda_minver,
476                cxx_flags,
477                depends,
478                include,
479                src,
480            },
481            super::Kernel::Metal {
482                cxx_flags,
483                depends,
484                include,
485                src,
486            } => Kernel::Metal {
487                cxx_flags,
488                depends,
489                include,
490                src,
491            },
492            super::Kernel::Rocm {
493                cxx_flags,
494                depends,
495                rocm_archs,
496                hip_flags,
497                include,
498                src,
499            } => Kernel::Rocm {
500                cxx_flags,
501                depends,
502                rocm_archs,
503                hip_flags,
504                include,
505                src,
506            },
507            super::Kernel::Xpu {
508                cxx_flags,
509                depends,
510                sycl_flags,
511                include,
512                src,
513            } => Kernel::Xpu {
514                cxx_flags,
515                depends,
516                sycl_flags,
517                include,
518                src,
519            },
520        }
521    }
522}