1use std::{
2 collections::{BTreeSet, HashMap},
3 fmt::Display,
4 path::PathBuf,
5};
6
7use eyre::{bail, Result};
8use serde::{Deserialize, Serialize};
9
10use super::{Backend, Dependency, KernelName};
11use crate::version::Version;
12
13#[derive(Debug, Deserialize, Serialize)]
14#[serde(deny_unknown_fields)]
15pub struct Build {
16 pub general: General,
17 pub torch: Option<Torch>,
18
19 #[serde(rename = "kernel", default)]
20 pub kernels: HashMap<String, Kernel>,
21}
22
23#[derive(Debug, Deserialize, Serialize)]
24#[serde(deny_unknown_fields, rename_all = "kebab-case")]
25pub struct General {
26 pub name: KernelName,
27 #[serde(default)]
28 pub universal: bool,
29
30 pub cuda_maxver: Option<Version>,
31
32 pub cuda_minver: Option<Version>,
33
34 pub hub: Option<Hub>,
35
36 pub python_depends: Option<Vec<PythonDependency>>,
37}
38
39#[derive(Debug, Deserialize, Serialize)]
40#[serde(deny_unknown_fields, rename_all = "kebab-case")]
41pub struct Hub {
42 pub repo_id: Option<String>,
43 pub branch: Option<String>,
44}
45
46#[derive(Clone, Debug, Deserialize, Serialize)]
47#[serde(deny_unknown_fields, rename_all = "kebab-case")]
48pub enum PythonDependency {
49 Einops,
50 NvidiaCutlassDsl,
51}
52
53impl Display for PythonDependency {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 PythonDependency::Einops => write!(f, "einops"),
57 PythonDependency::NvidiaCutlassDsl => write!(f, "nvidia-cutlass-dsl"),
58 }
59 }
60}
61
62#[derive(Debug, Deserialize, Clone, Serialize)]
63#[serde(deny_unknown_fields)]
64pub struct Torch {
65 pub include: Option<Vec<String>>,
66 pub minver: Option<Version>,
67 pub maxver: Option<Version>,
68 pub pyext: Option<Vec<String>>,
69
70 #[serde(default)]
71 pub src: Vec<PathBuf>,
72}
73
74#[derive(Debug, Deserialize, Serialize)]
75#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")]
76pub enum Kernel {
77 #[serde(rename_all = "kebab-case")]
78 Cpu {
79 cxx_flags: Option<Vec<String>>,
80 depends: Vec<Dependency>,
81 include: Option<Vec<String>>,
82 src: Vec<String>,
83 },
84 #[serde(rename_all = "kebab-case")]
85 Cuda {
86 cuda_capabilities: Option<Vec<String>>,
87 cuda_flags: Option<Vec<String>>,
88 cuda_minver: Option<Version>,
89 cxx_flags: Option<Vec<String>>,
90 depends: Vec<Dependency>,
91 include: Option<Vec<String>>,
92 src: Vec<String>,
93 },
94 #[serde(rename_all = "kebab-case")]
95 Metal {
96 cxx_flags: Option<Vec<String>>,
97 depends: Vec<Dependency>,
98 include: Option<Vec<String>>,
99 src: Vec<String>,
100 },
101 #[serde(rename_all = "kebab-case")]
102 Rocm {
103 cxx_flags: Option<Vec<String>>,
104 depends: Vec<Dependency>,
105 rocm_archs: Option<Vec<String>>,
106 hip_flags: Option<Vec<String>>,
107 include: Option<Vec<String>>,
108 src: Vec<String>,
109 },
110 #[serde(rename_all = "kebab-case")]
111 Xpu {
112 cxx_flags: Option<Vec<String>>,
113 depends: Vec<Dependency>,
114 sycl_flags: Option<Vec<String>>,
115 include: Option<Vec<String>>,
116 src: Vec<String>,
117 },
118}
119
120impl TryFrom<Build> for super::Build {
121 type Error = eyre::Error;
122
123 fn try_from(build: Build) -> Result<Self> {
124 let kernels: HashMap<String, super::Kernel> = build
125 .kernels
126 .into_iter()
127 .map(|(k, v)| (k, v.into()))
128 .collect();
129
130 let backends = if build.general.universal {
131 vec![
132 Backend::Cpu,
133 Backend::Cuda,
134 Backend::Metal,
135 Backend::Neuron,
136 Backend::Rocm,
137 Backend::Xpu,
138 ]
139 } else {
140 let backend_set: BTreeSet<Backend> =
141 kernels.values().map(|kernel| kernel.backend()).collect();
142 backend_set.into_iter().collect()
143 };
144
145 let torch = match build.torch {
146 Some(torch) => torch,
147 None => bail!("Torch section is required build.toml v2"),
148 };
149
150 Ok(Self {
151 general: General::from_v2(build.general, backends),
152 framework: super::Framework::Torch(torch.into()),
153 kernels,
154 })
155 }
156}
157
158impl General {
159 fn from_v2(general: General, backends: Vec<Backend>) -> super::General {
160 let cuda = if general.cuda_minver.is_some() || general.cuda_maxver.is_some() {
161 Some(super::CudaGeneral {
162 minver: general.cuda_minver,
163 maxver: general.cuda_maxver,
164 python_depends: None,
165 })
166 } else {
167 None
168 };
169
170 super::General {
171 name: general.name,
172 version: None,
173 license: None,
174 upstream: None,
175 backends,
176 cuda,
177 hub: general.hub.map(Into::into),
178 neuron: None,
179 python_depends: None,
180 xpu: None,
181 }
182 }
183}
184
185impl From<Hub> for super::Hub {
186 fn from(hub: Hub) -> Self {
187 Self {
188 repo_id: hub.repo_id,
189 branch: hub.branch,
190 }
191 }
192}
193
194impl From<Torch> for super::Torch {
195 fn from(torch: Torch) -> Self {
196 Self {
197 include: torch.include,
198 minver: torch.minver,
199 maxver: torch.maxver,
200 pyext: torch.pyext,
201 src: torch.src,
202 }
203 }
204}
205
206impl From<Kernel> for super::Kernel {
207 fn from(kernel: Kernel) -> Self {
208 match kernel {
209 Kernel::Cpu {
210 cxx_flags,
211 depends,
212 include,
213 src,
214 } => super::Kernel::Cpu {
215 cxx_flags,
216 depends,
217 include,
218 src,
219 },
220 Kernel::Cuda {
221 cuda_capabilities,
222 cuda_flags,
223 cuda_minver,
224 cxx_flags,
225 depends,
226 include,
227 src,
228 } => super::Kernel::Cuda {
229 cuda_capabilities,
230 cuda_flags,
231 cuda_minver,
232 cxx_flags,
233 depends,
234 include,
235 src,
236 },
237 Kernel::Metal {
238 cxx_flags,
239 depends,
240 include,
241 src,
242 } => super::Kernel::Metal {
243 cxx_flags,
244 depends,
245 include,
246 src,
247 },
248 Kernel::Rocm {
249 cxx_flags,
250 depends,
251 rocm_archs,
252 hip_flags,
253 include,
254 src,
255 } => super::Kernel::Rocm {
256 cxx_flags,
257 depends,
258 rocm_archs,
259 hip_flags,
260 include,
261 src,
262 },
263 Kernel::Xpu {
264 cxx_flags,
265 depends,
266 sycl_flags,
267 include,
268 src,
269 } => super::Kernel::Xpu {
270 cxx_flags,
271 depends,
272 sycl_flags,
273 include,
274 src,
275 },
276 }
277 }
278}