// oxionnx/execution_providers.rs
1//! Execution Provider compatibility layer for migration from `ort`.
2//!
3//! `ort` 2.x exposes a rich set of execution provider (EP) types such as
4//! `CUDAExecutionProvider`, `CoreMLExecutionProvider`, etc., that allow
5//! callers to configure hardware acceleration at session build time.
6//!
7//! oxionnx selects its backend at compile time via Cargo feature flags
8//! (`gpu`, `cuda`). These stub types mirror the `ort` EP API surface so
9//! that code written against `ort` can compile against oxionnx with only
10//! a `use` path change — no call-site edits required.
11//!
12//! Every `build()` call returns an [`ExecutionProviderDispatch`] no-op token.
13//! The actual backend selection is governed by the crate's feature flags.
14
15use oxionnx_core::graph::OpKind;
16use std::collections::HashMap;
17
/// Opaque no-op token returned by EP `.build()` calls.
///
/// Passed to [`crate::SessionBuilder::with_execution_providers`], which
/// accepts but ignores the list; actual backend selection is governed by
/// the crate's feature flags.
// `#[must_use]`: the token only has a purpose when handed to the session
// builder, so a silently dropped one is almost certainly a mistake.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[must_use = "pass this token to SessionBuilder::with_execution_providers"]
pub struct ExecutionProviderDispatch;
24
25// ── CPU ─────────────────────────────────────────────────────────────────────
26
27/// CPU execution provider stub (always active in oxionnx).
28#[derive(Debug, Clone, Default)]
29pub struct CPUExecutionProvider;
30
31impl CPUExecutionProvider {
32 /// Finalise configuration and return an [`ExecutionProviderDispatch`].
33 pub fn build(self) -> ExecutionProviderDispatch {
34 ExecutionProviderDispatch
35 }
36}
37
38// ── CUDA ────────────────────────────────────────────────────────────────────
39
40/// CUDA execution provider stub.
41///
42/// When the `cuda` feature is enabled, oxionnx routes eligible ops through
43/// the CUDA backend automatically. This type is accepted at the API level
44/// for ort-compatible code.
45#[derive(Debug, Clone, Default)]
46pub struct CUDAExecutionProvider;
47
48impl CUDAExecutionProvider {
49 /// Finalise configuration and return an [`ExecutionProviderDispatch`].
50 pub fn build(self) -> ExecutionProviderDispatch {
51 ExecutionProviderDispatch
52 }
53}
54
55// ── CoreML ──────────────────────────────────────────────────────────────────
56
57/// Apple CoreML execution provider stub.
58#[derive(Debug, Clone, Default)]
59pub struct CoreMLExecutionProvider;
60
61impl CoreMLExecutionProvider {
62 /// Finalise configuration and return an [`ExecutionProviderDispatch`].
63 pub fn build(self) -> ExecutionProviderDispatch {
64 ExecutionProviderDispatch
65 }
66}
67
68// ── DirectML ────────────────────────────────────────────────────────────────
69
70/// DirectML (Windows GPU) execution provider stub.
71#[derive(Debug, Clone, Default)]
72pub struct DirectMLExecutionProvider;
73
74impl DirectMLExecutionProvider {
75 /// Finalise configuration and return an [`ExecutionProviderDispatch`].
76 pub fn build(self) -> ExecutionProviderDispatch {
77 ExecutionProviderDispatch
78 }
79}
80
81// ── TensorRT ────────────────────────────────────────────────────────────────
82
83/// NVIDIA TensorRT execution provider stub.
84#[derive(Debug, Clone, Default)]
85pub struct TensorRTExecutionProvider;
86
87impl TensorRTExecutionProvider {
88 /// Finalise configuration and return an [`ExecutionProviderDispatch`].
89 pub fn build(self) -> ExecutionProviderDispatch {
90 ExecutionProviderDispatch
91 }
92}
93
94// ── OpenVINO ────────────────────────────────────────────────────────────────
95
96/// Intel OpenVINO execution provider stub.
97#[derive(Debug, Clone, Default)]
98pub struct OpenVINOExecutionProvider;
99
100impl OpenVINOExecutionProvider {
101 /// Finalise configuration and return an [`ExecutionProviderDispatch`].
102 pub fn build(self) -> ExecutionProviderDispatch {
103 ExecutionProviderDispatch
104 }
105}
106
107// ── Operator Placement ──────────────────────────────────────────────────────
108
109/// Controls how operators are assigned to execution providers.
110#[derive(Debug, Clone, Default)]
111pub enum OpPlacement {
112 /// All ops on CPU (default when no GPU feature).
113 #[default]
114 CpuOnly,
115 /// Auto-select based on op type and tensor size thresholds.
116 Auto {
117 /// Minimum output tensor bytes for GPU dispatch (default: 65536 = 64KB).
118 gpu_threshold_bytes: usize,
119 },
120 /// Manual per-operator placement.
121 Manual(HashMap<OpKind, ProviderKind>),
122}
123
/// Which provider to use for an operator invocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    /// Portable CPU backend; the only variant that is always compiled in.
    Cpu,
    /// Generic GPU backend; exists only when the `gpu` feature is enabled.
    #[cfg(feature = "gpu")]
    Gpu,
    /// CUDA backend; exists only when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    Cuda,
}
133
134/// Decide placement for a specific operator invocation.
135pub fn decide_placement(op: &OpKind, output_bytes: usize, placement: &OpPlacement) -> ProviderKind {
136 match placement {
137 OpPlacement::CpuOnly => ProviderKind::Cpu,
138 OpPlacement::Auto {
139 gpu_threshold_bytes,
140 } => {
141 if output_bytes >= *gpu_threshold_bytes && is_gpu_capable(op) {
142 #[cfg(feature = "gpu")]
143 return ProviderKind::Gpu;
144 #[cfg(not(feature = "gpu"))]
145 return ProviderKind::Cpu;
146 }
147 ProviderKind::Cpu
148 }
149 OpPlacement::Manual(map) => map.get(op).copied().unwrap_or(ProviderKind::Cpu),
150 }
151}
152
153/// Check if an operator has a GPU implementation.
154pub fn is_gpu_capable(op: &OpKind) -> bool {
155 matches!(
156 op,
157 OpKind::MatMul
158 | OpKind::Gemm
159 | OpKind::Conv
160 | OpKind::Add
161 | OpKind::Mul
162 | OpKind::Sub
163 | OpKind::Relu
164 | OpKind::Sigmoid
165 | OpKind::Softmax
166 | OpKind::LayerNorm
167 | OpKind::BatchNorm
168 | OpKind::Transpose
169 | OpKind::ReduceMean
170 )
171}