use crate::config::Llama32Config;
use anyhow::{Result, bail};
use rlx_runtime::Device;
/// Checks that `device` is a backend the Llama 3.2 model may be run on.
///
/// `cfg` and `packed_weights` are part of the signature but are not currently
/// consulted. Acceptance depends only on which `Device` variant is passed.
///
/// # Errors
///
/// Returns an error naming the device when it is not one of `Device::Cpu`,
/// `Device::Metal`, `Device::Mlx`, `Device::Cuda`, `Device::Rocm`, or
/// `Device::Gpu`.
pub fn validate_device(cfg: &Llama32Config, device: Device, packed_weights: bool) -> Result<()> {
    // Backend support does not depend on the model shape or on the weight
    // layout; bind both explicitly so they do not trigger unused warnings.
    let _ = cfg;
    let _ = packed_weights;

    let supported = matches!(
        device,
        Device::Cpu | Device::Metal | Device::Mlx | Device::Cuda | Device::Rocm | Device::Gpu
    );
    if supported {
        return Ok(());
    }

    let other = device;
    bail!(
        "llama32: device {other:?} is not supported \
         (use Device::Cpu, Device::Metal, Device::Mlx, Device::Cuda, Device::Rocm, or Device::Gpu)"
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Llama32Config;

    /// Builds a minimal one-layer configuration. The exact values only need to
    /// be well-formed, because `validate_device` currently ignores the config.
    fn tiny_cfg() -> Llama32Config {
        Llama32Config {
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 500_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            head_dim: None,
            rope_scaling: None,
        }
    }

    /// Every supported backend must validate with both unpacked and packed
    /// weights.
    #[test]
    fn all_backends_allowed() {
        let cfg = tiny_cfg();
        let backends = [
            Device::Cpu,
            Device::Metal,
            Device::Mlx,
            Device::Cuda,
            Device::Rocm,
            Device::Gpu,
        ];

        for dev in backends {
            // Unpacked first, then packed: same call order as checking each pair inline.
            for packed in [false, true] {
                validate_device(&cfg, dev, packed).unwrap();
            }
        }
    }
}