/// Compute backend on which MNN runs inference.
///
/// `CPU` and `Auto` always exist; each GPU variant is only compiled in when
/// its cargo feature (`cuda`, `opencl`, `vulkan`, `metal`) is enabled.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum BackendType {
    /// Host CPU backend. This is the `Default` variant.
    #[default]
    CPU,

    /// NVIDIA CUDA backend (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,

    /// OpenCL backend (requires the `opencl` feature).
    #[cfg(feature = "opencl")]
    OpenCL,

    /// Vulkan backend (requires the `vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan,

    /// Apple Metal backend (requires the `metal` feature).
    #[cfg(feature = "metal")]
    Metal,

    /// Let MNN pick the backend.
    Auto,
}
32
33impl BackendType {
34 pub(crate) fn to_mnn_type(&self) -> i32 {
36 match self {
37 BackendType::CPU => mnn_rs_sys::MNN_FORWARD_CPU,
38 #[cfg(feature = "cuda")]
39 BackendType::Cuda => mnn_rs_sys::MNN_FORWARD_CUDA,
40 #[cfg(feature = "opencl")]
41 BackendType::OpenCL => mnn_rs_sys::MNN_FORWARD_OPENCL,
42 #[cfg(feature = "vulkan")]
43 BackendType::Vulkan => mnn_rs_sys::MNN_FORWARD_VULKAN,
44 #[cfg(feature = "metal")]
45 BackendType::Metal => mnn_rs_sys::MNN_FORWARD_METAL,
46 BackendType::Auto => mnn_rs_sys::MNN_FORWARD_AUTO,
47 }
48 }
49
50 pub(crate) fn from_mnn_type(code: i32) -> Self {
52 match code {
53 #[cfg(feature = "cuda")]
54 x if x == mnn_rs_sys::MNN_FORWARD_CUDA => BackendType::Cuda,
55 #[cfg(feature = "opencl")]
56 x if x == mnn_rs_sys::MNN_FORWARD_OPENCL => BackendType::OpenCL,
57 #[cfg(feature = "vulkan")]
58 x if x == mnn_rs_sys::MNN_FORWARD_VULKAN => BackendType::Vulkan,
59 #[cfg(feature = "metal")]
60 x if x == mnn_rs_sys::MNN_FORWARD_METAL => BackendType::Metal,
61 _ => BackendType::CPU,
62 }
63 }
64
65 pub fn name(&self) -> &'static str {
67 match self {
68 BackendType::CPU => "CPU",
69 #[cfg(feature = "cuda")]
70 BackendType::Cuda => "CUDA",
71 #[cfg(feature = "opencl")]
72 BackendType::OpenCL => "OpenCL",
73 #[cfg(feature = "vulkan")]
74 BackendType::Vulkan => "Vulkan",
75 #[cfg(feature = "metal")]
76 BackendType::Metal => "Metal",
77 BackendType::Auto => "Auto",
78 }
79 }
80
81 pub fn is_gpu(&self) -> bool {
83 match self {
84 #[cfg(feature = "cuda")]
85 BackendType::Cuda => true,
86 #[cfg(feature = "opencl")]
87 BackendType::OpenCL => true,
88 #[cfg(feature = "vulkan")]
89 BackendType::Vulkan => true,
90 #[cfg(feature = "metal")]
91 BackendType::Metal => true,
92 _ => false,
93 }
94 }
95
96 pub fn is_available(&self) -> bool {
98 unsafe { mnn_rs_sys::mnn_is_backend_available(self.to_mnn_type()) != 0 }
99 }
100
101 pub fn available_backends() -> Vec<BackendType> {
103 let mut backends = Vec::new();
104
105 backends.push(BackendType::CPU);
107
108 #[cfg(feature = "cuda")]
109 {
110 if BackendType::Cuda.is_available() {
111 backends.push(BackendType::Cuda);
112 }
113 }
114
115 #[cfg(feature = "opencl")]
116 {
117 if BackendType::OpenCL.is_available() {
118 backends.push(BackendType::OpenCL);
119 }
120 }
121
122 #[cfg(feature = "vulkan")]
123 {
124 if BackendType::Vulkan.is_available() {
125 backends.push(BackendType::Vulkan);
126 }
127 }
128
129 #[cfg(feature = "metal")]
130 {
131 if BackendType::Metal.is_available() {
132 backends.push(BackendType::Metal);
133 }
134 }
135
136 backends
137 }
138}
139
140impl std::fmt::Display for BackendType {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 write!(f, "{}", self.name())
143 }
144}
145
146pub fn available_backends() -> Vec<BackendType> {
148 BackendType::available_backends()
149}
150
151pub fn is_backend_available(backend: BackendType) -> bool {
153 backend.is_available()
154}
155
/// Capability flags reported for a backend.
#[derive(Clone, Copy, Debug)]
pub struct BackendCapabilities {
    /// Maximum number of tensor dimensions.
    pub max_tensor_dimensions: i32,

    /// Whether half-precision (fp16) is supported.
    pub supports_fp16: bool,

    /// Whether 8-bit integer (int8) is supported.
    pub supports_int8: bool,

    /// Whether bfloat16 is supported.
    pub supports_bf16: bool,
}
171
172impl BackendCapabilities {
173 pub fn query(_backend: BackendType) -> Self {
175 Self {
176 max_tensor_dimensions: 8,
177 supports_fp16: cfg!(feature = "fp16"),
178 supports_int8: cfg!(feature = "int8"),
179 supports_bf16: cfg!(feature = "bf16"),
180 }
181 }
182}
183
184#[derive(Debug, Clone)]
186pub struct BackendConfig {
187 pub backend_type: BackendType,
189
190 pub device_id: Option<i32>,
192
193 pub memory_mode: crate::config::MemoryMode,
195
196 pub power_mode: crate::config::PowerMode,
198
199 pub precision_mode: crate::config::PrecisionMode,
201}
202
203impl Default for BackendConfig {
204 fn default() -> Self {
205 Self {
206 backend_type: BackendType::CPU,
207 device_id: None,
208 memory_mode: crate::config::MemoryMode::Normal,
209 power_mode: crate::config::PowerMode::Normal,
210 precision_mode: crate::config::PrecisionMode::Normal,
211 }
212 }
213}
214
215impl BackendConfig {
216 pub fn new(backend_type: BackendType) -> Self {
218 Self {
219 backend_type,
220 ..Default::default()
221 }
222 }
223
224 pub fn cpu() -> Self {
226 Self::new(BackendType::CPU)
227 }
228
229 pub fn gpu() -> Self {
231 Self::new(BackendType::Auto)
232 }
233
234 pub fn with_device_id(mut self, id: i32) -> Self {
236 self.device_id = Some(id);
237 self
238 }
239
240 pub fn with_memory_mode(mut self, mode: crate::config::MemoryMode) -> Self {
242 self.memory_mode = mode;
243 self
244 }
245
246 pub fn with_power_mode(mut self, mode: crate::config::PowerMode) -> Self {
248 self.power_mode = mode;
249 self
250 }
251
252 pub fn with_precision_mode(mut self, mode: crate::config::PrecisionMode) -> Self {
254 self.precision_mode = mode;
255 self
256 }
257}
258
/// Element type of tensor data.
///
/// The reduced-precision variants only exist when their cargo feature is enabled:
/// `Float16` needs `fp16`, `BFloat16` needs `bf16`, and `Int8` needs `int8`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum DataType {
    /// 32-bit float. This is the `Default` variant.
    #[default]
    Float32,

    /// 16-bit IEEE half float (requires `fp16`).
    #[cfg(feature = "fp16")]
    Float16,

    /// 16-bit brain float (requires `bf16`).
    #[cfg(feature = "bf16")]
    BFloat16,

    /// 32-bit signed integer.
    Int32,

    /// 8-bit signed integer (requires `int8`).
    #[cfg(feature = "int8")]
    Int8,

    /// 8-bit unsigned integer.
    UInt8,

    /// 16-bit signed integer.
    Int16,

    /// 64-bit float.
    Float64,
}
290
291impl DataType {
292 pub fn size(&self) -> usize {
294 match self {
295 DataType::Float32 => 4,
296 #[cfg(feature = "fp16")]
297 DataType::Float16 => 2,
298 #[cfg(feature = "bf16")]
299 DataType::BFloat16 => 2,
300 DataType::Int32 => 4,
301 #[cfg(feature = "int8")]
302 DataType::Int8 => 1,
303 DataType::UInt8 => 1,
304 DataType::Int16 => 2,
305 DataType::Float64 => 8,
306 }
307 }
308
309 pub fn name(&self) -> &'static str {
311 match self {
312 DataType::Float32 => "float32",
313 #[cfg(feature = "fp16")]
314 DataType::Float16 => "float16",
315 #[cfg(feature = "bf16")]
316 DataType::BFloat16 => "bfloat16",
317 DataType::Int32 => "int32",
318 #[cfg(feature = "int8")]
319 DataType::Int8 => "int8",
320 DataType::UInt8 => "uint8",
321 DataType::Int16 => "int16",
322 DataType::Float64 => "float64",
323 }
324 }
325
326 pub fn is_float(&self) -> bool {
328 match self {
329 DataType::Float32 | DataType::Float64 => true,
330 #[cfg(feature = "fp16")]
331 DataType::Float16 => true,
332 #[cfg(feature = "bf16")]
333 DataType::BFloat16 => true,
334 _ => false,
335 }
336 }
337
338 pub fn is_integer(&self) -> bool {
340 match self {
341 DataType::Int32 | DataType::Int16 | DataType::UInt8 => true,
342 #[cfg(feature = "int8")]
343 DataType::Int8 => true,
344 _ => false,
345 }
346 }
347
348 pub fn is_signed(&self) -> bool {
350 !matches!(self, DataType::UInt8)
351 }
352
353 pub(crate) fn to_type_code(&self) -> i32 {
355 match self {
357 DataType::Float32 => (0 << 8) | 32, DataType::Float64 => (0 << 8) | 64,
359 DataType::Int32 => (1 << 8) | 32, DataType::Int16 => (1 << 8) | 16,
361 #[cfg(feature = "int8")]
362 DataType::Int8 => (1 << 8) | 8,
363 DataType::UInt8 => (2 << 8) | 8, #[cfg(feature = "fp16")]
365 DataType::Float16 => (0 << 8) | 16,
366 #[cfg(feature = "bf16")]
367 DataType::BFloat16 => (0 << 8) | 16,
368 }
369 }
370
371 pub(crate) fn from_type_code(code: i32) -> Self {
373 let type_code = (code >> 8) & 0xFF;
375 let bits = code & 0xFF;
376
377 match (type_code, bits) {
378 (0, 32) => DataType::Float32,
379 (0, 64) => DataType::Float64,
380 (1, 32) => DataType::Int32,
381 (1, 16) => DataType::Int16,
382 #[cfg(feature = "int8")]
383 (1, 8) => DataType::Int8,
384 (2, 8) => DataType::UInt8,
385 #[cfg(feature = "fp16")]
386 (0, 16) => DataType::Float16,
387 #[cfg(feature = "bf16")]
388 (0, 16) => DataType::BFloat16,
389 _ => DataType::Float32, }
391 }
392}
393
394impl std::fmt::Display for DataType {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 write!(f, "{}", self.name())
397 }
398}
399
400pub fn version() -> String {
402 unsafe {
403 let ptr = mnn_rs_sys::mnn_get_version();
404 if ptr.is_null() {
405 return String::from("unknown");
406 }
407 std::ffi::CStr::from_ptr(ptr)
408 .to_string_lossy()
409 .into_owned()
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// The always-present backend variants report their expected names.
    #[test]
    fn test_backend_type_name() {
        let cases = [(BackendType::CPU, "CPU"), (BackendType::Auto, "Auto")];
        for (backend, expected) in cases {
            assert_eq!(backend.name(), expected);
        }
    }

    /// Element sizes (in bytes) of the always-present data types.
    #[test]
    fn test_data_type_size() {
        let cases = [
            (DataType::Float32, 4),
            (DataType::Int32, 4),
            (DataType::Float64, 8),
        ];
        for (ty, expected) in cases {
            assert_eq!(ty.size(), expected);
        }
    }

    /// The default config targets the CPU and leaves the device id unset.
    #[test]
    fn test_backend_config_default() {
        let BackendConfig {
            backend_type,
            device_id,
            ..
        } = BackendConfig::default();
        assert_eq!(backend_type, BackendType::CPU);
        assert!(device_id.is_none());
    }
}
436}