oxibonsai-model 0.1.4

Qwen3-8B Transformer implementation for OxiBonsai 1-bit inference
Documentation
//! Rotary Position Embeddings (RoPE).
//!
//! Applies rotation to query and key vectors to encode positional
//! information. Uses precomputed cos/sin tables.

use crate::error::ModelResult;

/// Precomputed RoPE sin/cos table.
#[derive(Debug)]
pub struct RopeTable {
    /// Cosine values: [max_seq_len × head_dim/2].
    cos: Vec<f32>,
    /// Sine values: [max_seq_len × head_dim/2].
    sin: Vec<f32>,
    /// Half of head dimension (rotation pairs).
    half_dim: usize,
    /// Maximum sequence length.
    max_seq_len: usize,
}

impl RopeTable {
    /// Precompute RoPE rotation table.
    ///
    /// - `head_dim`: Dimension of each attention head.
    /// - `max_seq_len`: Maximum sequence length to precompute.
    /// - `freq_base`: RoPE frequency base (default: 1000000.0 for Qwen3).
    pub fn new(head_dim: usize, max_seq_len: usize, freq_base: f32) -> Self {
        let half_dim = head_dim / 2;
        let mut cos = vec![0.0f32; max_seq_len * half_dim];
        let mut sin = vec![0.0f32; max_seq_len * half_dim];

        for pos in 0..max_seq_len {
            for i in 0..half_dim {
                let freq = 1.0 / freq_base.powf(2.0 * i as f32 / head_dim as f32);
                let angle = pos as f32 * freq;
                cos[pos * half_dim + i] = angle.cos();
                sin[pos * half_dim + i] = angle.sin();
            }
        }

        Self {
            cos,
            sin,
            half_dim,
            max_seq_len,
        }
    }

    /// Apply RoPE rotation to a query or key vector at the given position.
    ///
    /// - `vec`: Input vector of length `head_dim` (modified in-place via output).
    /// - `output`: Output vector of length `head_dim`.
    /// - `pos`: Token position in the sequence.
    ///
    /// Delegates the inner rotation to SIMD-accelerated `oxibonsai_kernels::rope_apply_simd`.
    pub fn apply(&self, vec: &[f32], output: &mut [f32], pos: usize) -> ModelResult<()> {
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(vec.len(), self.half_dim * 2);
        debug_assert!(output.len() >= self.half_dim * 2);

        let cos_row = &self.cos[pos * self.half_dim..(pos + 1) * self.half_dim];
        let sin_row = &self.sin[pos * self.half_dim..(pos + 1) * self.half_dim];

        oxibonsai_kernels::rope_apply_simd(vec, output, cos_row, sin_row);

        Ok(())
    }

    /// Maximum precomputed sequence length.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Get cos values for a given position: `&[half_dim]`.
    pub fn cos_at(&self, pos: usize) -> &[f32] {
        &self.cos[pos * self.half_dim..(pos + 1) * self.half_dim]
    }

    /// Get sin values for a given position: `&[half_dim]`.
    pub fn sin_at(&self, pos: usize) -> &[f32] {
        &self.sin[pos * self.half_dim..(pos + 1) * self.half_dim]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rope_at_position_zero_is_identity() {
        let table = RopeTable::new(4, 16, 10000.0);
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];

        table
            .apply(&input, &mut output, 0)
            .expect("rope apply should succeed");

        // At position 0, cos=1, sin=0 → identity
        assert!((output[0] - 1.0).abs() < 1e-5);
        assert!((output[1] - 2.0).abs() < 1e-5);
        assert!((output[2] - 3.0).abs() < 1e-5);
        assert!((output[3] - 4.0).abs() < 1e-5);
    }

    #[test]
    fn rope_preserves_norm() {
        let table = RopeTable::new(4, 16, 10000.0);
        let input = vec![1.0, 0.0, 0.0, 1.0];
        let mut output = vec![0.0; 4];

        table
            .apply(&input, &mut output, 5)
            .expect("rope apply should succeed");

        let input_norm: f32 = input.iter().map(|x| x * x).sum::<f32>().sqrt();
        let output_norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (input_norm - output_norm).abs() < 1e-4,
            "RoPE should preserve vector norm"
        );
    }
}