use candle_core::{Result, Tensor};
use candle_nn::VarBuilder;
use crate::{config::talker_config::TalkerConfig, nn::kv_cache::KVCache};
use crate::nn::attention::rope_strategy::RopeStrategy;
use crate::nn::attention::unified::UnifiedAttention;
/// Attention layer used by the talker model.
///
/// This is a thin wrapper: all work is delegated to the wrapped
/// `UnifiedAttention` instance.
#[derive(Clone, Debug)]
pub struct TalkerAttention {
    /// The underlying attention implementation every method forwards to.
    inner: UnifiedAttention,
}
impl TalkerAttention {
pub fn new(
config: &TalkerConfig,
layer_idx: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let (mrope_section, interleaved) = if let Some(ref rope_scaling) = config.rope_scaling {
(rope_scaling.mrope_section.clone(), rope_scaling.interleaved)
} else {
(vec![16, 24, 24], false) };
let inner = UnifiedAttention::new(
config,
RopeStrategy::multimodal(mrope_section, interleaved),
layer_idx,
use_flash_attn,
vb,
)?;
Ok(Self { inner })
}
pub fn load(
config: &TalkerConfig,
layer_idx: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
Self::new(config, layer_idx, use_flash_attn, vb)
}
pub fn layer_idx(&self) -> usize {
self.inner.layer_idx()
}
pub fn get_sliding_window(&self) -> Option<usize> {
self.inner.sliding_window()
}
pub fn forward(
&self,
hidden_states: &Tensor,
position_embeddings: (&Tensor, &Tensor),
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
self.inner
.forward(hidden_states, position_embeddings, attention_mask)
}
pub fn forward_with_cache(
&self,
hidden_states: &Tensor,
position_embeddings: (&Tensor, &Tensor),
attention_mask: Option<&Tensor>,
cache: &mut KVCache,
) -> Result<Tensor> {
self.inner
.forward_with_cache(hidden_states, position_embeddings, attention_mask, cache)
}
}