#![allow(clippy::module_name_repetitions)]
#![allow(clippy::needless_pass_by_value)]
#![allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
use candle_core::{D, Result, Tensor};
use candle_nn::VarBuilder;
/// Applies rotary position embeddings to the leading `rope_dim` channels of `qk`.
///
/// `qk` is unpacked with `dims4`, and its last two dimensions are read as
/// `(seq_len, head_dim)`. `rope_dim` is twice the size of the last dimension of
/// `freqs_cos`. Both frequency tables are cut down to their first `seq_len`
/// entries along dimension 0 before being handed to `candle_nn::rotary_emb::rope`.
///
/// When `rope_dim == head_dim` the whole tensor is rotated. Otherwise only the
/// first `rope_dim` channels of the last dimension are rotated; the remaining
/// `head_dim - rope_dim` channels are left untouched and concatenated back on
/// the last dimension.
///
/// # Errors
///
/// Returns an error if `rope_dim` exceeds `head_dim`, or if any of the
/// underlying tensor operations fails.
pub fn apply_rope(qk: &Tensor, freqs_cos: &Tensor, freqs_sin: &Tensor) -> Result<Tensor> {
    let (_, _, seq_len, head_dim) = qk.dims4()?;
    let rope_dim = 2 * freqs_cos.dim(D::Minus1)?;
    if rope_dim > head_dim {
        candle_core::bail!("rope_dim {} cannot exceed head_dim {}", rope_dim, head_dim);
    }

    // Both code paths need the tables restricted to the current sequence length.
    let truncate =
        |table: &Tensor| -> Result<Tensor> { table.narrow(0, 0, seq_len)?.contiguous() };

    // Cannot underflow: `rope_dim <= head_dim` was established above.
    let passthrough_dim = head_dim - rope_dim;
    if passthrough_dim == 0 {
        let (cos, sin) = (truncate(freqs_cos)?, truncate(freqs_sin)?);
        return candle_nn::rotary_emb::rope(&qk.contiguous()?, &cos, &sin);
    }

    // Partial rotation: split the last dimension, rotate the head, keep the tail.
    let head = qk.narrow(D::Minus1, 0, rope_dim)?.contiguous()?;
    let tail = qk.narrow(D::Minus1, rope_dim, passthrough_dim)?;
    let (cos, sin) = (truncate(freqs_cos)?, truncate(freqs_sin)?);
    let head = candle_nn::rotary_emb::rope(&head, &cos, &sin)?;
    Tensor::cat(&[&head, &tail.contiguous()?], D::Minus1)
}
/// Builds the cosine and sine lookup tables used for rotary position embeddings.
///
/// For every position `p` in `0..max_seq_len` and every index `k` in
/// `0..rope_dim / 2`, the angle is `p * 10000^(-k / (rope_dim / 2))`. The
/// function returns `(cos, sin)` of those angles; each tensor is built from
/// `f32` data on `device` and has shape `(max_seq_len, rope_dim / 2)`.
///
/// # Panics
///
/// Panics if `rope_dim` is odd.
///
/// # Errors
///
/// Returns an error if any tensor construction or operation fails.
pub fn precompute_rope_freqs(
    max_seq_len: usize,
    rope_dim: usize,
    device: &candle_core::Device,
) -> Result<(Tensor, Tensor)> {
    assert!(
        rope_dim.is_multiple_of(2),
        "rope_dim must be even (got {rope_dim})"
    );
    let theta = 10_000f32;
    let pair_count = rope_dim / 2;

    // One inverse frequency per rotated channel pair.
    let inv_freq: Vec<f32> = (0..pair_count)
        .map(|k| theta.powf(-(k as f32 / pair_count as f32)))
        .collect();
    let inv_freq = Tensor::from_vec(inv_freq, pair_count, device)?;

    // Integer positions 0, 1, ..., max_seq_len - 1 as floats.
    let positions: Vec<f32> = (0..max_seq_len).map(|p| p as f32).collect();
    let positions = Tensor::from_vec(positions, max_seq_len, device)?;

    // Outer product via broadcasting: (max_seq_len, 1) * (1, pair_count).
    let column = positions.unsqueeze(1)?;
    let row = inv_freq.unsqueeze(0)?;
    let angles = column.broadcast_mul(&row)?;

    let cos = angles.cos()?;
    let sin = angles.sin()?;
    Ok((cos, sin))
}
/// Fourier feature embedding: holds a projection matrix whose output is expanded
/// into cosine and sine components by `FourierFeatures::forward`.
#[derive(Debug)]
pub struct FourierFeatures {
/// Projection matrix, fetched in `new` with shape `(out_features / 2, in_features)`.
weight: Tensor,
/// Output width requested at construction; `new` asserts that it is even.
out_features: usize,
}
impl FourierFeatures {
    /// Creates the layer by fetching the tensor named `"weight"` from `vb`.
    ///
    /// The tensor is requested with shape `(out_features / 2, in_features)`.
    ///
    /// # Panics
    ///
    /// Panics if `out_features` is odd.
    ///
    /// # Errors
    ///
    /// Returns an error if `vb.get` fails to provide the `"weight"` tensor.
    pub fn new(in_features: usize, out_features: usize, vb: VarBuilder) -> Result<Self> {
        assert!(
            out_features.is_multiple_of(2),
            "fourier feature out_features must be even (got {out_features})"
        );
        let rows = out_features / 2;
        vb.get((rows, in_features), "weight").map(|weight| Self {
            weight,
            out_features,
        })
    }

    /// Returns the `out_features` value this layer was constructed with.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Projects `x` with the transposed weight, scales by 2π, and returns the
    /// cosine and sine of the result concatenated along the last dimension.
    ///
    /// The weight is converted to the dtype of `x` before the matrix product.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the underlying tensor operations fails.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Transposed weight, made contiguous and cast to the input dtype.
        let projection = self.weight.t()?.contiguous()?.to_dtype(x.dtype())?;
        let angles = (x.matmul(&projection)? * std::f64::consts::TAU)?;
        let cos_part = angles.cos()?;
        let sin_part = angles.sin()?;
        Tensor::cat(&[&cos_part, &sin_part], D::Minus1)
    }
}