#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Embedding;
use mistralrs_quant::{
ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
ShardedVarBuilder,
};
use super::config::{LayerType, TextConfig};
use crate::{
attention::SdpaParams,
device_map::{DeviceMappedMask, DeviceMapper},
kv_cache::{
HybridCache, HybridCacheConfig, HybridLayerCache, HybridLayerType, RecurrentLayerConfig,
},
layers::{self, GemmaRmsNorm, Qwen3VLRotaryEmbedding, Sdpa},
models::gdn::{GatedDeltaNet, GdnConfig, GdnLayerCache, GdnWeightMode},
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
EitherCache, IsqModel, KvCache, NormalLoadingMetadata,
},
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
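// Expose the text config's linear-attention hyperparameters through the shared
// `GdnConfig` trait so the generic `GatedDeltaNet` implementation can be reused here.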
impl GdnConfig for TextConfig {
fn hidden_size(&self) -> usize {
self.hidden_size
}
fn rms_norm_eps(&self) -> f64 {
self.rms_norm_eps
}
fn linear_conv_kernel_dim(&self) -> usize {
self.linear_conv_kernel_dim
}
fn linear_key_head_dim(&self) -> usize {
self.linear_key_head_dim
}
fn linear_value_head_dim(&self) -> usize {
self.linear_value_head_dim
}
fn linear_num_key_heads(&self) -> usize {
self.linear_num_key_heads
}
fn linear_num_value_heads(&self) -> usize {
self.linear_num_value_heads
}
fn quantization_config(&self) -> &Option<QuantizedConfig> {
&self.quantization_config
}
}
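/// Full (softmax) attention layer with QK-norm, a fused query/output-gate projection,
/// and rotary embeddings applied to the first `rot_dim` features of each head.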
#[allow(dead_code)]
struct FullAttention {
q_proj: Arc<dyn QuantMethod>,
k_proj: Arc<dyn QuantMethod>,
v_proj: Arc<dyn QuantMethod>,
o_proj: Arc<dyn QuantMethod>,
q_norm: GemmaRmsNorm,
k_norm: GemmaRmsNorm,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rotary_emb: Arc<Qwen3VLRotaryEmbedding>,
rot_dim: usize,
paged_attn: Option<PagedAttention>,
sdpa_params: SdpaParams,
}
impl FullAttention {
#[allow(clippy::too_many_arguments)]
fn load(
vb: ShardedVarBuilder,
cfg: &TextConfig,
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
rotary_emb: Arc<Qwen3VLRotaryEmbedding>,
paged_attn: Option<PagedAttention>,
comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let vb_sa = mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq);
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let head_dim = cfg.head_dim;
let q_proj = ColumnParallelLayer::new(
cfg.hidden_size,
num_heads * head_dim * 2,
&cfg.quantization_config,
false,
comm,
vb_sa.pp("q_proj"),
)?;
let kv_shard = mistralrs_quant::compute_kv_shard(num_kv_heads, head_dim, comm);
let k_proj = ColumnParallelLayer::new_with_shard(
cfg.hidden_size,
num_kv_heads * head_dim,
&cfg.quantization_config,
false,
comm,
kv_shard,
vb_sa.pp("k_proj"),
)?;
let v_proj = ColumnParallelLayer::new_with_shard(
cfg.hidden_size,
num_kv_heads * head_dim,
&cfg.quantization_config,
false,
comm,
kv_shard,
vb_sa.pp("v_proj"),
)?;
let o_proj = RowParallelLayer::new(
num_heads * head_dim,
cfg.hidden_size,
&cfg.quantization_config,
false,
comm,
vb_sa.pp("o_proj"),
)?;
let vb_sa_norms = mapper.set_device(layer_idx, vb.pp("self_attn"), false);
let q_norm = GemmaRmsNorm::new(head_dim, cfg.rms_norm_eps, vb_sa_norms.pp("q_norm"))?;
let k_norm = GemmaRmsNorm::new(head_dim, cfg.rms_norm_eps, vb_sa_norms.pp("k_norm"))?;
let rot_dim = cfg.rot_dim();
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads: num_heads / comm.world_size(),
num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
head_dim,
rotary_emb,
rot_dim,
paged_attn,
sdpa_params: SdpaParams {
n_kv_groups: mistralrs_quant::compute_n_kv_groups(num_kv_heads, num_heads, comm),
softcap: None,
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window: None,
sinks: None,
},
})
}
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
x: &Tensor,
attention_mask: Option<&Tensor>,
cos_sin: &(Tensor, Tensor),
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let (b_sz, seq_len, _) = x.dims3()?;
let original_dtype = x.dtype();
let mut x = x.clone();
if let Some(t) = self.q_proj.quantized_act_type() {
x = x.to_dtype(t)?;
}
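        // q_proj is a fused projection: its output holds both the query and the output
        // gate (hence `num_heads * head_dim * 2` features); k and v are projected separately.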
let mut q_gate = mistralrs_quant::MatMul.qmethod_matmul(&x, &*self.q_proj)?;
let mut k = mistralrs_quant::MatMul.qmethod_matmul(&x, &*self.k_proj)?;
let mut v = mistralrs_quant::MatMul.qmethod_matmul(&x, &*self.v_proj)?;
if self.q_proj.quantized_act_type().is_some() {
q_gate = q_gate.to_dtype(original_dtype)?;
k = k.to_dtype(original_dtype)?;
v = v.to_dtype(original_dtype)?;
}
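        // Split the fused projection into the query half and the output-gate half.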
let q_gate = q_gate.reshape((b_sz, seq_len, self.num_heads, self.head_dim * 2))?;
let q = q_gate.narrow(D::Minus1, 0, self.head_dim)?;
let gate = q_gate.narrow(D::Minus1, self.head_dim, self.head_dim)?;
let gate = gate.reshape((b_sz, seq_len, self.num_heads * self.head_dim))?;
let (mut q, mut k, v) = if seq_len != 1 {
let q = q.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
(q, k, v)
} else {
let q = q.reshape((b_sz, self.num_heads, seq_len, self.head_dim))?;
let k = k.reshape((b_sz, self.num_kv_heads, seq_len, self.head_dim))?;
let v = v.reshape((b_sz, self.num_kv_heads, seq_len, self.head_dim))?;
(q, k, v)
};
q = q.apply(&self.q_norm)?;
k = k.apply(&self.k_norm)?;
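        // Partial RoPE: rotate only the first `rot_dim` features of each head and pass
        // the remaining features through unchanged.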
if self.rot_dim < self.head_dim {
let mut q_rot = q.narrow(D::Minus1, 0, self.rot_dim)?;
let q_pass = q.narrow(D::Minus1, self.rot_dim, self.head_dim - self.rot_dim)?;
let mut k_rot = k.narrow(D::Minus1, 0, self.rot_dim)?;
let k_pass = k.narrow(D::Minus1, self.rot_dim, self.head_dim - self.rot_dim)?;
self.rotary_emb.forward(cos_sin, &mut q_rot, &mut k_rot)?;
q = Tensor::cat(&[q_rot, q_pass], D::Minus1)?;
k = Tensor::cat(&[k_rot, k_pass], D::Minus1)?;
} else {
self.rotary_emb.forward(cos_sin, &mut q, &mut k)?;
}
let mut y = match &self.paged_attn {
Some(paged_attn) => match metadata {
Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
&q,
&k,
&v,
attention_mask,
Some(key_cache),
Some(value_cache),
input_metadata,
&self.sdpa_params,
Some(flash_params),
)?,
None => {
let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
assert!(attention_mask.is_some());
paged_attn.forward(
&q,
&k,
&v,
attention_mask,
None,
None,
&input_metadata,
&self.sdpa_params,
Some(flash_params),
)?
}
},
None => {
let (cache_k, cache_v) = kv_cache.append(&k, &v)?;
Sdpa.run_attention(
&q,
&cache_k,
&cache_v,
attention_mask,
Some(flash_params),
&self.sdpa_params,
)?
}
};
if let Some(t) = self.q_proj.quantized_act_type() {
y = y.to_dtype(t)?;
}
y = if attention_mask.is_some() {
y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
} else {
y.reshape((b_sz, seq_len, ()))?
};
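        // Gate the attention output elementwise with sigmoid(gate) before the output projection.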
let gate = candle_nn::ops::sigmoid(&gate.to_dtype(y.dtype())?)?;
y = y.broadcast_mul(&gate)?;
let mut res = mistralrs_quant::MatMul.qmethod_matmul(&y, &*self.o_proj)?;
if self.q_proj.quantized_act_type().is_some() {
res = res.to_dtype(original_dtype)?;
}
Ok(res)
}
}
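/// Gated feed-forward block: gate/up column-parallel projections, the configured
/// activation, and a row-parallel down projection.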
#[derive(Clone)]
struct Mlp {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
down_proj: Arc<dyn QuantMethod>,
act_fn: crate::layers::Activation,
}
impl Mlp {
fn new(
vb: ShardedVarBuilder,
hidden_size: usize,
intermediate_size: usize,
quant_config: &Option<mistralrs_quant::QuantizedConfig>,
act_fn: crate::layers::Activation,
comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let gate_proj = ColumnParallelLayer::new(
hidden_size,
intermediate_size,
quant_config,
false,
comm,
vb.pp("gate_proj"),
)?;
let up_proj = ColumnParallelLayer::new(
hidden_size,
intermediate_size,
quant_config,
false,
comm,
vb.pp("up_proj"),
)?;
let down_proj = RowParallelLayer::new(
intermediate_size,
hidden_size,
quant_config,
false,
comm,
vb.pp("down_proj"),
)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let original_dtype = xs.dtype();
let mut xs = xs.clone();
if let Some(t) = self.gate_proj.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
let gate = mistralrs_quant::MatMul.qmethod_matmul(&xs, &*self.gate_proj)?;
let up = mistralrs_quant::MatMul.qmethod_matmul(&xs, &*self.up_proj)?;
let activated = crate::ops::mul_and_act(&gate, &up, self.act_fn)?;
let mut res = mistralrs_quant::MatMul.qmethod_matmul(&activated, &*self.down_proj)?;
if self.gate_proj.quantized_act_type().is_some() {
res = res.to_dtype(original_dtype)?;
}
Ok(res)
}
fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj]
}
}
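/// The per-layer token-mixing mechanism: either full softmax attention or a
/// linear-attention `GatedDeltaNet` layer.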
enum LayerImpl {
FullAttention(FullAttention),
LinearAttention(GatedDeltaNet),
}
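/// One decoder block: pre-norm residual token mixing followed by a pre-norm residual MLP.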
struct DecoderLayer {
layer_impl: LayerImpl,
input_layernorm: GemmaRmsNorm,
post_attention_layernorm: GemmaRmsNorm,
mlp: Mlp,
}
impl DecoderLayer {
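    /// Residual block around a full-attention layer; bails if this layer is not full attention.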
#[allow(clippy::too_many_arguments)]
fn forward_attention(
&self,
x: &Tensor,
attention_mask: Option<&Tensor>,
cos_sin: &(Tensor, Tensor),
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let attn = match &self.layer_impl {
LayerImpl::FullAttention(attn) => attn,
_ => candle_core::bail!("Expected full attention layer"),
};
let residual = x;
let x = self.input_layernorm.forward(x)?;
let attn_out = attn.forward(
&x,
attention_mask,
cos_sin,
kv_cache,
metadata,
flash_params,
)?;
let x = (attn_out + residual)?;
let residual = &x;
let normed = self.post_attention_layernorm.forward(&x)?;
let ffn_out = self.mlp.forward(&normed)?;
ffn_out + residual
}
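    /// Residual block around a linear-attention (GDN) layer; bails if this layer is not linear attention.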
fn forward_linear(&self, x: &Tensor, cache: &mut GdnLayerCache) -> Result<Tensor> {
let gdn = match &self.layer_impl {
LayerImpl::LinearAttention(gdn) => gdn,
_ => candle_core::bail!("Expected linear attention layer"),
};
let residual = x;
let x = self.input_layernorm.forward(x)?;
let gdn_out = gdn.forward(&x, cache)?;
let x = (gdn_out + residual)?;
let residual = &x;
let normed = self.post_attention_layernorm.forward(&x)?;
let ffn_out = self.mlp.forward(&normed)?;
ffn_out + residual
}
}
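/// Text decoder for Qwen3.5-style hybrid models: a mix of full-attention and
/// linear-attention (GatedDeltaNet) layers backed by a `HybridCache`.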
pub struct Qwen3_5TextModel {
embed_tokens: Embedding,
pub(super) norm: GemmaRmsNorm,
layers: Vec<DecoderLayer>,
layer_types: Vec<LayerType>,
mapper: Box<dyn DeviceMapper + Send + Sync>,
lm_head: Arc<dyn QuantMethod>,
pub(super) cache: EitherCache,
pub(super) cfg: ModelConfigMetadata,
pub(super) device: Device,
pub(super) dtype: DType,
pub(super) max_seq_len: usize,
}
impl Qwen3_5TextModel {
pub fn new(
cfg: &TextConfig,
vb: ShardedVarBuilder,
tie: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
let mapper = normal_loading_metadata.mapper;
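        // Checkpoints may nest the text model as either `language_model.model.*` or
        // `model.language_model.*`; pick whichever layout the weights actually use.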
let vb_m = if vb.contains_tensor("language_model.model.embed_tokens.weight") {
vb.pp("language_model").pp("model")
} else {
vb.pp("model").pp("language_model")
};
let embed_tokens = layers::embedding(
cfg.vocab_size,
cfg.hidden_size,
mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
&cfg.quantization_config,
)?;
let layer_types = cfg.layer_types();
let rot_dim = cfg.rot_dim();
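        // Build one rotary embedding per device so that full-attention layers mapped to
        // different devices each have a local copy.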
let mut ropes = HashMap::new();
for (layer_idx, layer_type) in layer_types.iter().enumerate() {
if *layer_type != LayerType::FullAttention {
continue;
}
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
ropes.entry(device.location()).or_insert_with(|| {
Arc::new(
Qwen3VLRotaryEmbedding::new(
cfg.rope_theta() as f32,
rot_dim,
device,
cfg.mrope_section().to_vec(),
)
.expect("Failed to create rotary embedding"),
)
});
}
let vb_l = vb_m.pp("layers");
let layers = NiceProgressBar::<_, 'b'>(
0..cfg.num_hidden_layers,
"Loading repeating layers",
&normal_loading_metadata.multi_progress,
)
.par_iter_if_isq(|layer_idx| {
let comm = mapper.get_comm_for(layer_idx)?;
let layer_impl = match layer_types[layer_idx] {
LayerType::FullAttention => {
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
let rotary_emb = ropes
.get(&device.location())
.expect("No RoPE for device location!")
.clone();
let paged_attn = match &attention_mechanism {
AttentionImplementation::Eager => None,
AttentionImplementation::PagedAttention => {
Some(PagedAttention::new(cfg.head_dim, device, None)?)
}
};
LayerImpl::FullAttention(FullAttention::load(
vb_l.pp(layer_idx),
cfg,
&*mapper,
layer_idx,
normal_loading_metadata.loading_isq,
rotary_emb,
paged_attn,
&comm,
)?)
}
LayerType::LinearAttention => LayerImpl::LinearAttention(GatedDeltaNet::load(
vb_l.pp(layer_idx),
cfg as &dyn GdnConfig,
&*mapper,
layer_idx,
normal_loading_metadata.loading_isq,
&comm,
GdnWeightMode::MergedWithFallback,
)?),
};
let input_layernorm = GemmaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb_l.pp(layer_idx).pp("input_layernorm"), false),
)?;
let post_attention_layernorm = GemmaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(
layer_idx,
vb_l.pp(layer_idx).pp("post_attention_layernorm"),
false,
),
)?;
let mlp = Mlp::new(
mapper.set_device(
layer_idx,
vb_l.pp(layer_idx).pp("mlp"),
normal_loading_metadata.loading_isq,
),
cfg.hidden_size,
cfg.intermediate_size,
&cfg.quantization_config,
cfg.hidden_act,
&comm,
)?;
Ok(DecoderLayer {
layer_impl,
input_layernorm,
post_attention_layernorm,
mlp,
})
})?;
let norm = GemmaRmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_nm_device(vb_m.pp("norm"), false),
)?;
let lm_head = if !tie {
ReplicatedLayer::new(
cfg.hidden_size,
cfg.vocab_size,
&cfg.quantization_config,
false,
mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
)?
} else {
ReplicatedLayer::from_linear(candle_nn::Linear::new(
mapper.cast_nm_device(
embed_tokens.embeddings(),
normal_loading_metadata.loading_isq,
)?,
None,
))?
};
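        // Describe the layer layout and recurrent-state shapes to the hybrid cache, which
        // keeps per-layer KV caches for attention layers and conv/recurrent state pools
        // for linear-attention layers.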
let pipeline_layer_types: Vec<HybridLayerType> = layer_types
.iter()
.map(|lt| match lt {
LayerType::FullAttention => HybridLayerType::Attention,
LayerType::LinearAttention => HybridLayerType::Recurrent,
})
.collect();
let hybrid_cache_config = HybridCacheConfig {
layer_types: pipeline_layer_types,
max_seq_len: cfg.max_position_embeddings,
recurrent: RecurrentLayerConfig {
conv_dim: cfg.linear_conv_dim(),
conv_width: cfg.linear_conv_kernel_dim,
state_dims: vec![
cfg.linear_num_value_heads,
cfg.linear_key_head_dim,
cfg.linear_value_head_dim,
],
},
};
let pipeline_cache = Arc::new(Mutex::new(
HybridCache::new(
hybrid_cache_config,
vb_m.dtype(),
&normal_loading_metadata.real_device,
)
.map_err(|e| {
                candle_core::Error::Msg(format!("Failed to create hybrid cache: {e}"))
})?,
));
Ok(Self {
embed_tokens,
norm,
layers,
layer_types: layer_types.clone(),
lm_head,
cache: EitherCache::Hybrid(pipeline_cache),
max_seq_len: cfg.max_position_embeddings,
cfg: ModelConfigMetadata {
max_seq_len: cfg.max_position_embeddings,
num_layers: cfg.num_hidden_layers,
hidden_size: cfg.hidden_size,
num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
.max(1),
sliding_window: None,
k_head_dim: cfg.head_dim,
v_head_dim: cfg.head_dim,
kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
},
device: normal_loading_metadata.real_device.clone(),
dtype: vb.dtype(),
mapper,
})
}
pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
self.embed_tokens.forward(input_ids)
}
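    /// Runs the decoder stack over precomputed input embeddings and returns logits for
    /// the positions selected by `context_lens`.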
#[allow(clippy::too_many_arguments)]
pub fn forward_embeds(
&self,
mut xs: Tensor,
attention_mask: Option<&Tensor>,
position_ids: &Tensor,
_seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
visual_pos_masks: Option<&Tensor>,
deepstack_visual_embeds: Option<&[Tensor]>,
) -> Result<Tensor> {
let mut hybrid_cache = self.cache.hybrid();
let state_indices = hybrid_cache.state_indices().cloned();
if self
.layer_types
.iter()
.any(|lt| matches!(lt, LayerType::LinearAttention))
&& state_indices.is_none()
{
candle_core::bail!(
"Hybrid recurrent state indices are required for linear-attention layers."
);
}
let cos_sin = {
let first_attn_idx = self
.layer_types
.iter()
.position(|lt| *lt == LayerType::FullAttention)
.expect("No full attention layer found");
match &self.layers[first_attn_idx].layer_impl {
LayerImpl::FullAttention(attn) => {
attn.rotary_emb.compute_cos_sin(position_ids, xs.dtype())?
}
_ => unreachable!(),
}
};
let attention_mask = DeviceMappedMask::new(attention_mask.cloned(), &*self.mapper)?;
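        // Precompute the flattened positions of visual tokens once so DeepStack visual
        // embeddings can be scatter-added into the hidden states after the corresponding layers.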
let deepstack_indices = if let Some(visual_pos_masks) = visual_pos_masks {
let mask_flat: Vec<f32> = visual_pos_masks
.to_device(&self.device)?
.to_dtype(DType::F32)?
.flatten_all()?
.to_vec1()?;
let indices: Vec<u32> = mask_flat
.iter()
.enumerate()
.filter(|(_, &v)| v > 0.0)
.map(|(i, _)| i as u32)
.collect();
if indices.is_empty() {
None
} else {
let hidden = xs.dim(candle_core::D::Minus1)?;
let n = indices.len();
let idx = Tensor::from_vec(indices, (n,), &self.device)?;
let idx_expanded = idx.unsqueeze(1)?.repeat((1, hidden))?;
Some((idx, idx_expanded))
}
} else {
None
};
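        // Walk the decoder stack, dispatching each layer to its attention or recurrent
        // path based on the configured layer type.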
for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
match &self.layer_types[i] {
LayerType::FullAttention => {
if let Some(HybridLayerCache::Attention(kv_cache)) = hybrid_cache.get_mut(i) {
xs = layer.forward_attention(
&xs,
attention_mask.as_ref().map(|m| m.get(xs.device())),
&cos_sin,
kv_cache,
metadata
.as_ref()
.map(|(kv_cache, meta)| (kv_cache[i].clone(), *meta)),
flash_params,
)?;
}
}
LayerType::LinearAttention => {
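                    // Gather this batch's conv and recurrent state from the shared pool,
                    // run the GDN layer, then scatter the updated state and offsets back.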
if let Some(HybridLayerCache::Recurrent(pool)) = hybrid_cache.get_mut(i) {
let indices = state_indices.as_ref().expect(
"checked above: linear-attention layers require recurrent indices",
);
let indices_vec: Vec<u32> = indices.to_vec1()?;
if indices_vec.is_empty() {
candle_core::bail!("Hybrid recurrent state indices are empty.");
}
let first_offset = pool.get_seqlen_offset(indices_vec[0] as usize);
if indices_vec
.iter()
.any(|&idx| pool.get_seqlen_offset(idx as usize) != first_offset)
{
candle_core::bail!(
"Hybrid recurrent seqlen offsets diverged within a batch for layer {i}."
);
}
let conv_state = pool.gather_conv_state(indices)?;
let recurrent_state = pool.gather_recurrent_state(indices)?;
let mut gdn_cache = GdnLayerCache {
conv_state,
recurrent_state,
seqlen_offset: first_offset,
};
xs = layer.forward_linear(&xs, &mut gdn_cache)?;
pool.scatter_conv_state(indices, &gdn_cache.conv_state)?;
pool.scatter_recurrent_state(indices, &gdn_cache.recurrent_state)?;
let delta = gdn_cache.seqlen_offset.saturating_sub(first_offset);
for &idx in &indices_vec {
let updated = pool.get_seqlen_offset(idx as usize) + delta;
pool.set_seqlen_offset(idx as usize, updated);
}
} else {
candle_core::bail!(
"Hybrid cache layer {i} is not recurrent for a linear-attention layer."
);
}
}
}
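            // Inject DeepStack visual embeddings into the visual-token positions after
            // each decoder layer that has a corresponding embedding tensor.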
if let (Some((idx, idx_expanded)), Some(deepstack)) =
(&deepstack_indices, deepstack_visual_embeds)
{
if i < deepstack.len() {
xs = self.deepstack_process(xs, idx, idx_expanded, &deepstack[i])?;
}
}
}
let xs = xs.to_device(&self.device)?;
let xs = xs.apply(&self.norm)?;
let mut xs = extract_logits(&xs, context_lens)?;
if let Some(t) = self.lm_head.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
mistralrs_quant::MatMul.qmethod_matmul(&xs, &*self.lm_head)
}
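    /// Scatter-adds DeepStack visual embeddings into the hidden states at the visual-token
    /// positions (flattened over batch and sequence).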
fn deepstack_process(
&self,
hidden_states: Tensor,
idx: &Tensor,
idx_expanded: &Tensor,
visual_embeds: &Tensor,
) -> Result<Tensor> {
let device = hidden_states.device();
let dtype = hidden_states.dtype();
let visual_embeds = visual_embeds.to_device(device)?.to_dtype(dtype)?;
let (batch, seq, hidden) = hidden_states.dims3()?;
let total = batch * seq;
let hidden_flat = hidden_states.reshape((total, hidden))?;
if idx.dim(0)? != visual_embeds.dim(0)? {
candle_core::bail!(
"Mismatch between DeepStack visual embeds ({}) and mask positions ({})",
visual_embeds.dim(0)?,
idx.dim(0)?
);
}
let result = hidden_flat.scatter_add(idx_expanded, &visual_embeds, 0)?;
result.reshape((batch, seq, hidden))
}
}
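// ISQ support: `get_layers` exposes the quantizable projections per layer;
// `residual_tensors` collects everything kept unquantized (embeddings, norms, and the
// GDN parameters).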
impl IsqModel for Qwen3_5TextModel {
fn get_layers(
&mut self,
) -> (
Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
&dyn DeviceMapper,
) {
let mut tensors = Vec::new();
tensors.push((&mut self.lm_head, None));
for (i, layer) in self.layers.iter_mut().enumerate() {
match &mut layer.layer_impl {
LayerImpl::FullAttention(attn) => {
tensors.push((&mut attn.q_proj, Some(i)));
tensors.push((&mut attn.k_proj, Some(i)));
tensors.push((&mut attn.v_proj, Some(i)));
tensors.push((&mut attn.o_proj, Some(i)));
}
LayerImpl::LinearAttention(gdn) => {
tensors.push((&mut gdn.out_proj, Some(i)));
}
}
for l in layer.mlp.get_isq_layers() {
tensors.push((l, Some(i)));
}
}
(tensors, &*self.mapper)
}
fn residual_tensors(&self) -> Vec<(String, Tensor)> {
let uvb = UnVarBuilder::new();
let uvb_lm = uvb.pp("model").pp("language_model");
uvb_lm.pp("embed_tokens").add(&self.embed_tokens);
uvb_lm.pp("norm").add(&self.norm);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let uvb_l = uvb_lm.pp("layers").pp(layer_idx);
uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
uvb_l
.pp("post_attention_layernorm")
.add(&layer.post_attention_layernorm);
match &layer.layer_impl {
LayerImpl::FullAttention(attn) => {
uvb_l.pp("self_attn").pp("q_norm").add(&attn.q_norm);
uvb_l.pp("self_attn").pp("k_norm").add(&attn.k_norm);
}
LayerImpl::LinearAttention(gdn) => {
uvb_l
.pp("linear_attn")
.pp("in_proj_qkvz")
.add_tensor("weight", gdn.in_proj_qkvz.weight().clone());
uvb_l
.pp("linear_attn")
.pp("in_proj_ba")
.add_tensor("weight", gdn.in_proj_ba.weight().clone());
uvb_l
.pp("linear_attn")
.add_tensor("conv1d.weight", gdn.conv1d_weight.clone());
uvb_l
.pp("linear_attn")
.add_tensor("dt_bias", gdn.dt_bias.clone());
uvb_l
.pp("linear_attn")
.add_tensor("A_log", gdn.a_log.clone());
uvb_l
.pp("linear_attn")
.pp("norm")
.add_tensor("weight", gdn.norm.weight.clone());
}
}
}
uvb.to_safetensors()
}
}