Skip to main content

kernels_data/config/
deps.rs

1use std::{collections::HashMap, sync::LazyLock};
2
3use eyre::Result;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use super::Backend;
8
9pub static PYTHON_DEPENDENCIES: LazyLock<PythonDependencies> =
10    LazyLock::new(|| serde_json::from_str(include_str!("../python_dependencies.json")).unwrap());
11
12#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
13#[non_exhaustive]
14#[serde(rename_all = "lowercase")]
15pub enum Dependency {
16    #[serde(rename = "cutlass_2_10")]
17    Cutlass2_10,
18    #[serde(rename = "cutlass_3_5")]
19    Cutlass3_5,
20    #[serde(rename = "cutlass_3_6")]
21    Cutlass3_6,
22    #[serde(rename = "cutlass_3_8")]
23    Cutlass3_8,
24    #[serde(rename = "cutlass_3_9")]
25    Cutlass3_9,
26    #[serde(rename = "cutlass_4_0")]
27    Cutlass4_0,
28    #[serde(rename = "sycl_tla")]
29    SyclTla,
30    #[serde(rename = "metal-cpp")]
31    MetalCpp,
32    Torch,
33}
34
35#[derive(Debug, Deserialize, Serialize)]
36#[serde(deny_unknown_fields)]
37pub struct PythonDependencies {
38    general: HashMap<String, PythonDependency>,
39    backends: HashMap<Backend, HashMap<String, PythonDependency>>,
40}
41
42impl PythonDependencies {
43    pub fn get_dependency(&self, dependency: &str) -> Result<&PythonDependency, DependencyError> {
44        match self.general.get(dependency) {
45            None => Err(DependencyError::GeneralDependency {
46                dependency: dependency.to_string(),
47            }),
48            Some(dep) => Ok(dep),
49        }
50    }
51
52    pub fn get_backend_dependency(
53        &self,
54        backend: Backend,
55        dependency: &str,
56    ) -> Result<&PythonDependency, DependencyError> {
57        let backend_deps = match self.backends.get(&backend) {
58            None => {
59                return Err(DependencyError::Backend {
60                    backend: backend.to_string(),
61                })
62            }
63            Some(backend_deps) => backend_deps,
64        };
65        match backend_deps.get(dependency) {
66            None => Err(DependencyError::Dependency {
67                backend: backend.to_string(),
68                dependency: dependency.to_string(),
69            }),
70            Some(dep) => Ok(dep),
71        }
72    }
73}
74
75/// Entry for a builder Python dependency.
76#[derive(Debug, Deserialize, Serialize)]
77pub struct PythonDependency {
78    /// Nix dependencies in `python3.pkgs`.
79    pub nix: Vec<String>,
80
81    /// Python dependency.
82    pub python: Vec<PythonPkgImport>,
83}
84
85/// Python package and (module to import).
86#[derive(Debug, Deserialize, Serialize)]
87pub struct PythonPkgImport {
88    pub pkg: String,
89    pub import: Option<String>,
90}
91
92#[derive(Debug, Error)]
93pub enum DependencyError {
94    #[error("No dependencies are defined for backend: {backend:?}")]
95    Backend { backend: String },
96    #[error("Unknown dependency `{dependency:?}` for backend `{backend:?}`")]
97    Dependency { backend: String, dependency: String },
98    #[error("Unknown dependency: `{dependency:?}`")]
99    GeneralDependency { dependency: String },
100}