1use std::{
2 collections::{BTreeSet, HashMap},
3 fmt::Display,
4 path::PathBuf,
5};
6
7use eyre::{bail, Result};
8use serde::Deserialize;
9
10use super::{Backend, Dependency, KernelName};
11
/// Top-level representation of a v1 `build.toml` manifest.
///
/// Unknown keys are rejected so that typos in manifests surface as parse
/// errors instead of being silently ignored.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Build {
    /// Mandatory `[general]` section.
    pub general: General,
    /// Optional `[torch]` section. NOTE(review): conversion in `TryFrom`
    /// below treats it as required for v1 manifests.
    pub torch: Option<Torch>,

    /// `[kernel.<name>]` tables, keyed by kernel name. Defaults to empty
    /// when no kernel tables are present.
    #[serde(rename = "kernel", default)]
    pub kernels: HashMap<String, Kernel>,
}
21
/// The `[general]` section of a v1 manifest.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct General {
    /// The kernel extension's name.
    pub name: KernelName,
}
27
/// The `[torch]` section of a v1 manifest.
#[derive(Debug, Deserialize, Clone)]
#[serde(deny_unknown_fields)]
pub struct Torch {
    /// Optional list of include paths/patterns.
    pub include: Option<Vec<String>>,
    /// Optional list of Python extension entries.
    pub pyext: Option<Vec<String>>,

    /// Source files; empty when omitted from the manifest.
    #[serde(default)]
    pub src: Vec<PathBuf>,

    /// When `true`, the build targets every backend (see the `TryFrom`
    /// conversion below). Defaults to `false`.
    #[serde(default)]
    pub universal: bool,
}
40
/// A `[kernel.<name>]` table of a v1 manifest.
///
/// Field names are kebab-cased in TOML (e.g. `cuda-capabilities`).
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct Kernel {
    /// CUDA compute capabilities to build for (`cuda-capabilities`).
    pub cuda_capabilities: Option<Vec<String>>,
    /// ROCm GPU architectures to build for (`rocm-archs`); only meaningful
    /// when `language` is `cuda-hipify`.
    pub rocm_archs: Option<Vec<String>>,
    /// Source language of this kernel; defaults to `cuda`.
    #[serde(default)]
    pub language: Language,
    /// Build dependencies of this kernel.
    pub depends: Vec<Dependency>,
    /// Optional list of include paths/patterns.
    pub include: Option<Vec<String>>,
    /// Source files of this kernel.
    pub src: Vec<String>,
}
52
/// Source language of a v1 kernel.
///
/// Serialized in kebab-case (`cuda`, `cuda-hipify`, `metal`), matching the
/// strings produced by the `Display` impl below.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub enum Language {
    /// Plain CUDA (the default when `language` is omitted).
    #[default]
    Cuda,
    /// CUDA that is additionally hipified into a ROCm kernel.
    CudaHipify,
    /// Apple Metal.
    Metal,
}
61
62impl Display for Language {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Language::Cuda => f.write_str("cuda"),
66 Language::CudaHipify => f.write_str("cuda-hipify"),
67 Language::Metal => f.write_str("metal"),
68 }
69 }
70}
71
72impl TryFrom<Build> for super::Build {
73 type Error = eyre::Error;
74
75 fn try_from(build: Build) -> Result<Self> {
76 let universal = build
77 .torch
78 .as_ref()
79 .map(|torch| torch.universal)
80 .unwrap_or(false);
81
82 let kernels = convert_kernels(build.kernels)?;
83
84 let backends = if universal {
85 vec![
86 Backend::Cpu,
87 Backend::Cuda,
88 Backend::Metal,
89 Backend::Neuron,
90 Backend::Rocm,
91 Backend::Xpu,
92 ]
93 } else {
94 let backend_set: BTreeSet<Backend> =
95 kernels.values().map(|kernel| kernel.backend()).collect();
96 backend_set.into_iter().collect()
97 };
98
99 let torch = match build.torch {
100 Some(torch) => torch,
101 None => bail!("Torch section is required build.toml v1"),
102 };
103
104 Ok(Self {
105 general: super::General {
106 name: build.general.name,
107 version: None,
108 license: None,
109 upstream: None,
110 backends,
111 hub: None,
112 neuron: None,
113 python_depends: None,
114 cuda: None,
115 xpu: None,
116 },
117 framework: super::Framework::Torch(torch.into()),
118 kernels,
119 })
120 }
121}
122
123fn convert_kernels(v1_kernels: HashMap<String, Kernel>) -> Result<HashMap<String, super::Kernel>> {
124 let mut kernels = HashMap::new();
125
126 for (name, kernel) in v1_kernels {
127 if kernel.language == Language::CudaHipify {
128 let rocm_name = format!("{name}_rocm");
130 if kernels.contains_key(&rocm_name) {
131 bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`")
132 }
133
134 kernels.insert(
135 format!("{name}_rocm"),
136 super::Kernel::Rocm {
137 cxx_flags: None,
138 rocm_archs: kernel.rocm_archs,
139 hip_flags: None,
140 depends: kernel.depends.clone(),
141 include: kernel.include.clone(),
142 src: kernel.src.clone(),
143 },
144 );
145 }
146
147 kernels.insert(
148 name,
149 super::Kernel::Cuda {
150 cuda_capabilities: kernel.cuda_capabilities,
151 cuda_flags: None,
152 cuda_minver: None,
153 cxx_flags: None,
154 depends: kernel.depends,
155 include: kernel.include,
156 src: kernel.src,
157 },
158 );
159 }
160
161 Ok(kernels)
162}
163
164impl From<Torch> for super::Torch {
165 fn from(torch: Torch) -> Self {
166 Self {
167 include: torch.include,
168 minver: None,
169 maxver: None,
170 pyext: torch.pyext,
171 src: torch.src,
172 }
173 }
174}