use candle_core::{Result, Tensor};
use crate::nn::rope::{apply_multimodal_rotary_pos_emb, apply_rotary_pos_emb};
/// Selects which rotary position embedding (RoPE) routine
/// [`RopeStrategy::apply`] dispatches to.
///
/// `PartialEq`/`Eq` are derived so strategies can be compared, e.g. when
/// validating a model configuration or asserting on it in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RopeStrategy {
    /// Plain RoPE: dispatches to `apply_rotary_pos_emb`.
    Standard,
    /// Multimodal RoPE: dispatches to `apply_multimodal_rotary_pos_emb`,
    /// forwarding the two fields below unchanged.
    Multimodal {
        /// Section sizes forwarded to `apply_multimodal_rotary_pos_emb`.
        /// NOTE(review): presumably how the rotary dimension is split across
        /// position axes — confirm against the rope module.
        mrope_section: Vec<usize>,
        /// Flag forwarded unchanged to `apply_multimodal_rotary_pos_emb`;
        /// its exact meaning is defined by that function.
        interleaved: bool,
    },
}
/// Dispatch and construction helpers for [`RopeStrategy`].
impl RopeStrategy {
    /// Applies rotary position embeddings to the query `q` and key `k`
    /// tensors, using the `cos`/`sin` tensors, via the routine selected by `self`.
    ///
    /// - [`RopeStrategy::Standard`] forwards to `apply_rotary_pos_emb(q, k, cos, sin)`.
    /// - [`RopeStrategy::Multimodal`] forwards to
    ///   `apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, interleaved)`,
    ///   passing the stored section list by reference and the flag by value.
    ///
    /// No shape checks are done here. The tensor pair is returned exactly as the
    /// selected routine produced it (presumably rotated `q`, then rotated `k`).
    ///
    /// # Errors
    ///
    /// Propagates whatever error the selected rotary routine returns.
    pub fn apply(
        &self,
        q: &Tensor,
        k: &Tensor,
        cos: &Tensor,
        sin: &Tensor,
    ) -> Result<(Tensor, Tensor)> {
        match *self {
            Self::Standard => apply_rotary_pos_emb(q, k, cos, sin),
            // Borrow the section list and copy out the `bool` flag.
            Self::Multimodal {
                ref mrope_section,
                interleaved,
            } => apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, interleaved),
        }
    }

    /// Builds the [`RopeStrategy::Standard`] variant.
    pub fn standard() -> Self {
        Self::Standard
    }

    /// Builds a [`RopeStrategy::Multimodal`] variant.
    ///
    /// The variant takes ownership of `mrope_section` and stores `interleaved`.
    /// [`RopeStrategy::apply`] later hands both to `apply_multimodal_rotary_pos_emb`.
    pub fn multimodal(mrope_section: Vec<usize>, interleaved: bool) -> Self {
        Self::Multimodal {
            interleaved,
            mrope_section,
        }
    }
}