//! xq-vision 0.1.3
//!
//! High-performance ONNX recognition pipeline for Chinese chessboard corners and pieces.
use ort::execution_providers::ExecutionProviderDispatch;
use ort::session::Session;
use ort::session::builder::SessionBuilder;

use crate::config::ModelSource;
use crate::config::SessionConfig;
use crate::error::Result;
use crate::error::XqVisionError;

/// Failure policy attached to every dispatched execution provider.
///
/// The variant picks which `ort` dispatch setting `ExecutionProvider::dispatch`
/// applies to the provider it builds.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ProviderFailure {
    /// Dispatch the provider with `fail_silently()`.
    Fallback,
    /// Dispatch the provider with `error_on_failure()`.
    Error,
}

/// ONNX Runtime execution backends that can be requested for a session.
///
/// Only `Cpu` is unconditionally available. Each other backend is compiled in
/// solely when its cargo feature is enabled; requesting it otherwise yields
/// `XqVisionError::UnsupportedProvider`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ExecutionProvider {
    /// Default CPU backend; needs no cargo feature.
    Cpu,
    /// Apple Core ML; requires the `coreml` feature.
    CoreMl,
    /// NVIDIA CUDA; requires the `cuda` feature.
    Cuda,
    /// NVIDIA TensorRT; requires the `tensorrt` feature.
    TensorRt,
    /// DirectML; requires the `directml` feature.
    DirectMl,
    /// Intel OpenVINO; requires the `openvino` feature.
    OpenVino,
    /// XNNPACK; requires the `xnnpack` feature.
    Xnnpack,
}

impl ExecutionProvider {
    /// Builds the `ort` dispatch handle for this provider and applies `failure` to it.
    ///
    /// `Cpu` is always built. Every other variant is only built when its matching
    /// cargo feature (`coreml`, `cuda`, `tensorrt`, `directml`, `openvino`, `xnnpack`)
    /// is enabled at compile time.
    ///
    /// `ProviderFailure::Fallback` marks the handle with `fail_silently()`;
    /// `ProviderFailure::Error` marks it with `error_on_failure()`.
    ///
    /// # Errors
    ///
    /// Returns [`XqVisionError::UnsupportedProvider`] carrying `self` when the
    /// provider's cargo feature was not enabled.
    ///
    /// NOTE(review): that error is returned regardless of `failure`, so
    /// `ProviderFailure::Fallback` does not silently skip a provider that was
    /// compiled out — confirm this is intended.
    fn dispatch(self, failure: ProviderFailure) -> Result<ExecutionProviderDispatch> {
        // In each feature-gated arm below, exactly one of the two `#[cfg]` blocks
        // survives compilation: either the built provider becomes the arm's value,
        // or the arm returns early with `UnsupportedProvider`.
        let dispatch = match self {
            Self::Cpu => ort::execution_providers::CPUExecutionProvider::default().build(),
            Self::CoreMl => {
                #[cfg(feature = "coreml")]
                {
                    ort::execution_providers::CoreMLExecutionProvider::default().build()
                }
                #[cfg(not(feature = "coreml"))]
                {
                    return Err(XqVisionError::UnsupportedProvider { provider: self });
                }
            }
            Self::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    ort::execution_providers::CUDAExecutionProvider::default().build()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(XqVisionError::UnsupportedProvider { provider: self });
                }
            }
            Self::TensorRt => {
                #[cfg(feature = "tensorrt")]
                {
                    ort::execution_providers::TensorRTExecutionProvider::default().build()
                }
                #[cfg(not(feature = "tensorrt"))]
                {
                    return Err(XqVisionError::UnsupportedProvider { provider: self });
                }
            }
            Self::DirectMl => {
                #[cfg(feature = "directml")]
                {
                    ort::execution_providers::DirectMLExecutionProvider::default().build()
                }
                #[cfg(not(feature = "directml"))]
                {
                    return Err(XqVisionError::UnsupportedProvider { provider: self });
                }
            }
            Self::OpenVino => {
                #[cfg(feature = "openvino")]
                {
                    ort::execution_providers::OpenVINOExecutionProvider::default().build()
                }
                #[cfg(not(feature = "openvino"))]
                {
                    return Err(XqVisionError::UnsupportedProvider { provider: self });
                }
            }
            Self::Xnnpack => {
                #[cfg(feature = "xnnpack")]
                {
                    ort::execution_providers::XNNPACKExecutionProvider::default().build()
                }
                #[cfg(not(feature = "xnnpack"))]
                {
                    return Err(XqVisionError::UnsupportedProvider { provider: self });
                }
            }
        };

        // Translate the crate-level failure policy into the matching ort dispatch setting.
        Ok(match failure {
            ProviderFailure::Fallback => dispatch.fail_silently(),
            ProviderFailure::Error => dispatch.error_on_failure(),
        })
    }
}

/// Builds an ONNX Runtime [`Session`] for `source` using the options in `config`.
///
/// In order, this applies:
/// - the graph optimization level;
/// - the intra-op and inter-op thread counts, when configured;
/// - parallel execution, when requested;
/// - the configured execution providers, each dispatched with `config.provider_failure()`.
///
/// It then loads the model either from a file path or from in-memory bytes.
///
/// # Errors
///
/// Propagates failures from:
/// - creating or configuring the session builder;
/// - dispatching an execution provider (for example one whose cargo feature is not enabled);
/// - loading the model.
pub(crate) fn create_session(source: &ModelSource, config: &SessionConfig) -> Result<Session> {
    let mut builder = Session::builder()?
        .with_optimization_level(config.graph_optimization().into())
        .map_err(map_builder_error)?;

    // Optional settings are only applied when configured; otherwise the builder's
    // existing value is kept as-is.
    builder = match config.intra_threads() {
        Some(count) => builder.with_intra_threads(count).map_err(map_builder_error)?,
        None => builder,
    };
    builder = match config.inter_threads() {
        Some(count) => builder.with_inter_threads(count).map_err(map_builder_error)?,
        None => builder,
    };
    builder = if config.parallel_execution() {
        builder.with_parallel_execution(true).map_err(map_builder_error)?
    } else {
        builder
    };

    // Dispatch every requested provider first; the first one that fails aborts
    // before any provider is registered on the builder.
    let mut dispatched = Vec::new();
    for provider in config.execution_providers().iter().copied() {
        dispatched.push(provider.dispatch(config.provider_failure())?);
    }
    builder = builder.with_execution_providers(dispatched).map_err(map_builder_error)?;

    let session = match source {
        ModelSource::File(path) => builder.commit_from_file(path)?,
        ModelSource::Memory(bytes) => builder.commit_from_memory(bytes)?,
    };
    Ok(session)
}

fn map_builder_error(error: ort::Error<SessionBuilder>) -> XqVisionError {
    XqVisionError::Ort(ort::Error::new(error.to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `Cpu` has no feature gate, so it must always dispatch to the CPU provider.
    #[test]
    fn cpu_provider_dispatches_without_feature_gate() -> Result<()> {
        let dispatch = ExecutionProvider::Cpu.dispatch(ProviderFailure::Fallback)?;
        assert!(dispatch.downcast_ref::<ort::execution_providers::CPUExecutionProvider>().is_some());
        Ok(())
    }

    /// The strict failure mode must not change which provider is dispatched.
    #[test]
    fn cpu_provider_dispatches_with_error_on_failure() -> Result<()> {
        let dispatch = ExecutionProvider::Cpu.dispatch(ProviderFailure::Error)?;
        assert!(dispatch.downcast_ref::<ort::execution_providers::CPUExecutionProvider>().is_some());
        Ok(())
    }

    /// Without the `cuda` feature, requesting CUDA must report it as unsupported,
    /// whichever failure mode is requested.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn cuda_without_feature_is_unsupported() {
        for failure in [ProviderFailure::Fallback, ProviderFailure::Error] {
            let result = ExecutionProvider::Cuda.dispatch(failure);
            assert!(matches!(
                result,
                Err(XqVisionError::UnsupportedProvider { provider: ExecutionProvider::Cuda })
            ));
        }
    }
}