Skip to main content

parakeet_rs/
execution.rs

1use std::{fmt, rc::Rc};
2
3use crate::error::Result;
4use ort::session::builder::SessionBuilder;
5
/// Hardware acceleration options. CPU is the default and the most reliable choice.
///
/// GPU providers (CUDA, TensorRT, MIGraphX) offer a 5-10x speedup but require specific
/// hardware. All GPU providers automatically fall back to CPU if they fail.
///
/// Note: CoreML currently fails with this model due to unsupported operations.
/// WebGPU is experimental and may produce incorrect results.
///
/// Every variant other than [`ExecutionProvider::Cpu`] is compiled in only when the Cargo
/// feature named in its documentation is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionProvider {
    /// CPU execution. Always available; this is the `Default` value.
    #[default]
    Cpu,
    /// CUDA provider. Requires the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda,
    /// TensorRT provider. Requires the `tensorrt` feature.
    #[cfg(feature = "tensorrt")]
    TensorRT,
    /// CoreML provider. Requires the `coreml` feature.
    #[cfg(feature = "coreml")]
    CoreML,
    /// DirectML provider. Requires the `directml` feature.
    #[cfg(feature = "directml")]
    DirectML,
    /// MIGraphX provider. Requires the `migraphx` feature.
    #[cfg(feature = "migraphx")]
    MIGraphX,
    /// OpenVINO provider. Requires the `openvino` feature.
    #[cfg(feature = "openvino")]
    OpenVINO,
    /// WebGPU provider. Requires the `webgpu` feature.
    #[cfg(feature = "webgpu")]
    WebGPU,
    /// NNAPI provider. Requires the `nnapi` feature.
    #[cfg(feature = "nnapi")]
    NNAPI,
}
33
34#[derive(Clone)]
35pub struct ModelConfig {
36    pub execution_provider: ExecutionProvider,
37    pub intra_threads: usize,
38    pub inter_threads: usize,
39    pub configure: Option<Rc<dyn Fn(SessionBuilder) -> ort::Result<SessionBuilder>>>,
40}
41
42impl fmt::Debug for ModelConfig {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        f.debug_struct("ModelConfig")
45            .field("execution_provider", &self.execution_provider)
46            .field("intra_threads", &self.intra_threads)
47            .field("inter_threads", &self.inter_threads)
48            .field(
49                "configure",
50                &if self.configure.is_some() {
51                    "<fn>"
52                } else {
53                    "None"
54                },
55            )
56            .finish()
57    }
58}
59
60impl Default for ModelConfig {
61    fn default() -> Self {
62        Self {
63            execution_provider: ExecutionProvider::default(),
64            intra_threads: 4,
65            inter_threads: 1,
66            configure: None,
67        }
68    }
69}
70
71impl ModelConfig {
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
77        self.execution_provider = provider;
78        self
79    }
80
81    pub fn with_intra_threads(mut self, threads: usize) -> Self {
82        self.intra_threads = threads;
83        self
84    }
85
86    pub fn with_inter_threads(mut self, threads: usize) -> Self {
87        self.inter_threads = threads;
88        self
89    }
90
91    pub fn with_custom_configure(
92        mut self,
93        configure: impl Fn(SessionBuilder) -> ort::Result<SessionBuilder> + 'static,
94    ) -> Self {
95        self.configure = Some(Rc::new(configure));
96        self
97    }
98
99    pub(crate) fn apply_to_session_builder(
100        &self,
101        builder: SessionBuilder,
102    ) -> Result<SessionBuilder> {
103        #[cfg(any(
104            feature = "cuda",
105            feature = "tensorrt",
106            feature = "coreml",
107            feature = "directml",
108            feature = "migraphx",
109            feature = "openvino",
110            feature = "webgpu",
111            feature = "nnapi"
112        ))]
113        use ort::ep::CPU as CPUExecutionProvider;
114        use ort::session::builder::GraphOptimizationLevel;
115
116        let mut builder = builder
117            .with_optimization_level(GraphOptimizationLevel::Level3)?
118            .with_intra_threads(self.intra_threads)?
119            .with_inter_threads(self.inter_threads)?;
120
121        builder = match self.execution_provider {
122            ExecutionProvider::Cpu => builder,
123
124            #[cfg(feature = "cuda")]
125            ExecutionProvider::Cuda => builder.with_execution_providers([
126                ort::ep::CUDA::default().build(),
127                CPUExecutionProvider::default().build().error_on_failure(),
128            ])?,
129
130            #[cfg(feature = "tensorrt")]
131            ExecutionProvider::TensorRT => builder.with_execution_providers([
132                ort::ep::TensorRT::default().build(),
133                CPUExecutionProvider::default().build().error_on_failure(),
134            ])?,
135
136            #[cfg(feature = "coreml")]
137            ExecutionProvider::CoreML => {
138                use ort::ep::coreml::{ComputeUnits, CoreML};
139                builder.with_execution_providers([
140                    CoreML::default()
141                        .with_compute_units(ComputeUnits::CPUAndGPU)
142                        .build(),
143                    CPUExecutionProvider::default().build().error_on_failure(),
144                ])?
145            }
146
147            #[cfg(feature = "directml")]
148            ExecutionProvider::DirectML => builder.with_execution_providers([
149                ort::ep::DirectML::default().build(),
150                CPUExecutionProvider::default().build().error_on_failure(),
151            ])?,
152
153            #[cfg(feature = "migraphx")]
154            ExecutionProvider::MIGraphX => builder.with_execution_providers([
155                ort::ep::MIGraphX::default().build(),
156                CPUExecutionProvider::default().build().error_on_failure(),
157            ])?,
158
159            #[cfg(feature = "openvino")]
160            ExecutionProvider::OpenVINO => builder.with_execution_providers([
161                ort::ep::OpenVINO::default().build(),
162                CPUExecutionProvider::default().build().error_on_failure(),
163            ])?,
164
165            #[cfg(feature = "webgpu")]
166            ExecutionProvider::WebGPU => builder.with_execution_providers([
167                ort::ep::WebGPU::default().build(),
168                CPUExecutionProvider::default().build().error_on_failure(),
169            ])?,
170
171            #[cfg(feature = "nnapi")]
172            ExecutionProvider::NNAPI => builder.with_execution_providers([
173                ort::ep::NNAPI::default().build(),
174                CPUExecutionProvider::default().build().error_on_failure(),
175            ])?,
176        };
177
178        if let Some(configure) = self.configure.as_ref() {
179            builder = configure(builder)?;
180        }
181
182        Ok(builder)
183    }
184}