#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::{collections::HashMap, sync::Arc};
use candle_core::{Device, Module, Result, Tensor};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};
use crate::{
amoe::{AnyMoeBaseModelMixin, MlpLayer},
attention::SdpaParams,
device_map::DeviceMapper,
layers::{
embedding, Gemma3RotaryEmbedding, GemmaRmsNorm, MatMul, Mlp, RotaryEmbedding,
ScaledEmbedding, Sdpa,
},
layers_masker::BidirectionalMasker,
paged_attention::AttentionImplementation,
pipeline::{
text_models_inputs_processor::FlashParams, EmbeddingModel, IsqModel, NormalLoadingMetadata,
},
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
use mistralrs_quant::QuantizedConfig;
use crate::{
layers::{Activation, Gemma3RopeScalingConfig},
serde_default_fn,
};
// Fallback values for the optional fields of `EmbeddingGemmaConfig`. Each name
// declared here is referenced by a `#[serde(default = "...")]` attribute on the
// config field of the same name.
// NOTE(review): `serde_default_fn!` is defined elsewhere in the crate; presumably it
// expands to a zero-argument function `name() -> type` returning the value — confirm.
serde_default_fn!(bool, attention_bias, false);
serde_default_fn!(usize, head_dim, 256);
serde_default_fn!(Activation, hidden_activation, Activation::GeluPytorchTanh);
serde_default_fn!(f64, rms_norm_eps, 1e-6);
serde_default_fn!(f64, rope_theta, 1000000.);
serde_default_fn!(usize, vocab_size, 262208);
serde_default_fn!(usize, query_pre_attn_scalar, 256);
serde_default_fn!(usize, max_position_embeddings, 131072);
serde_default_fn!(f64, rope_local_base_freq, 10000.);
serde_default_fn!(usize, sliding_window_pattern, 6);
serde_default_fn!(usize, num_attention_heads, 8);
serde_default_fn!(usize, num_key_value_heads, 4);
/// Deserialized configuration for the EmbeddingGemma model.
///
/// Fields carrying `#[serde(default = "...")]` fall back to the value registered
/// under that name when absent from the input; `Option` fields may be omitted;
/// every other field is required.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct EmbeddingGemmaConfig {
/// Passed as the `bias` flag when building the q/k/v/o projections. Default: `false`.
#[serde(default = "attention_bias")]
pub attention_bias: bool,
/// Width of a single attention head; projections are sized `heads * head_dim`
/// and the per-head q/k norms use this width. Default: `256`.
#[serde(default = "head_dim")]
pub head_dim: usize,
/// Activation handed to the MLP constructor. Default: `GeluPytorchTanh`.
#[serde(default = "hidden_activation")]
pub hidden_activation: Activation,
/// Model width; the token embedding is scaled by `sqrt(hidden_size)`. Required.
pub hidden_size: usize,
/// Inner width handed to the MLP constructor. Required.
pub intermediate_size: usize,
/// Number of query heads. Default: `8`.
#[serde(default = "num_attention_heads")]
pub num_attention_heads: usize,
/// Number of decoder layers to build. Required.
pub num_hidden_layers: usize,
/// Number of key/value heads. Default: `4`.
#[serde(default = "num_key_value_heads")]
pub num_key_value_heads: usize,
/// Epsilon handed to every `GemmaRmsNorm`. Default: `1e-6`.
#[serde(default = "rms_norm_eps")]
pub rms_norm_eps: f64,
/// Not read directly in this file; the whole config is handed to
/// `Gemma3RotaryEmbedding::new_embedding_gemma`, which presumably uses it as
/// the global RoPE base — TODO confirm. Default: `1000000.`.
#[serde(default = "rope_theta")]
pub rope_theta: f64,
/// Number of rows in the token-embedding table. Default: `262208`.
#[serde(default = "vocab_size")]
pub vocab_size: usize,
/// Window size used for the sliding-window layers and their mask. Required.
pub sliding_window: usize,
/// Optional soft cap forwarded (as `f32`) to the attention parameters.
pub attn_logit_softcapping: Option<f64>,
/// Optional cap `c`; when set the final output becomes `tanh(x / c) * c`.
pub final_logit_softcapping: Option<f64>,
/// Attention softmax scale is `1 / sqrt(query_pre_attn_scalar)`. Default: `256`.
#[serde(default = "query_pre_attn_scalar")]
pub query_pre_attn_scalar: usize,
/// Maximum position handed to the local rotary embedding. Default: `131072`.
#[serde(default = "max_position_embeddings")]
pub max_position_embeddings: usize,
/// Optional quantization settings forwarded to every quantizable layer.
pub quantization_config: Option<QuantizedConfig>,
/// Base frequency of the local (sliding-window) rotary embedding. Default: `10000.`.
#[serde(default = "rope_local_base_freq")]
pub rope_local_base_freq: f64,
/// Layer `i` uses sliding-window attention unless `(i + 1)` is a multiple of
/// this value. Default: `6`.
#[serde(default = "sliding_window_pattern")]
pub sliding_window_pattern: usize,
/// Not read directly in this file; presumably consumed by the global rotary
/// embedding constructor — TODO confirm.
pub rope_scaling: Option<Gemma3RopeScalingConfig>,
/// Required field with no serde default. NOTE(review): it is not read anywhere
/// in this file — the masks are always produced by `BidirectionalMasker`.
pub use_bidirectional_attention: bool,
}
// Evaluates to `true` when layer `$layer_idx` (zero-based) is a sliding-window
// layer: every layer except those where `$layer_idx + 1` is a multiple of
// `$cfg.sliding_window_pattern`.
// NOTE: a `sliding_window_pattern` of 0 makes the `%` panic (remainder by zero).
macro_rules! is_sliding {
($layer_idx:expr, $cfg:expr) => {
($layer_idx + 1) % $cfg.sliding_window_pattern != 0
};
}
/// Self-attention block of one decoder layer.
struct Attention {
/// Query projection (`hidden_size -> num_attention_heads * head_dim`).
q_proj: Arc<dyn QuantMethod>,
/// Key projection (`hidden_size -> num_key_value_heads * head_dim`).
k_proj: Arc<dyn QuantMethod>,
/// Value projection (`hidden_size -> num_key_value_heads * head_dim`).
v_proj: Arc<dyn QuantMethod>,
/// Output projection (`num_attention_heads * head_dim -> hidden_size`).
o_proj: Arc<dyn QuantMethod>,
/// Query heads handled locally: configured count divided by `comm.world_size()`.
num_heads: usize,
/// Key/value heads handled locally: configured count divided by
/// `comm.world_size()`, clamped to at least 1.
num_kv_heads: usize,
/// Width of one head, taken from `cfg.head_dim`.
head_dim: usize,
/// Rotary embedding applied when this layer is NOT a sliding-window layer.
rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
/// Rotary embedding applied when this layer IS a sliding-window layer.
rotary_emb_local: Arc<RotaryEmbedding>,
/// Whether this layer uses the sliding-window mask and the local rotary embedding.
use_sliding_window: bool,
/// Scale, softcap, sliding window and KV-group settings handed to `Sdpa`.
sdpa_params: SdpaParams,
/// RMS norm applied to the queries after the per-head reshape.
q_norm: GemmaRmsNorm,
/// RMS norm applied to the keys after the per-head reshape.
k_norm: GemmaRmsNorm,
}
impl Attention {
    /// Builds the self-attention block for decoder layer `layer_idx`.
    ///
    /// * `rotary_emb_global` / `rotary_emb_local` - rotary embeddings used by
    ///   full-attention and sliding-window layers respectively.
    /// * `cfg` - model configuration (head counts, head width, bias, softcap, ...).
    /// * `layer_idx` - zero-based layer index; decides sliding vs. full attention
    ///   via `is_sliding!` and selects the device for the q/k norm weights.
    /// * `mapper` - places the `q_norm` / `k_norm` weights on the layer's device.
    /// * `vb` - var builder from which `q_proj`, `k_proj`, `v_proj`, `o_proj`,
    ///   `q_norm` and `k_norm` are loaded.
    /// * `comm` - communicator; its `world_size()` divides the local head counts.
    ///
    /// # Errors
    /// Propagates any error from constructing the projections or the norms.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
        rotary_emb_local: Arc<RotaryEmbedding>,
        cfg: &EmbeddingGemmaConfig,
        layer_idx: usize,
        mapper: &dyn DeviceMapper,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;

        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("q_proj"),
        )?;

        // The K/V projections emit `num_kv_heads * head_dim` features, so the shard
        // has to be computed from the configured `head_dim`. Deriving the width as
        // `hidden_size / num_attention_heads` is only right when the two happen to
        // coincide, which the config does not guarantee (`head_dim` is its own field).
        let kv_shard = mistralrs_quant::compute_kv_shard(num_kv_heads, head_dim, comm);
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("o_proj"),
        )?;

        // Only sliding-window layers carry a window size.
        let sliding_window = if is_sliding!(layer_idx, cfg) {
            Some(cfg.sliding_window)
        } else {
            None
        };

        // The norm weights are never loaded with the ISQ flag set.
        let q_norm = GemmaRmsNorm::new(
            head_dim,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("q_norm"), false),
        )?;
        let k_norm = GemmaRmsNorm::new(
            head_dim,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("k_norm"), false),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            // At least one KV head stays local even when there are more ranks than KV heads.
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb_global,
            rotary_emb_local,
            use_sliding_window: sliding_window.is_some(),
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
                sinks: None,
            },
            q_norm,
            k_norm,
        })
    }

    /// Runs self-attention over `xs`.
    ///
    /// * `xs` - rank-3 input `(batch, seq_len, features)`.
    /// * `attention_mask` - mask used when this is a full-attention layer.
    /// * `sliding_attention_mask` - mask used when this is a sliding-window layer.
    /// * `seqlen_offsets` - offsets forwarded to the rotary embedding.
    /// * `flash_params` - forwarded unchanged to `Sdpa::run_attention`.
    ///
    /// Returns the output of `o_proj` in the dtype of `xs`.
    ///
    /// # Errors
    /// Propagates any tensor, projection or attention error.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: &Tensor,
        sliding_attention_mask: &Tensor,
        seqlen_offsets: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;
        let original_dtype = xs.dtype();
        // Activation dtype requested by the (possibly quantized) q projection; when
        // present, inputs are cast to it before each matmul and results cast back.
        let act_dtype = self.q_proj.quantized_act_type();

        let xs = match act_dtype {
            Some(t) => xs.to_dtype(t)?,
            None => xs.clone(),
        };
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if act_dtype.is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Split the feature axis into heads: (batch, heads, seq_len, head_dim).
        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Per-head RMS norm on queries and keys precedes the rotary embedding.
        let q = q.apply(&self.q_norm)?;
        let k = k.apply(&self.k_norm)?;

        // Sliding-window layers pair the local rotary table with the sliding mask;
        // full-attention layers pair the global table with the full mask.
        let (q, k) = if self.use_sliding_window {
            self.rotary_emb_local.forward(&q, &k, seqlen_offsets)?
        } else {
            self.rotary_emb_global.forward(&q, &k, seqlen_offsets)?
        };
        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            Some(mask),
            Some(flash_params),
            &self.sdpa_params,
        )?;
        if let Some(t) = act_dtype {
            attn_output = attn_output.to_dtype(t)?;
        }

        // Merge the heads back into one feature axis before the output projection.
        let attn_output = attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?;
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if act_dtype.is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
/// One transformer layer: self-attention and an MLP, each wrapped by an RMS norm
/// before and after, with a residual connection around each sub-block.
struct DecoderLayer {
/// Self-attention sub-block.
self_attn: Attention,
/// Feed-forward sub-block, stored behind the `MlpLayer` trait object.
mlp: Box<dyn MlpLayer>,
/// Norm applied to the layer input before attention.
input_layernorm: GemmaRmsNorm,
/// Norm applied to the attention output before the first residual add.
post_attention_layernorm: GemmaRmsNorm,
/// Norm applied before the MLP.
pre_feedforward_layernorm: GemmaRmsNorm,
/// Norm applied to the MLP output before the second residual add.
post_feedforward_layernorm: GemmaRmsNorm,
}
impl DecoderLayer {
    /// Loads one decoder layer from `vb`.
    ///
    /// The attention and MLP weights are placed through `mapper` with the
    /// `loading_isq` flag; the four RMS norms are always placed with that flag
    /// set to `false`.
    ///
    /// # Errors
    /// Propagates any error raised while constructing a sub-module.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb_global: Arc<Gemma3RotaryEmbedding>,
        rotary_emb_local: Arc<RotaryEmbedding>,
        cfg: &EmbeddingGemmaConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        // The four norms share width and epsilon; only the weight prefix differs.
        let rms_norm = |prefix: &str| {
            GemmaRmsNorm::new(
                cfg.hidden_size,
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp(prefix), false),
            )
        };

        let attn_vb = mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq);
        let self_attn = Attention::new(
            rotary_emb_global,
            rotary_emb_local,
            cfg,
            layer_idx,
            mapper,
            attn_vb,
            comm,
        )?;

        let mlp_vb = mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq);
        let mlp = Mlp::new(
            mlp_vb,
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_activation,
            comm,
        )?;

        // Struct fields are evaluated top to bottom, so the norms load in the
        // order they are listed here.
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm: rms_norm("input_layernorm")?,
            post_attention_layernorm: rms_norm("post_attention_layernorm")?,
            pre_feedforward_layernorm: rms_norm("pre_feedforward_layernorm")?,
            post_feedforward_layernorm: rms_norm("post_feedforward_layernorm")?,
        })
    }

    /// Applies the layer to `xs`:
    /// `h = xs + post_attn_norm(attn(input_norm(xs)))`, then
    /// `h + post_ff_norm(mlp(pre_ff_norm(h)))`.
    ///
    /// The mask, offset and flash arguments are forwarded unchanged to the
    /// attention block.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: &Tensor,
        sliding_attention_mask: &Tensor,
        seqlen_offsets: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // Attention sub-block with its residual connection.
        let attn_in = self.input_layernorm.forward(xs)?;
        let attn_out = self.self_attn.forward(
            &attn_in,
            attention_mask,
            sliding_attention_mask,
            seqlen_offsets,
            flash_params,
        )?;
        let attn_out = attn_out.apply(&self.post_attention_layernorm)?;
        let hidden = (attn_out + xs)?;

        // Feed-forward sub-block with its residual connection.
        let mlp_in = hidden.apply(&self.pre_feedforward_layernorm)?;
        let mlp_out = self
            .mlp
            .forward(&mlp_in)?
            .apply(&self.post_feedforward_layernorm)?;
        &hidden + mlp_out
    }
}
/// EmbeddingGemma model: scaled token embedding, a stack of decoder layers and a
/// final RMS norm.
pub struct EmbeddingGemma {
/// Token embedding, scaled by `sqrt(hidden_size)` at construction.
embed_tokens: ScaledEmbedding,
/// Decoder layers, applied in order.
layers: Vec<DecoderLayer>,
/// Final RMS norm applied after the last layer.
norm: GemmaRmsNorm,
/// Device the hidden states are moved to before the final norm; also the
/// device reported by `EmbeddingModel::device`.
device: Device,
/// Maps each layer index to its device and moves activations between layers.
mapper: Box<dyn DeviceMapper + Send + Sync>,
/// Window size used when building the sliding attention mask.
sliding_window: usize,
/// Optional cap `c`; when set the output becomes `tanh(x / c) * c`.
final_logit_softcapping: Option<f64>,
}
impl EmbeddingGemma {
    /// Loads the model weights from `vb` according to `cfg`.
    ///
    /// * `is_gptx` - forwarded to both rotary-embedding constructors.
    /// * `normal_loading_metadata` - supplies the device mapper, the fallback
    ///   device, the ISQ flag and the progress reporter.
    /// * `attention_mechanism` - must be `AttentionImplementation::Eager`.
    ///
    /// # Errors
    /// Fails when `attention_mechanism` is not `Eager`, or when any sub-module
    /// fails to load.
    ///
    /// # Panics
    /// Panics if no rotary embedding was prepared for a layer's device location;
    /// this cannot happen as long as the tables are built from the same
    /// layer-to-device lookup used when the layers are constructed.
    pub fn new(
        cfg: &EmbeddingGemmaConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected AttentionImplementation::Eager");
        }

        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = ScaledEmbedding::new(
            (cfg.hidden_size as f64).sqrt(),
            embedding(
                cfg.vocab_size,
                cfg.hidden_size,
                mapper.set_nm_device(vb.pp("embed_tokens"), false),
                &cfg.quantization_config,
            )?,
        );

        // Rotary tables are shared by every layer living on the same device, so
        // each table is built once per device location rather than once per layer.
        let mut global_ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            if let std::collections::hash_map::Entry::Vacant(slot) =
                global_ropes.entry(device.location())
            {
                slot.insert(Arc::new(Gemma3RotaryEmbedding::new_embedding_gemma(
                    is_gptx,
                    vb.dtype(),
                    cfg,
                    device,
                )?));
            }
        }
        let mut local_ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            if let std::collections::hash_map::Entry::Vacant(slot) =
                local_ropes.entry(device.location())
            {
                slot.insert(Arc::new(RotaryEmbedding::new(
                    cfg.rope_local_base_freq as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?));
            }
        }

        let vb_l = vb.pp("layers");
        let layers = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb_global = global_ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let rotary_emb_local = local_ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb_global,
                rotary_emb_local,
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                &comm,
            )
        })?;

        let norm = GemmaRmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("norm"), false),
        )?;

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            device: normal_loading_metadata.real_device,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            mapper,
        })
    }

    /// Looks up the scaled token embeddings for `input_ids`.
    pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

    /// Runs the decoder stack over pre-computed embeddings `xs`.
    ///
    /// * `input_ids` - rank-2 token ids `(batch, seq_len)`; used for the batch
    ///   size and to build both attention masks.
    /// * `xs` - input embeddings; their dtype is used for the masks.
    /// * `flash_params` - forwarded unchanged to every layer.
    ///
    /// Returns the hidden states after the final norm, soft-capped when
    /// `final_logit_softcapping` is set.
    ///
    /// # Errors
    /// Fails if `input_ids` is not rank 2 or if any layer or tensor op fails.
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        mut xs: Tensor,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, _seqlen) = input_ids.dims2()?;
        // Every sequence starts at position zero.
        let seqlen_offsets = vec![0; bs];

        let attention_mask = BidirectionalMasker.make_mask(input_ids, xs.dtype())?;
        let sliding_mask =
            BidirectionalMasker.make_sliding_mask(input_ids, xs.dtype(), self.sliding_window)?;

        for (i, layer) in self.layers.iter().enumerate() {
            // Move the activations to layer `i`'s device, then bring both masks along.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                &attention_mask.to_device(xs.device())?,
                &sliding_mask.to_device(xs.device())?,
                &seqlen_offsets,
                flash_params,
            )?;
        }

        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(cap) = self.final_logit_softcapping {
            // Soft cap: cap * tanh(x / cap).
            xs = (xs / cap)?;
            xs = xs.tanh()?;
            xs = (xs * cap)?;
        }
        Ok(xs)
    }
}
impl IsqModel for EmbeddingGemma {
    /// Collects every quantizable projection, tagged with its layer index.
    ///
    /// Per layer the order is `q_proj`, `k_proj`, `v_proj`, `o_proj`, followed by
    /// whatever `mlp.get_isq_layers()` yields. The device mapper is returned
    /// alongside the list.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            // `extend` consumes the iterator directly; no intermediate Vec needed.
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i))),
            );
        }
        (tensors, &*self.mapper)
    }

    /// Collects the tensors not covered by `get_layers`: the token embedding,
    /// the final norm and every per-layer RMS norm, under their weight prefixes.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        uvb.pp("embed_tokens").add(&self.embed_tokens);
        uvb.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(layer_idx);
            uvb_l
                .pp("self_attn")
                .pp("q_norm")
                .add(&layer.self_attn.q_norm);
            uvb_l
                .pp("self_attn")
                .pp("k_norm")
                .add(&layer.self_attn.k_norm);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            uvb_l
                .pp("pre_feedforward_layernorm")
                .add(&layer.pre_feedforward_layernorm);
            uvb_l
                .pp("post_feedforward_layernorm")
                .add(&layer.post_feedforward_layernorm);
        }

        uvb.to_safetensors()
    }

    /// Lists the imatrix weight name for each entry produced by `get_layers`,
    /// in the same order: seven names per layer (q, k, v, o, gate, up, down).
    ///
    /// No leading `None` placeholder is emitted: `get_layers` registers no
    /// non-layer tensor (there is no lm_head in this model), so a placeholder
    /// would shift every name by one position relative to the layers.
    /// NOTE(review): this relies on the ISQ pipeline pairing entry `k` of this
    /// list with entry `k` of `get_layers`, and on the MLP exposing its layers
    /// as gate, up, down — confirm against the `IsqModel` trait and `Mlp`.
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        const PARTS: [&str; 7] = [
            "attn_q",
            "attn_k",
            "attn_v",
            "attn_output",
            "ffn_gate",
            "ffn_up",
            "ffn_down",
        ];

        let mut names = Vec::with_capacity(self.layers.len() * PARTS.len());
        for i in 0..self.layers.len() {
            names.extend(
                PARTS
                    .iter()
                    .map(|part| Some(format!("blk.{i}.{part}.weight"))),
            );
        }
        Ok(names)
    }
}
impl EmbeddingModel for EmbeddingGemma {
    /// Embeds `input_ids` and runs the result through the decoder stack.
    fn forward(
        &self,
        input_ids: &Tensor,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let token_embeds = self.embed_tokens(input_ids)?;
        self.forward_embeds(input_ids, token_embeds, flash_params)
    }

    /// Device the final hidden states are moved to before the last norm.
    fn device(&self) -> &Device {
        &self.device
    }
}
// Empty implementation: no trait methods are overridden here, so the model relies
// entirely on whatever default methods `AnyMoeBaseModelMixin` provides.
impl AnyMoeBaseModelMixin for EmbeddingGemma {}