1use std::path::PathBuf;
2use std::{fmt, rc::Rc};
3
4use crate::error::Result;
5use ort::session::builder::SessionBuilder;
6
/// Hardware backend that ONNX Runtime should run the model on.
///
/// [`ExecutionProvider::Cpu`] is always present and is the default. Every other
/// variant only exists when the crate is compiled with the Cargo feature named in
/// its documentation.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ExecutionProvider {
    /// CPU execution; no additional execution provider is registered.
    #[default]
    Cpu,
    /// NVIDIA CUDA (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
    /// NVIDIA TensorRT (requires the `tensorrt` feature).
    #[cfg(feature = "tensorrt")]
    TensorRT,
    /// Apple CoreML (requires the `coreml` feature).
    #[cfg(feature = "coreml")]
    CoreML,
    /// DirectML (requires the `directml` feature).
    #[cfg(feature = "directml")]
    DirectML,
    /// AMD MIGraphX (requires the `migraphx` feature).
    #[cfg(feature = "migraphx")]
    MIGraphX,
    /// Intel OpenVINO (requires the `openvino` feature).
    #[cfg(feature = "openvino")]
    OpenVINO,
    /// WebGPU (requires the `webgpu` feature).
    #[cfg(feature = "webgpu")]
    WebGPU,
    /// Android NNAPI (requires the `nnapi` feature).
    #[cfg(feature = "nnapi")]
    NNAPI,
}
37
38#[derive(Clone)]
39pub struct ModelConfig {
40 pub execution_provider: ExecutionProvider,
41 pub intra_threads: usize,
42 pub inter_threads: usize,
43 pub configure: Option<Rc<dyn Fn(SessionBuilder) -> ort::Result<SessionBuilder>>>,
44 pub coreml_cache_dir: Option<PathBuf>,
48}
49
50impl fmt::Debug for ModelConfig {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 f.debug_struct("ModelConfig")
53 .field("execution_provider", &self.execution_provider)
54 .field("intra_threads", &self.intra_threads)
55 .field("inter_threads", &self.inter_threads)
56 .field(
57 "configure",
58 &if self.configure.is_some() {
59 "<fn>"
60 } else {
61 "None"
62 },
63 )
64 .field("coreml_cache_dir", &self.coreml_cache_dir)
65 .finish()
66 }
67}
68
69impl Default for ModelConfig {
70 fn default() -> Self {
71 Self {
72 execution_provider: ExecutionProvider::default(),
73 intra_threads: 4,
74 inter_threads: 1,
75 configure: None,
76 coreml_cache_dir: None,
77 }
78 }
79}
80
81impl ModelConfig {
82 pub fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
87 self.execution_provider = provider;
88 self
89 }
90
91 pub fn with_intra_threads(mut self, threads: usize) -> Self {
92 self.intra_threads = threads;
93 self
94 }
95
96 pub fn with_inter_threads(mut self, threads: usize) -> Self {
97 self.inter_threads = threads;
98 self
99 }
100
101 pub fn with_custom_configure(
102 mut self,
103 configure: impl Fn(SessionBuilder) -> ort::Result<SessionBuilder> + 'static,
104 ) -> Self {
105 self.configure = Some(Rc::new(configure));
106 self
107 }
108
109 pub fn with_coreml_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
112 self.coreml_cache_dir = Some(path.into());
113 self
114 }
115
116 pub(crate) fn apply_to_session_builder(
117 &self,
118 builder: SessionBuilder,
119 ) -> Result<SessionBuilder> {
120 #[cfg(any(
121 feature = "cuda",
122 feature = "tensorrt",
123 feature = "coreml",
124 feature = "directml",
125 feature = "migraphx",
126 feature = "openvino",
127 feature = "webgpu",
128 feature = "nnapi"
129 ))]
130 use ort::ep::CPU as CPUExecutionProvider;
131 use ort::session::builder::GraphOptimizationLevel;
132
133 let mut builder = builder
134 .with_optimization_level(GraphOptimizationLevel::Level3)?
135 .with_intra_threads(self.intra_threads)?
136 .with_inter_threads(self.inter_threads)?;
137
138 builder = match self.execution_provider {
139 ExecutionProvider::Cpu => builder,
140
141 #[cfg(feature = "cuda")]
142 ExecutionProvider::Cuda => builder.with_execution_providers([
143 ort::ep::CUDA::default().build(),
144 CPUExecutionProvider::default().build().error_on_failure(),
145 ])?,
146
147 #[cfg(feature = "tensorrt")]
148 ExecutionProvider::TensorRT => builder.with_execution_providers([
149 ort::ep::TensorRT::default().build(),
150 CPUExecutionProvider::default().build().error_on_failure(),
151 ])?,
152
153 #[cfg(feature = "coreml")]
154 ExecutionProvider::CoreML => {
155 use ort::ep::coreml::{ComputeUnits, CoreML};
156 let mut coreml = CoreML::default().with_compute_units(ComputeUnits::CPUAndGPU);
157
158 if let Some(cache_dir) = &self.coreml_cache_dir {
159 coreml = coreml.with_model_cache_dir(cache_dir.to_string_lossy());
160 }
161
162 builder.with_execution_providers([
163 coreml.build(),
164 CPUExecutionProvider::default().build().error_on_failure(),
165 ])?
166 }
167
168 #[cfg(feature = "directml")]
169 ExecutionProvider::DirectML => builder.with_execution_providers([
170 ort::ep::DirectML::default().build(),
171 CPUExecutionProvider::default().build().error_on_failure(),
172 ])?,
173
174 #[cfg(feature = "migraphx")]
175 ExecutionProvider::MIGraphX => builder.with_execution_providers([
176 ort::ep::MIGraphX::default().build(),
177 CPUExecutionProvider::default().build().error_on_failure(),
178 ])?,
179
180 #[cfg(feature = "openvino")]
181 ExecutionProvider::OpenVINO => builder.with_execution_providers([
182 ort::ep::OpenVINO::default().build(),
183 CPUExecutionProvider::default().build().error_on_failure(),
184 ])?,
185
186 #[cfg(feature = "webgpu")]
187 ExecutionProvider::WebGPU => builder.with_execution_providers([
188 ort::ep::WebGPU::default().build(),
189 CPUExecutionProvider::default().build().error_on_failure(),
190 ])?,
191
192 #[cfg(feature = "nnapi")]
193 ExecutionProvider::NNAPI => builder.with_execution_providers([
194 ort::ep::NNAPI::default().build(),
195 CPUExecutionProvider::default().build().error_on_failure(),
196 ])?,
197 };
198
199 if let Some(configure) = self.configure.as_ref() {
200 builder = configure(builder)?;
201 }
202
203 Ok(builder)
204 }
205}