#[cfg(feature = "cuda")]
use oar_ocr::core::config::OrtExecutionProvider;
use oar_ocr::core::config::OrtSessionConfig;
pub fn parse_device_config(
device: &str,
) -> Result<Option<OrtSessionConfig>, Box<dyn std::error::Error>> {
let device_lower = device.to_lowercase();
if device_lower == "cpu" {
return Ok(None);
}
#[cfg(feature = "cuda")]
{
if device_lower.starts_with("cuda") {
let device_id = if device_lower == "cuda" {
0
} else if let Some(id_str) = device_lower.strip_prefix("cuda:") {
id_str.parse::<i32>().map_err(|_| {
format!(
"Invalid CUDA device ID: {}. Expected format: 'cuda' or 'cuda:N'",
device
)
})?
} else {
return Err(format!(
"Invalid device format: {}. Expected 'cuda' or 'cuda:N'",
device
)
.into());
};
let config = OrtSessionConfig::new().with_execution_providers(vec![
OrtExecutionProvider::CUDA {
device_id: Some(device_id),
gpu_mem_limit: None,
arena_extend_strategy: None,
cudnn_conv_algo_search: None,
cudnn_conv_use_max_workspace: None,
},
OrtExecutionProvider::CPU, ]);
return Ok(Some(config));
}
}
#[cfg(not(feature = "cuda"))]
{
if device_lower.starts_with("cuda") {
return Err(format!(
"CUDA device '{}' requested but CUDA feature is not enabled. \
Rebuild with --features=cuda to enable CUDA support.",
device
)
.into());
}
}
Err(format!(
"Unsupported device: {}. Supported devices: cpu{}",
device,
if cfg!(feature = "cuda") {
", cuda, cuda:N"
} else {
""
}
)
.into())
}