1use std::{fmt, rc::Rc};
2
3use crate::error::Result;
4use ort::session::builder::SessionBuilder;
5
/// Selects which `ort` execution provider a session builder is configured with.
///
/// Only [`ExecutionProvider::Cpu`] is always present; every other variant exists only when
/// the cargo feature of the same (lower-case) name is enabled, so code matching on this enum
/// has to be feature-aware.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ExecutionProvider {
    /// Plain CPU execution. This is the default variant.
    #[default]
    Cpu,
    /// CUDA execution provider (requires the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda,
    /// TensorRT execution provider (requires the `tensorrt` feature).
    #[cfg(feature = "tensorrt")]
    TensorRT,
    /// CoreML execution provider (requires the `coreml` feature).
    #[cfg(feature = "coreml")]
    CoreML,
    /// DirectML execution provider (requires the `directml` feature).
    #[cfg(feature = "directml")]
    DirectML,
    /// MIGraphX execution provider (requires the `migraphx` feature).
    #[cfg(feature = "migraphx")]
    MIGraphX,
    /// OpenVINO execution provider (requires the `openvino` feature).
    #[cfg(feature = "openvino")]
    OpenVINO,
    /// WebGPU execution provider (requires the `webgpu` feature).
    #[cfg(feature = "webgpu")]
    WebGPU,
    /// NNAPI execution provider (requires the `nnapi` feature).
    #[cfg(feature = "nnapi")]
    NNAPI,
}
33
34#[derive(Clone)]
35pub struct ModelConfig {
36 pub execution_provider: ExecutionProvider,
37 pub intra_threads: usize,
38 pub inter_threads: usize,
39 pub configure: Option<Rc<dyn Fn(SessionBuilder) -> ort::Result<SessionBuilder>>>,
40}
41
42impl fmt::Debug for ModelConfig {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.debug_struct("ModelConfig")
45 .field("execution_provider", &self.execution_provider)
46 .field("intra_threads", &self.intra_threads)
47 .field("inter_threads", &self.inter_threads)
48 .field(
49 "configure",
50 &if self.configure.is_some() {
51 "<fn>"
52 } else {
53 "None"
54 },
55 )
56 .finish()
57 }
58}
59
60impl Default for ModelConfig {
61 fn default() -> Self {
62 Self {
63 execution_provider: ExecutionProvider::default(),
64 intra_threads: 4,
65 inter_threads: 1,
66 configure: None,
67 }
68 }
69}
70
71impl ModelConfig {
72 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
77 self.execution_provider = provider;
78 self
79 }
80
81 pub fn with_intra_threads(mut self, threads: usize) -> Self {
82 self.intra_threads = threads;
83 self
84 }
85
86 pub fn with_inter_threads(mut self, threads: usize) -> Self {
87 self.inter_threads = threads;
88 self
89 }
90
91 pub fn with_custom_configure(
92 mut self,
93 configure: impl Fn(SessionBuilder) -> ort::Result<SessionBuilder> + 'static,
94 ) -> Self {
95 self.configure = Some(Rc::new(configure));
96 self
97 }
98
99 pub(crate) fn apply_to_session_builder(
100 &self,
101 builder: SessionBuilder,
102 ) -> Result<SessionBuilder> {
103 #[cfg(any(
104 feature = "cuda",
105 feature = "tensorrt",
106 feature = "coreml",
107 feature = "directml",
108 feature = "migraphx",
109 feature = "openvino",
110 feature = "webgpu",
111 feature = "nnapi"
112 ))]
113 use ort::ep::CPU as CPUExecutionProvider;
114 use ort::session::builder::GraphOptimizationLevel;
115
116 let mut builder = builder
117 .with_optimization_level(GraphOptimizationLevel::Level3)?
118 .with_intra_threads(self.intra_threads)?
119 .with_inter_threads(self.inter_threads)?;
120
121 builder = match self.execution_provider {
122 ExecutionProvider::Cpu => builder,
123
124 #[cfg(feature = "cuda")]
125 ExecutionProvider::Cuda => builder.with_execution_providers([
126 ort::ep::CUDA::default().build(),
127 CPUExecutionProvider::default().build().error_on_failure(),
128 ])?,
129
130 #[cfg(feature = "tensorrt")]
131 ExecutionProvider::TensorRT => builder.with_execution_providers([
132 ort::ep::TensorRT::default().build(),
133 CPUExecutionProvider::default().build().error_on_failure(),
134 ])?,
135
136 #[cfg(feature = "coreml")]
137 ExecutionProvider::CoreML => {
138 use ort::ep::coreml::{ComputeUnits, CoreML};
139 builder.with_execution_providers([
140 CoreML::default()
141 .with_compute_units(ComputeUnits::CPUAndGPU)
142 .build(),
143 CPUExecutionProvider::default().build().error_on_failure(),
144 ])?
145 }
146
147 #[cfg(feature = "directml")]
148 ExecutionProvider::DirectML => builder.with_execution_providers([
149 ort::ep::DirectML::default().build(),
150 CPUExecutionProvider::default().build().error_on_failure(),
151 ])?,
152
153 #[cfg(feature = "migraphx")]
154 ExecutionProvider::MIGraphX => builder.with_execution_providers([
155 ort::ep::MIGraphX::default().build(),
156 CPUExecutionProvider::default().build().error_on_failure(),
157 ])?,
158
159 #[cfg(feature = "openvino")]
160 ExecutionProvider::OpenVINO => builder.with_execution_providers([
161 ort::ep::OpenVINO::default().build(),
162 CPUExecutionProvider::default().build().error_on_failure(),
163 ])?,
164
165 #[cfg(feature = "webgpu")]
166 ExecutionProvider::WebGPU => builder.with_execution_providers([
167 ort::ep::WebGPU::default().build(),
168 CPUExecutionProvider::default().build().error_on_failure(),
169 ])?,
170
171 #[cfg(feature = "nnapi")]
172 ExecutionProvider::NNAPI => builder.with_execution_providers([
173 ort::ep::NNAPI::default().build(),
174 CPUExecutionProvider::default().build().error_on_failure(),
175 ])?,
176 };
177
178 if let Some(configure) = self.configure.as_ref() {
179 builder = configure(builder)?;
180 }
181
182 Ok(builder)
183 }
184}