numrs/backend/
capabilities.rs

1//! Backend capabilities - Define qué dtypes soporta cada backend
2//!
3//! Cada backend declara explícitamente qué tipos de datos puede manejar.
4//! Esto permite al sistema de dispatch seleccionar el backend apropiado
5//! basado tanto en disponibilidad como en el dtype del array.
6
7use crate::array::DType;
8
9/// Capabilities de un backend - qué tipos soporta
10#[derive(Debug, Clone)]
11pub struct BackendCapabilities {
12    /// Nombre del backend
13    pub name: &'static str,
14    /// Tipos de datos soportados
15    pub supported_dtypes: &'static [DType],
16}
17
18impl BackendCapabilities {
19    /// Verifica si este backend soporta el dtype dado
20    pub fn supports(&self, dtype: DType) -> bool {
21        self.supported_dtypes.contains(&dtype)
22    }
23
24    /// Lista de dtypes soportados como strings (para logging)
25    pub fn supported_types_str(&self) -> String {
26        self.supported_dtypes
27            .iter()
28            .map(|dt| dt.name())
29            .collect::<Vec<_>>()
30            .join(", ")
31    }
32}
33
34// ============================================================================
35// Backend Capabilities Constants
36// ============================================================================
37
38/// BLAS capabilities: solo F32 y F64
39/// BLAS tiene funciones específicas por tipo:
40/// - sgemm (single precision, f32)
41/// - dgemm (double precision, f64)
42/// F16/BF16 no están soportados en BLAS estándar
43pub const BLAS_CAPABILITIES: BackendCapabilities = BackendCapabilities {
44    name: "blas",
45    supported_dtypes: &[DType::F32, DType::F64],
46};
47
48/// CPU SIMD capabilities: F16, F32, F64, I32
49/// AVX2/SSE soportan operaciones vectorizadas para estos tipos
50/// F16 via conversión F16->F32->operación->F16
51/// BF16 tiene soporte limitado (solo en AVX512_BF16)
52pub const SIMD_CAPABILITIES: BackendCapabilities = BackendCapabilities {
53    name: "cpu-simd",
54    supported_dtypes: &[DType::F16, DType::F32, DType::F64, DType::I32],
55};
56
57/// CPU Scalar capabilities: todos los tipos
58/// El backend escalar puede manejar cualquier tipo mediante loops simples
59pub const SCALAR_CAPABILITIES: BackendCapabilities = BackendCapabilities {
60    name: "cpu-scalar",
61    supported_dtypes: &[
62        DType::F16,
63        DType::BF16,
64        DType::F32,
65        DType::F64,
66        DType::U8,
67        DType::I8,
68        DType::I32,
69        DType::Bool,
70    ],
71};
72
73/// WebGPU capabilities: F16, F32
74/// WGSL (WebGPU Shading Language) soporta f16 (con extensión) y f32 nativamente
75pub const WEBGPU_CAPABILITIES: BackendCapabilities = BackendCapabilities {
76    name: "webgpu",
77    supported_dtypes: &[DType::F16, DType::F32],
78};
79
80/// Metal capabilities: F16, F32, F64, I32, U8
81/// Metal Shading Language tiene buen soporte para estos tipos
82pub const METAL_CAPABILITIES: BackendCapabilities = BackendCapabilities {
83    name: "metal",
84    supported_dtypes: &[DType::F16, DType::F32, DType::F64, DType::I32, DType::U8],
85};
86
87/// CUDA capabilities: F16, BF16, F32, F64, I32, U8, I8
88/// CUDA tiene excelente soporte para tipos ML (incluyendo BF16)
89pub const CUDA_CAPABILITIES: BackendCapabilities = BackendCapabilities {
90    name: "cuda",
91    supported_dtypes: &[
92        DType::F16,
93        DType::BF16,
94        DType::F32,
95        DType::F64,
96        DType::U8,
97        DType::I8,
98        DType::I32,
99    ],
100};
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `caps.supports(dtype) == expected` for every dtype in
    /// `dtypes`, naming the backend and the dtype in the failure message.
    fn assert_support(
        caps: &BackendCapabilities,
        dtypes: impl IntoIterator<Item = DType>,
        expected: bool,
    ) {
        for dtype in dtypes {
            // Render the label first: `supports` takes the dtype by value.
            let label = format!("{:?}", dtype);
            assert_eq!(
                caps.supports(dtype),
                expected,
                "{}: unexpected support status for {}",
                caps.name,
                label
            );
        }
    }

    #[test]
    fn test_blas_capabilities() {
        assert_support(&BLAS_CAPABILITIES, [DType::F32, DType::F64], true);
        assert_support(&BLAS_CAPABILITIES, [DType::I32, DType::Bool], false);
    }

    #[test]
    fn test_simd_capabilities() {
        assert_support(
            &SIMD_CAPABILITIES,
            [DType::F16, DType::F32, DType::F64, DType::I32],
            true,
        );
        // BF16 would need AVX512_BF16, so it is not advertised.
        assert_support(
            &SIMD_CAPABILITIES,
            [DType::BF16, DType::U8, DType::Bool],
            false,
        );
    }

    #[test]
    fn test_scalar_capabilities() {
        // The scalar backend is expected to support every dtype.
        assert_support(
            &SCALAR_CAPABILITIES,
            [
                DType::F16,
                DType::BF16,
                DType::F32,
                DType::F64,
                DType::U8,
                DType::I8,
                DType::I32,
                DType::Bool,
            ],
            true,
        );
    }

    #[test]
    fn test_webgpu_capabilities() {
        assert_support(&WEBGPU_CAPABILITIES, [DType::F16, DType::F32], true);
        assert_support(&WEBGPU_CAPABILITIES, [DType::F64, DType::I32], false);
    }

    #[test]
    fn test_metal_capabilities() {
        assert_support(
            &METAL_CAPABILITIES,
            [DType::F16, DType::F32, DType::I32, DType::U8],
            true,
        );
        assert_support(
            &METAL_CAPABILITIES,
            [DType::BF16, DType::I8, DType::Bool],
            false,
        );
    }

    #[test]
    fn test_cuda_capabilities() {
        assert_support(
            &CUDA_CAPABILITIES,
            [
                DType::F16,
                DType::BF16,
                DType::F32,
                DType::F64,
                DType::U8,
                DType::I8,
                DType::I32,
            ],
            true,
        );
        assert_support(&CUDA_CAPABILITIES, [DType::Bool], false);
    }

    #[test]
    fn test_supported_types_str() {
        let s = BLAS_CAPABILITIES.supported_types_str();
        assert!(s.contains("f32"));
        assert!(s.contains("f64"));
        // Two entries are joined, so the separator must be present.
        assert!(s.contains(", "));
    }
}