1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
use alloc::string::ToString;
use super::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderOptions, RegisterError};
use crate::{error::Result, session::builder::SessionBuilder};
/// Floating-point precision policy applied by the CANN EP, passed as the `precision_mode`
/// option; see [`CANN::with_precision_mode`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PrecisionMode {
/// Convert to float32 first according to operator implementation (`force_fp32`).
ForceFP32,
/// Convert to float16 when float16 and float32 are both supported (`force_fp16`).
ForceFP16,
/// Convert to float16 when float32 is not supported (`allow_fp32_to_fp16`).
AllowFP32ToFP16,
/// Keep dtypes as is (`must_keep_origin_dtype`).
MustKeepOrigin,
/// Allow mixed precision (`allow_mix_precision`).
AllowMixedPrecision
}
/// Operator implementation preference for CANN operators that ship both a high-precision and a
/// high-performance implementation; passed as the `op_select_impl_mode` option. See
/// [`CANN::with_implementation_mode`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ImplementationMode {
/// Prefer high precision, potentially at the cost of some performance (`high_precision`).
HighPrecision,
/// Prefer high performance, potentially with lower accuracy (`high_performance`).
HighPerformance
}
/// [CANN execution provider](https://onnxruntime.ai/docs/execution-providers/community-maintained/CANN-ExecutionProvider.html)
/// for hardware acceleration using Huawei Ascend AI devices.
#[derive(Default, Debug, Clone)]
pub struct CANN {
// Accumulated key/value options, forwarded verbatim to `UpdateCANNProviderOptions` at
// registration time (see `ExecutionProvider::register` below).
options: ExecutionProviderOptions
}
// NOTE(review): presumably generates the builder plumbing shared by all EPs (e.g. the `build()`
// used in the doc examples) — confirm against `impl_ep!`'s definition.
super::impl_ep!(arbitrary; CANN);
impl CANN {
	/// Stores a boolean option, encoded as `"1"`/`"0"` as the CANN EP expects.
	fn with_flag(mut self, key: &str, enable: bool) -> Self {
		self.options.set(key, if enable { "1" } else { "0" });
		self
	}

	/// Selects which Ascend device the EP should run on.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_device_id(0).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_device_id(mut self, device_id: i32) -> Self {
		let id = device_id.to_string();
		self.options.set("device_id", id);
		self
	}

	/// Caps the size of the device memory arena, in bytes. The cap applies only to this
	/// execution provider's arena; total device memory usage may be higher.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_memory_limit(2 * 1024 * 1024 * 1024).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_memory_limit(mut self, limit: usize) -> Self {
		let bytes = limit.to_string();
		self.options.set("npu_mem_limit", bytes);
		self
	}

	/// Chooses how the device memory arena grows when it needs more space.
	///
	/// ```
	/// # use ort::{ep::{self, ArenaExtendStrategy}, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default()
	/// 	.with_arena_extend_strategy(ArenaExtendStrategy::SameAsRequested)
	/// 	.build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_arena_extend_strategy(mut self, strategy: ArenaExtendStrategy) -> Self {
		let value = match strategy {
			ArenaExtendStrategy::NextPowerOfTwo => "kNextPowerOfTwo",
			ArenaExtendStrategy::SameAsRequested => "kSameAsRequested"
		};
		self.options.set("arena_extend_strategy", value);
		self
	}

	/// Toggles the graph inference engine, which speeds up performance; `true` is the
	/// recommended and default setting. When disabled, execution falls back to the
	/// single-operator inference engine.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_cann_graph(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_cann_graph(self, enable: bool) -> Self {
		self.with_flag("enable_cann_graph", enable)
	}

	/// Toggles dumping of subgraphs in ONNX format, for analysis of subgraph segmentation.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_dump_graphs(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_dump_graphs(self, enable: bool) -> Self {
		self.with_flag("dump_graphs", enable)
	}

	/// Toggles dumping of the offline model to an `.om` file.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_dump_om_model(true).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_dump_om_model(self, enable: bool) -> Self {
		self.with_flag("dump_om_model", enable)
	}

	/// Sets the floating-point precision policy; see [`PrecisionMode`].
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_precision_mode(ep::cann::PrecisionMode::ForceFP16).build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self {
		let value = match mode {
			PrecisionMode::ForceFP32 => "force_fp32",
			PrecisionMode::ForceFP16 => "force_fp16",
			PrecisionMode::AllowFP32ToFP16 => "allow_fp32_to_fp16",
			PrecisionMode::MustKeepOrigin => "must_keep_origin_dtype",
			PrecisionMode::AllowMixedPrecision => "allow_mix_precision"
		};
		self.options.set("precision_mode", value);
		self
	}

	/// Sets the implementation mode for operators. Some CANN operators ship both a
	/// high-precision and a high-performance implementation.
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default()
	/// 	.with_implementation_mode(ep::cann::ImplementationMode::HighPerformance)
	/// 	.build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_implementation_mode(mut self, mode: ImplementationMode) -> Self {
		let value = match mode {
			ImplementationMode::HighPrecision => "high_precision",
			ImplementationMode::HighPerformance => "high_performance"
		};
		self.options.set("op_select_impl_mode", value);
		self
	}

	/// Sets the list of operator types which use the mode configured via
	/// [`CANN::with_implementation_mode`].
	///
	/// ```
	/// # use ort::{ep, session::Session};
	/// # fn main() -> ort::Result<()> {
	/// let ep = ep::CANN::default().with_implementation_mode_oplist("LayerNorm,Gelu").build();
	/// # Ok(())
	/// # }
	/// ```
	#[must_use]
	pub fn with_implementation_mode_oplist(mut self, list: impl ToString) -> Self {
		let oplist = list.to_string();
		self.options.set("optypelist_for_implmode", oplist);
		self
	}
}
impl ExecutionProvider for CANN {
/// Returns the canonical ONNX Runtime identifier for this execution provider.
fn name(&self) -> &'static str {
"CANNExecutionProvider"
}
/// CANN is only usable on Linux aarch64/x86_64 targets; everything else can never
/// support this EP regardless of enabled features.
fn supported_by_platform(&self) -> bool {
cfg!(all(target_os = "linux", any(target_arch = "aarch64", target_arch = "x86_64")))
}
// `unused`/`unreachable_code` are allowed because when the cfg'd block below is compiled in,
// it always returns, leaving the trailing `Err` unreachable and its bindings unused.
#[allow(unused, unreachable_code)]
fn register(&self, session_builder: &mut SessionBuilder) -> Result<(), RegisterError> {
#[cfg(any(feature = "load-dynamic", feature = "cann"))]
{
use core::ptr;
use crate::{AsPointer, ortsys, util};
// Allocate the native CANN options struct via the ONNX Runtime C API.
let mut cann_options: *mut ort_sys::OrtCANNProviderOptions = ptr::null_mut();
ortsys![unsafe CreateCANNProviderOptions(&mut cann_options)?];
// Drop guard: release the native options on every exit path, including the `?`
// early returns below.
let _guard = util::run_on_drop(|| {
ortsys![unsafe ReleaseCANNProviderOptions(cann_options)];
});
// Convert the accumulated key/value option strings into C-compatible pointer arrays
// and apply them to the native options struct.
let ffi_options = self.options.to_ffi();
ortsys![unsafe UpdateCANNProviderOptions(
cann_options,
ffi_options.key_ptrs(),
ffi_options.value_ptrs(),
ffi_options.len()
)?];
ortsys![unsafe SessionOptionsAppendExecutionProvider_CANN(session_builder.ptr_mut(), cann_options)?];
return Ok(());
}
// Reached only when neither the static `cann` feature nor `load-dynamic` is enabled,
// in which case registration cannot possibly succeed.
Err(RegisterError::MissingFeature)
}
}