Skip to main content

kernels_data/config/
v2.rs

1use std::{
2    collections::{BTreeSet, HashMap},
3    fmt::Display,
4    path::PathBuf,
5};
6
7use eyre::{bail, Result};
8use serde::{Deserialize, Serialize};
9
10use super::{Backend, Dependency, KernelName};
11use crate::version::Version;
12
/// Top-level contents of a v2 `build.toml`.
///
/// Converted into the current representation through
/// `TryFrom<Build> for super::Build`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Build {
    /// The required `[general]` table.
    pub general: General,
    /// The `[torch]` table. It is optional while parsing, but conversion to
    /// `super::Build` fails when it is absent.
    pub torch: Option<Torch>,

    /// Kernel definitions from the `[kernel.<name>]` tables, keyed by name.
    /// This is an empty map when no `kernel` table is present.
    #[serde(rename = "kernel", default)]
    pub kernels: HashMap<String, Kernel>,
}
22
/// The v2 `[general]` table. Keys are kebab-case (e.g. `cuda-minver`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct General {
    /// Kernel name.
    pub name: KernelName,
    /// Defaults to `false` when omitted. When `true`, conversion assigns every
    /// backend, instead of only the backends used by the `[kernel.*]` tables.
    #[serde(default)]
    pub universal: bool,

    /// Optional CUDA `maxver`, forwarded to `super::CudaGeneral::maxver`.
    pub cuda_maxver: Option<Version>,

    /// Optional CUDA `minver`, forwarded to `super::CudaGeneral::minver`.
    pub cuda_minver: Option<Version>,

    /// Optional `[general.hub]` settings.
    pub hub: Option<Hub>,

    /// Python dependencies listed under `python-depends`.
    ///
    /// NOTE(review): `General::from_v2` always emits `python_depends: None`,
    /// so these values are currently not carried over. Confirm intent.
    pub python_depends: Option<Vec<PythonDependency>>,
}
38
/// The v2 `[general.hub]` table. Both keys are optional.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct Hub {
    /// Read from the `repo-id` key.
    pub repo_id: Option<String>,
    /// Read from the `branch` key.
    pub branch: Option<String>,
}
45
/// A Python package that a v2 kernel may list in `python-depends`.
///
/// Serialized in kebab-case (`einops`, `nvidia-cutlass-dsl`). The `Display`
/// impl produces the same strings.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum PythonDependency {
    Einops,
    NvidiaCutlassDsl,
}
52
53impl Display for PythonDependency {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            PythonDependency::Einops => write!(f, "einops"),
57            PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"),
58        }
59    }
60}
61
/// The v2 `[torch]` table.
///
/// All fields are forwarded unchanged to `super::Torch`.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Torch {
    /// Optional `include` entries.
    pub include: Option<Vec<String>>,
    /// Optional `minver` version constraint.
    pub minver: Option<Version>,
    /// Optional `maxver` version constraint.
    pub maxver: Option<Version>,
    /// Optional `pyext` entries.
    pub pyext: Option<Vec<String>>,

    /// Source paths. Defaults to an empty list when the key is omitted.
    #[serde(default)]
    pub src: Vec<PathBuf>,
}
73
/// A v2 `[kernel.<name>]` table.
///
/// The enum is internally tagged: the `backend` key selects the variant
/// (`cpu`, `cuda`, `metal`, `rocm`, `xpu`). The remaining keys are the variant's
/// fields in kebab-case (e.g. `cxx-flags`, `cuda-capabilities`).
///
/// `depends` and `src` have no serde default, so they are required for every
/// backend. The `Option` fields may be omitted. v2 has no Neuron kernel variant.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")]
pub enum Kernel {
    /// `backend = "cpu"`.
    #[serde(rename_all = "kebab-case")]
    Cpu {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// `backend = "cuda"`.
    #[serde(rename_all = "kebab-case")]
    Cuda {
        cuda_capabilities: Option<Vec<String>>,
        cuda_flags: Option<Vec<String>>,
        /// Per-kernel `cuda-minver`, separate from `[general]`'s `cuda-minver`.
        cuda_minver: Option<Version>,
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// `backend = "metal"`.
    #[serde(rename_all = "kebab-case")]
    Metal {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// `backend = "rocm"`.
    #[serde(rename_all = "kebab-case")]
    Rocm {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        rocm_archs: Option<Vec<String>>,
        hip_flags: Option<Vec<String>>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
    /// `backend = "xpu"`.
    #[serde(rename_all = "kebab-case")]
    Xpu {
        cxx_flags: Option<Vec<String>>,
        depends: Vec<Dependency>,
        sycl_flags: Option<Vec<String>>,
        include: Option<Vec<String>>,
        src: Vec<String>,
    },
}
119
120impl TryFrom<Build> for super::Build {
121    type Error = eyre::Error;
122
123    fn try_from(build: Build) -> Result<Self> {
124        let kernels: HashMap<String, super::Kernel> = build
125            .kernels
126            .into_iter()
127            .map(|(k, v)| (k, v.into()))
128            .collect();
129
130        let backends = if build.general.universal {
131            vec![
132                Backend::Cpu,
133                Backend::Cuda,
134                Backend::Metal,
135                Backend::Neuron,
136                Backend::Rocm,
137                Backend::Xpu,
138            ]
139        } else {
140            let backend_set: BTreeSet<Backend> =
141                kernels.values().map(|kernel| kernel.backend()).collect();
142            backend_set.into_iter().collect()
143        };
144
145        let torch = match build.torch {
146            Some(torch) => torch,
147            None => bail!("Torch section is required build.toml v2"),
148        };
149
150        Ok(Self {
151            general: General::from_v2(build.general, backends),
152            framework: super::Framework::Torch(torch.into()),
153            kernels,
154        })
155    }
156}
157
158impl General {
159    fn from_v2(general: General, backends: Vec<Backend>) -> super::General {
160        let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() {
161            Some(super::CudaGeneral {
162                minver: general.cuda_minver,
163                maxver: general.cuda_maxver,
164                python_depends: None,
165            })
166        } else {
167            None
168        };
169
170        super::General {
171            name: general.name,
172            version: None,
173            license: None,
174            upstream: None,
175            backends,
176            cuda,
177            hub: general.hub.map(Into::into),
178            neuron: None,
179            python_depends: None,
180            xpu: None,
181        }
182    }
183}
184
185impl From<Hub> for super::Hub {
186    fn from(hub: Hub) -> Self {
187        Self {
188            repo_id: hub.repo_id,
189            branch: hub.branch,
190        }
191    }
192}
193
194impl From<Torch> for super::Torch {
195    fn from(torch: Torch) -> Self {
196        Self {
197            include: torch.include,
198            minver: torch.minver,
199            maxver: torch.maxver,
200            pyext: torch.pyext,
201            src: torch.src,
202        }
203    }
204}
205
206impl From<Kernel> for super::Kernel {
207    fn from(kernel: Kernel) -> Self {
208        match kernel {
209            Kernel::Cpu {
210                cxx_flags,
211                depends,
212                include,
213                src,
214            } => super::Kernel::Cpu {
215                cxx_flags,
216                depends,
217                include,
218                src,
219            },
220            Kernel::Cuda {
221                cuda_capabilities,
222                cuda_flags,
223                cuda_minver,
224                cxx_flags,
225                depends,
226                include,
227                src,
228            } => super::Kernel::Cuda {
229                cuda_capabilities,
230                cuda_flags,
231                cuda_minver,
232                cxx_flags,
233                depends,
234                include,
235                src,
236            },
237            Kernel::Metal {
238                cxx_flags,
239                depends,
240                include,
241                src,
242            } => super::Kernel::Metal {
243                cxx_flags,
244                depends,
245                include,
246                src,
247            },
248            Kernel::Rocm {
249                cxx_flags,
250                depends,
251                rocm_archs,
252                hip_flags,
253                include,
254                src,
255            } => super::Kernel::Rocm {
256                cxx_flags,
257                depends,
258                rocm_archs,
259                hip_flags,
260                include,
261                src,
262            },
263            Kernel::Xpu {
264                cxx_flags,
265                depends,
266                sycl_flags,
267                include,
268                src,
269            } => super::Kernel::Xpu {
270                cxx_flags,
271                depends,
272                sycl_flags,
273                include,
274                src,
275            },
276        }
277    }
278}