Skip to main content

mnn_rs/
backend.rs

//! Backend configuration and management for MNN.
//!
//! This module provides types for configuring compute backends (CPU, GPU, etc.)
//! and querying backend capabilities.
5
/// Selects which compute backend MNN runs inference on.
///
/// Only [`BackendType::CPU`] and [`BackendType::Auto`] exist in every build; each GPU
/// variant is compiled in only when the cargo feature of the same name is enabled.
/// The default value is [`BackendType::CPU`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BackendType {
    /// Host CPU execution (no feature gate).
    #[default]
    CPU,

    /// NVIDIA CUDA execution; requires the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda,

    /// OpenCL execution; requires the `opencl` feature.
    #[cfg(feature = "opencl")]
    OpenCL,

    /// Vulkan execution; requires the `vulkan` feature.
    #[cfg(feature = "vulkan")]
    Vulkan,

    /// Apple Metal execution; requires the `metal` feature.
    #[cfg(feature = "metal")]
    Metal,

    /// Defer the choice of backend to MNN (`MNN_FORWARD_AUTO`).
    Auto,
}
32
33impl BackendType {
34    /// Convert to MNN forward type constant.
35    pub(crate) fn to_mnn_type(&self) -> i32 {
36        match self {
37            BackendType::CPU => mnn_rs_sys::MNN_FORWARD_CPU,
38            #[cfg(feature = "cuda")]
39            BackendType::Cuda => mnn_rs_sys::MNN_FORWARD_CUDA,
40            #[cfg(feature = "opencl")]
41            BackendType::OpenCL => mnn_rs_sys::MNN_FORWARD_OPENCL,
42            #[cfg(feature = "vulkan")]
43            BackendType::Vulkan => mnn_rs_sys::MNN_FORWARD_VULKAN,
44            #[cfg(feature = "metal")]
45            BackendType::Metal => mnn_rs_sys::MNN_FORWARD_METAL,
46            BackendType::Auto => mnn_rs_sys::MNN_FORWARD_AUTO,
47        }
48    }
49
50    /// Convert from MNN forward type constant.
51    pub(crate) fn from_mnn_type(code: i32) -> Self {
52        match code {
53            #[cfg(feature = "cuda")]
54            x if x == mnn_rs_sys::MNN_FORWARD_CUDA => BackendType::Cuda,
55            #[cfg(feature = "opencl")]
56            x if x == mnn_rs_sys::MNN_FORWARD_OPENCL => BackendType::OpenCL,
57            #[cfg(feature = "vulkan")]
58            x if x == mnn_rs_sys::MNN_FORWARD_VULKAN => BackendType::Vulkan,
59            #[cfg(feature = "metal")]
60            x if x == mnn_rs_sys::MNN_FORWARD_METAL => BackendType::Metal,
61            _ => BackendType::CPU,
62        }
63    }
64
65    /// Get the name of this backend.
66    pub fn name(&self) -> &'static str {
67        match self {
68            BackendType::CPU => "CPU",
69            #[cfg(feature = "cuda")]
70            BackendType::Cuda => "CUDA",
71            #[cfg(feature = "opencl")]
72            BackendType::OpenCL => "OpenCL",
73            #[cfg(feature = "vulkan")]
74            BackendType::Vulkan => "Vulkan",
75            #[cfg(feature = "metal")]
76            BackendType::Metal => "Metal",
77            BackendType::Auto => "Auto",
78        }
79    }
80
81    /// Check if this backend is a GPU backend.
82    pub fn is_gpu(&self) -> bool {
83        match self {
84            #[cfg(feature = "cuda")]
85            BackendType::Cuda => true,
86            #[cfg(feature = "opencl")]
87            BackendType::OpenCL => true,
88            #[cfg(feature = "vulkan")]
89            BackendType::Vulkan => true,
90            #[cfg(feature = "metal")]
91            BackendType::Metal => true,
92            _ => false,
93        }
94    }
95
96    /// Check if this backend is available on the current system.
97    pub fn is_available(&self) -> bool {
98        unsafe { mnn_rs_sys::mnn_is_backend_available(self.to_mnn_type()) != 0 }
99    }
100
101    /// Get all available backends on this system.
102    pub fn available_backends() -> Vec<BackendType> {
103        let mut backends = Vec::new();
104
105        // CPU is always available
106        backends.push(BackendType::CPU);
107
108        #[cfg(feature = "cuda")]
109        {
110            if BackendType::Cuda.is_available() {
111                backends.push(BackendType::Cuda);
112            }
113        }
114
115        #[cfg(feature = "opencl")]
116        {
117            if BackendType::OpenCL.is_available() {
118                backends.push(BackendType::OpenCL);
119            }
120        }
121
122        #[cfg(feature = "vulkan")]
123        {
124            if BackendType::Vulkan.is_available() {
125                backends.push(BackendType::Vulkan);
126            }
127        }
128
129        #[cfg(feature = "metal")]
130        {
131            if BackendType::Metal.is_available() {
132                backends.push(BackendType::Metal);
133            }
134        }
135
136        backends
137    }
138}
139
140impl std::fmt::Display for BackendType {
141    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142        write!(f, "{}", self.name())
143    }
144}
145
146/// Get list of available backends on this system.
147pub fn available_backends() -> Vec<BackendType> {
148    BackendType::available_backends()
149}
150
151/// Check if a specific backend is available.
152pub fn is_backend_available(backend: BackendType) -> bool {
153    backend.is_available()
154}
155
/// Feature summary describing what a backend can do.
///
/// Produced by `BackendCapabilities::query`.
#[derive(Debug, Clone, Copy)]
pub struct BackendCapabilities {
    /// Largest tensor rank (number of dimensions) supported.
    pub max_tensor_dimensions: i32,

    /// Whether half-precision (FP16) operations are supported.
    pub supports_fp16: bool,

    /// Whether 8-bit integer (INT8) operations are supported.
    pub supports_int8: bool,

    /// Whether brain-float (BF16) operations are supported.
    pub supports_bf16: bool,
}
171
172impl BackendCapabilities {
173    /// Query capabilities for a specific backend.
174    pub fn query(_backend: BackendType) -> Self {
175        Self {
176            max_tensor_dimensions: 8,
177            supports_fp16: cfg!(feature = "fp16"),
178            supports_int8: cfg!(feature = "int8"),
179            supports_bf16: cfg!(feature = "bf16"),
180        }
181    }
182}
183
184/// Configuration for a compute backend.
185#[derive(Debug, Clone)]
186pub struct BackendConfig {
187    /// The backend type to use
188    pub backend_type: BackendType,
189
190    /// Device ID for GPU backends (default: 0)
191    pub device_id: Option<i32>,
192
193    /// Memory usage mode
194    pub memory_mode: crate::config::MemoryMode,
195
196    /// Power usage mode
197    pub power_mode: crate::config::PowerMode,
198
199    /// Precision mode
200    pub precision_mode: crate::config::PrecisionMode,
201}
202
203impl Default for BackendConfig {
204    fn default() -> Self {
205        Self {
206            backend_type: BackendType::CPU,
207            device_id: None,
208            memory_mode: crate::config::MemoryMode::Normal,
209            power_mode: crate::config::PowerMode::Normal,
210            precision_mode: crate::config::PrecisionMode::Normal,
211        }
212    }
213}
214
215impl BackendConfig {
216    /// Create a new backend config with default settings.
217    pub fn new(backend_type: BackendType) -> Self {
218        Self {
219            backend_type,
220            ..Default::default()
221        }
222    }
223
224    /// Create a CPU backend config.
225    pub fn cpu() -> Self {
226        Self::new(BackendType::CPU)
227    }
228
229    /// Create a GPU backend config with auto-detection.
230    pub fn gpu() -> Self {
231        Self::new(BackendType::Auto)
232    }
233
234    /// Set the device ID.
235    pub fn with_device_id(mut self, id: i32) -> Self {
236        self.device_id = Some(id);
237        self
238    }
239
240    /// Set the memory mode.
241    pub fn with_memory_mode(mut self, mode: crate::config::MemoryMode) -> Self {
242        self.memory_mode = mode;
243        self
244    }
245
246    /// Set the power mode.
247    pub fn with_power_mode(mut self, mode: crate::config::PowerMode) -> Self {
248        self.power_mode = mode;
249        self
250    }
251
252    /// Set the precision mode.
253    pub fn with_precision_mode(mut self, mode: crate::config::PrecisionMode) -> Self {
254        self.precision_mode = mode;
255        self
256    }
257}
258
/// Element type of a tensor.
///
/// `Float32` is the default. `Float16`, `BFloat16` and `Int8` exist only when the
/// crate's `fp16`, `bf16` and `int8` cargo features are enabled, respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DataType {
    /// 32-bit floating point
    #[default]
    Float32,

    /// 16-bit floating point (IEEE half precision)
    #[cfg(feature = "fp16")]
    Float16,

    /// Brain float 16
    #[cfg(feature = "bf16")]
    BFloat16,

    /// 32-bit signed integer
    Int32,

    /// 8-bit signed integer
    #[cfg(feature = "int8")]
    Int8,

    /// 8-bit unsigned integer
    UInt8,

    /// 16-bit signed integer
    Int16,

    /// 64-bit floating point
    Float64,
}

impl DataType {
    // Tags stored in the high byte of the packed code `(type_code << 8) | bits`.
    //
    // NOTE(review): upstream Halide's `halide_type_code_t` numbers these as
    // int = 0, uint = 1, float = 2, handle = 3, bfloat = 4. This crate uses
    // float = 0, int = 1, uint = 2. That is only correct if the C shim in `mnn_rs_sys`
    // decodes this crate-specific numbering — confirm before relying on it.
    const TYPE_CODE_FLOAT: i32 = 0;
    const TYPE_CODE_INT: i32 = 1;
    const TYPE_CODE_UINT: i32 = 2;
    // bfloat16 needs a tag distinct from IEEE float; otherwise `Float16` and `BFloat16`
    // encode identically and cannot be told apart on decode. 4 matches upstream
    // Halide's `halide_type_bfloat` and collides with none of the tags above.
    #[cfg(feature = "bf16")]
    const TYPE_CODE_BFLOAT: i32 = 4;

    /// Get the size in bytes of one element of this type.
    pub fn size(&self) -> usize {
        match self {
            DataType::Float32 => 4,
            #[cfg(feature = "fp16")]
            DataType::Float16 => 2,
            #[cfg(feature = "bf16")]
            DataType::BFloat16 => 2,
            DataType::Int32 => 4,
            #[cfg(feature = "int8")]
            DataType::Int8 => 1,
            DataType::UInt8 => 1,
            DataType::Int16 => 2,
            DataType::Float64 => 8,
        }
    }

    /// Get the lowercase name of this data type (also used by `Display`).
    pub fn name(&self) -> &'static str {
        match self {
            DataType::Float32 => "float32",
            #[cfg(feature = "fp16")]
            DataType::Float16 => "float16",
            #[cfg(feature = "bf16")]
            DataType::BFloat16 => "bfloat16",
            DataType::Int32 => "int32",
            #[cfg(feature = "int8")]
            DataType::Int8 => "int8",
            DataType::UInt8 => "uint8",
            DataType::Int16 => "int16",
            DataType::Float64 => "float64",
        }
    }

    /// Check if this is a floating point type (including FP16/BF16 when enabled).
    pub fn is_float(&self) -> bool {
        match self {
            DataType::Float32 | DataType::Float64 => true,
            #[cfg(feature = "fp16")]
            DataType::Float16 => true,
            #[cfg(feature = "bf16")]
            DataType::BFloat16 => true,
            _ => false,
        }
    }

    /// Check if this is an integer type (signed or unsigned).
    pub fn is_integer(&self) -> bool {
        match self {
            DataType::Int32 | DataType::Int16 | DataType::UInt8 => true,
            #[cfg(feature = "int8")]
            DataType::Int8 => true,
            _ => false,
        }
    }

    /// Check if this type can represent negative values.
    ///
    /// Every type except `UInt8` is signed; floating-point types count as signed.
    pub fn is_signed(&self) -> bool {
        !matches!(self, DataType::UInt8)
    }

    /// Convert to the packed MNN type code `(type_code << 8) | bits`.
    pub(crate) fn to_type_code(&self) -> i32 {
        let (type_code, bits) = match self {
            DataType::Float32 => (Self::TYPE_CODE_FLOAT, 32),
            DataType::Float64 => (Self::TYPE_CODE_FLOAT, 64),
            #[cfg(feature = "fp16")]
            DataType::Float16 => (Self::TYPE_CODE_FLOAT, 16),
            #[cfg(feature = "bf16")]
            DataType::BFloat16 => (Self::TYPE_CODE_BFLOAT, 16),
            DataType::Int32 => (Self::TYPE_CODE_INT, 32),
            DataType::Int16 => (Self::TYPE_CODE_INT, 16),
            #[cfg(feature = "int8")]
            DataType::Int8 => (Self::TYPE_CODE_INT, 8),
            DataType::UInt8 => (Self::TYPE_CODE_UINT, 8),
        };
        (type_code << 8) | bits
    }

    /// Decode a packed MNN type code `(type_code << 8) | bits`.
    ///
    /// This is the inverse of `to_type_code` for every variant in this build. Codes
    /// that are unknown, or that belong to a variant whose feature is disabled, fall
    /// back to `Float32`.
    pub(crate) fn from_type_code(code: i32) -> Self {
        let type_code = (code >> 8) & 0xFF;
        let bits = code & 0xFF;

        match (type_code, bits) {
            (Self::TYPE_CODE_FLOAT, 32) => DataType::Float32,
            (Self::TYPE_CODE_FLOAT, 64) => DataType::Float64,
            #[cfg(feature = "fp16")]
            (Self::TYPE_CODE_FLOAT, 16) => DataType::Float16,
            #[cfg(feature = "bf16")]
            (Self::TYPE_CODE_BFLOAT, 16) => DataType::BFloat16,
            (Self::TYPE_CODE_INT, 32) => DataType::Int32,
            (Self::TYPE_CODE_INT, 16) => DataType::Int16,
            #[cfg(feature = "int8")]
            (Self::TYPE_CODE_INT, 8) => DataType::Int8,
            (Self::TYPE_CODE_UINT, 8) => DataType::UInt8,
            _ => DataType::Float32, // Default fallback
        }
    }
}

impl std::fmt::Display for DataType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.name())
    }
}
399
400/// Get the MNN version string.
401pub fn version() -> String {
402    unsafe {
403        let ptr = mnn_rs_sys::mnn_get_version();
404        if ptr.is_null() {
405            return String::from("unknown");
406        }
407        std::ffi::CStr::from_ptr(ptr)
408            .to_string_lossy()
409            .into_owned()
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backend_type_name() {
        assert_eq!(BackendType::CPU.name(), "CPU");
        assert_eq!(BackendType::Auto.name(), "Auto");
        // Display delegates to `name()`.
        assert_eq!(BackendType::CPU.to_string(), "CPU");
        assert_eq!(BackendType::Auto.to_string(), "Auto");
    }

    #[test]
    fn test_backend_type_defaults_and_gpu_flag() {
        assert_eq!(BackendType::default(), BackendType::CPU);
        assert!(!BackendType::CPU.is_gpu());
        assert!(!BackendType::Auto.is_gpu());
    }

    #[test]
    fn test_cpu_mnn_type_round_trip() {
        let code = BackendType::CPU.to_mnn_type();
        assert_eq!(BackendType::from_mnn_type(code), BackendType::CPU);
    }

    #[test]
    fn test_data_type_size() {
        assert_eq!(DataType::Float32.size(), 4);
        assert_eq!(DataType::Int32.size(), 4);
        assert_eq!(DataType::Float64.size(), 8);
        assert_eq!(DataType::Int16.size(), 2);
        assert_eq!(DataType::UInt8.size(), 1);
    }

    #[test]
    fn test_data_type_code_round_trip() {
        for dt in [
            DataType::Float32,
            DataType::Float64,
            DataType::Int32,
            DataType::Int16,
            DataType::UInt8,
        ] {
            assert_eq!(DataType::from_type_code(dt.to_type_code()), dt);
        }
    }

    #[test]
    fn test_data_type_classification() {
        assert!(DataType::Float32.is_float());
        assert!(!DataType::Float32.is_integer());
        assert!(DataType::Int32.is_integer());
        assert!(!DataType::UInt8.is_signed());
        assert_eq!(DataType::Float64.to_string(), "float64");
    }

    #[test]
    fn test_backend_config_default() {
        let config = BackendConfig::default();
        assert_eq!(config.backend_type, BackendType::CPU);
        assert!(config.device_id.is_none());
    }

    #[test]
    fn test_backend_config_builder() {
        let config = BackendConfig::cpu().with_device_id(1);
        assert_eq!(config.backend_type, BackendType::CPU);
        assert_eq!(config.device_id, Some(1));
        assert_eq!(BackendConfig::gpu().backend_type, BackendType::Auto);
    }
}