use oxionnx_core::graph::OpKind;
use std::collections::HashMap;
/// Field-less handle returned by the `build` methods of the execution-provider
/// types declared in this file.
///
/// The type stores no data, so every value of it is interchangeable.
// NOTE(review): looks like a placeholder for real provider dispatch state — confirm.
#[derive(Clone, Copy, Debug, Default)]
pub struct ExecutionProviderDispatch;
/// Stateless marker type for the CPU execution provider.
///
/// It has no configuration fields; [`CPUExecutionProvider::build`] converts it
/// into an [`ExecutionProviderDispatch`].
#[derive(Clone, Debug, Default)]
pub struct CPUExecutionProvider;

impl CPUExecutionProvider {
    /// Consumes the provider and returns the field-less dispatch handle.
    ///
    /// The returned value carries no CPU-specific information.
    pub fn build(self) -> ExecutionProviderDispatch {
        ExecutionProviderDispatch
    }
}
/// Stateless marker type for the CUDA execution provider.
///
/// It has no configuration fields; [`CUDAExecutionProvider::build`] converts it
/// into an [`ExecutionProviderDispatch`].
#[derive(Clone, Debug, Default)]
pub struct CUDAExecutionProvider;

impl CUDAExecutionProvider {
    /// Consumes the provider and returns the field-less dispatch handle.
    ///
    /// The returned value carries no CUDA-specific information.
    pub fn build(self) -> ExecutionProviderDispatch {
        ExecutionProviderDispatch
    }
}
/// Stateless marker type for the CoreML execution provider.
///
/// It has no configuration fields; [`CoreMLExecutionProvider::build`] converts
/// it into an [`ExecutionProviderDispatch`].
#[derive(Clone, Debug, Default)]
pub struct CoreMLExecutionProvider;

impl CoreMLExecutionProvider {
    /// Consumes the provider and returns the field-less dispatch handle.
    ///
    /// The returned value carries no CoreML-specific information.
    pub fn build(self) -> ExecutionProviderDispatch {
        ExecutionProviderDispatch
    }
}
/// Stateless marker type for the DirectML execution provider.
///
/// It has no configuration fields; [`DirectMLExecutionProvider::build`]
/// converts it into an [`ExecutionProviderDispatch`].
#[derive(Clone, Debug, Default)]
pub struct DirectMLExecutionProvider;

impl DirectMLExecutionProvider {
    /// Consumes the provider and returns the field-less dispatch handle.
    ///
    /// The returned value carries no DirectML-specific information.
    pub fn build(self) -> ExecutionProviderDispatch {
        ExecutionProviderDispatch
    }
}
/// Stateless marker type for the TensorRT execution provider.
///
/// It has no configuration fields; [`TensorRTExecutionProvider::build`]
/// converts it into an [`ExecutionProviderDispatch`].
#[derive(Clone, Debug, Default)]
pub struct TensorRTExecutionProvider;

impl TensorRTExecutionProvider {
    /// Consumes the provider and returns the field-less dispatch handle.
    ///
    /// The returned value carries no TensorRT-specific information.
    pub fn build(self) -> ExecutionProviderDispatch {
        ExecutionProviderDispatch
    }
}
/// Stateless marker type for the OpenVINO execution provider.
///
/// It has no configuration fields; [`OpenVINOExecutionProvider::build`]
/// converts it into an [`ExecutionProviderDispatch`].
#[derive(Clone, Debug, Default)]
pub struct OpenVINOExecutionProvider;

impl OpenVINOExecutionProvider {
    /// Consumes the provider and returns the field-less dispatch handle.
    ///
    /// The returned value carries no OpenVINO-specific information.
    pub fn build(self) -> ExecutionProviderDispatch {
        ExecutionProviderDispatch
    }
}
/// Policy that determines which provider an operator is assigned to.
///
/// The default policy is [`OpPlacement::CpuOnly`].
#[derive(Clone, Debug, Default)]
pub enum OpPlacement {
    /// Assign every operator to the CPU provider.
    #[default]
    CpuOnly,
    /// Choose per operator based on the size of its output.
    Auto {
        /// Output size, in bytes, that `decide_placement` compares against
        /// with `>=` before it considers a non-CPU provider.
        gpu_threshold_bytes: usize,
    },
    /// Explicit operator-kind → provider table; `decide_placement` falls back
    /// to the CPU provider for kinds that have no entry.
    Manual(HashMap<OpKind, ProviderKind>),
}
/// Identifies the provider an operator has been assigned to.
///
/// `Cpu` exists in every build; the remaining variants are compiled in only
/// when the matching Cargo feature is enabled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProviderKind {
    /// CPU provider; always available.
    Cpu,
    /// Present only when the `gpu` feature is enabled.
    #[cfg(feature = "gpu")]
    Gpu,
    /// Present only when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    Cuda,
}
/// Selects the provider that should execute `op` under the given policy.
///
/// * `CpuOnly` always yields `ProviderKind::Cpu`.
/// * `Auto` yields `ProviderKind::Gpu` only when `output_bytes` is at least
///   `gpu_threshold_bytes`, `op` passes [`is_gpu_capable`], and the crate is
///   built with the `gpu` feature; in every other case it yields the CPU.
/// * `Manual` returns the table entry for `op`, or the CPU when none exists.
// NOTE(review): `Auto` never selects `ProviderKind::Cuda`, even in builds with
// the `cuda` feature — confirm that this is intended.
pub fn decide_placement(op: &OpKind, output_bytes: usize, placement: &OpPlacement) -> ProviderKind {
    // Provider handed to offload-worthy operators. `ProviderKind::Gpu` does not
    // exist without the `gpu` feature, so the choice collapses to the CPU there.
    #[cfg(feature = "gpu")]
    let accelerated = ProviderKind::Gpu;
    #[cfg(not(feature = "gpu"))]
    let accelerated = ProviderKind::Cpu;

    match placement {
        OpPlacement::Manual(assignments) => match assignments.get(op) {
            Some(&chosen) => chosen,
            None => ProviderKind::Cpu,
        },
        OpPlacement::Auto {
            gpu_threshold_bytes,
        } => {
            let big_enough = output_bytes >= *gpu_threshold_bytes;
            if big_enough && is_gpu_capable(op) {
                accelerated
            } else {
                ProviderKind::Cpu
            }
        }
        OpPlacement::CpuOnly => ProviderKind::Cpu,
    }
}
/// Reports whether `op` belongs to the fixed allow-list of operator kinds that
/// are treated as eligible for GPU placement.
///
/// Returns `true` for exactly the thirteen kinds listed below and `false` for
/// every other `OpKind`.
pub fn is_gpu_capable(op: &OpKind) -> bool {
    matches!(
        op,
        // Matrix products and convolution.
        OpKind::Conv
            | OpKind::Gemm
            | OpKind::MatMul
            // Element-wise arithmetic.
            | OpKind::Add
            | OpKind::Sub
            | OpKind::Mul
            // Activations.
            | OpKind::Relu
            | OpKind::Sigmoid
            | OpKind::Softmax
            // Normalization.
            | OpKind::BatchNorm
            | OpKind::LayerNorm
            // Layout changes and reductions.
            | OpKind::Transpose
            | OpKind::ReduceMean
    )
}