candle-mi 0.1.7

Mechanistic interpretability for language models in Rust, built on candle
Documentation
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Rotary position embeddings (`RoPE`).
//!
//! Pre-computes `cos` and `sin` tensors at model load time and applies
//! them to query and key tensors during the forward pass.
//!
//! Uses `candle_nn::rotary_emb::rope()` for the actual rotation, matching
//! the reference implementation in plip-rs (frozen predecessor project, v1.4.0).

use candle_core::{DType, Device, Tensor};

use crate::error::Result;

// ---------------------------------------------------------------------------
// RoPE cache — pre-computed cos/sin
// ---------------------------------------------------------------------------

/// Pre-computed cosine and sine tensors for rotary position embeddings.
///
/// Both tables are built once by [`RopeCache::new`] and sliced per forward
/// pass by [`RopeCache::apply`]. Cloning is cheap: `Tensor` clones share the
/// underlying storage.
#[derive(Debug, Clone)]
pub struct RopeCache {
    /// Cosine values: `[max_position, head_dim / 2]`.
    cos: Tensor,
    /// Sine values: `[max_position, head_dim / 2]`.
    sin: Tensor,
}

impl RopeCache {
    /// Pre-compute the `RoPE` cache.
    ///
    /// # Shapes
    /// - `cos`: `[max_position, head_dim / 2]`
    /// - `sin`: `[max_position, head_dim / 2]`
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] on tensor operation failures.
    pub fn new(
        head_dim: usize,
        max_position: usize,
        theta: f64,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let half_dim = head_dim / 2;

        // Compute inverse frequencies: theta^(-2i/d) for i in 0..half_dim
        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| {
                // CAST: usize → f64, loop index and head_dim fit in f64 mantissa
                #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
                let freq = 1.0 / theta.powf(2.0 * i as f64 / head_dim as f64);
                // CAST: f64 → f32, precision loss acceptable for RoPE frequencies
                #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
                let freq_f32 = freq as f32;
                freq_f32
            })
            .collect();

        let inv_freq_tensor = Tensor::from_vec(inv_freq, (1, half_dim), device)?.to_dtype(dtype)?;

        // Position indices: [0, 1, 2, ..., max_position - 1]
        // CAST: usize → u32, max_position fits in u32 (max ~128K)
        #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
        let pos_tensor = Tensor::arange(0u32, max_position as u32, device)?
            .to_dtype(dtype)?
            .reshape((max_position, 1))?;

        // Outer product: [max_position, half_dim]
        let freqs = pos_tensor.matmul(&inv_freq_tensor)?;

        let cos = freqs.cos()?;
        let sin = freqs.sin()?;

        Ok(Self { cos, sin })
    }

    /// Apply rotary embeddings to a query or key tensor.
    ///
    /// Uses `candle_nn::rotary_emb::rope()` for the rotation.
    ///
    /// # Shapes
    /// - `x`: `[batch, n_heads, seq_len, head_dim]`
    /// - returns: `[batch, n_heads, seq_len, head_dim]`
    ///
    /// The `start_pos` parameter supports incremental generation (KV-cache):
    /// positions are offset by `start_pos` so that cached keys keep their
    /// original positional encoding.
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Model`] on tensor operation or shape errors.
    pub fn apply(&self, x: &Tensor, start_pos: usize) -> Result<Tensor> {
        let (_, _, seq_len, _) = x.dims4()?;

        // Slice cos/sin for the relevant positions: [seq_len, half_dim]
        let cos = self.cos.narrow(0, start_pos, seq_len)?;
        let sin = self.sin.narrow(0, start_pos, seq_len)?;

        // candle_nn::rotary_emb::rope() expects contiguous input
        Ok(candle_nn::rotary_emb::rope(&x.contiguous()?, &cos, &sin)?)
    }
}