1use crate::backend::{BackendConfig, BackendType};
7
/// Memory layout of tensor data, named by axis order.
///
/// Conventionally N = batch, H = height, W = width, C = channels.
/// [`DataFormat::Nhwc`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum DataFormat {
    /// Channels-last layout, labelled `"NHWC"`; the default.
    #[default]
    Nhwc,

    /// Channels-first layout, labelled `"NCHW"`.
    Nchw,

    /// Layout labelled `"NC4HW4"`. Presumably MNN's channel-packed format with
    /// channels grouped in blocks of four (TODO: confirm against MNN docs).
    Nc4hw4,
}
23
24impl DataFormat {
25 pub fn name(&self) -> &'static str {
27 match self {
28 DataFormat::Nhwc => "NHWC",
29 DataFormat::Nchw => "NCHW",
30 DataFormat::Nc4hw4 => "NC4HW4",
31 }
32 }
33
34 pub(crate) fn to_mnn(&self) -> i32 {
36 match self {
37 DataFormat::Nhwc => mnn_rs_sys::MNN_DATA_FORMAT_NHWC,
38 DataFormat::Nchw => mnn_rs_sys::MNN_DATA_FORMAT_NCHW,
39 DataFormat::Nc4hw4 => mnn_rs_sys::MNN_DATA_FORMAT_NC4HW4,
40 }
41 }
42}
43
/// Memory-usage preference stored in the backend configuration.
///
/// Defaults to [`MemoryMode::Normal`]. How each level is interpreted is up to the
/// selected backend (presumably trading memory for speed — confirm per backend).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum MemoryMode {
    /// Balanced memory usage; the default.
    #[default]
    Normal,

    /// Prefer lower memory usage.
    Low,

    /// Allow higher memory usage.
    High,
}
57
/// Session creation mode; each variant corresponds to one `MNN_SESSION_MODE_*`
/// constant in `mnn_rs_sys`.
///
/// Defaults to [`SessionMode::Debug`]. The variants come in related pairs
/// (debug/release, input inside/user, output inside/user, resize direct/defer,
/// backend fix/auto); consult the MNN documentation for their exact semantics.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum SessionMode {
    /// `MNN_SESSION_MODE_DEBUG`; the default.
    #[default]
    Debug,
    /// `MNN_SESSION_MODE_RELEASE`.
    Release,
    /// `MNN_SESSION_MODE_INPUT_INSIDE`.
    InputInside,
    /// `MNN_SESSION_MODE_INPUT_USER`.
    InputUser,
    /// `MNN_SESSION_MODE_OUTPUT_INSIDE`.
    OutputInside,
    /// `MNN_SESSION_MODE_OUTPUT_USER`.
    OutputUser,
    /// `MNN_SESSION_MODE_RESIZE_DIRECT`.
    ResizeDirect,
    /// `MNN_SESSION_MODE_RESIZE_DEFER`.
    ResizeDefer,
    /// `MNN_SESSION_MODE_BACKEND_FIX`.
    BackendFix,
    /// `MNN_SESSION_MODE_BACKEND_AUTO`.
    BackendAuto,
}
83
84impl SessionMode {
85 pub(crate) fn to_mnn(self) -> i32 {
87 match self {
88 SessionMode::Debug => mnn_rs_sys::MNN_SESSION_MODE_DEBUG,
89 SessionMode::Release => mnn_rs_sys::MNN_SESSION_MODE_RELEASE,
90 SessionMode::InputInside => mnn_rs_sys::MNN_SESSION_MODE_INPUT_INSIDE,
91 SessionMode::InputUser => mnn_rs_sys::MNN_SESSION_MODE_INPUT_USER,
92 SessionMode::OutputInside => mnn_rs_sys::MNN_SESSION_MODE_OUTPUT_INSIDE,
93 SessionMode::OutputUser => mnn_rs_sys::MNN_SESSION_MODE_OUTPUT_USER,
94 SessionMode::ResizeDirect => mnn_rs_sys::MNN_SESSION_MODE_RESIZE_DIRECT,
95 SessionMode::ResizeDefer => mnn_rs_sys::MNN_SESSION_MODE_RESIZE_DEFER,
96 SessionMode::BackendFix => mnn_rs_sys::MNN_SESSION_MODE_BACKEND_FIX,
97 SessionMode::BackendAuto => mnn_rs_sys::MNN_SESSION_MODE_BACKEND_AUTO,
98 }
99 }
100}
101
/// Power-usage preference stored in the backend configuration.
///
/// Defaults to [`PowerMode::Normal`]. Interpretation is backend-specific
/// (presumably a hint about power draw versus throughput — confirm per backend).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum PowerMode {
    /// Balanced power usage; the default.
    #[default]
    Normal,

    /// Prefer lower power usage.
    Low,

    /// Allow higher power usage.
    High,
}
115
/// Numeric-precision preference stored in the backend configuration.
///
/// Defaults to [`PrecisionMode::Normal`]. Interpretation is backend-specific.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum PrecisionMode {
    /// Standard precision; the default.
    #[default]
    Normal,

    /// Reduced precision.
    Low,

    /// Increased precision.
    High,

    /// Reduced precision, presumably using bfloat16 arithmetic (TODO: confirm).
    LowBf16,
}
132
133#[derive(Debug, Clone)]
138pub struct ScheduleConfig {
139 pub backend_config: BackendConfig,
141
142 pub num_threads: u32,
144}
145
146impl Default for ScheduleConfig {
147 fn default() -> Self {
148 Self {
149 backend_config: BackendConfig::default(),
150 num_threads: 4,
151 }
152 }
153}
154
155impl ScheduleConfig {
156 pub fn new() -> Self {
158 Self::default()
159 }
160
161 pub fn cpu() -> Self {
163 Self::default()
164 }
165
166 pub fn with_backend(backend: BackendType) -> Self {
168 Self {
169 backend_config: BackendConfig::new(backend),
170 ..Default::default()
171 }
172 }
173
174 pub fn backend(mut self, backend: BackendType) -> Self {
176 self.backend_config.backend_type = backend;
177 self
178 }
179
180 pub fn num_threads(mut self, threads: u32) -> Self {
182 self.num_threads = threads;
183 self
184 }
185
186 pub fn memory_mode(mut self, mode: MemoryMode) -> Self {
188 self.backend_config.memory_mode = mode;
189 self
190 }
191
192 pub fn power_mode(mut self, mode: PowerMode) -> Self {
194 self.backend_config.power_mode = mode;
195 self
196 }
197
198 pub fn precision_mode(mut self, mode: PrecisionMode) -> Self {
200 self.backend_config.precision_mode = mode;
201 self
202 }
203
204 pub fn device_id(mut self, id: i32) -> Self {
206 self.backend_config.device_id = Some(id);
207 self
208 }
209}
210
211#[derive(Debug, Default)]
215pub struct ScheduleConfigBuilder {
216 config: ScheduleConfig,
217}
218
219impl ScheduleConfigBuilder {
220 pub fn new() -> Self {
222 Self::default()
223 }
224
225 pub fn backend(mut self, backend: BackendType) -> Self {
227 self.config.backend_config.backend_type = backend;
228 self
229 }
230
231 pub fn num_threads(mut self, threads: u32) -> Self {
233 self.config.num_threads = threads;
234 self
235 }
236
237 pub fn memory_mode(mut self, mode: MemoryMode) -> Self {
239 self.config.backend_config.memory_mode = mode;
240 self
241 }
242
243 pub fn power_mode(mut self, mode: PowerMode) -> Self {
245 self.config.backend_config.power_mode = mode;
246 self
247 }
248
249 pub fn precision_mode(mut self, mode: PrecisionMode) -> Self {
251 self.config.backend_config.precision_mode = mode;
252 self
253 }
254
255 pub fn device_id(mut self, id: i32) -> Self {
257 self.config.backend_config.device_id = Some(id);
258 self
259 }
260
261 pub fn build(self) -> ScheduleConfig {
263 self.config
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ScheduleConfig::default();
        assert_eq!(config.num_threads, 4);
    }

    #[test]
    fn test_builder() {
        let config = ScheduleConfigBuilder::new()
            .backend(BackendType::CPU)
            .num_threads(8)
            .memory_mode(MemoryMode::Low)
            .precision_mode(PrecisionMode::High)
            .power_mode(PowerMode::High)
            .device_id(2)
            .build();

        assert_eq!(config.num_threads, 8);
        assert_eq!(config.backend_config.memory_mode, MemoryMode::Low);
        assert_eq!(config.backend_config.precision_mode, PrecisionMode::High);
        assert_eq!(config.backend_config.power_mode, PowerMode::High);
        assert_eq!(config.backend_config.device_id, Some(2));
    }

    // The chainable `ScheduleConfig` setters and the builder must produce the same
    // settings for the same inputs.
    #[test]
    fn test_chained_setters_match_builder() {
        let direct = ScheduleConfig::new()
            .num_threads(2)
            .memory_mode(MemoryMode::High)
            .power_mode(PowerMode::Low)
            .precision_mode(PrecisionMode::LowBf16)
            .device_id(1);
        let built = ScheduleConfigBuilder::new()
            .num_threads(2)
            .memory_mode(MemoryMode::High)
            .power_mode(PowerMode::Low)
            .precision_mode(PrecisionMode::LowBf16)
            .device_id(1)
            .build();

        assert_eq!(direct.num_threads, built.num_threads);
        assert_eq!(direct.backend_config.memory_mode, built.backend_config.memory_mode);
        assert_eq!(direct.backend_config.power_mode, built.backend_config.power_mode);
        assert_eq!(
            direct.backend_config.precision_mode,
            built.backend_config.precision_mode
        );
        assert_eq!(direct.backend_config.device_id, Some(1));
        assert_eq!(built.backend_config.device_id, Some(1));
    }

    #[test]
    fn test_data_format_names() {
        assert_eq!(DataFormat::default(), DataFormat::Nhwc);
        assert_eq!(DataFormat::Nhwc.name(), "NHWC");
        assert_eq!(DataFormat::Nchw.name(), "NCHW");
        assert_eq!(DataFormat::Nc4hw4.name(), "NC4HW4");
    }

    #[test]
    fn test_enum_defaults() {
        assert_eq!(MemoryMode::default(), MemoryMode::Normal);
        assert_eq!(PowerMode::default(), PowerMode::Normal);
        assert_eq!(PrecisionMode::default(), PrecisionMode::Normal);
        assert_eq!(SessionMode::default(), SessionMode::Debug);
    }
}