use crate::error::Result;
use candle_core::Device;
use std::sync::Once;
/// Settings that control how [`get_device`] picks a compute device.
///
/// Construct it with [`DeviceConfig::new`] plus the `with_*` builder methods,
/// or read it from the process environment with [`DeviceConfig::from_env`].
#[derive(Debug, Clone, Default)]
pub struct DeviceConfig {
    /// Ordinal of the CUDA device to request (default `0`).
    pub cuda_device: usize,
    /// When `true`, skip CUDA entirely and use the CPU (default `false`).
    pub force_cpu: bool,
    /// Prefix for log messages; `None` means `"rust-ai"` is used.
    pub crate_name: Option<String>,
}

impl DeviceConfig {
    /// Variables consulted by [`DeviceConfig::from_env`] to force CPU mode.
    const FORCE_CPU_ENV_VARS: [&'static str; 3] = [
        "RUST_AI_FORCE_CPU",
        "AXOLOTL_FORCE_CPU",
        "VSA_OPTIM_FORCE_CPU",
    ];

    /// Variables consulted by [`DeviceConfig::from_env`] for the CUDA ordinal,
    /// in priority order (first valid value wins).
    const CUDA_DEVICE_ENV_VARS: [&'static str; 3] = [
        "RUST_AI_CUDA_DEVICE",
        "AXOLOTL_CUDA_DEVICE",
        "VSA_OPTIM_CUDA_DEVICE",
    ];

    /// Creates a configuration with default settings: CUDA device `0`,
    /// CPU not forced, and no crate name.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the CUDA device ordinal to request.
    #[must_use]
    pub fn with_cuda_device(mut self, ordinal: usize) -> Self {
        self.cuda_device = ordinal;
        self
    }

    /// Sets whether CPU execution is forced, bypassing CUDA detection.
    #[must_use]
    pub fn with_force_cpu(mut self, force: bool) -> Self {
        self.force_cpu = force;
        self
    }

    /// Sets the crate name used to prefix log messages.
    #[must_use]
    pub fn with_crate_name(mut self, name: impl Into<String>) -> Self {
        self.crate_name = Some(name.into());
        self
    }

    /// Builds a configuration from environment variables.
    ///
    /// * CPU mode is forced when any of `RUST_AI_FORCE_CPU`, `AXOLOTL_FORCE_CPU`
    ///   or `VSA_OPTIM_FORCE_CPU` is `1` or `true` (ASCII case-insensitive).
    ///   Surrounding whitespace is ignored. Any other value does not force CPU.
    /// * The CUDA ordinal is taken from the first of `RUST_AI_CUDA_DEVICE`,
    ///   `AXOLOTL_CUDA_DEVICE`, `VSA_OPTIM_CUDA_DEVICE` that holds a valid
    ///   non-negative integer (whitespace ignored). Malformed values are skipped.
    ///
    /// Unset or non-Unicode variables are treated as absent. `crate_name` is
    /// never read from the environment and stays `None`.
    #[must_use]
    pub fn from_env() -> Self {
        // Any recognised variable with a truthy value forces CPU mode.
        let force_cpu = Self::FORCE_CPU_ENV_VARS
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .any(|value| Self::is_truthy(&value));

        // Trim so values written by shell scripts / `.env` files (e.g. "1\n") still parse.
        // When no variable holds a valid ordinal, fall back to the default ordinal 0.
        let cuda_device = Self::CUDA_DEVICE_ENV_VARS
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .find_map(|value| value.trim().parse::<usize>().ok())
            .unwrap_or_default();

        Self {
            cuda_device,
            force_cpu,
            ..Self::default()
        }
    }

    /// Returns `true` for `"1"` or `"true"` (ASCII case-insensitive), ignoring
    /// surrounding whitespace.
    fn is_truthy(value: &str) -> bool {
        let value = value.trim();
        value == "1" || value.eq_ignore_ascii_case("true")
    }
}
/// Selects the compute device described by `config`, preferring CUDA.
///
/// Resolution order:
/// 1. If `config.force_cpu` is set, logs a warning and returns `Device::Cpu`
///    without probing CUDA.
/// 2. Otherwise asks candle for CUDA device `config.cuda_device`. On success
///    an info message is logged and that device is returned.
/// 3. If candle reports CPU, or CUDA initialization returns an error, falls
///    back to `Device::Cpu` and emits the one-time CPU warning. Initialization
///    errors are logged (not discarded) so the cause of the fallback is visible.
///
/// Log messages are prefixed with `config.crate_name`, or `"rust-ai"` when unset.
///
/// # Errors
///
/// The current implementation never returns `Err`: every failure path degrades
/// to the CPU device.
pub fn get_device(config: &DeviceConfig) -> Result<Device> {
    let crate_name = config.crate_name.as_deref().unwrap_or("rust-ai");

    if config.force_cpu {
        tracing::warn!(
            "{}: CPU device forced via configuration. \
             CUDA is the intended default for optimal performance.",
            crate_name
        );
        return Ok(Device::Cpu);
    }

    match Device::cuda_if_available(config.cuda_device) {
        Ok(device @ Device::Cuda(_)) => {
            tracing::info!(
                "{}: Using CUDA device {} for GPU-accelerated execution",
                crate_name,
                config.cuda_device
            );
            Ok(device)
        }
        Ok(Device::Cpu) => {
            warn_if_cpu_internal(&Device::Cpu, crate_name);
            Ok(Device::Cpu)
        }
        Err(err) => {
            // Previously swallowed: surface why CUDA could not be used before degrading.
            tracing::warn!(
                "{}: failed to initialize CUDA device {}: {}; falling back to CPU",
                crate_name,
                config.cuda_device,
                err
            );
            warn_if_cpu_internal(&Device::Cpu, crate_name);
            Ok(Device::Cpu)
        }
        // Defensive: any other device variant is passed through unchanged.
        Ok(device) => Ok(device),
    }
}
/// Public entry point for the CPU-fallback warning.
///
/// Lets other code report the device it ended up on. Non-CPU devices are
/// ignored. The warning shares its once-per-process guard with the fallback
/// path in [`get_device`], so it is emitted at most once regardless of which
/// path triggers it.
pub fn warn_if_cpu(device: &Device, crate_name: &str) {
    warn_if_cpu_internal(device, crate_name)
}
/// Emits the "running on CPU" warning at most once per process.
///
/// Does nothing unless `device` is `Device::Cpu`. The first CPU call logs via
/// `tracing` and also prints to stderr. Every later call, whatever its
/// `crate_name`, is silent because a single function-local [`Once`] guards
/// the message.
fn warn_if_cpu_internal(device: &Device, crate_name: &str) {
    static CPU_WARNING_EMITTED: Once = Once::new();

    if !matches!(device, Device::Cpu) {
        return;
    }

    CPU_WARNING_EMITTED.call_once(|| {
        tracing::warn!(
            "{crate_name}: CPU device in use. CUDA is the intended default; \
             CPU mode exists only as a compatibility fallback. \
             For production workloads, ensure CUDA is available. \
             Set RUST_AI_FORCE_CPU=1 to silence this warning."
        );
        // Also written to stderr, independent of any `tracing` subscriber setup.
        eprintln!(
            "WARNING: {crate_name}: CPU device in use. CUDA is the intended default; \
             CPU mode exists only as a compatibility fallback."
        );
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` yields CUDA ordinal 0, no forced CPU and no crate name.
    #[test]
    fn test_device_config_default() {
        let DeviceConfig {
            cuda_device,
            force_cpu,
            crate_name,
        } = DeviceConfig::default();
        assert_eq!(cuda_device, 0);
        assert!(!force_cpu);
        assert!(crate_name.is_none());
    }

    /// Every builder method stores its argument in the matching field.
    #[test]
    fn test_device_config_builder() {
        let DeviceConfig {
            cuda_device,
            force_cpu,
            crate_name,
        } = DeviceConfig::new()
            .with_cuda_device(1)
            .with_force_cpu(true)
            .with_crate_name("test-crate");
        assert_eq!(cuda_device, 1);
        assert!(force_cpu);
        assert_eq!(crate_name.as_deref(), Some("test-crate"));
    }

    /// Forcing CPU short-circuits device detection and returns the CPU device.
    #[test]
    fn test_force_cpu_returns_cpu() {
        let device = get_device(&DeviceConfig::new().with_force_cpu(true)).unwrap();
        assert!(matches!(device, Device::Cpu));
    }
}