// numrs/backend/capabilities.rs
use crate::array::DType;

9#[derive(Debug, Clone)]
11pub struct BackendCapabilities {
12 pub name: &'static str,
14 pub supported_dtypes: &'static [DType],
16}
17
18impl BackendCapabilities {
19 pub fn supports(&self, dtype: DType) -> bool {
21 self.supported_dtypes.contains(&dtype)
22 }
23
24 pub fn supported_types_str(&self) -> String {
26 self.supported_dtypes
27 .iter()
28 .map(|dt| dt.name())
29 .collect::<Vec<_>>()
30 .join(", ")
31 }
32}
33
34pub const BLAS_CAPABILITIES: BackendCapabilities = BackendCapabilities {
44 name: "blas",
45 supported_dtypes: &[DType::F32, DType::F64],
46};
47
48pub const SIMD_CAPABILITIES: BackendCapabilities = BackendCapabilities {
53 name: "cpu-simd",
54 supported_dtypes: &[DType::F16, DType::F32, DType::F64, DType::I32],
55};
56
57pub const SCALAR_CAPABILITIES: BackendCapabilities = BackendCapabilities {
60 name: "cpu-scalar",
61 supported_dtypes: &[
62 DType::F16,
63 DType::BF16,
64 DType::F32,
65 DType::F64,
66 DType::U8,
67 DType::I8,
68 DType::I32,
69 DType::Bool,
70 ],
71};
72
73pub const WEBGPU_CAPABILITIES: BackendCapabilities = BackendCapabilities {
76 name: "webgpu",
77 supported_dtypes: &[DType::F16, DType::F32],
78};
79
80pub const METAL_CAPABILITIES: BackendCapabilities = BackendCapabilities {
83 name: "metal",
84 supported_dtypes: &[DType::F16, DType::F32, DType::F64, DType::I32, DType::U8],
85};
86
87pub const CUDA_CAPABILITIES: BackendCapabilities = BackendCapabilities {
90 name: "cuda",
91 supported_dtypes: &[
92 DType::F16,
93 DType::BF16,
94 DType::F32,
95 DType::F64,
96 DType::U8,
97 DType::I8,
98 DType::I32,
99 ],
100};
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Every capability table defined in this module, so that invariants can be
    /// checked uniformly rather than one backend at a time.
    const ALL_BACKENDS: [&BackendCapabilities; 6] = [
        &BLAS_CAPABILITIES,
        &SIMD_CAPABILITIES,
        &SCALAR_CAPABILITIES,
        &WEBGPU_CAPABILITIES,
        &METAL_CAPABILITIES,
        &CUDA_CAPABILITIES,
    ];

    #[test]
    fn test_blas_capabilities() {
        assert!(BLAS_CAPABILITIES.supports(DType::F32));
        assert!(BLAS_CAPABILITIES.supports(DType::F64));
        assert!(!BLAS_CAPABILITIES.supports(DType::I32));
        assert!(!BLAS_CAPABILITIES.supports(DType::Bool));
    }

    #[test]
    fn test_simd_capabilities() {
        assert!(SIMD_CAPABILITIES.supports(DType::F16));
        assert!(SIMD_CAPABILITIES.supports(DType::F32));
        assert!(SIMD_CAPABILITIES.supports(DType::F64));
        assert!(SIMD_CAPABILITIES.supports(DType::I32));
        assert!(!SIMD_CAPABILITIES.supports(DType::BF16));
        assert!(!SIMD_CAPABILITIES.supports(DType::U8));
        assert!(!SIMD_CAPABILITIES.supports(DType::Bool));
    }

    #[test]
    fn test_scalar_capabilities() {
        assert!(SCALAR_CAPABILITIES.supports(DType::F16));
        assert!(SCALAR_CAPABILITIES.supports(DType::BF16));
        assert!(SCALAR_CAPABILITIES.supports(DType::F32));
        assert!(SCALAR_CAPABILITIES.supports(DType::F64));
        assert!(SCALAR_CAPABILITIES.supports(DType::U8));
        assert!(SCALAR_CAPABILITIES.supports(DType::I8));
        assert!(SCALAR_CAPABILITIES.supports(DType::I32));
        assert!(SCALAR_CAPABILITIES.supports(DType::Bool));
    }

    #[test]
    fn test_webgpu_capabilities() {
        assert!(WEBGPU_CAPABILITIES.supports(DType::F16));
        assert!(WEBGPU_CAPABILITIES.supports(DType::F32));
        assert!(!WEBGPU_CAPABILITIES.supports(DType::F64));
        assert!(!WEBGPU_CAPABILITIES.supports(DType::I32));
        assert!(!WEBGPU_CAPABILITIES.supports(DType::Bool));
    }

    #[test]
    fn test_metal_capabilities() {
        // F64 is intentionally not pinned here. MSL has no 64-bit float type, so
        // that entry may be revised once its execution path is confirmed.
        assert!(METAL_CAPABILITIES.supports(DType::F16));
        assert!(METAL_CAPABILITIES.supports(DType::F32));
        assert!(METAL_CAPABILITIES.supports(DType::I32));
        assert!(METAL_CAPABILITIES.supports(DType::U8));
        assert!(!METAL_CAPABILITIES.supports(DType::BF16));
        assert!(!METAL_CAPABILITIES.supports(DType::I8));
        assert!(!METAL_CAPABILITIES.supports(DType::Bool));
    }

    #[test]
    fn test_cuda_capabilities() {
        assert!(CUDA_CAPABILITIES.supports(DType::F16));
        assert!(CUDA_CAPABILITIES.supports(DType::BF16));
        assert!(CUDA_CAPABILITIES.supports(DType::F32));
        assert!(CUDA_CAPABILITIES.supports(DType::F64));
        assert!(CUDA_CAPABILITIES.supports(DType::U8));
        assert!(CUDA_CAPABILITIES.supports(DType::I8));
        assert!(CUDA_CAPABILITIES.supports(DType::I32));
        assert!(!CUDA_CAPABILITIES.supports(DType::Bool));
    }

    #[test]
    fn test_supported_types_str() {
        let s = BLAS_CAPABILITIES.supported_types_str();
        assert!(s.contains("f32"));
        assert!(s.contains("f64"));
        // Two dtypes are joined by exactly one separator.
        assert_eq!(s.matches(", ").count(), 1);
    }

    #[test]
    fn test_backend_tables_are_well_formed() {
        for (index, caps) in ALL_BACKENDS.iter().enumerate() {
            assert!(!caps.name.is_empty(), "backend #{} has an empty name", index);
            assert!(
                !caps.supported_dtypes.is_empty(),
                "backend `{}` lists no dtypes",
                caps.name
            );
            // A duplicated entry would also show up twice in `supported_types_str`.
            for (pos, dtype) in caps.supported_dtypes.iter().enumerate() {
                assert!(
                    !caps.supported_dtypes[..pos].contains(dtype),
                    "backend `{}` lists a dtype more than once",
                    caps.name
                );
            }
            assert!(
                ALL_BACKENDS[..index]
                    .iter()
                    .all(|other| other.name != caps.name),
                "backend name `{}` is used by more than one table",
                caps.name
            );
        }
    }
}