rlx-llama32 0.2.1

LLaMA 3.2 for RLX
Documentation
//! Backend capability checks for llama32 graphs.
//!
//! Every backend accepted by [`validate_device`] (the CPU plus the portable
//! GPU backends) supports the standard Llama decoder IR as well as packed
//! GGUF K-quant weights (`Op::DequantMatMul`).

use crate::config::Llama32Config;
use anyhow::{Result, bail};
use rlx_runtime::Device;

/// Check that `device` is a backend able to run a llama32 graph.
///
/// The accepted backends are [`Device::Cpu`], [`Device::Metal`], [`Device::Mlx`],
/// [`Device::Cuda`], [`Device::Rocm`], and [`Device::Gpu`]. Neither `cfg` nor
/// `packed_weights` currently affects the outcome: every accepted backend is
/// accepted for both dense and packed weights.
///
/// # Errors
///
/// Returns an error naming the rejected device when `device` is not one of
/// the backends listed above.
pub fn validate_device(cfg: &Llama32Config, device: Device, packed_weights: bool) -> Result<()> {
    // This check does not consult `cfg` or `packed_weights`; discard them
    // explicitly so the compiler does not warn about unused parameters.
    let _ = (cfg, packed_weights);

    let supported = matches!(
        device,
        Device::Cpu | Device::Metal | Device::Mlx | Device::Cuda | Device::Rocm | Device::Gpu
    );
    if supported {
        return Ok(());
    }

    let other = device;
    bail!(
        "llama32: device {other:?} is not supported \
         (use Device::Cpu, Device::Metal, Device::Mlx, Device::Cuda, Device::Rocm, or Device::Gpu)"
    )
}

#[cfg(test)]
mod tests {
    // The glob import already brings in `Llama32Config`, `Device`, and
    // `validate_device` through the parent module's own imports.
    use super::*;

    /// Minimal one-layer config used only as an argument to `validate_device`.
    fn tiny_cfg() -> Llama32Config {
        Llama32Config {
            vocab_size: 32,
            hidden_size: 16,
            intermediate_size: 32,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            max_position_embeddings: 16,
            rms_norm_eps: 1e-5,
            rope_theta: 500_000.0,
            hidden_act: "silu".into(),
            tie_word_embeddings: false,
            attention_bias: false,
            head_dim: None,
            rope_scaling: None,
        }
    }

    /// Every supported backend must accept both dense and packed weights.
    #[test]
    fn all_backends_allowed() {
        let cfg = tiny_cfg();
        let devices = [
            Device::Cpu,
            Device::Metal,
            Device::Mlx,
            Device::Cuda,
            Device::Rocm,
            Device::Gpu,
        ];
        for dev in devices {
            for packed in [false, true] {
                // Report which device/packing combination failed instead of a
                // bare `unwrap` panic that hides the loop state.
                validate_device(&cfg, dev, packed).unwrap_or_else(|e| {
                    panic!("{dev:?} rejected (packed_weights={packed}): {e:#}")
                });
            }
        }
    }
}