// oxionnx — execution_providers.rs
1//! Execution Provider compatibility layer for migration from `ort`.
2//!
3//! `ort` 2.x exposes a rich set of execution provider (EP) types such as
4//! `CUDAExecutionProvider`, `CoreMLExecutionProvider`, etc., that allow
5//! callers to configure hardware acceleration at session build time.
6//!
7//! oxionnx selects its backend at compile time via Cargo feature flags
8//! (`gpu`, `cuda`).  These stub types mirror the `ort` EP API surface so
9//! that code written against `ort` can compile against oxionnx with only
10//! a `use` path change — no call-site edits required.
11//!
12//! Every `build()` call returns an [`ExecutionProviderDispatch`] no-op token.
13//! The actual backend selection is governed by the crate's feature flags.
14
15use oxionnx_core::graph::OpKind;
16use std::collections::HashMap;
17
/// Opaque no-op token returned by EP `.build()` calls.
///
/// Passed to [`crate::SessionBuilder::with_execution_providers`], which
/// accepts but ignores the list.
///
/// Marked `#[must_use]`: the only purpose of this token is to be handed
/// to the session builder, so silently dropping it almost certainly
/// means a provider was configured but never registered.
#[derive(Debug, Clone, Copy, Default)]
#[must_use = "pass this token to SessionBuilder::with_execution_providers"]
pub struct ExecutionProviderDispatch;
24
25// ── CPU ─────────────────────────────────────────────────────────────────────
26
27/// CPU execution provider stub (always active in oxionnx).
28#[derive(Debug, Clone, Default)]
29pub struct CPUExecutionProvider;
30
31impl CPUExecutionProvider {
32    /// Finalise configuration and return an [`ExecutionProviderDispatch`].
33    pub fn build(self) -> ExecutionProviderDispatch {
34        ExecutionProviderDispatch
35    }
36}
37
38// ── CUDA ────────────────────────────────────────────────────────────────────
39
40/// CUDA execution provider stub.
41///
42/// When the `cuda` feature is enabled, oxionnx routes eligible ops through
43/// the CUDA backend automatically.  This type is accepted at the API level
44/// for ort-compatible code.
45#[derive(Debug, Clone, Default)]
46pub struct CUDAExecutionProvider;
47
48impl CUDAExecutionProvider {
49    /// Finalise configuration and return an [`ExecutionProviderDispatch`].
50    pub fn build(self) -> ExecutionProviderDispatch {
51        ExecutionProviderDispatch
52    }
53}
54
55// ── CoreML ──────────────────────────────────────────────────────────────────
56
57/// Apple CoreML execution provider stub.
58#[derive(Debug, Clone, Default)]
59pub struct CoreMLExecutionProvider;
60
61impl CoreMLExecutionProvider {
62    /// Finalise configuration and return an [`ExecutionProviderDispatch`].
63    pub fn build(self) -> ExecutionProviderDispatch {
64        ExecutionProviderDispatch
65    }
66}
67
68// ── DirectML ────────────────────────────────────────────────────────────────
69
70/// DirectML (Windows GPU) execution provider stub.
71#[derive(Debug, Clone, Default)]
72pub struct DirectMLExecutionProvider;
73
74impl DirectMLExecutionProvider {
75    /// Finalise configuration and return an [`ExecutionProviderDispatch`].
76    pub fn build(self) -> ExecutionProviderDispatch {
77        ExecutionProviderDispatch
78    }
79}
80
81// ── TensorRT ────────────────────────────────────────────────────────────────
82
83/// NVIDIA TensorRT execution provider stub.
84#[derive(Debug, Clone, Default)]
85pub struct TensorRTExecutionProvider;
86
87impl TensorRTExecutionProvider {
88    /// Finalise configuration and return an [`ExecutionProviderDispatch`].
89    pub fn build(self) -> ExecutionProviderDispatch {
90        ExecutionProviderDispatch
91    }
92}
93
94// ── OpenVINO ────────────────────────────────────────────────────────────────
95
96/// Intel OpenVINO execution provider stub.
97#[derive(Debug, Clone, Default)]
98pub struct OpenVINOExecutionProvider;
99
100impl OpenVINOExecutionProvider {
101    /// Finalise configuration and return an [`ExecutionProviderDispatch`].
102    pub fn build(self) -> ExecutionProviderDispatch {
103        ExecutionProviderDispatch
104    }
105}
106
107// ── Operator Placement ──────────────────────────────────────────────────────
108
/// Controls how operators are assigned to execution providers.
///
/// Consumed by [`decide_placement`] once per operator invocation.
#[derive(Debug, Clone, Default)]
pub enum OpPlacement {
    /// All ops on CPU.  This is the `Default` variant, and the only
    /// effective behaviour when no GPU feature is compiled in.
    #[default]
    CpuOnly,
    /// Auto-select based on op type and tensor size thresholds.
    Auto {
        /// Minimum output tensor bytes for GPU dispatch; outputs below
        /// this stay on CPU.  Suggested value: 65536 (64 KiB).
        /// NOTE(review): despite earlier docs calling 65536 a "default",
        /// nothing in this module applies one — callers must always
        /// supply a value.  Confirm whether a `Default` impl is intended.
        gpu_threshold_bytes: usize,
    },
    /// Manual per-operator placement.  Ops absent from the map fall back
    /// to CPU (see [`decide_placement`]).
    Manual(HashMap<OpKind, ProviderKind>),
}
123
/// Which provider to use for an operator invocation.
///
/// Accelerated variants only exist under their corresponding Cargo
/// features, so `match`es over this enum are feature-dependent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderKind {
    /// Portable CPU backend; always available.
    Cpu,
    /// Generic GPU backend; compiled only with the `gpu` feature.
    #[cfg(feature = "gpu")]
    Gpu,
    /// CUDA backend; compiled only with the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda,
}
133
134/// Decide placement for a specific operator invocation.
135pub fn decide_placement(op: &OpKind, output_bytes: usize, placement: &OpPlacement) -> ProviderKind {
136    match placement {
137        OpPlacement::CpuOnly => ProviderKind::Cpu,
138        OpPlacement::Auto {
139            gpu_threshold_bytes,
140        } => {
141            if output_bytes >= *gpu_threshold_bytes && is_gpu_capable(op) {
142                #[cfg(feature = "gpu")]
143                return ProviderKind::Gpu;
144                #[cfg(not(feature = "gpu"))]
145                return ProviderKind::Cpu;
146            }
147            ProviderKind::Cpu
148        }
149        OpPlacement::Manual(map) => map.get(op).copied().unwrap_or(ProviderKind::Cpu),
150    }
151}
152
153/// Check if an operator has a GPU implementation.
154pub fn is_gpu_capable(op: &OpKind) -> bool {
155    matches!(
156        op,
157        OpKind::MatMul
158            | OpKind::Gemm
159            | OpKind::Conv
160            | OpKind::Add
161            | OpKind::Mul
162            | OpKind::Sub
163            | OpKind::Relu
164            | OpKind::Sigmoid
165            | OpKind::Softmax
166            | OpKind::LayerNorm
167            | OpKind::BatchNorm
168            | OpKind::Transpose
169            | OpKind::ReduceMean
170    )
171}