use candle_core::{DType, IndexOp, Result, Tensor};
// Submodules declared by this file; their contents are not visible from here.
pub mod core;
pub mod simple;
// The unit tests in this file import `RotaryEmbedding` from `standard`.
pub mod standard;
pub mod talker;
pub fn rotate_half(x: &Tensor) -> Result<Tensor> {
let last_dim = x.dim(candle_core::D::Minus1)?;
let half = last_dim / 2;
let x1 = x.narrow(candle_core::D::Minus1, 0, half)?;
let x2 = x.narrow(candle_core::D::Minus1, half, half)?;
Tensor::cat(&[&x2.neg()?, &x1], candle_core::D::Minus1)
}
/// Rotates `q` and `k` with `candle_nn::rotary_emb::rope`, using the same
/// `cos` / `sin` tables for both.
///
/// Both inputs are made contiguous before the call (presumably because the
/// `rope` kernel requires contiguous inputs — confirm against candle's docs).
///
/// # Returns
/// `(q_embed, k_embed)`, the rotated query and key tensors.
///
/// # Errors
/// Propagates any error from `contiguous` or `rope`.
pub fn apply_rotary_pos_emb(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
) -> Result<(Tensor, Tensor)> {
    // One rotation routine shared by the query and the key.
    let rotate = |t: &Tensor| candle_nn::rotary_emb::rope(t, cos, sin);

    let (q_dense, k_dense) = (q.contiguous()?, k.contiguous()?);
    let q_rotated = rotate(&q_dense)?;
    let k_rotated = rotate(&k_dense)?;
    Ok((q_rotated, k_rotated))
}
/// Applies rotary position embeddings with explicit tensor arithmetic instead
/// of the fused `rope` kernel.
///
/// For each of `q` and `k` this computes `x * cos + rotate_half(x) * sin`.
/// All operands are first converted to `F32`; `cos` and `sin` additionally get
/// a new axis inserted at index 1 so that they broadcast against `q` / `k`
/// (presumably the attention-head axis — confirm against callers).
///
/// # Returns
/// `(q_embed, k_embed)`. Both outputs are cast to the dtype of `q`, including
/// the key output.
///
/// # Errors
/// Propagates any error from the dtype conversions or broadcast operations.
pub fn apply_rotary_pos_emb_manual(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
) -> Result<(Tensor, Tensor)> {
    let out_dtype = q.dtype();

    let q32 = q.to_dtype(DType::F32)?;
    let k32 = k.to_dtype(DType::F32)?;
    let cos32 = cos.to_dtype(DType::F32)?.unsqueeze(1)?;
    let sin32 = sin.to_dtype(DType::F32)?.unsqueeze(1)?;

    // x * cos + rotate_half(x) * sin, evaluated in F32.
    let rotate = |t: &Tensor| -> Result<Tensor> {
        let direct = t.broadcast_mul(&cos32)?;
        let crossed = rotate_half(t)?.broadcast_mul(&sin32)?;
        direct.broadcast_add(&crossed)
    };

    let q_rotated = rotate(&q32)?;
    let k_rotated = rotate(&k32)?;

    let q_out = q_rotated.to_dtype(out_dtype)?;
    let k_out = k_rotated.to_dtype(out_dtype)?;
    Ok((q_out, k_out))
}
/// Applies multimodal rotary position embeddings to `q` and `k`.
///
/// This is a thin dispatcher: when `interleaved` is `true` the work is done by
/// `apply_multimodal_rotary_pos_emb_interleaved`, otherwise by
/// `apply_multimodal_rotary_pos_emb_standard`. All other arguments are
/// forwarded unchanged.
///
/// # Returns
/// `(q_embed, k_embed)` as produced by the selected implementation.
///
/// # Errors
/// Propagates any error from the selected implementation.
pub fn apply_multimodal_rotary_pos_emb(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
    mrope_section: &[usize],
    interleaved: bool,
) -> Result<(Tensor, Tensor)> {
    if !interleaved {
        return apply_multimodal_rotary_pos_emb_standard(q, k, cos, sin, mrope_section);
    }
    apply_multimodal_rotary_pos_emb_interleaved(q, k, cos, sin, mrope_section)
}
/// Applies multimodal RoPE with each modality occupying one contiguous range
/// of the last (frequency) axis.
///
/// For entry `i` of `mrope_section`, row `i` of `cos` / `sin` is selected along
/// the leading axis, and a window of `section_size` columns starting at the
/// running offset is taken from its last axis. The windows are concatenated
/// along the last axis, in section order, and the result is passed to
/// `candle_nn::rotary_emb::rope` for both `q` and `k`.
///
/// The section sizes are assumed to add up to the table width that `rope`
/// expects; this is not checked here.
///
/// # Returns
/// `(q_embed, k_embed)`, the rotated query and key tensors.
///
/// # Errors
/// Propagates any error from indexing, `narrow`, `Tensor::cat`, `contiguous`
/// or `rope` — for example when `cos` / `sin` have fewer leading rows than
/// `mrope_section` has entries, or when a window runs past the last axis.
fn apply_multimodal_rotary_pos_emb_standard(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
    mrope_section: &[usize],
) -> Result<(Tensor, Tensor)> {
    let mut cos_parts = Vec::with_capacity(mrope_section.len());
    let mut sin_parts = Vec::with_capacity(mrope_section.len());

    // Start column of the current section's window along the last axis.
    let mut offset = 0;
    for (i, &section_size) in mrope_section.iter().enumerate() {
        let cos_modality = cos.i(i)?;
        let sin_modality = sin.i(i)?;
        cos_parts.push(cos_modality.narrow(candle_core::D::Minus1, offset, section_size)?);
        sin_parts.push(sin_modality.narrow(candle_core::D::Minus1, offset, section_size)?);
        offset += section_size;
    }

    let cos_half = Tensor::cat(&cos_parts, candle_core::D::Minus1)?.contiguous()?;
    let sin_half = Tensor::cat(&sin_parts, candle_core::D::Minus1)?.contiguous()?;

    let q = q.contiguous()?;
    let k = k.contiguous()?;
    let q_embed = candle_nn::rotary_emb::rope(&q, &cos_half, &sin_half)?;
    let k_embed = candle_nn::rotary_emb::rope(&k, &cos_half, &sin_half)?;
    Ok((q_embed, k_embed))
}
/// Applies multimodal RoPE with the modalities interleaved along the last
/// (frequency) axis.
///
/// `cos` must be 4-D; its dimensions are read as
/// `(modalities, batch, seq_len, half_dim)`. `sin` is indexed the same way
/// (presumably it has the same shape — confirm against callers).
///
/// With `n = mrope_section.len()`, column `pos` of the combined table is taken
/// from:
/// * modality 1 when `n >= 3`, `1 <= pos < min(mrope_section[1] * n, half_dim)`
///   and `(pos - 1) % n == 0`;
/// * modality 2 when `n >= 3`, `2 <= pos < min(mrope_section[2] * n, half_dim)`
///   and `(pos - 2) % n == 0`;
/// * modality 0 otherwise (which is every column when `n < 3`).
///
/// Rows 0, 1 and 2 of `cos` / `sin` are always materialised, so both tables
/// need at least three leading rows even when every column ends up coming
/// from modality 0.
///
/// Column selection runs in `F32`; the combined tables are then cast to the
/// dtype of `cos` (for both `cos` and `sin`) before being handed to
/// `candle_nn::rotary_emb::rope`.
///
/// # Returns
/// `(q_embed, k_embed)`, the rotated query and key tensors.
///
/// # Errors
/// Propagates any error from the tensor operations, including `cos` not being
/// 4-D or either table having fewer than three leading rows.
fn apply_multimodal_rotary_pos_emb_interleaved(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
    mrope_section: &[usize],
) -> Result<(Tensor, Tensor)> {
    let (_modalities, _batch, _seq_len, half_dim) = cos.dims4()?;
    let modality_num = mrope_section.len();
    let original_dtype = cos.dtype();
    let cos_f32 = cos.contiguous()?.to_dtype(DType::F32)?;
    let sin_f32 = sin.contiguous()?.to_dtype(DType::F32)?;

    // Interleaving only happens when at least three sections are configured.
    let interleave = modality_num >= 3;
    let (m1_end, m2_end) = if interleave {
        (
            (mrope_section[1] * modality_num).min(half_dim),
            (mrope_section[2] * modality_num).min(half_dim),
        )
    } else {
        (0, 0)
    };

    // Source modality (0, 1 or 2) for a given output column.
    let modality_for = |pos: usize| -> usize {
        if !interleave {
            0
        } else if pos >= 1 && pos < m1_end && (pos - 1) % modality_num == 0 {
            1
        } else if pos >= 2 && pos < m2_end && (pos - 2) % modality_num == 0 {
            2
        } else {
            0
        }
    };

    // Per-modality (cos, sin) tables, indexed by the value `modality_for` returns.
    let mut tables = Vec::with_capacity(3);
    for modality in 0..3 {
        tables.push((
            cos_f32.i(modality)?.contiguous()?,
            sin_f32.i(modality)?.contiguous()?,
        ));
    }

    // Pick every output column from the modality that owns it.
    let mut cos_columns = Vec::with_capacity(half_dim);
    let mut sin_columns = Vec::with_capacity(half_dim);
    for pos in 0..half_dim {
        let (cos_src, sin_src) = &tables[modality_for(pos)];
        cos_columns.push(cos_src.narrow(2, pos, 1)?);
        sin_columns.push(sin_src.narrow(2, pos, 1)?);
    }

    let cos_interleaved = Tensor::cat(&cos_columns, 2)?;
    let sin_interleaved = Tensor::cat(&sin_columns, 2)?;
    let cos_half = cos_interleaved.to_dtype(original_dtype)?.contiguous()?;
    let sin_half = sin_interleaved.to_dtype(original_dtype)?.contiguous()?;

    let q = q.contiguous()?;
    let k = k.contiguous()?;
    let q_embed = candle_nn::rotary_emb::rope(&q, &cos_half, &sin_half)?;
    let k_embed = candle_nn::rotary_emb::rope(&k, &cos_half, &sin_half)?;
    Ok((q_embed, k_embed))
}
#[cfg(test)]
mod tests {
    use crate::nn::rope::standard::RotaryEmbedding;
    use crate::nn::rope::{
        apply_multimodal_rotary_pos_emb_interleaved, apply_multimodal_rotary_pos_emb_standard,
        rotate_half,
    };
    use crate::nn::rope_scaling::RopeScalingType;
    use candle_core::{Device, Result, Tensor};

    /// `rotate_half` on `0..8` must yield the negated upper half followed by
    /// the lower half.
    #[test]
    fn test_rotate_half() -> Result<()> {
        let cpu = Device::Cpu;
        let input = Tensor::arange(0f32, 8.0, &cpu)?.reshape((1, 1, 1, 8))?;
        let values = rotate_half(&input)?.flatten_all()?.to_vec1::<f32>()?;
        let expected: Vec<f32> = vec![-4.0, -5.0, -6.0, -7.0, 0.0, 1.0, 2.0, 3.0];
        assert_eq!(values, expected);
        Ok(())
    }

    /// `RotaryEmbedding::forward` must return cos/sin tables of shape
    /// `[2, 10, 32]` for this input and report the default scaling type.
    #[test]
    fn test_standard_rope() -> Result<()> {
        let cpu = Device::Cpu;
        let rotary = RotaryEmbedding::new(64, 8192, 10000.0, &cpu)?;
        let hidden = Tensor::randn(0f32, 1.0, (2, 4, 10, 64), &cpu)?;
        let positions = Tensor::arange(0i64, 10, &cpu)?;
        let positions = positions.unsqueeze(0)?.repeat((2, 1))?;
        let (cos, sin) = rotary.forward(&hidden, &positions)?;
        for table in [&cos, &sin] {
            assert_eq!(table.dims(), &[2, 10, 32]);
        }
        assert_eq!(rotary.scaling_type(), RopeScalingType::Default);
        Ok(())
    }

    /// Every recognised name maps to its variant; "ntk" maps to `Dynamic` and
    /// an unrecognised name maps to `Default`.
    #[test]
    fn test_rope_scaling_type_parsing() {
        let cases = [
            ("default", RopeScalingType::Default),
            ("linear", RopeScalingType::Linear),
            ("dynamic", RopeScalingType::Dynamic),
            ("ntk", RopeScalingType::Dynamic),
            ("yarn", RopeScalingType::Yarn),
            ("longrope", RopeScalingType::LongRope),
            ("llama3", RopeScalingType::Llama3),
            ("unknown", RopeScalingType::Default),
        ];
        for (name, expected) in cases {
            assert_eq!(RopeScalingType::parse(name), expected);
        }
    }

    /// Both multimodal variants must preserve the shape of their inputs.
    #[test]
    fn test_interleaved_multimodal_rope() -> Result<()> {
        let cpu = Device::Cpu;
        let states = Tensor::randn(0f32, 1.0, (2, 8, 5, 64), &cpu)?;
        let cos = Tensor::randn(0f32, 1.0, (3, 2, 5, 32), &cpu)?;
        let sin = Tensor::randn(0f32, 1.0, (3, 2, 5, 32), &cpu)?;
        let sections: [usize; 3] = [8, 12, 12];

        let (q_std, k_std) =
            apply_multimodal_rotary_pos_emb_standard(&states, &states, &cos, &sin, &sections)?;
        let (q_int, k_int) =
            apply_multimodal_rotary_pos_emb_interleaved(&states, &states, &cos, &sin, &sections)?;

        for output in [&q_std, &k_std, &q_int, &k_int] {
            assert_eq!(output.dims(), &[2, 8, 5, 64]);
        }
        Ok(())
    }
}