Skip to main content

mnn_rs/
config.rs

//! Configuration types for MNN sessions and backends.
//!
//! This module provides the configuration structures used to create interpreter
//! sessions: backend selection, thread counts, memory/power/precision modes, and
//! tensor data formats.
5
6use crate::backend::{BackendConfig, BackendType};
7
/// Memory layout of a tensor's dimensions.
///
/// [`DataFormat::Nhwc`] is the default layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DataFormat {
    /// Channels-last layout `(batch, height, width, channels)`, common in
    /// TensorFlow models.
    #[default]
    Nhwc,

    /// Channels-first layout `(batch, channels, height, width)`, common in
    /// PyTorch/ONNX models.
    Nchw,

    /// MNN's `NC4HW4` layout, optimized for GPU backends.
    Nc4hw4,
}
23
24impl DataFormat {
25    /// Get the name of this format
26    pub fn name(&self) -> &'static str {
27        match self {
28            DataFormat::Nhwc => "NHWC",
29            DataFormat::Nchw => "NCHW",
30            DataFormat::Nc4hw4 => "NC4HW4",
31        }
32    }
33
34    /// Convert to MNN dimension type constant.
35    pub(crate) fn to_mnn(&self) -> i32 {
36        match self {
37            DataFormat::Nhwc => mnn_rs_sys::MNN_DATA_FORMAT_NHWC,
38            DataFormat::Nchw => mnn_rs_sys::MNN_DATA_FORMAT_NCHW,
39            DataFormat::Nc4hw4 => mnn_rs_sys::MNN_DATA_FORMAT_NC4HW4,
40        }
41    }
42}
43
/// Memory usage mode.
///
/// [`MemoryMode::Normal`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MemoryMode {
    /// Normal memory usage (balanced).
    #[default]
    Normal,

    /// Low memory usage (may impact performance).
    Low,

    /// High memory usage for better performance.
    High,
}
57
/// Session mode controlling how the interpreter behaves.
///
/// [`SessionMode::Debug`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum SessionMode {
    /// Debug mode: allows callbacks and exposes internal op info.
    #[default]
    Debug,
    /// Release mode: no callbacks, optimized.
    Release,
    /// Input tensors are allocated by the session.
    InputInside,
    /// Input tensors are allocated by the user.
    InputUser,
    /// Output tensors depend on the session.
    OutputInside,
    /// Output tensors can be used separately from the session.
    OutputUser,
    /// The session is resized directly.
    ResizeDirect,
    /// Session resizing is deferred.
    ResizeDefer,
    /// The backend is fixed by the user's setting.
    BackendFix,
    /// The backend is determined automatically by MNN.
    BackendAuto,
}
83
84impl SessionMode {
85    /// Convert to MNN constant.
86    pub(crate) fn to_mnn(self) -> i32 {
87        match self {
88            SessionMode::Debug => mnn_rs_sys::MNN_SESSION_MODE_DEBUG,
89            SessionMode::Release => mnn_rs_sys::MNN_SESSION_MODE_RELEASE,
90            SessionMode::InputInside => mnn_rs_sys::MNN_SESSION_MODE_INPUT_INSIDE,
91            SessionMode::InputUser => mnn_rs_sys::MNN_SESSION_MODE_INPUT_USER,
92            SessionMode::OutputInside => mnn_rs_sys::MNN_SESSION_MODE_OUTPUT_INSIDE,
93            SessionMode::OutputUser => mnn_rs_sys::MNN_SESSION_MODE_OUTPUT_USER,
94            SessionMode::ResizeDirect => mnn_rs_sys::MNN_SESSION_MODE_RESIZE_DIRECT,
95            SessionMode::ResizeDefer => mnn_rs_sys::MNN_SESSION_MODE_RESIZE_DEFER,
96            SessionMode::BackendFix => mnn_rs_sys::MNN_SESSION_MODE_BACKEND_FIX,
97            SessionMode::BackendAuto => mnn_rs_sys::MNN_SESSION_MODE_BACKEND_AUTO,
98        }
99    }
100}
101
/// Power usage mode.
///
/// [`PowerMode::Normal`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PowerMode {
    /// Balanced power usage.
    #[default]
    Normal,

    /// Reduced power usage; may impact performance.
    Low,

    /// Increased power usage for maximum performance.
    High,
}
115
/// Numeric precision preference for inference.
///
/// [`PrecisionMode::Normal`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PrecisionMode {
    /// Normal precision.
    #[default]
    Normal,

    /// Reduced precision: faster, but may reduce accuracy.
    Low,

    /// High precision.
    High,

    /// Reduced precision using BF16.
    LowBf16,
}
132
133/// Schedule configuration for creating sessions.
134///
135/// This configuration determines how MNN will execute the model,
136/// including backend selection, thread count, and optimization settings.
137#[derive(Debug, Clone)]
138pub struct ScheduleConfig {
139    /// Backend configuration
140    pub backend_config: BackendConfig,
141
142    /// Number of threads for CPU backend (default: 4)
143    pub num_threads: u32,
144}
145
146impl Default for ScheduleConfig {
147    fn default() -> Self {
148        Self {
149            backend_config: BackendConfig::default(),
150            num_threads: 4,
151        }
152    }
153}
154
155impl ScheduleConfig {
156    /// Create a new schedule config with default settings.
157    pub fn new() -> Self {
158        Self::default()
159    }
160
161    /// Create a schedule config for CPU backend.
162    pub fn cpu() -> Self {
163        Self::default()
164    }
165
166    /// Create a schedule config for a specific backend type.
167    pub fn with_backend(backend: BackendType) -> Self {
168        Self {
169            backend_config: BackendConfig::new(backend),
170            ..Default::default()
171        }
172    }
173
174    /// Set the backend type.
175    pub fn backend(mut self, backend: BackendType) -> Self {
176        self.backend_config.backend_type = backend;
177        self
178    }
179
180    /// Set the number of threads for CPU backend.
181    pub fn num_threads(mut self, threads: u32) -> Self {
182        self.num_threads = threads;
183        self
184    }
185
186    /// Set the memory mode.
187    pub fn memory_mode(mut self, mode: MemoryMode) -> Self {
188        self.backend_config.memory_mode = mode;
189        self
190    }
191
192    /// Set the power mode.
193    pub fn power_mode(mut self, mode: PowerMode) -> Self {
194        self.backend_config.power_mode = mode;
195        self
196    }
197
198    /// Set the precision mode.
199    pub fn precision_mode(mut self, mode: PrecisionMode) -> Self {
200        self.backend_config.precision_mode = mode;
201        self
202    }
203
204    /// Set the device ID for GPU backends.
205    pub fn device_id(mut self, id: i32) -> Self {
206        self.backend_config.device_id = Some(id);
207        self
208    }
209}
210
211/// Builder for creating schedule configurations.
212///
213/// Provides a fluent interface for constructing [`ScheduleConfig`].
214#[derive(Debug, Default)]
215pub struct ScheduleConfigBuilder {
216    config: ScheduleConfig,
217}
218
219impl ScheduleConfigBuilder {
220    /// Create a new builder.
221    pub fn new() -> Self {
222        Self::default()
223    }
224
225    /// Set the backend type.
226    pub fn backend(mut self, backend: BackendType) -> Self {
227        self.config.backend_config.backend_type = backend;
228        self
229    }
230
231    /// Set the number of threads.
232    pub fn num_threads(mut self, threads: u32) -> Self {
233        self.config.num_threads = threads;
234        self
235    }
236
237    /// Set memory mode.
238    pub fn memory_mode(mut self, mode: MemoryMode) -> Self {
239        self.config.backend_config.memory_mode = mode;
240        self
241    }
242
243    /// Set power mode.
244    pub fn power_mode(mut self, mode: PowerMode) -> Self {
245        self.config.backend_config.power_mode = mode;
246        self
247    }
248
249    /// Set precision mode.
250    pub fn precision_mode(mut self, mode: PrecisionMode) -> Self {
251        self.config.backend_config.precision_mode = mode;
252        self
253    }
254
255    /// Set device ID.
256    pub fn device_id(mut self, id: i32) -> Self {
257        self.config.backend_config.device_id = Some(id);
258        self
259    }
260
261    /// Build the schedule config.
262    pub fn build(self) -> ScheduleConfig {
263        self.config
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ScheduleConfig::default();
        assert_eq!(config.num_threads, 4);
    }

    #[test]
    fn test_cpu_and_with_backend_keep_default_threads() {
        assert_eq!(ScheduleConfig::cpu().num_threads, 4);
        assert_eq!(ScheduleConfig::new().num_threads, 4);
        assert_eq!(ScheduleConfig::with_backend(BackendType::CPU).num_threads, 4);
    }

    #[test]
    fn test_builder() {
        let config = ScheduleConfigBuilder::new()
            .backend(BackendType::CPU)
            .num_threads(8)
            .memory_mode(MemoryMode::Low)
            .precision_mode(PrecisionMode::High)
            .build();

        assert_eq!(config.num_threads, 8);
        assert_eq!(config.backend_config.memory_mode, MemoryMode::Low);
        assert_eq!(config.backend_config.precision_mode, PrecisionMode::High);
    }

    #[test]
    fn test_builder_power_mode_and_device_id() {
        let config = ScheduleConfigBuilder::new()
            .power_mode(PowerMode::High)
            .device_id(3)
            .build();

        assert_eq!(config.backend_config.power_mode, PowerMode::High);
        assert_eq!(config.backend_config.device_id, Some(3));
    }

    #[test]
    fn test_fluent_setters() {
        let config = ScheduleConfig::new()
            .backend(BackendType::CPU)
            .num_threads(2)
            .memory_mode(MemoryMode::High)
            .power_mode(PowerMode::Low)
            .precision_mode(PrecisionMode::LowBf16)
            .device_id(1);

        assert_eq!(config.num_threads, 2);
        assert_eq!(config.backend_config.memory_mode, MemoryMode::High);
        assert_eq!(config.backend_config.power_mode, PowerMode::Low);
        assert_eq!(config.backend_config.precision_mode, PrecisionMode::LowBf16);
        assert_eq!(config.backend_config.device_id, Some(1));
    }

    #[test]
    fn test_enum_defaults() {
        assert_eq!(DataFormat::default(), DataFormat::Nhwc);
        assert_eq!(MemoryMode::default(), MemoryMode::Normal);
        assert_eq!(PowerMode::default(), PowerMode::Normal);
        assert_eq!(PrecisionMode::default(), PrecisionMode::Normal);
        assert_eq!(SessionMode::default(), SessionMode::Debug);
    }

    #[test]
    fn test_data_format_names() {
        assert_eq!(DataFormat::Nhwc.name(), "NHWC");
        assert_eq!(DataFormat::Nchw.name(), "NCHW");
        assert_eq!(DataFormat::Nc4hw4.name(), "NC4HW4");
    }
}