use std::fmt;
use std::str::FromStr;
use std::sync::atomic::{AtomicI32, AtomicU8, Ordering};
use serde::{Deserialize, Serialize};
/// Preferred execution backend for ORT (presumably ONNX Runtime — confirm).
///
/// The explicit `u8` discriminants are the values `set_ort_accelerator` stores
/// in the process-wide atomic and that `OrtAccelerator::from_u8` decodes, so
/// they must stay in sync with that decoder. Serde names are lowercase;
/// variants with an `alias` additionally accept a snake_case spelling when
/// deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
#[repr(u8)]
pub enum OrtAccelerator {
/// The `Default` value and the fallback `from_u8` returns for unrecognised
/// bytes. Serialized as `"auto"`.
Auto = 0,
/// Serialized as `"cpu"`; `"cpu_only"` is also accepted on input.
#[serde(rename = "cpu", alias = "cpu_only")]
CpuOnly = 1,
/// Serialized as `"cuda"`.
Cuda = 2,
/// Serialized as `"tensorrt"`; `"tensor_rt"` is also accepted on input.
/// Note the discriminant is 7 even though the variant is declared fourth.
#[serde(rename = "tensorrt", alias = "tensor_rt")]
TensorRt = 7,
/// Serialized as `"directml"`; `"direct_ml"` is also accepted on input.
#[serde(rename = "directml", alias = "direct_ml")]
DirectMl = 3,
/// Serialized as `"rocm"`.
Rocm = 4,
/// Serialized as `"coreml"` (no alias).
#[serde(rename = "coreml")]
CoreMl = 5,
/// Serialized as `"webgpu"` (no alias).
#[serde(rename = "webgpu")]
WebGpu = 6,
}
/// Process-wide ORT accelerator preference, held as the `u8` discriminant of
/// an `OrtAccelerator`; initialised to `OrtAccelerator::Auto`.
static ORT_ACCELERATOR: AtomicU8 = AtomicU8::new(OrtAccelerator::Auto as u8);
/// Records `pref` as the process-wide ORT accelerator preference.
///
/// The variant's `u8` discriminant is written to `ORT_ACCELERATOR` with
/// `Ordering::Relaxed`.
pub fn set_ort_accelerator(pref: OrtAccelerator) {
    let encoded = pref as u8;
    ORT_ACCELERATOR.store(encoded, Ordering::Relaxed);
}
/// Returns the process-wide ORT accelerator preference.
///
/// Reads the raw byte from `ORT_ACCELERATOR` with `Ordering::Relaxed` and
/// decodes it through `OrtAccelerator::from_u8`.
pub fn get_ort_accelerator() -> OrtAccelerator {
    let raw = ORT_ACCELERATOR.load(Ordering::Relaxed);
    OrtAccelerator::from_u8(raw)
}
impl OrtAccelerator {
    /// Lists the accelerators compiled into this build.
    ///
    /// `CpuOnly` is always first; every other entry is appended only when its
    /// `ort-*` cargo feature is enabled. `Auto` is never part of the list.
    pub fn available() -> Vec<OrtAccelerator> {
        // Without any `ort-*` feature nothing is pushed and `mut` is unused.
        #[allow(unused_mut)]
        let mut backends = Vec::from([OrtAccelerator::CpuOnly]);

        #[cfg(feature = "ort-cuda")]
        backends.push(OrtAccelerator::Cuda);

        #[cfg(feature = "ort-tensorrt")]
        backends.push(OrtAccelerator::TensorRt);

        #[cfg(feature = "ort-directml")]
        backends.push(OrtAccelerator::DirectMl);

        #[cfg(feature = "ort-rocm")]
        backends.push(OrtAccelerator::Rocm);

        #[cfg(feature = "ort-coreml")]
        backends.push(OrtAccelerator::CoreMl);

        #[cfg(feature = "ort-webgpu")]
        backends.push(OrtAccelerator::WebGpu);

        backends
    }

    /// Decodes a raw discriminant byte back into a variant.
    ///
    /// Any byte that is not the discriminant of a known variant yields `Auto`.
    fn from_u8(val: u8) -> Self {
        // Every variant, matched by its own discriminant rather than by
        // hard-coded numbers.
        const KNOWN: [OrtAccelerator; 8] = [
            OrtAccelerator::Auto,
            OrtAccelerator::CpuOnly,
            OrtAccelerator::Cuda,
            OrtAccelerator::TensorRt,
            OrtAccelerator::DirectMl,
            OrtAccelerator::Rocm,
            OrtAccelerator::CoreMl,
            OrtAccelerator::WebGpu,
        ];

        KNOWN
            .iter()
            .copied()
            .find(|accel| *accel as u8 == val)
            .unwrap_or(OrtAccelerator::Auto)
    }
}
impl Default for OrtAccelerator {
fn default() -> Self {
Self::Auto
}
}
impl fmt::Display for OrtAccelerator {
    /// Writes the canonical lowercase name of the accelerator (`"auto"`,
    /// `"cpu"`, `"cuda"`, `"tensorrt"`, `"directml"`, `"rocm"`, `"coreml"`,
    /// `"webgpu"`).
    ///
    /// Width, fill and alignment flags (e.g. `{:>10}`) are honoured; a
    /// precision (e.g. `{:.3}`) truncates the name, as it does for `str`.
    /// With no flags the output is exactly the bare name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Auto => "auto",
            Self::CpuOnly => "cpu",
            Self::Cuda => "cuda",
            Self::TensorRt => "tensorrt",
            Self::DirectMl => "directml",
            Self::Rocm => "rocm",
            Self::CoreMl => "coreml",
            Self::WebGpu => "webgpu",
        };
        // `pad` instead of `write_str`: `write_str` silently ignores the
        // caller's formatting flags, which breaks column-aligned output.
        f.pad(name)
    }
}
impl FromStr for OrtAccelerator {
    type Err = String;

    /// Parses an accelerator name, ignoring ASCII case.
    ///
    /// Accepts each canonical lowercase name plus short and legacy aliases:
    /// `cpu_only`/`cpuonly`, `trt`/`tensor_rt`, `dml`/`direct_ml`, `core_ml`
    /// and `web_gpu`. `direct_ml` is accepted so that every legacy spelling
    /// the serde `alias` attributes allow (`cpu_only`, `tensor_rt`,
    /// `direct_ml`) also parses here.
    ///
    /// # Errors
    ///
    /// Returns `"unknown ORT accelerator: <input lowercased>"` for any other
    /// string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "auto" => Ok(Self::Auto),
            "cpu" | "cpu_only" | "cpuonly" => Ok(Self::CpuOnly),
            "cuda" => Ok(Self::Cuda),
            "tensorrt" | "trt" | "tensor_rt" => Ok(Self::TensorRt),
            "directml" | "dml" | "direct_ml" => Ok(Self::DirectMl),
            "rocm" => Ok(Self::Rocm),
            "coreml" | "core_ml" => Ok(Self::CoreMl),
            "webgpu" | "web_gpu" => Ok(Self::WebGpu),
            other => Err(format!("unknown ORT accelerator: {other}")),
        }
    }
}
/// Accelerator preference for Whisper.
///
/// The `u8` discriminants are what `set_whisper_accelerator` stores in the
/// process-wide atomic and what `WhisperAccelerator::from_u8` decodes. Serde
/// uses the snake_case variant names (`"auto"`, `"cpu_only"`, `"gpu"`).
///
/// NOTE(review): serde emits `"cpu_only"` for `CpuOnly`, whereas `Display`
/// prints `"cpu"`; `FromStr` accepts both spellings but serde only accepts
/// `"cpu_only"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[repr(u8)]
pub enum WhisperAccelerator {
/// The `Default` value and the fallback `from_u8` returns for unrecognised
/// bytes; `use_gpu()` returns `true` for it.
Auto = 0,
/// The only variant for which `use_gpu()` returns `false`.
CpuOnly = 1,
/// Listed by `available()` only when a `whisper-metal`, `whisper-vulkan` or
/// `whisper-cuda` feature is enabled.
Gpu = 2,
}
/// Process-wide Whisper accelerator preference, held as the `u8` discriminant
/// of a `WhisperAccelerator`; initialised to `WhisperAccelerator::Auto`.
static WHISPER_ACCELERATOR: AtomicU8 = AtomicU8::new(WhisperAccelerator::Auto as u8);
/// Records `pref` as the process-wide Whisper accelerator preference.
///
/// The variant's `u8` discriminant is written to `WHISPER_ACCELERATOR` with
/// `Ordering::Relaxed`.
pub fn set_whisper_accelerator(pref: WhisperAccelerator) {
    let encoded = pref as u8;
    WHISPER_ACCELERATOR.store(encoded, Ordering::Relaxed);
}
/// Returns the process-wide Whisper accelerator preference.
///
/// Reads the raw byte from `WHISPER_ACCELERATOR` with `Ordering::Relaxed` and
/// decodes it through `WhisperAccelerator::from_u8`.
pub fn get_whisper_accelerator() -> WhisperAccelerator {
    let raw = WHISPER_ACCELERATOR.load(Ordering::Relaxed);
    WhisperAccelerator::from_u8(raw)
}
impl WhisperAccelerator {
    /// Lists the accelerators compiled into this build.
    ///
    /// `CpuOnly` is always present; `Gpu` follows when at least one of the
    /// `whisper-metal`, `whisper-vulkan` or `whisper-cuda` features is
    /// enabled. `Auto` is never part of the list.
    pub fn available() -> Vec<WhisperAccelerator> {
        // Without any GPU feature nothing is pushed and `mut` is unused.
        #[allow(unused_mut)]
        let mut backends = Vec::from([WhisperAccelerator::CpuOnly]);

        #[cfg(any(
            feature = "whisper-metal",
            feature = "whisper-vulkan",
            feature = "whisper-cuda"
        ))]
        backends.push(WhisperAccelerator::Gpu);

        backends
    }

    /// Returns `true` for every preference except `CpuOnly` (so `Auto` counts
    /// as GPU-enabled).
    pub fn use_gpu(&self) -> bool {
        !matches!(self, Self::CpuOnly)
    }

    /// Decodes a raw discriminant byte back into a variant.
    ///
    /// Any byte that is not the discriminant of a known variant yields `Auto`.
    fn from_u8(val: u8) -> Self {
        const KNOWN: [WhisperAccelerator; 3] = [
            WhisperAccelerator::Auto,
            WhisperAccelerator::CpuOnly,
            WhisperAccelerator::Gpu,
        ];

        KNOWN
            .iter()
            .copied()
            .find(|accel| *accel as u8 == val)
            .unwrap_or(WhisperAccelerator::Auto)
    }
}
impl Default for WhisperAccelerator {
fn default() -> Self {
Self::Auto
}
}
impl fmt::Display for WhisperAccelerator {
    /// Writes the canonical lowercase name: `"auto"`, `"cpu"` or `"gpu"`.
    ///
    /// Width, fill and alignment flags (e.g. `{:>6}`) are honoured; a
    /// precision truncates the name, as it does for `str`. With no flags the
    /// output is exactly the bare name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Auto => "auto",
            Self::CpuOnly => "cpu",
            Self::Gpu => "gpu",
        };
        // `pad` instead of `write_str`: `write_str` silently ignores the
        // caller's formatting flags, which breaks column-aligned output.
        f.pad(name)
    }
}
impl FromStr for WhisperAccelerator {
    type Err = String;

    /// Parses an accelerator name, ignoring ASCII case.
    ///
    /// Accepts `auto`, `gpu`, and `cpu` with its aliases `cpu_only` and
    /// `cpuonly`.
    ///
    /// # Errors
    ///
    /// Returns `"unknown Whisper accelerator: <input lowercased>"` for any
    /// other string.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_ascii_lowercase();
        let parsed = match lowered.as_str() {
            "auto" => Self::Auto,
            "cpu" | "cpu_only" | "cpuonly" => Self::CpuOnly,
            "gpu" => Self::Gpu,
            _ => return Err(format!("unknown Whisper accelerator: {lowered}")),
        };
        Ok(parsed)
    }
}
/// Sentinel device index (`-1`) used as the initial Whisper GPU device value.
/// Presumably means "let the backend pick a device" — confirm with the code
/// that consumes `get_whisper_gpu_device`.
pub const GPU_DEVICE_AUTO: i32 = -1;
/// Process-wide Whisper GPU device index; holds `GPU_DEVICE_AUTO` until
/// `set_whisper_gpu_device` stores another value.
static WHISPER_GPU_DEVICE: AtomicI32 = AtomicI32::new(GPU_DEVICE_AUTO);
/// Records `device` as the process-wide Whisper GPU device index.
///
/// The value is stored verbatim (no range check) in `WHISPER_GPU_DEVICE` with
/// `Ordering::Relaxed`; pass `GPU_DEVICE_AUTO` to restore the initial value.
pub fn set_whisper_gpu_device(device: i32) {
    let slot = &WHISPER_GPU_DEVICE;
    slot.store(device, Ordering::Relaxed);
}
/// Returns the process-wide Whisper GPU device index.
///
/// Reads `WHISPER_GPU_DEVICE` with `Ordering::Relaxed`; the result is
/// `GPU_DEVICE_AUTO` unless `set_whisper_gpu_device` stored another value.
pub fn get_whisper_gpu_device() -> i32 {
    let slot = &WHISPER_GPU_DEVICE;
    slot.load(Ordering::Relaxed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises the tests that mutate the process-wide preference atomics.
    static ACCEL_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ACCEL_LOCK` for the duration of a test and resets every global
    /// preference to its initial value when dropped.
    struct AccelGuard(#[allow(dead_code)] MutexGuard<'static, ()>);

    impl AccelGuard {
        /// Acquires the lock, recovering it if a previous test panicked while
        /// holding it.
        fn new() -> Self {
            Self(ACCEL_LOCK.lock().unwrap_or_else(PoisonError::into_inner))
        }
    }

    impl Drop for AccelGuard {
        fn drop(&mut self) {
            // Runs before the inner guard field is dropped, i.e. while the
            // lock is still held.
            set_whisper_gpu_device(GPU_DEVICE_AUTO);
            set_whisper_accelerator(WhisperAccelerator::Auto);
            set_ort_accelerator(OrtAccelerator::Auto);
        }
    }

    // ---- OrtAccelerator ----

    #[test]
    fn ort_default_is_auto() {
        let _guard = AccelGuard::new();
        set_ort_accelerator(OrtAccelerator::Auto);
        let current = get_ort_accelerator();
        assert_eq!(current, OrtAccelerator::Auto);
    }

    #[test]
    fn ort_set_and_get() {
        let _guard = AccelGuard::new();
        for wanted in [OrtAccelerator::Cuda, OrtAccelerator::CpuOnly].iter().copied() {
            set_ort_accelerator(wanted);
            assert_eq!(get_ort_accelerator(), wanted);
        }
    }

    #[test]
    fn ort_display_roundtrip() {
        let every_variant = [
            OrtAccelerator::Auto,
            OrtAccelerator::CpuOnly,
            OrtAccelerator::Cuda,
            OrtAccelerator::TensorRt,
            OrtAccelerator::DirectMl,
            OrtAccelerator::Rocm,
            OrtAccelerator::CoreMl,
            OrtAccelerator::WebGpu,
        ];
        for variant in every_variant.iter().copied() {
            let rendered = variant.to_string();
            let reparsed: OrtAccelerator = rendered.parse().unwrap();
            assert_eq!(reparsed, variant);
        }
    }

    #[test]
    fn ort_parse_aliases() {
        let cases = [
            ("dml", OrtAccelerator::DirectMl),
            ("CPU", OrtAccelerator::CpuOnly),
            ("cpu_only", OrtAccelerator::CpuOnly),
            ("trt", OrtAccelerator::TensorRt),
        ];
        for (text, wanted) in cases.iter().copied() {
            assert_eq!(text.parse::<OrtAccelerator>().unwrap(), wanted);
        }
    }

    #[test]
    fn ort_parse_unknown_errors() {
        let outcome = "foobar".parse::<OrtAccelerator>();
        assert!(outcome.is_err());
    }

    #[test]
    fn ort_serde_roundtrip() {
        let cases = [
            (OrtAccelerator::Auto, "\"auto\""),
            (OrtAccelerator::CpuOnly, "\"cpu\""),
            (OrtAccelerator::Cuda, "\"cuda\""),
            (OrtAccelerator::TensorRt, "\"tensorrt\""),
            (OrtAccelerator::DirectMl, "\"directml\""),
            (OrtAccelerator::Rocm, "\"rocm\""),
            (OrtAccelerator::CoreMl, "\"coreml\""),
            (OrtAccelerator::WebGpu, "\"webgpu\""),
        ];
        for (pref, expected) in cases.iter().copied() {
            let json = serde_json::to_string(&pref).unwrap();
            assert_eq!(json, expected, "serialize {:?}", pref);
            let back: OrtAccelerator = serde_json::from_str(&json).unwrap();
            assert_eq!(back, pref, "deserialize {}", json);
        }
    }

    #[test]
    fn ort_serde_backward_compat() {
        // Legacy snake_case spellings must still deserialize.
        let legacy = [
            ("\"cpu_only\"", OrtAccelerator::CpuOnly),
            ("\"direct_ml\"", OrtAccelerator::DirectMl),
            ("\"tensor_rt\"", OrtAccelerator::TensorRt),
        ];
        for (json, wanted) in legacy.iter().copied() {
            let decoded: OrtAccelerator = serde_json::from_str(json).unwrap();
            assert_eq!(decoded, wanted);
        }
    }

    #[test]
    fn ort_available_always_includes_cpu() {
        let listed = OrtAccelerator::available();
        assert!(listed.contains(&OrtAccelerator::CpuOnly));
    }

    #[test]
    fn ort_from_u8_invalid_returns_auto() {
        let decoded = OrtAccelerator::from_u8(255);
        assert_eq!(decoded, OrtAccelerator::Auto);
    }

    // ---- WhisperAccelerator ----

    #[test]
    fn whisper_default_is_auto() {
        let _guard = AccelGuard::new();
        set_whisper_accelerator(WhisperAccelerator::Auto);
        let current = get_whisper_accelerator();
        assert_eq!(current, WhisperAccelerator::Auto);
    }

    #[test]
    fn whisper_set_and_get() {
        let _guard = AccelGuard::new();
        for wanted in [WhisperAccelerator::CpuOnly, WhisperAccelerator::Gpu]
            .iter()
            .copied()
        {
            set_whisper_accelerator(wanted);
            assert_eq!(get_whisper_accelerator(), wanted);
        }
    }

    #[test]
    fn whisper_display_roundtrip() {
        let every_variant = [
            WhisperAccelerator::Auto,
            WhisperAccelerator::CpuOnly,
            WhisperAccelerator::Gpu,
        ];
        for variant in every_variant.iter().copied() {
            let rendered = variant.to_string();
            let reparsed: WhisperAccelerator = rendered.parse().unwrap();
            assert_eq!(reparsed, variant);
        }
    }

    #[test]
    fn whisper_use_gpu_flag() {
        let expectations = [
            (WhisperAccelerator::Auto, true),
            (WhisperAccelerator::CpuOnly, false),
            (WhisperAccelerator::Gpu, true),
        ];
        for (pref, wants_gpu) in expectations.iter().copied() {
            assert!(pref.use_gpu() == wants_gpu);
        }
    }

    #[test]
    fn whisper_parse_unknown_errors() {
        let outcome = "foobar".parse::<WhisperAccelerator>();
        assert!(outcome.is_err());
    }

    #[test]
    fn whisper_serde_roundtrip() {
        let original = WhisperAccelerator::Gpu;
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, "\"gpu\"");
        let decoded: WhisperAccelerator = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn whisper_available_always_includes_cpu() {
        let listed = WhisperAccelerator::available();
        assert!(listed.contains(&WhisperAccelerator::CpuOnly));
    }

    #[test]
    fn whisper_from_u8_invalid_returns_auto() {
        let decoded = WhisperAccelerator::from_u8(255);
        assert_eq!(decoded, WhisperAccelerator::Auto);
    }

    // ---- Whisper GPU device index ----

    #[test]
    fn gpu_device_default_is_auto() {
        let _guard = AccelGuard::new();
        set_whisper_gpu_device(GPU_DEVICE_AUTO);
        let current = get_whisper_gpu_device();
        assert_eq!(current, GPU_DEVICE_AUTO);
    }

    #[test]
    fn gpu_device_set_and_get() {
        let _guard = AccelGuard::new();
        for index in [1, GPU_DEVICE_AUTO].iter().copied() {
            set_whisper_gpu_device(index);
            assert_eq!(get_whisper_gpu_device(), index);
        }
    }
}