// rlx-gemma 0.2.0
//
// Gemma / Gemma 2 causal LMs for RLX
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

use super::config::GemmaConfig;

/// Computes the default RoPE inverse frequencies for a head of size `head_dim`.
///
/// One value is produced for every even index `i` in `0..head_dim`, equal to
/// `1 / rope_theta^(i / head_dim)`. The result therefore has
/// `ceil(head_dim / 2)` entries and is empty when `head_dim` is zero.
pub fn default_inv_freq(rope_theta: f64, head_dim: usize) -> Vec<f64> {
    let dim = head_dim as f64;
    // Number of even indices in `0..head_dim`.
    let mut inv_freq = Vec::with_capacity(head_dim / 2 + head_dim % 2);
    for even_idx in (0..head_dim).step_by(2) {
        let exponent = even_idx as f64 / dim;
        inv_freq.push(rope_theta.powf(exponent).recip());
    }
    inv_freq
}

/// Resolves the RoPE inverse frequencies for `cfg`, applying optional
/// per-frequency scaling factors.
///
/// The base frequencies come from [`default_inv_freq`] using `cfg.rope_theta`
/// and `cfg.head_dim()`. When `rope_freq_factors` is `Some` and non-empty,
/// each base frequency is divided by the factor at the same index; when it is
/// `None` or empty, the base frequencies are returned unchanged.
///
/// The result always has the same length as the base list, because callers
/// such as [`build_rope_tables`] use that length as the row width. If fewer
/// factors than frequencies are supplied, the uncovered trailing frequencies
/// are left unscaled; surplus factors are ignored.
///
/// Factors are not validated: a zero factor produces a non-finite frequency.
pub fn resolve_inv_freq(cfg: &GemmaConfig, rope_freq_factors: Option<&[f32]>) -> Vec<f64> {
    let base = default_inv_freq(cfg.rope_theta, cfg.head_dim());
    match rope_freq_factors {
        Some(factors) if !factors.is_empty() => base
            .iter()
            .enumerate()
            // Look the factor up by index instead of zipping, so a short
            // factor slice cannot truncate the output.
            .map(|(idx, &inv)| match factors.get(idx) {
                Some(&factor) => inv / f64::from(factor),
                None => inv,
            })
            .collect(),
        _ => base,
    }
}

/// Builds flattened cos/sin lookup tables for positions `0..max_pos`.
///
/// Returns `(cos, sin)`. Each table is row-major with `max_pos` rows of
/// `inv_freq.len()` entries, so the value for position `pos` and frequency
/// index `i` is at `pos * inv_freq.len() + i`. The angle is `pos * inv_freq[i]`,
/// evaluated in `f64` and narrowed to `f32` on store.
pub fn build_rope_tables(inv_freq: &[f64], max_pos: usize) -> (Vec<f32>, Vec<f32>) {
    let width = inv_freq.len();
    let total = max_pos * width;
    let mut cos_table = vec![0f32; total];
    let mut sin_table = vec![0f32; total];

    // Angles in the same row-major order as the output buffers.
    let angles = (0..max_pos)
        .flat_map(|pos| inv_freq.iter().map(move |&freq| pos as f64 * freq));

    for ((cos_out, sin_out), angle) in cos_table
        .iter_mut()
        .zip(sin_table.iter_mut())
        .zip(angles)
    {
        let (sin_val, cos_val) = angle.sin_cos();
        *cos_out = cos_val as f32;
        *sin_out = sin_val as f32;
    }

    (cos_table, sin_table)
}

/// Computes one row of cos/sin values for decoding at absolute position `pos`.
///
/// Returns `(cos, sin)`, each with `inv_freq.len()` entries. Entry `i` is the
/// cosine (respectively sine) of `pos * inv_freq[i]`, evaluated in `f64` and
/// narrowed to `f32`.
pub fn rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
    let position = pos as f64;
    inv_freq
        .iter()
        .map(|&freq| {
            let (sin_val, cos_val) = (position * freq).sin_cos();
            // Pair order matches the return order: cos first, then sin.
            (cos_val as f32, sin_val as f32)
        })
        .unzip()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Both tables must hold `max_pos` rows of `inv_freq.len()` entries each.
    #[test]
    fn tables_len_matches_max_pos() {
        const MAX_POS: usize = 4;
        let inv_freq = default_inv_freq(10_000.0, 8);
        let expected_len = MAX_POS * inv_freq.len();

        let (cos_table, sin_table) = build_rope_tables(&inv_freq, MAX_POS);

        assert_eq!(cos_table.len(), expected_len);
        assert_eq!(sin_table.len(), cos_table.len());
    }
}