//! Execution-provider selection and ONNX Runtime session configuration (`parakeet_rs/execution.rs`).

use std::path::PathBuf;
use std::{fmt, rc::Rc};

use crate::error::Result;
use ort::session::builder::SessionBuilder;

/// Hardware acceleration options. CPU is default and most reliable.
///
/// GPU providers (CUDA, TensorRT, MIGraphX) offer 5-10x speedup but require specific hardware.
/// All GPU providers automatically fall back to CPU if they fail: every non-CPU provider is
/// registered together with a CPU provider when the session builder is configured.
///
/// Each non-CPU variant only exists when the Cargo feature of the same (lower-case) name is
/// enabled, so code matching on this enum must gate those arms with the same `cfg`.
///
/// Note: CoreML EP currently runs slower than CPU for Sortformer/Parakeet models because
/// the ONNX graphs have dynamic input shapes, preventing CoreML from building optimised
/// execution plans for ANE/GPU. CoreML claims nodes but runs them on CPU with overhead.
///
/// WebGPU is experimental and may produce incorrect results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionProvider {
    /// CPU execution; the default, always available.
    #[default]
    Cpu,
    /// CUDA execution provider. Requires the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda,
    /// TensorRT execution provider. Requires the `tensorrt` feature.
    #[cfg(feature = "tensorrt")]
    TensorRT,
    /// CoreML execution provider. Requires the `coreml` feature.
    #[cfg(feature = "coreml")]
    CoreML,
    /// DirectML execution provider. Requires the `directml` feature.
    #[cfg(feature = "directml")]
    DirectML,
    /// MIGraphX execution provider. Requires the `migraphx` feature.
    #[cfg(feature = "migraphx")]
    MIGraphX,
    /// OpenVINO execution provider. Requires the `openvino` feature.
    #[cfg(feature = "openvino")]
    OpenVINO,
    /// WebGPU execution provider (experimental). Requires the `webgpu` feature.
    #[cfg(feature = "webgpu")]
    WebGPU,
    /// NNAPI execution provider. Requires the `nnapi` feature.
    #[cfg(feature = "nnapi")]
    NNAPI,
}

38#[derive(Clone)]
39pub struct ModelConfig {
40    pub execution_provider: ExecutionProvider,
41    pub intra_threads: usize,
42    pub inter_threads: usize,
43    pub configure: Option<Rc<dyn Fn(SessionBuilder) -> ort::Result<SessionBuilder>>>,
44    /// Optional cache directory for compiled CoreML models. When set, avoids
45    /// recompiling the ONNX-to-CoreML conversion on each session load (~5s).
46    /// Only used when execution_provider is CoreML.
47    pub coreml_cache_dir: Option<PathBuf>,
48}
49
50impl fmt::Debug for ModelConfig {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        f.debug_struct("ModelConfig")
53            .field("execution_provider", &self.execution_provider)
54            .field("intra_threads", &self.intra_threads)
55            .field("inter_threads", &self.inter_threads)
56            .field(
57                "configure",
58                &if self.configure.is_some() {
59                    "<fn>"
60                } else {
61                    "None"
62                },
63            )
64            .field("coreml_cache_dir", &self.coreml_cache_dir)
65            .finish()
66    }
67}
68
69impl Default for ModelConfig {
70    fn default() -> Self {
71        Self {
72            execution_provider: ExecutionProvider::default(),
73            intra_threads: 4,
74            inter_threads: 1,
75            configure: None,
76            coreml_cache_dir: None,
77        }
78    }
79}
80
81impl ModelConfig {
82    pub fn new() -> Self {
83        Self::default()
84    }
85
86    pub fn with_execution_provider(mut self, provider: ExecutionProvider) -> Self {
87        self.execution_provider = provider;
88        self
89    }
90
91    pub fn with_intra_threads(mut self, threads: usize) -> Self {
92        self.intra_threads = threads;
93        self
94    }
95
96    pub fn with_inter_threads(mut self, threads: usize) -> Self {
97        self.inter_threads = threads;
98        self
99    }
100
101    pub fn with_custom_configure(
102        mut self,
103        configure: impl Fn(SessionBuilder) -> ort::Result<SessionBuilder> + 'static,
104    ) -> Self {
105        self.configure = Some(Rc::new(configure));
106        self
107    }
108
109    /// Set cache directory for compiled CoreML models.
110    /// Avoids ~5s recompilation on each session load.
111    pub fn with_coreml_cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
112        self.coreml_cache_dir = Some(path.into());
113        self
114    }
115
116    pub(crate) fn apply_to_session_builder(
117        &self,
118        builder: SessionBuilder,
119    ) -> Result<SessionBuilder> {
120        #[cfg(any(
121            feature = "cuda",
122            feature = "tensorrt",
123            feature = "coreml",
124            feature = "directml",
125            feature = "migraphx",
126            feature = "openvino",
127            feature = "webgpu",
128            feature = "nnapi"
129        ))]
130        use ort::ep::CPU as CPUExecutionProvider;
131        use ort::session::builder::GraphOptimizationLevel;
132
133        let mut builder = builder
134            .with_optimization_level(GraphOptimizationLevel::Level3)?
135            .with_intra_threads(self.intra_threads)?
136            .with_inter_threads(self.inter_threads)?;
137
138        builder = match self.execution_provider {
139            ExecutionProvider::Cpu => builder,
140
141            #[cfg(feature = "cuda")]
142            ExecutionProvider::Cuda => builder.with_execution_providers([
143                ort::ep::CUDA::default().build(),
144                CPUExecutionProvider::default().build().error_on_failure(),
145            ])?,
146
147            #[cfg(feature = "tensorrt")]
148            ExecutionProvider::TensorRT => builder.with_execution_providers([
149                ort::ep::TensorRT::default().build(),
150                CPUExecutionProvider::default().build().error_on_failure(),
151            ])?,
152
153            #[cfg(feature = "coreml")]
154            ExecutionProvider::CoreML => {
155                use ort::ep::coreml::{ComputeUnits, CoreML};
156                let mut coreml = CoreML::default().with_compute_units(ComputeUnits::CPUAndGPU);
157
158                if let Some(cache_dir) = &self.coreml_cache_dir {
159                    coreml = coreml.with_model_cache_dir(cache_dir.to_string_lossy());
160                }
161
162                builder.with_execution_providers([
163                    coreml.build(),
164                    CPUExecutionProvider::default().build().error_on_failure(),
165                ])?
166            }
167
168            #[cfg(feature = "directml")]
169            ExecutionProvider::DirectML => builder.with_execution_providers([
170                ort::ep::DirectML::default().build(),
171                CPUExecutionProvider::default().build().error_on_failure(),
172            ])?,
173
174            #[cfg(feature = "migraphx")]
175            ExecutionProvider::MIGraphX => builder.with_execution_providers([
176                ort::ep::MIGraphX::default().build(),
177                CPUExecutionProvider::default().build().error_on_failure(),
178            ])?,
179
180            #[cfg(feature = "openvino")]
181            ExecutionProvider::OpenVINO => builder.with_execution_providers([
182                ort::ep::OpenVINO::default().build(),
183                CPUExecutionProvider::default().build().error_on_failure(),
184            ])?,
185
186            #[cfg(feature = "webgpu")]
187            ExecutionProvider::WebGPU => builder.with_execution_providers([
188                ort::ep::WebGPU::default().build(),
189                CPUExecutionProvider::default().build().error_on_failure(),
190            ])?,
191
192            #[cfg(feature = "nnapi")]
193            ExecutionProvider::NNAPI => builder.with_execution_providers([
194                ort::ep::NNAPI::default().build(),
195                CPUExecutionProvider::default().build().error_on_failure(),
196            ])?,
197        };
198
199        if let Some(configure) = self.configure.as_ref() {
200            builder = configure(builder)?;
201        }
202
203        Ok(builder)
204    }
205}