rlx-onnx 0.2.4

ONNX inference for RLX — native compile by default, optional ORT fallback
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.

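//! Map [`rlx_runtime::Device`] to ONNX Runtime execution providers and build ORT sessions,
//! falling back to the CPU execution provider when no accelerator chain can load the model.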

use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::OnceLock;

use anyhow::{Context, Result, bail};
use ort::execution_providers::ExecutionProviderDispatch;
use ort::session::Session;
use ort::session::builder::GraphOptimizationLevel;
use rlx_runtime::{Device, is_available};

/// Devices accepted by [`validate_device`]; every other [`Device`] variant is rejected
/// with an error.
const STANDARD_DEVICES: &[Device] = &[
    Device::Cpu,
    Device::Metal,
    Device::Mlx,
    Device::Cuda,
    Device::Rocm,
    Device::Gpu,
    Device::Vulkan,
];

/// Result of building an ONNX Runtime session.
///
/// Returned by [`build_onnx_session`] on success.
pub struct OrtSession {
    /// The loaded ORT session.
    pub session: Session,
    /// Short label for the execution provider chain that loaded the model.
    ///
    /// For the chains built in this module this is one of `"cuda"`, `"rocm"`,
    /// `"directml"`, `"coreml"`, or `"cpu"`.
    pub ort_ep: String,
}

/// Check that `device` is one of the devices this crate can route to ONNX Runtime.
///
/// # Errors
/// Returns an error naming `device` when it is not one of the supported devices
/// (cpu, metal, mlx, cuda, rocm, gpu, vulkan).
pub fn validate_device(device: Device) -> Result<()> {
    let supported = STANDARD_DEVICES.iter().any(|candidate| *candidate == device);
    if !supported {
        bail!(
            "rlx-onnx: device {device:?} is not supported \
             (use cpu|metal|mlx|cuda|rocm|gpu|vulkan)"
        );
    }
    Ok(())
}

pub fn execution_providers_for(device: Device) -> Vec<ExecutionProviderDispatch> {
    let mut eps = primary_execution_providers(device);
    if !matches!(device, Device::Cpu) {
        eps.push(ort::ep::CPU::default().build());
    }
    eps
}

/// ORT CPU execution provider with default options.
fn cpu_ep() -> ExecutionProviderDispatch {
    let provider = ort::ep::CPU::default();
    provider.build()
}

/// CoreML execution provider with default options (Apple targets only).
#[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
fn coreml_ep_default() -> ExecutionProviderDispatch {
    let provider = ort::ep::CoreML::default();
    provider.build()
}

/// CoreML execution provider using the ML Program model format, restricted to CPU
/// compute units (Apple targets only).
#[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
fn coreml_ep_mlprogram_cpu() -> ExecutionProviderDispatch {
    use ort::ep::coreml::{ComputeUnits, ModelFormat};

    let provider = ort::ep::CoreML::default()
        .with_model_format(ModelFormat::MLProgram)
        .with_compute_units(ComputeUnits::CPUOnly);
    provider.build()
}

/// CUDA execution provider on device 0.
///
/// Uses heuristic conv-algorithm search, enables TF32, and allows the flash,
/// memory-efficient and cuDNN-flash attention backends. CUDA graph capture is opt-in.
/// It is turned on only when `RLX_ORT_CUDA_GRAPH` is exactly `1`. If that variable is
/// unset, `KITTENTTS_ORT_CUDA_GRAPH` is consulted instead.
#[cfg(feature = "ort-cuda")]
fn cuda_ep() -> ExecutionProviderDispatch {
    use ort::ep::cuda::{AttentionBackend, ConvAlgorithmSearch};

    let attention = AttentionBackend::FLASH_ATTENTION
        | AttentionBackend::EFFICIENT_ATTENTION
        | AttentionBackend::CUDNN_FLASH_ATTENTION;
    let base = ort::ep::CUDA::default()
        .with_device_id(0)
        .with_conv_algorithm_search(ConvAlgorithmSearch::Heuristic)
        .with_tf32(true)
        .with_attention_backend(attention);

    // The first variable that is *set* decides, even when its value is not "1".
    let graph_requested = ["RLX_ORT_CUDA_GRAPH", "KITTENTTS_ORT_CUDA_GRAPH"]
        .into_iter()
        .find_map(std::env::var_os)
        .is_some_and(|value| value == "1");

    let configured = if graph_requested {
        base.with_cuda_graph(true)
    } else {
        base
    };
    configured.build()
}

/// ROCm execution provider with default options.
#[cfg(feature = "ort-rocm")]
fn rocm_ep() -> ExecutionProviderDispatch {
    let provider = ort::ep::ROCm::default();
    provider.build()
}

/// DirectML execution provider with default options (Windows only).
#[cfg(all(feature = "ort-directml", target_os = "windows"))]
fn directml_ep() -> ExecutionProviderDispatch {
    let provider = ort::ep::DirectML::default();
    provider.build()
}

/// Device-specific execution providers for `device`, without any trailing CPU fallback.
///
/// [`Device::Cpu`] maps to the CPU provider itself. For other devices the list is
/// empty, with a warning logged, when no matching provider is compiled into this build.
fn primary_execution_providers(device: Device) -> Vec<ExecutionProviderDispatch> {
    match device {
        Device::Cpu => vec![cpu_ep()],
        Device::Cuda => cuda_primary(),
        Device::Rocm => rocm_primary(),
        Device::Metal | Device::Mlx => coreml_primary(device),
        Device::Gpu | Device::Vulkan => gpu_primary(device),
        other => {
            eprintln!("[rlx-onnx] WARN: {other:?} has no ONNX Runtime EP mapping; using CPU");
            Vec::new()
        }
    }
}

/// Primary providers for a Metal/MLX request.
///
/// Returns the CoreML provider when it is compiled in for an Apple target. Otherwise
/// returns an empty list, so only the caller's CPU fallback remains. Logs which path
/// was taken.
fn coreml_primary(device: Device) -> Vec<ExecutionProviderDispatch> {
    #[cfg(not(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios"))))]
    {
        eprintln!(
            "[rlx-onnx] WARN: {device:?} requested but CoreML EP unavailable on this target; using CPU"
        );
        return Vec::new();
    }
    #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
    {
        eprintln!("[rlx-onnx] ORT execution provider: CoreML ({device:?})");
        vec![coreml_ep_default()]
    }
}

/// Primary providers for a CUDA request.
///
/// Returns the CUDA provider when the `ort-cuda` feature is enabled, otherwise an
/// empty list. Logs which path was taken.
fn cuda_primary() -> Vec<ExecutionProviderDispatch> {
    #[cfg(not(feature = "ort-cuda"))]
    {
        eprintln!("[rlx-onnx] WARN: CUDA requested but ort/cuda not enabled; using CPU");
        return Vec::new();
    }
    #[cfg(feature = "ort-cuda")]
    {
        eprintln!("[rlx-onnx] ORT execution provider: CUDA");
        vec![cuda_ep()]
    }
}

/// Primary providers for a ROCm request.
///
/// Returns the ROCm provider when the `ort-rocm` feature is enabled, otherwise an
/// empty list. Logs which path was taken.
fn rocm_primary() -> Vec<ExecutionProviderDispatch> {
    #[cfg(not(feature = "ort-rocm"))]
    {
        eprintln!("[rlx-onnx] WARN: ROCm requested but ort/rocm not enabled; using CPU");
        return Vec::new();
    }
    #[cfg(feature = "ort-rocm")]
    {
        eprintln!("[rlx-onnx] ORT execution provider: ROCm");
        vec![rocm_ep()]
    }
}

/// Primary providers for a generic GPU/Vulkan request, chosen per platform.
///
/// Uses DirectML on Windows, CUDA on Linux, or CoreML on macOS, provided the matching
/// feature is enabled. Returns an empty list when none applies. The platform branches
/// are mutually exclusive by `target_os`. Logs which path was taken.
fn gpu_primary(device: Device) -> Vec<ExecutionProviderDispatch> {
    #[cfg(not(any(
        all(feature = "ort-directml", target_os = "windows"),
        all(feature = "ort-cuda", target_os = "linux"),
        all(feature = "ort-coreml", target_os = "macos")
    )))]
    {
        eprintln!(
            "[rlx-onnx] WARN: {device:?} requested but no matching ORT GPU EP on this target; using CPU"
        );
        return Vec::new();
    }
    #[cfg(all(feature = "ort-directml", target_os = "windows"))]
    {
        eprintln!("[rlx-onnx] ORT execution provider: DirectML ({device:?})");
        return vec![directml_ep()];
    }
    #[cfg(all(feature = "ort-cuda", target_os = "linux"))]
    {
        eprintln!("[rlx-onnx] ORT execution provider: CUDA ({device:?})");
        return vec![cuda_ep()];
    }
    #[cfg(all(feature = "ort-coreml", target_os = "macos"))]
    {
        eprintln!("[rlx-onnx] ORT execution provider: CoreML ({device:?})");
        vec![coreml_ep_default()]
    }
}

/// Ordered `(label, providers)` chains that [`build_onnx_session`] tries for `device`.
///
/// Accelerator chains come first, each with a CPU provider appended, and only for
/// providers compiled into this build. A CPU-only chain labelled `"cpu"` is always the
/// final entry.
fn ep_attempts(device: Device) -> Vec<(&'static str, Vec<ExecutionProviderDispatch>)> {
    let mut attempts = Vec::new();

    match device {
        Device::Metal | Device::Mlx => {
            #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
            {
                attempts.push(("coreml_default+cpu", vec![coreml_ep_default(), cpu_ep()]));
                attempts.push((
                    "coreml_mlprogram_cpu+cpu",
                    vec![coreml_ep_mlprogram_cpu(), cpu_ep()],
                ));
            }
        }
        Device::Cuda => {
            #[cfg(feature = "ort-cuda")]
            attempts.push(("cuda+cpu", vec![cuda_ep(), cpu_ep()]));
        }
        Device::Rocm => {
            #[cfg(feature = "ort-rocm")]
            attempts.push(("rocm+cpu", vec![rocm_ep(), cpu_ep()]));
        }
        Device::Gpu | Device::Vulkan => {
            #[cfg(all(feature = "ort-directml", target_os = "windows"))]
            attempts.push(("directml+cpu", vec![directml_ep(), cpu_ep()]));
            #[cfg(all(feature = "ort-cuda", target_os = "linux"))]
            attempts.push(("cuda+cpu", vec![cuda_ep(), cpu_ep()]));
            #[cfg(all(feature = "ort-coreml", target_os = "macos"))]
            {
                attempts.push(("coreml_default+cpu", vec![coreml_ep_default(), cpu_ep()]));
                attempts.push((
                    "coreml_mlprogram_cpu+cpu",
                    vec![coreml_ep_mlprogram_cpu(), cpu_ep()],
                ));
            }
        }
        // CPU needs no accelerator chain, and unmapped devices have none.
        _ => {}
    }

    // Every device, CPU included, finishes with a plain CPU-only chain.
    attempts.push(("cpu", vec![cpu_ep()]));
    attempts
}

/// Intra-op thread count used for every ORT session built by this module.
///
/// Resolved once per process and then cached. The first of `RLX_ORT_INTRA_THREADS` /
/// `KITTENTTS_ORT_INTRA_THREADS` that holds a valid count wins. Otherwise the count
/// comes from [`std::thread::available_parallelism`], or 4 if that is unavailable.
/// The result is always at least 1.
fn ort_intra_threads() -> usize {
    static THREADS: OnceLock<usize> = OnceLock::new();
    *THREADS.get_or_init(|| {
        for key in ["RLX_ORT_INTRA_THREADS", "KITTENTTS_ORT_INTRA_THREADS"] {
            let Ok(raw) = std::env::var(key) else {
                continue;
            };
            match parse_thread_override(&raw) {
                Some(n) => return n,
                // Report a malformed override instead of ignoring it, then try the next key.
                None => eprintln!(
                    "[rlx-onnx] WARN: ignoring {key}={raw:?} (expected a non-negative integer)"
                ),
            }
        }
        std::thread::available_parallelism()
            .map(NonZeroUsize::get)
            .unwrap_or(4)
            .max(1)
    })
}

/// Parse a thread-count override, tolerating surrounding whitespace; `0` is raised to 1.
fn parse_thread_override(raw: &str) -> Option<usize> {
    raw.trim().parse::<usize>().ok().map(|n| n.max(1))
}

/// Build one ORT session for `model_path` using the provider chain `eps`.
///
/// Applies level-3 graph optimization and the shared intra-op thread count. When `eps`
/// is empty, no providers are registered and ORT's own default applies. `label` is used
/// only to annotate error context.
///
/// # Errors
/// Returns any ORT builder or model-loading failure, wrapped with context naming the
/// step that failed.
fn try_build(
    model_path: &Path,
    label: &str,
    eps: Vec<ExecutionProviderDispatch>,
) -> Result<Session> {
    let mut builder = Session::builder()
        .context("ORT session builder")?
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .context("ORT graph optimization level")?
        .with_intra_threads(ort_intra_threads())
        .context("ORT intra-op threads")?;

    let has_providers = !eps.is_empty();
    if has_providers {
        builder = builder
            .with_execution_providers(eps)
            .with_context(|| format!("ORT execution providers ({label})"))?;
    }

    let describe_load = || format!("load ONNX model ({label}): {}", model_path.display());
    builder.commit_from_file(model_path).with_context(describe_load)
}

/// Build an ONNX Runtime session for `model_path` on `device`.
///
/// Provider chains are tried in order: device-specific chains first, then a CPU-only
/// chain last. The first chain that loads the model wins. Every failed attempt is
/// logged to stderr.
///
/// # Errors
/// Returns an error in two cases:
/// - [`validate_device`] rejects `device`.
/// - No chain can load the model. The error from the final attempt is returned.
pub fn build_onnx_session(model_path: &Path, device: Device) -> Result<OrtSession> {
    validate_device(device)?;
    if !is_available(device) {
        eprintln!(
            "[rlx-onnx] WARN: RLX backend {device:?} not available in this build; ORT may fall back to CPU EP"
        );
    }

    let mut last_err = None;

    for (label, eps) in ep_attempts(device) {
        match try_build(model_path, label, eps) {
            Ok(session) => {
                if device != Device::Cpu {
                    if label == "cpu" {
                        // Reached either because no accelerator chain is compiled into this
                        // build, or because every accelerator chain failed for this model.
                        eprintln!(
                            "[rlx-onnx] loaded {device:?} via CPU EP fallback (GPU EP unavailable in this build or for this graph)"
                        );
                    } else {
                        eprintln!("[rlx-onnx] loaded {device:?} via ORT EP chain '{label}'");
                    }
                }
                return Ok(OrtSession {
                    session,
                    ort_ep: ort_ep_name(label).to_string(),
                });
            }
            Err(e) => {
                eprintln!("[rlx-onnx] WARN: ORT {label} failed for {device:?}: {e:#}");
                last_err = Some(e);
            }
        }
    }

    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("no ORT EP attempts for {device:?}")))
}

/// Short provider name reported in `OrtSession::ort_ep` for an attempt `label`.
///
/// For example, `"cuda+cpu"` becomes `"cuda"`. Labels without a recognized accelerator
/// prefix, such as `"cpu"`, are returned unchanged.
fn ort_ep_name(label: &str) -> &str {
    if label.starts_with("cuda") {
        "cuda"
    } else if label.starts_with("rocm") {
        "rocm"
    } else if label.starts_with("directml") {
        "directml"
    } else if label.contains("coreml") {
        "coreml"
    } else {
        label
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cpu_always_has_cpu_ep() {
        assert_eq!(execution_providers_for(Device::Cpu).len(), 1);
    }

    #[test]
    fn validate_rejects_exotic_devices() {
        assert!(validate_device(Device::Tpu).is_err());
    }

    #[test]
    fn validate_accepts_standard_devices() {
        for &device in STANDARD_DEVICES {
            assert!(validate_device(device).is_ok(), "{device:?} should be accepted");
        }
    }

    #[test]
    fn every_attempt_list_ends_with_cpu_chain() {
        for &device in STANDARD_DEVICES {
            let attempts = ep_attempts(device);
            let last_label = attempts.last().map(|(label, _)| *label);
            assert_eq!(last_label, Some("cpu"), "last attempt for {device:?}");
        }
    }

    #[test]
    fn non_cpu_provider_lists_are_never_empty() {
        // Even without any accelerator feature, the CPU fallback must be present.
        for &device in STANDARD_DEVICES {
            assert!(
                !execution_providers_for(device).is_empty(),
                "no providers for {device:?}"
            );
        }
    }
}