//! v1 `build.toml` configuration schema (`kernels_data/config/v1.rs`).
1use std::{
2    collections::{BTreeSet, HashMap},
3    fmt::Display,
4    path::PathBuf,
5};
6
7use eyre::{bail, Result};
8use serde::Deserialize;
9
10use super::{Backend, Dependency, KernelName};
11
12#[derive(Debug, Deserialize)]
13#[serde(deny_unknown_fields)]
14pub struct Build {
15    pub general: General,
16    pub torch: Option<Torch>,
17
18    #[serde(rename = "kernel", default)]
19    pub kernels: HashMap<String, Kernel>,
20}
21
22#[derive(Debug, Deserialize)]
23#[serde(deny_unknown_fields)]
24pub struct General {
25    pub name: KernelName,
26}
27
28#[derive(Debug, Deserialize, Clone)]
29#[serde(deny_unknown_fields)]
30pub struct Torch {
31    pub include: Option<Vec<String>>,
32    pub pyext: Option<Vec<String>>,
33
34    #[serde(default)]
35    pub src: Vec<PathBuf>,
36
37    #[serde(default)]
38    pub universal: bool,
39}
40
41#[derive(Debug, Deserialize)]
42#[serde(deny_unknown_fields, rename_all = "kebab-case")]
43pub struct Kernel {
44    pub cuda_capabilities: Option<Vec<String>>,
45    pub rocm_archs: Option<Vec<String>>,
46    #[serde(default)]
47    pub language: Language,
48    pub depends: Vec<Dependency>,
49    pub include: Option<Vec<String>>,
50    pub src: Vec<String>,
51}
52
53#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq)]
54#[serde(deny_unknown_fields, rename_all = "kebab-case")]
55pub enum Language {
56    #[default]
57    Cuda,
58    CudaHipify,
59    Metal,
60}
61
62impl Display for Language {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            Language::Cuda => f.write_str("cuda"),
66            Language::CudaHipify => f.write_str("cuda-hipify"),
67            Language::Metal => f.write_str("metal"),
68        }
69    }
70}
71
72impl TryFrom<Build> for super::Build {
73    type Error = eyre::Error;
74
75    fn try_from(build: Build) -> Result<Self> {
76        let universal = build
77            .torch
78            .as_ref()
79            .map(|torch| torch.universal)
80            .unwrap_or(false);
81
82        let kernels = convert_kernels(build.kernels)?;
83
84        let backends = if universal {
85            vec![
86                Backend::Cpu,
87                Backend::Cuda,
88                Backend::Metal,
89                Backend::Neuron,
90                Backend::Rocm,
91                Backend::Xpu,
92            ]
93        } else {
94            let backend_set: BTreeSet<Backend> =
95                kernels.values().map(|kernel| kernel.backend()).collect();
96            backend_set.into_iter().collect()
97        };
98
99        let torch = match build.torch {
100            Some(torch) => torch,
101            None => bail!("Torch section is required build.toml v1"),
102        };
103
104        Ok(Self {
105            general: super::General {
106                name: build.general.name,
107                version: None,
108                license: None,
109                upstream: None,
110                backends,
111                hub: None,
112                neuron: None,
113                python_depends: None,
114                cuda: None,
115                xpu: None,
116            },
117            framework: super::Framework::Torch(torch.into()),
118            kernels,
119        })
120    }
121}
122
123fn convert_kernels(v1_kernels: HashMap<String, Kernel>) -> Result<HashMap<String, super::Kernel>> {
124    let mut kernels = HashMap::new();
125
126    for (name, kernel) in v1_kernels {
127        if kernel.language == Language::CudaHipify {
128            // We need to add an affix to avoid conflict with the CUDA kernel.
129            let rocm_name = format!("{name}_rocm");
130            if kernels.contains_key(&rocm_name) {
131                bail!("Found an existing kernel with name `{rocm_name}` while expanding `{name}`")
132            }
133
134            kernels.insert(
135                format!("{name}_rocm"),
136                super::Kernel::Rocm {
137                    cxx_flags: None,
138                    rocm_archs: kernel.rocm_archs,
139                    hip_flags: None,
140                    depends: kernel.depends.clone(),
141                    include: kernel.include.clone(),
142                    src: kernel.src.clone(),
143                },
144            );
145        }
146
147        kernels.insert(
148            name,
149            super::Kernel::Cuda {
150                cuda_capabilities: kernel.cuda_capabilities,
151                cuda_flags: None,
152                cuda_minver: None,
153                cxx_flags: None,
154                depends: kernel.depends,
155                include: kernel.include,
156                src: kernel.src,
157            },
158        );
159    }
160
161    Ok(kernels)
162}
163
164impl From<Torch> for super::Torch {
165    fn from(torch: Torch) -> Self {
166        Self {
167            include: torch.include,
168            minver: None,
169            maxver: None,
170            pyext: torch.pyext,
171            src: torch.src,
172        }
173    }
174}