use anyhow::{Ok, Result};
use candle_core::{D, Device, Tensor};
use candle_nn::{Embedding, Linear, Module, RmsNorm, VarBuilder, embedding, rms_norm};
use crate::{
models::{
common::{
InferenceModel,
modules::{GateUpDownMLP, NaiveAttention},
},
minicpm4::config::MiniCPM4Config,
},
position_embed::rope::compute_default_rope_parameters,
utils::tensor_utils::prepare_causal_attention_mask,
};
/// LongRoPE rotary position embedding for MiniCPM4.
///
/// Keeps precomputed `cos`/`sin` tables (scaled by `scaling_factor`) covering
/// `max_seq_len_cached` positions; the tables are rebuilt on demand when a
/// longer sequence is requested.
pub struct MiniCPMLongRoPE {
    // Per-frequency rescale factors used while the cached length stays within
    // the original training context window.
    short_factor: Vec<f32>,
    // Per-frequency rescale factors used once the cached length exceeds
    // `original_max_position_embeddings`.
    long_factor: Vec<f32>,
    // Context length the base model was trained with; threshold for switching
    // from `short_factor` to `long_factor`.
    original_max_position_embeddings: usize,
    // Number of positions currently covered by `cos_cached`/`sin_cached`.
    max_seq_len_cached: usize,
    // Attention scaling applied to both cached tables
    // (sqrt(1 + ln(scale)/ln(original_max_position_embeddings))).
    scaling_factor: f64,
    // Base inverse frequencies, shape (1, head_dim / 2).
    inv_freq: Tensor,
    cos_cached: Tensor,
    sin_cached: Tensor,
    // Device the caches live on; needed to rebuild them.
    device: Device,
}
impl MiniCPMLongRoPE {
    /// Builds the LongRoPE embedding from `cfg` and precomputes the cos/sin
    /// caches for `cfg.max_position_embeddings` positions on `device`.
    pub fn new(cfg: &MiniCPM4Config, device: &Device) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let rope_theta = 10000.0;
        let short_factor = cfg.rope_scaling.short_factor.clone();
        // BUG FIX: this previously cloned `short_factor`, so the long-context
        // rescale factors were never actually used by `update_cos_sin_cache`.
        let long_factor = cfg.rope_scaling.long_factor.clone();
        let original_max_position_embeddings = cfg.rope_scaling.original_max_position_embeddings;
        let max_position_embeddings = cfg.max_position_embeddings;
        // LongRoPE attention scaling:
        // sqrt(1 + ln(extension_ratio) / ln(original_context_length)).
        let scale = max_position_embeddings as f64 / original_max_position_embeddings as f64;
        let scaling_factor =
            (1.0 + scale.ln() / (original_max_position_embeddings as f64).ln()).sqrt();
        let inv_freq = compute_default_rope_parameters(head_dim, rope_theta);
        let inv_freq = Tensor::from_slice(&inv_freq, (1, inv_freq.len()), device)?;
        let mut rope = Self {
            short_factor,
            long_factor,
            original_max_position_embeddings,
            // Set by the `update_cos_sin_cache` call below.
            max_seq_len_cached: 0,
            scaling_factor,
            inv_freq: inv_freq.clone(),
            // Placeholders, immediately overwritten by `update_cos_sin_cache`.
            cos_cached: inv_freq.clone(),
            sin_cached: inv_freq,
            device: device.clone(),
        };
        // Build the initial caches through the same path as later rebuilds so
        // the short/long factor selection is consistent everywhere.
        rope.update_cos_sin_cache(max_position_embeddings)?;
        Ok(rope)
    }

    /// Recomputes the cos/sin caches to cover `seqlen` positions, switching to
    /// the long-context factors when `seqlen` exceeds the original training
    /// context length.
    pub fn update_cos_sin_cache(&mut self, seqlen: usize) -> Result<()> {
        self.max_seq_len_cached = seqlen;
        let t = Tensor::arange(0.0_f32, seqlen as f32, &self.device)?.reshape((seqlen, 1))?;
        let factors = if seqlen > self.original_max_position_embeddings {
            &self.long_factor
        } else {
            &self.short_factor
        };
        let ext_factors = Tensor::from_slice(factors, (1, factors.len()), &self.device)?;
        // Reciprocal of the factors: freqs = t * inv_freq / factor.
        let ext_factors = Tensor::ones_like(&ext_factors)?.div(&ext_factors)?;
        // (seqlen, 1) x (1, dim/2) -> (seqlen, dim/2), then per-column rescale.
        let freqs = t.matmul(&ext_factors)?.broadcast_mul(&self.inv_freq)?;
        // Duplicate along the last dim so cos/sin cover the full head_dim.
        let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        self.cos_cached = emb.cos()?.affine(self.scaling_factor, 0.0)?;
        self.sin_cached = emb.sin()?.affine(self.scaling_factor, 0.0)?;
        Ok(())
    }

    /// Returns `(cos, sin)` slices covering positions
    /// `[pos_offset, pos_offset + seqlen)`, growing the caches on demand.
    pub fn forward(&mut self, pos_offset: usize, seqlen: usize) -> Result<(Tensor, Tensor)> {
        if pos_offset + seqlen > self.max_seq_len_cached {
            self.update_cos_sin_cache(pos_offset + seqlen)?;
        }
        let cos = self.cos_cached.narrow(0, pos_offset, seqlen)?;
        let sin = self.sin_cached.narrow(0, pos_offset, seqlen)?;
        Ok((cos, sin))
    }
}
/// One MiniCPM decoder layer: pre-norm self-attention and MLP blocks, with
/// each residual branch scaled by `scale_depth / sqrt(num_hidden_layers)`.
pub struct MiniCPMDecoderLayer {
    self_attn: NaiveAttention,
    mlp: GateUpDownMLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    // Numerator of the residual-branch scaling applied after attention/MLP.
    scale_depth: f32,
    // Total layer count; its square root is the scaling denominator.
    num_hidden_layers: usize,
}
impl MiniCPMDecoderLayer {
    /// Loads a single decoder layer's weights from `vb` according to `cfg`.
    pub fn new(vb: VarBuilder, cfg: &MiniCPM4Config) -> Result<Self> {
        let self_attn = NaiveAttention::new(
            vb.pp("self_attn"),
            cfg.hidden_size,
            cfg.num_attention_heads,
            cfg.num_key_value_heads,
            None,
            false,
            None,
            None,
            None,
            None,
        )?;
        let mlp = GateUpDownMLP::new(
            vb.pp("mlp"),
            cfg.hidden_size,
            cfg.intermediate_size,
            cfg.hidden_act,
            false,
            None,
            None,
            None,
        )?;
        let input_layernorm =
            rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = rms_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            scale_depth: cfg.scale_depth,
            num_hidden_layers: cfg.num_hidden_layers,
        })
    }

    /// MiniCPM residual scaling: every branch output is multiplied by
    /// `scale_depth / sqrt(num_hidden_layers)` before being added back.
    /// Hoisted here to avoid repeating the expression in every branch.
    fn residual_scale(&self) -> f64 {
        self.scale_depth as f64 / (self.num_hidden_layers as f64).sqrt()
    }

    /// Forward pass without a KV cache (full-sequence attention).
    ///
    /// `cos`/`sin` are the rotary tables for the current positions;
    /// `attention_mask` is `None` for single-token inputs.
    pub fn forward(
        &self,
        xs: &Tensor,
        cos: &Tensor,
        sin: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let scale = self.residual_scale();
        // Pre-norm attention block with scaled residual.
        let residual = xs.clone();
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(&xs, Some(cos), Some(sin), attention_mask, true)?;
        let xs = (residual + xs.affine(scale, 0.0)?)?;
        // Pre-norm MLP block with scaled residual.
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        let xs = (residual + xs.affine(scale, 0.0)?)?;
        Ok(xs)
    }

    /// Forward pass that appends to the attention KV cache (autoregressive
    /// decoding). Mirrors `forward` apart from the cached attention call.
    pub fn forward_with_cache(
        &mut self,
        xs: &Tensor,
        cos: &Tensor,
        sin: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let scale = self.residual_scale();
        let residual = xs.clone();
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward_with_cache(&xs, cos, sin, attention_mask, true)?;
        let xs = (residual + xs.affine(scale, 0.0)?)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        let xs = (residual + xs.affine(scale, 0.0)?)?;
        Ok(xs)
    }

    /// Clears the attention KV cache so a new sequence can be decoded.
    pub fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}
/// Full MiniCPM4 causal language model: scaled token embeddings, a stack of
/// decoder layers, a final RMSNorm, and an LM head tied to the embeddings.
pub struct MiniCPMModel {
    cfg: MiniCPM4Config,
    embed_tokens: Embedding,
    layers: Vec<MiniCPMDecoderLayer>,
    norm: RmsNorm,
    // Mutable because the cos/sin caches grow on demand.
    rope_emb: MiniCPMLongRoPE,
    // Built from the embedding matrix (weight tying) in `new`.
    lm_head: Linear,
    // EOS ids supplied by the caller; generation stops on any of them.
    stop_token_ids: Vec<u32>,
}
impl MiniCPMModel {
pub fn new(vb: VarBuilder, cfg: MiniCPM4Config, eos_ids: Vec<u32>) -> Result<Self> {
let vb = vb.pp("model");
let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_layers = vb.pp("layers");
for i in 0..cfg.num_hidden_layers {
let layer = MiniCPMDecoderLayer::new(vb_layers.pp(i), &cfg)?;
layers.push(layer);
}
let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?;
let rope_emb = MiniCPMLongRoPE::new(&cfg, vb.device())?;
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
Ok(Self {
cfg,
embed_tokens,
layers,
norm,
rope_emb,
lm_head,
stop_token_ids: eos_ids,
})
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (bs, seq_len) = input_ids.dims2()?;
let input_embeds = self
.embed_tokens
.forward(input_ids)?
.affine(self.cfg.scale_emb, 0.0)?;
let attention_mask: Option<Tensor> = {
if seq_len <= 1 {
None
} else {
Some(prepare_causal_attention_mask(
bs,
seq_len,
0,
input_ids.device(),
)?)
}
};
let (cos, sin) = self.rope_emb.forward(seqlen_offset, seq_len)?;
let mut hidden_states = input_embeds;
for decode_layer in &self.layers {
hidden_states =
decode_layer.forward(&hidden_states, &cos, &sin, attention_mask.as_ref())?;
}
hidden_states = self.norm.forward(&hidden_states)?;
let hidden_state = hidden_states.narrow(1, seq_len - 1, 1)?;
let hidden_state = hidden_state.affine(
1.0 / (self.cfg.hidden_size / self.cfg.dim_model_base) as f64,
0.0,
)?;
let logits = self.lm_head.forward(&hidden_state)?;
Ok(logits)
}
pub fn forward_with_cache(
&mut self,
input_ids: &Tensor,
seqlen_offset: usize,
) -> Result<Tensor> {
let (bs, seq_len) = input_ids.dims2()?;
let input_embeds = self
.embed_tokens
.forward(input_ids)?
.affine(self.cfg.scale_emb, 0.0)?;
let attention_mask: Option<Tensor> = {
if seq_len <= 1 {
None
} else {
Some(prepare_causal_attention_mask(
bs,
seq_len,
0,
input_ids.device(),
)?)
}
};
let (cos, sin) = self.rope_emb.forward(seqlen_offset, seq_len)?;
let mut hidden_states = input_embeds;
for decode_layer in &mut self.layers {
hidden_states = decode_layer.forward_with_cache(
&hidden_states,
&cos,
&sin,
attention_mask.as_ref(),
)?;
}
hidden_states = self.norm.forward(&hidden_states)?;
let hidden_state = hidden_states.narrow(1, seq_len - 1, 1)?;
let hidden_state = hidden_state.affine(
1.0 / (self.cfg.hidden_size / self.cfg.dim_model_base) as f64,
0.0,
)?;
let logits = self.lm_head.forward(&hidden_state)?;
Ok(logits)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}
impl InferenceModel for MiniCPMModel {
    /// One autoregressive step: delegates to the KV-cached forward pass.
    fn forward_step(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        self.forward_with_cache(input_ids, seqlen_offset)
    }

    /// Drops the per-layer KV caches so a fresh sequence can be decoded.
    fn clear_cache(&mut self) {
        self.clear_kv_cache()
    }

    /// Token ids that terminate generation (the EOS ids given at load time).
    fn stop_token_ids(&self) -> Vec<u32> {
        self.stop_token_ids.to_vec()
    }
}