//! wax-core 0.1.0
//!
//! Core inference engine for wax, a small Candle-based local LLM runner.
//!
//! Documentation
use candle_core::{utils, DType, Device};

use crate::{Result, WaxError};

/// Requested compute backend, resolved to a concrete `Device` by
/// [`select_device`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceChoice {
    /// Use CUDA if it is available, otherwise Metal if it is available,
    /// otherwise fall back to the CPU.
    Auto,
    /// Always run on the CPU.
    Cpu,
    /// Require CUDA; [`select_device`] returns an error when it is unavailable.
    Cuda,
    /// Require Metal; [`select_device`] returns an error when it is unavailable.
    Metal,
}

/// Requested tensor element type, resolved to a concrete `DType` by
/// [`select_dtype`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DTypeChoice {
    /// Let the device decide: `F32` on the CPU, `F16` on CUDA and Metal.
    Auto,
    /// Force `DType::F32` regardless of device.
    F32,
    /// Force `DType::F16` regardless of device.
    F16,
    /// Force `DType::BF16` regardless of device.
    BF16,
}

/// Resolves a [`DeviceChoice`] into a concrete [`Device`].
///
/// Accelerators are always opened at ordinal `0`. With
/// [`DeviceChoice::Auto`] the backends are probed in the order CUDA, Metal,
/// CPU and the first available one is used.
///
/// # Errors
///
/// * [`WaxError::InvalidRequest`] when CUDA or Metal is requested explicitly
///   but the corresponding availability check reports it as unavailable.
/// * Any error raised while creating the CUDA or Metal device, propagated
///   through `?`.
pub fn select_device(choice: DeviceChoice) -> Result<Device> {
    let device = match choice {
        DeviceChoice::Cpu => Device::Cpu,

        // Explicit requests: an unavailable backend is an error, never a silent fallback.
        DeviceChoice::Cuda if utils::cuda_is_available() => Device::new_cuda(0)?,
        DeviceChoice::Cuda => {
            return Err(WaxError::InvalidRequest(String::from(
                "CUDA was requested but is not available in this build/runtime",
            )))
        }
        DeviceChoice::Metal if utils::metal_is_available() => Device::new_metal(0)?,
        DeviceChoice::Metal => {
            return Err(WaxError::InvalidRequest(String::from(
                "Metal was requested but is not available in this build/runtime",
            )))
        }

        // Automatic selection: the guards are evaluated top to bottom, so CUDA
        // is probed before Metal and the CPU is the last resort.
        DeviceChoice::Auto if utils::cuda_is_available() => Device::new_cuda(0)?,
        DeviceChoice::Auto if utils::metal_is_available() => Device::new_metal(0)?,
        DeviceChoice::Auto => Device::Cpu,
    };
    Ok(device)
}

/// Resolves a [`DTypeChoice`] into a concrete [`DType`] for `device`.
///
/// An explicit choice is returned unchanged and `device` is not consulted.
/// [`DTypeChoice::Auto`] yields `F32` on the CPU and `F16` on CUDA or Metal.
pub fn select_dtype(choice: DTypeChoice, device: &Device) -> DType {
    match (choice, device) {
        // Explicit requests win over any device default.
        (DTypeChoice::F32, _) => DType::F32,
        (DTypeChoice::F16, _) => DType::F16,
        (DTypeChoice::BF16, _) => DType::BF16,
        // Device-driven defaults.
        (DTypeChoice::Auto, Device::Cpu) => DType::F32,
        (DTypeChoice::Auto, Device::Cuda(_) | Device::Metal(_)) => DType::F16,
    }
}

/// Returns a short textual label for `device`.
///
/// The label depends only on the backend variant: the device payload is
/// ignored, so CUDA and Metal devices are always reported with ordinal `0`.
pub fn device_label(device: &Device) -> String {
    let label = match device {
        Device::Cpu => "cpu",
        Device::Cuda(_) => "cuda:0",
        Device::Metal(_) => "metal:0",
    };
    String::from(label)
}

/// Returns the ASCII-lowercased `Debug` rendering of `dtype`
/// (for example `DType::F32` becomes `"f32"`).
pub fn dtype_label(dtype: DType) -> String {
    let mut label = format!("{dtype:?}");
    // Lowercase in place instead of building a second string.
    label.make_ascii_lowercase();
    label
}

#[cfg(test)]
mod tests {
    use candle_core::{DType, Device};

    use super::{device_label, dtype_label, select_dtype, DTypeChoice};

    /// With no explicit request, the CPU resolves to `F32`.
    #[test]
    fn cpu_auto_dtype_defaults_to_f32() {
        let cpu = Device::Cpu;
        assert_eq!(select_dtype(DTypeChoice::Auto, &cpu), DType::F32);
    }

    /// An explicit dtype is honoured even where the device default differs.
    #[test]
    fn explicit_dtype_overrides_device_default() {
        let cpu = Device::Cpu;
        let cases = [
            (DTypeChoice::F16, DType::F16),
            (DTypeChoice::BF16, DType::BF16),
        ];
        for (requested, expected) in cases {
            assert_eq!(select_dtype(requested, &cpu), expected);
        }
    }

    /// Pins the exact label strings produced for the CPU and common dtypes.
    #[test]
    fn labels_are_stable_for_stats() {
        assert_eq!(device_label(&Device::Cpu), "cpu");

        let dtype_cases = [(DType::F32, "f32"), (DType::F16, "f16")];
        for (dtype, expected) in dtype_cases {
            assert_eq!(dtype_label(dtype), expected);
        }
    }
}