#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{Device, Module, Result, Tensor};
use mistralrs_quant::{
ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer, ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
use crate::{
amoe::{AnyMoeBaseModelMixin, MlpLayer},
attention::SdpaParams,
device_map::DeviceMapper,
layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
layers_masker::NotACache,
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{
text_models_inputs_processor::FlashParams, EmbeddingModel, IsqModel, NormalLoadingMetadata,
},
serde_default_fn,
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
/// Resolves the attention sliding window for decoder layer `$layer_idx`.
///
/// A layer receives `$cfg.sliding_window` only when sliding-window attention is switched on
/// (`use_sliding_window`) and the layer index is at or beyond `max_window_layers`. Every
/// other layer attends over the full context (`None`). If `sliding_window` itself is
/// `None`, the result is `None` regardless of the other fields.
macro_rules! sliding_window {
    ($layer_idx:expr, $cfg:expr) => {
        match $cfg.sliding_window {
            Some(window)
                if $cfg.use_sliding_window && $layer_idx >= $cfg.max_window_layers =>
            {
                Some(window)
            }
            _ => None,
        }
    };
}
// Serde default for `Config::tie_word_embeddings`, referenced by name in the
// `#[serde(default = "tie_word_embeddings")]` attribute below; presumably expands to a
// function returning `false` (project macro — confirm in `serde_default_fn!`).
serde_default_fn!(bool, tie_word_embeddings, false);
/// Model hyper-parameters (serde-deserializable).
///
/// NOTE(review): presumably populated from the checkpoint's `config.json`; confirm at the loader.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
    /// Number of rows in the token embedding table.
    pub(crate) vocab_size: usize,
    /// Model width; also the numerator of the `head_dim()` fallback.
    pub(crate) hidden_size: usize,
    /// Hidden width of each decoder layer's MLP.
    pub(crate) intermediate_size: usize,
    /// Number of decoder layers.
    pub(crate) num_hidden_layers: usize,
    /// Total query heads, before tensor-parallel division by the comm world size.
    pub(crate) num_attention_heads: usize,
    /// Total key/value heads, before tensor-parallel division by the comm world size.
    pub(crate) num_key_value_heads: usize,
    /// Activation passed to each decoder layer's MLP.
    pub(crate) hidden_act: Activation,
    /// Length of the RoPE tables; also reported as `max_seq_len` in the model metadata.
    pub(crate) max_position_embeddings: usize,
    /// Epsilon used by every RMS norm in the model.
    pub(crate) rms_norm_eps: f64,
    /// RoPE base frequency (converted to `f32` when the rotary tables are built).
    pub(crate) rope_theta: f64,
    /// Sliding-window size; which layers actually use it is decided by `sliding_window!`.
    pub(crate) sliding_window: Option<usize>,
    /// Explicit per-head dimension; `None` falls back to `hidden_size / num_attention_heads`.
    pub(crate) head_dim: Option<usize>,
    /// Optional pre-quantization settings forwarded to every linear layer and the embedding.
    pub(crate) quantization_config: Option<QuantizedConfig>,
    /// Defaults to `false` when the field is absent from the serialized config.
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    /// Layers whose index is `>=` this value use the sliding window (when enabled).
    pub(crate) max_window_layers: usize,
    /// Master switch for sliding-window attention.
    pub(crate) use_sliding_window: bool,
}
impl Config {
    /// Per-head attention dimension.
    ///
    /// Returns the explicit `head_dim` when the config provides one; otherwise falls back to
    /// `hidden_size / num_attention_heads`. The fallback is computed lazily, so an explicit
    /// `head_dim` never evaluates the division (which would panic for
    /// `num_attention_heads == 0`).
    pub(crate) fn head_dim(&self) -> usize {
        self.head_dim
            .unwrap_or_else(|| self.hidden_size / self.num_attention_heads)
    }
}
/// Self-attention sub-layer: Q/K/V/O projections, per-head RMS norms on queries and keys,
/// rotary position embeddings, and scaled dot-product attention.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    /// RMS norm applied to each query head (width `head_dim`).
    q_norm: RmsNorm,
    /// RMS norm applied to each key head (width `head_dim`).
    k_norm: RmsNorm,
    /// Query heads held by this rank (total heads divided by the comm world size).
    num_heads: usize,
    /// Key/value heads held by this rank (total divided by world size, clamped to >= 1).
    num_kv_heads: usize,
    head_dim: usize,
    /// RoPE tables, shared by all layers mapped to the same device.
    rotary_emb: Arc<RotaryEmbedding>,
    /// KV-group count, softmax scale and this layer's sliding window for SDPA.
    sdpa_params: SdpaParams,
}
impl Attention {
#[allow(clippy::too_many_arguments)]
fn new(
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: ShardedVarBuilder,
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let head_dim = cfg.head_dim();
let q_proj = ColumnParallelLayer::new(
hidden_sz,
num_heads * head_dim,
&cfg.quantization_config,
false,
comm,
mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
)?;
let kv_shard = mistralrs_quant::compute_kv_shard(
cfg.num_key_value_heads,
cfg.hidden_size / cfg.num_attention_heads,
comm,
);
let k_proj = ColumnParallelLayer::new_with_shard(
hidden_sz,
num_kv_heads * head_dim,
&cfg.quantization_config,
false,
comm,
kv_shard,
mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
)?;
let v_proj = ColumnParallelLayer::new_with_shard(
hidden_sz,
num_kv_heads * head_dim,
&cfg.quantization_config,
false,
comm,
kv_shard,
mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
)?;
let o_proj = RowParallelLayer::new(
num_heads * head_dim,
hidden_sz,
&cfg.quantization_config,
false,
comm,
mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
)?;
let sliding_window = sliding_window!(layer_idx, cfg);
let q_norm = RmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("q_norm"), false),
)?;
let k_norm = RmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("k_norm"), false),
)?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads: num_heads / comm.world_size(),
num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
head_dim,
rotary_emb,
sdpa_params: SdpaParams {
n_kv_groups: mistralrs_quant::compute_n_kv_groups(
cfg.num_key_value_heads,
cfg.num_attention_heads,
comm,
),
softcap: None,
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window,
sinks: None,
},
})
}
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
attention_mask: &Tensor,
seqlen_offsets: &[usize],
flash_params: &FlashParams,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let original_dtype = xs.dtype();
let mut xs = xs.clone();
if let Some(t) = self.q_proj.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
if self.q_proj.quantized_act_type().is_some() {
q = q.to_dtype(original_dtype)?;
k = k.to_dtype(original_dtype)?;
v = v.to_dtype(original_dtype)?;
}
q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
k = k
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
v = v
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
q = q.apply(&self.q_norm)?;
k = k.apply(&self.k_norm)?;
(q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
let mut attn_output = Sdpa.run_attention(
&q,
&k,
&v,
Some(attention_mask),
Some(flash_params),
&self.sdpa_params,
)?;
if let Some(t) = self.q_proj.quantized_act_type() {
attn_output = attn_output.to_dtype(t)?;
}
attn_output = attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?;
let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
if self.q_proj.quantized_act_type().is_some() {
res = res.to_dtype(original_dtype)?;
}
Ok(res)
}
}
struct DecoderLayer {
self_attn: Attention,
mlp: Box<dyn MlpLayer>,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
#[allow(clippy::too_many_arguments)]
fn new(
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: ShardedVarBuilder,
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let self_attn = Attention::new(
rotary_emb,
cfg,
mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
mapper,
layer_idx,
loading_isq,
comm,
)?;
let mlp = Mlp::new(
mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
cfg.hidden_size,
cfg.intermediate_size,
&cfg.quantization_config,
cfg.hidden_act,
comm,
)?;
let input_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
)?;
Ok(Self {
self_attn,
mlp: Box::new(mlp),
input_layernorm,
post_attention_layernorm,
})
}
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
attention_mask: &Tensor,
seqlen_offsets: &[usize],
flash_params: &FlashParams,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self
.self_attn
.forward(&xs, attention_mask, seqlen_offsets, flash_params)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.mlp
.forward(&xs.apply(&self.post_attention_layernorm)?)?;
residual + xs
}
}
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
sliding_window: Option<usize>,
device: Device,
mapper: Box<dyn DeviceMapper + Send + Sync>,
cfg: ModelConfigMetadata,
}
impl Model {
pub fn new(
cfg: &Config,
vb: ShardedVarBuilder,
is_gptx: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
Self::new_inner(
cfg,
vb,
is_gptx,
normal_loading_metadata,
attention_mechanism,
)
}
pub fn new_inner(
cfg: &Config,
vb_m: ShardedVarBuilder,
is_gptx: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
if let Some(ref quant_cfg) = &cfg.quantization_config {
tracing::info!(
"Using {} quantization: {}.",
quant_cfg.name(),
quant_cfg.get_bits_name(&vb_m)
);
}
if !matches!(attention_mechanism, AttentionImplementation::Eager) {
candle_core::bail!("Expected AttentionImplementation::Eager");
}
let mapper = normal_loading_metadata.mapper;
let embed_tokens = embedding(
cfg.vocab_size,
cfg.hidden_size,
mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
&cfg.quantization_config,
)?;
let head_dim = cfg.head_dim();
let mut ropes = HashMap::new();
for layer_idx in 0..cfg.num_hidden_layers {
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
ropes.insert(
device.location(),
Arc::new(RotaryEmbedding::new(
cfg.rope_theta as f32,
head_dim,
cfg.max_position_embeddings,
device,
is_gptx,
vb_m.dtype(),
)?),
);
}
let vb_l = vb_m.pp("layers");
let layers = NiceProgressBar::<_, 'b'>(
0..cfg.num_hidden_layers,
"Loading repeating layers",
&normal_loading_metadata.multi_progress,
)
.par_iter_if_isq(|layer_idx| -> Result<DecoderLayer> {
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
let rotary_emb = ropes
.get(&device.location())
.expect("No RoPE for device location!")
.clone();
let comm = mapper.get_comm_for(layer_idx)?;
DecoderLayer::new(
rotary_emb.clone(),
cfg,
vb_l.pp(layer_idx),
&*mapper,
layer_idx,
normal_loading_metadata.loading_isq,
&comm,
)
})?;
let norm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_nm_device(vb_m.pp("norm"), false),
)?;
Ok(Self {
embed_tokens,
layers,
norm,
sliding_window: cfg.sliding_window,
device: normal_loading_metadata.real_device,
cfg: ModelConfigMetadata {
max_seq_len: cfg.max_position_embeddings,
num_layers: cfg.num_hidden_layers,
hidden_size: cfg.hidden_size,
num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
.max(1),
num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
sliding_window: cfg.sliding_window,
k_head_dim: cfg.head_dim(),
v_head_dim: cfg.head_dim(),
kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
},
mapper,
})
}
pub fn forward(&self, input_ids: &Tensor, flash_params: &FlashParams) -> Result<Tensor> {
self.forward_embeds(
input_ids,
self.embed_tokens.forward(input_ids)?,
flash_params,
)
}
#[allow(clippy::too_many_arguments)]
pub fn forward_embeds(
&self,
input_ids: &Tensor,
input_embeds: Tensor,
flash_params: &FlashParams,
) -> Result<Tensor> {
let mut xs = input_embeds;
let (bs, _seqlen) = input_ids.dims2()?;
let seqlen_offsets = vec![0; bs];
let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
input_ids,
&NotACache,
self.sliding_window,
xs.dtype(),
self.cfg.num_attn_heads,
)?;
let Some(attention_mask) = attention_mask else {
unreachable!()
};
for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
xs = layer.forward(
&xs,
&attention_mask.to_device(xs.device())?,
&seqlen_offsets,
flash_params,
)?;
}
let xs = xs.to_device(&self.device)?;
xs.apply(&self.norm)
}
}
impl IsqModel for Model {
    /// Collects every ISQ-quantizable linear layer, tagged with its decoder-layer index.
    ///
    /// Per layer the order is: `q_proj`, `k_proj`, `v_proj`, `o_proj`, then whatever
    /// `mlp.get_isq_layers()` returns. There is no LM head, so nothing precedes layer 0.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i))),
            );
        }
        (tensors, &*self.mapper)
    }

    /// Non-quantized tensors to serialize alongside ISQ weights: embeddings, final norm, and
    /// each layer's RMS norms (including the per-head q/k norms).
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        uvb.pp("embed_tokens").add(&self.embed_tokens);
        uvb.pp("norm").add(&self.norm);
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            uvb_l
                .pp("self_attn")
                .pp("q_norm")
                .add(&layer.self_attn.q_norm);
            uvb_l
                .pp("self_attn")
                .pp("k_norm")
                .add(&layer.self_attn.k_norm);
        }
        uvb.to_safetensors()
    }

    /// imatrix tensor names, in the same order as the tensors produced by `get_layers`.
    ///
    /// `get_layers` yields no leading non-layer tensor (this model has no LM head), so the
    /// list starts directly with layer 0. A leading placeholder would shift every name by
    /// one slot. NOTE(review): the MLP names assume `get_isq_layers` returns gate, up, down
    /// in that order; confirm against `Mlp`.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        let mut names = Vec::with_capacity(self.layers.len() * 7);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}
impl EmbeddingModel for Model {
fn forward(
&self,
input_ids: &Tensor,
flash_params: &FlashParams,
) -> candle_core::Result<Tensor> {
self.forward(input_ids, flash_params)
}
fn device(&self) -> &Device {
&self.device
}
}
impl AnyMoeBaseModelMixin for Model {}