#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Linear;
use mistralrs_quant::{
ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};
use crate::moe::{MoEExperts, MoEExpertsConfig};
use crate::{
amoe::AnyMoeBaseModelMixin,
attention::SdpaParams,
device_map::{DeviceMappedMask, DeviceMapper},
layers::{self, embedding, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
layers_masker::PastKvLenCache,
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
extract_logits,
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
EitherCache, IsqModel, KvCache, NormalCache, NormalCacheType, NormalLoadingMetadata,
NormalModel,
},
serde_default_fn,
utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
macro_rules! sliding_window {
($layer_idx:expr, $cfg:expr) => {
if !($cfg.sliding_window.is_some()
&& $cfg.use_sliding_window
&& $layer_idx >= $cfg.max_window_layers)
{
None
} else {
$cfg.sliding_window
}
};
}
serde_default_fn!(bool, tie_word_embeddings, false);
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize,
pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) rms_norm_eps: f64,
pub(crate) rope_theta: f64,
pub(crate) sliding_window: Option<usize>,
pub(crate) head_dim: Option<usize>,
pub(crate) quantization_config: Option<QuantizedConfig>,
#[serde(default = "tie_word_embeddings")]
pub(crate) tie_word_embeddings: bool,
pub(crate) max_window_layers: usize,
pub(crate) use_sliding_window: bool,
pub(crate) moe_intermediate_size: usize,
pub(crate) num_experts: usize,
pub(crate) mlp_only_layers: Vec<usize>,
pub(crate) decoder_sparse_step: usize,
pub(crate) norm_topk_prob: bool,
pub(crate) num_experts_per_tok: usize,
}
impl Config {
pub(crate) fn head_dim(&self) -> usize {
self.head_dim
.unwrap_or(self.hidden_size / self.num_attention_heads)
}
}
struct Attention {
q_proj: Arc<dyn QuantMethod>,
k_proj: Arc<dyn QuantMethod>,
v_proj: Arc<dyn QuantMethod>,
o_proj: Arc<dyn QuantMethod>,
q_norm: RmsNorm,
k_norm: RmsNorm,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
paged_attn: Option<PagedAttention>,
sdpa_params: SdpaParams,
}
impl Attention {
#[allow(clippy::too_many_arguments)]
fn new(
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: ShardedVarBuilder,
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
paged_attn: Option<PagedAttention>,
comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let head_dim = cfg.head_dim();
let q_proj = ColumnParallelLayer::new(
hidden_sz,
num_heads * head_dim,
&cfg.quantization_config,
false,
comm,
mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
)?;
let kv_shard = mistralrs_quant::compute_kv_shard(
cfg.num_key_value_heads,
cfg.hidden_size / cfg.num_attention_heads,
comm,
);
let k_proj = ColumnParallelLayer::new_with_shard(
hidden_sz,
num_kv_heads * head_dim,
&cfg.quantization_config,
false,
comm,
kv_shard,
mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
)?;
let v_proj = ColumnParallelLayer::new_with_shard(
hidden_sz,
num_kv_heads * head_dim,
&cfg.quantization_config,
false,
comm,
kv_shard,
mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
)?;
let o_proj = RowParallelLayer::new(
num_heads * head_dim,
hidden_sz,
&cfg.quantization_config,
false,
comm,
mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
)?;
let sliding_window = sliding_window!(layer_idx, cfg);
let q_norm = RmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("q_norm"), false),
)?;
let k_norm = RmsNorm::new(
cfg.head_dim(),
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("k_norm"), false),
)?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
num_heads: num_heads / comm.world_size(),
num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
head_dim,
rotary_emb,
paged_attn,
sdpa_params: SdpaParams {
n_kv_groups: mistralrs_quant::compute_n_kv_groups(
cfg.num_key_value_heads,
cfg.num_attention_heads,
comm,
),
softcap: None,
softmax_scale: 1.0 / (head_dim as f32).sqrt(),
sliding_window,
sinks: None,
},
})
}
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let original_dtype = xs.dtype();
let mut xs = xs.clone();
if let Some(t) = self.q_proj.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
if self.q_proj.quantized_act_type().is_some() {
q = q.to_dtype(original_dtype)?;
k = k.to_dtype(original_dtype)?;
v = v.to_dtype(original_dtype)?;
}
(q, k, v) = if q_len != 1 {
let q = q
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
(q, k, v)
} else {
let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
(q, k, v)
};
q = q.apply(&self.q_norm)?;
k = k.apply(&self.k_norm)?;
(q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
let mut attn_output = match &self.paged_attn {
Some(paged_attn) => match metadata {
Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
&q,
&k,
&v,
attention_mask,
Some(key_cache),
Some(value_cache),
input_metadata,
&self.sdpa_params,
Some(flash_params),
)?,
None => {
let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
assert!(attention_mask.is_some());
paged_attn.forward(
&q,
&k,
&v,
attention_mask,
None,
None,
&input_metadata,
&self.sdpa_params,
Some(flash_params),
)?
}
},
None => {
let (k, v) = kv_cache.append(&k, &v)?;
Sdpa.run_attention(
&q,
&k,
&v,
attention_mask,
Some(flash_params),
&self.sdpa_params,
)?
}
};
if let Some(t) = self.q_proj.quantized_act_type() {
attn_output = attn_output.to_dtype(t)?;
}
attn_output = if attention_mask.is_some() {
attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
} else {
attn_output.reshape((b_sz, q_len, ()))?
};
let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
if self.q_proj.quantized_act_type().is_some() {
res = res.to_dtype(original_dtype)?;
}
Ok(res)
}
}
#[derive(Clone)]
struct Mlp {
gate_proj: Arc<dyn QuantMethod>,
up_proj: Arc<dyn QuantMethod>,
down_proj: Arc<dyn QuantMethod>,
act_fn: Activation,
}
impl Mlp {
fn new(
cfg: &Config,
vb: ShardedVarBuilder,
comm: &Arc<mistralrs_quant::Comm>,
i_size: usize,
) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let gate_proj = ColumnParallelLayer::new(
hidden_size,
i_size,
&cfg.quantization_config,
false,
comm,
vb.pp("gate_proj"),
)?;
let up_proj = RowParallelLayer::new(
hidden_size,
i_size,
&cfg.quantization_config,
false,
comm,
vb.pp("up_proj"),
)?;
let down_proj = ColumnParallelLayer::new(
i_size,
hidden_size,
&cfg.quantization_config,
false,
comm,
vb.pp("down_proj"),
)?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let original_dtype = xs.dtype();
let mut xs = xs.clone();
if let Some(t) = self.gate_proj.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
let gate_out = MatMul.qmethod_matmul(&xs, &*self.gate_proj)?;
let up_out = MatMul.qmethod_matmul(&xs, &*self.up_proj)?;
let current_hidden_states = crate::ops::mul_and_act(&gate_out, &up_out, self.act_fn)?;
let mut res = MatMul.qmethod_matmul(¤t_hidden_states, &*self.down_proj)?;
if self.gate_proj.quantized_act_type().is_some() {
res = res.to_dtype(original_dtype)?;
}
Ok(res)
}
}
struct MoeMlp {
gate: Linear,
experts: MoEExperts,
num_experts_per_tok: usize,
norm_topk_prob: bool,
}
impl MoeMlp {
fn new(
cfg: &Config,
vb: ShardedVarBuilder,
layer_device: Device,
comm: &Arc<mistralrs_quant::Comm>,
loading_isq: bool,
) -> Result<Self> {
let gate = layers::linear_no_bias(
cfg.hidden_size,
cfg.num_experts,
vb.pp("gate").set_device(layer_device.clone()),
)?;
let moe_cfg = MoEExpertsConfig {
num_experts: cfg.num_experts,
num_experts_per_tok: cfg.num_experts_per_tok,
hidden_size: cfg.hidden_size,
moe_intermediate_size: cfg.moe_intermediate_size,
};
let experts = MoEExperts::new(
&moe_cfg,
vb,
layer_device,
comm,
loading_isq,
&cfg.quantization_config,
cfg.hidden_act,
)?;
Ok(Self {
gate,
experts,
num_experts_per_tok: cfg.num_experts_per_tok,
norm_topk_prob: cfg.norm_topk_prob,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
let xs_flat = xs.reshape(((), hidden_dim))?;
let router_logits = self.gate.forward(&xs_flat)?;
let routing_weights =
candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;
let topk_ids = routing_weights
.arg_sort_last_dim(false)?
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
.contiguous()?;
let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?;
if self.norm_topk_prob {
topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?;
}
let ys = self.experts.forward(xs, topk_weights, &topk_ids)?;
ys.reshape((b_size, seq_len, hidden_dim))
}
fn gate(&self) -> &Linear {
&self.gate
}
fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
self.experts.get_isq_layers()
}
}
enum MoeOrMlp {
Moe(MoeMlp),
Mlp(Mlp),
}
impl MoeOrMlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Mlp(m) => m.forward(xs),
Self::Moe(m) => m.forward(xs),
}
}
}
struct DecoderLayer {
self_attn: Attention,
mlp: MoeOrMlp,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
#[allow(clippy::too_many_arguments)]
fn new(
rotary_emb: Arc<RotaryEmbedding>,
cfg: &Config,
vb: ShardedVarBuilder,
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
paged_attn: Option<PagedAttention>,
real_device: Device,
comm: &Arc<mistralrs_quant::Comm>,
) -> Result<Self> {
let self_attn = Attention::new(
rotary_emb,
cfg,
mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
mapper,
layer_idx,
loading_isq,
paged_attn,
comm,
)?;
let mlp = if !cfg.mlp_only_layers.contains(&layer_idx)
&& (cfg.num_experts > 0 && (layer_idx + 1).is_multiple_of(cfg.decoder_sparse_step))
{
let vb = mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq);
let layer_device = mapper
.device_for(layer_idx, false)
.cloned()
.unwrap_or(real_device);
MoeOrMlp::Moe(MoeMlp::new(cfg, vb, layer_device, comm, loading_isq)?)
} else {
MoeOrMlp::Mlp(Mlp::new(
cfg,
mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
comm,
cfg.intermediate_size,
)?)
};
let input_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
)?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offsets: &[usize],
kv_cache: &mut KvCache,
metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(
&xs,
attention_mask,
seqlen_offsets,
kv_cache,
metadata,
flash_params,
)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.mlp
.forward(&xs.apply(&self.post_attention_layernorm)?)?;
residual + xs
}
}
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Arc<dyn QuantMethod>,
sliding_window: Option<usize>,
device: Device,
cache: EitherCache,
max_seq_len: usize,
mapper: Box<dyn DeviceMapper + Send + Sync>,
cfg: ModelConfigMetadata,
}
impl Model {
pub fn new(
cfg: &Config,
vb: ShardedVarBuilder,
is_gptx: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
let vb_m = vb.pp("model");
let vb_lm_head = vb.pp("lm_head");
Self::new_inner(
cfg,
vb_m,
vb_lm_head,
is_gptx,
normal_loading_metadata,
attention_mechanism,
)
}
pub fn new_inner(
cfg: &Config,
vb_m: ShardedVarBuilder,
vb_lm_head: ShardedVarBuilder,
is_gptx: bool,
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
if let Some(ref quant_cfg) = &cfg.quantization_config {
tracing::info!(
"Using {} quantization: {}.",
quant_cfg.name(),
quant_cfg.get_bits_name(&vb_m)
);
}
let mapper = normal_loading_metadata.mapper;
let embed_tokens = embedding(
cfg.vocab_size,
cfg.hidden_size,
mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
&cfg.quantization_config,
)?;
let head_dim = cfg.head_dim();
let mut ropes = HashMap::new();
for layer_idx in 0..cfg.num_hidden_layers {
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
ropes.insert(
device.location(),
Arc::new(RotaryEmbedding::new(
cfg.rope_theta as f32,
head_dim,
cfg.max_position_embeddings,
device,
is_gptx,
vb_m.dtype(),
)?),
);
}
let load_in_parallel =
!(normal_loading_metadata.real_device.is_metal() && cfg.quantization_config.is_none());
let vb_l = vb_m.pp("layers");
let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
0..cfg.num_hidden_layers,
"Loading repeating layers",
&normal_loading_metadata.multi_progress,
)
.run(load_in_parallel, |layer_idx| {
let device = mapper
.device_for(layer_idx, false)
.unwrap_or(&normal_loading_metadata.real_device);
let rotary_emb = ropes
.get(&device.location())
.expect("No RoPE for device location!")
.clone();
let paged_attn = match &attention_mechanism {
AttentionImplementation::Eager => None,
AttentionImplementation::PagedAttention => Some(
PagedAttention::new(head_dim, device, None)
.expect("PagedAttention creation failed"),
),
};
let comm = mapper
.get_comm_for(layer_idx)
.expect("Failed to get comm for layer");
DecoderLayer::new(
rotary_emb.clone(),
cfg,
vb_l.pp(layer_idx),
&*mapper,
layer_idx,
normal_loading_metadata.loading_isq,
paged_attn,
normal_loading_metadata.real_device.clone(),
&comm,
)
})?;
let norm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
mapper.set_nm_device(vb_m.pp("norm"), false),
)?;
let lm_head = if !cfg.tie_word_embeddings {
ReplicatedLayer::new(
cfg.hidden_size,
cfg.vocab_size,
&cfg.quantization_config,
false,
mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
)?
} else {
ReplicatedLayer::from_linear(candle_nn::Linear::new(
mapper.cast_nm_device(
embed_tokens.embeddings(),
normal_loading_metadata.loading_isq,
)?,
None,
))?
};
let cache_types = (0..cfg.num_hidden_layers)
.map(|layer_idx| {
sliding_window!(layer_idx, cfg)
.map(|window| NormalCacheType::SlidingWindow { window })
.unwrap_or(NormalCacheType::Normal {
max_seq_len: cfg.max_position_embeddings,
})
})
.collect::<Vec<_>>();
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.sliding_window,
device: normal_loading_metadata.real_device,
cache: EitherCache::Normal(NormalCache::from_types(cache_types)),
max_seq_len: cfg.max_position_embeddings,
cfg: ModelConfigMetadata {
max_seq_len: cfg.max_position_embeddings,
num_layers: cfg.num_hidden_layers,
hidden_size: cfg.hidden_size,
num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
.max(1),
num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
sliding_window: cfg.sliding_window,
k_head_dim: cfg.head_dim(),
v_head_dim: cfg.head_dim(),
kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
},
mapper,
})
}
pub fn forward(
&self,
input_ids: &Tensor,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
self.forward_embeds(
input_ids,
self.embed_tokens.forward(input_ids)?,
seqlen_offsets,
context_lens,
metadata,
flash_params,
)
}
#[allow(clippy::too_many_arguments)]
pub fn forward_embeds(
&self,
input_ids: &Tensor,
input_embeds: Tensor,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
let mut xs = input_embeds;
let cache = &mut self.cache.normal().0;
let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
input_ids,
metadata
.as_ref()
.map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
.unwrap_or(cache as &dyn PastKvLenCache),
self.sliding_window,
xs.dtype(),
self.cfg.num_attn_heads,
)?;
let attention_mask = attention_mask.filter(|_| {
metadata
.as_ref()
.map(|(_, meta)| meta.is_first_prompt_chunk)
.unwrap_or(true)
});
let attention_mask = DeviceMappedMask::new(attention_mask, &*self.mapper)?;
for (i, layer) in self.layers.iter().enumerate() {
xs = self.mapper.map(xs, i)?;
xs = layer.forward(
&xs,
attention_mask.as_ref().map(|m| m.get(xs.device())),
seqlen_offsets,
&mut cache[i],
metadata
.as_ref()
.map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
flash_params,
)?;
}
let xs = xs.to_device(&self.device)?;
let xs = xs.apply(&self.norm)?;
let mut xs = extract_logits(&xs, context_lens)?;
if let Some(t) = self.lm_head.quantized_act_type() {
xs = xs.to_dtype(t)?;
}
MatMul.qmethod_matmul(&xs, &*self.lm_head)
}
}
impl IsqModel for Model {
fn get_layers(
&mut self,
) -> (
Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
&dyn DeviceMapper,
) {
let mut tensors = Vec::new();
tensors.push((&mut self.lm_head, None));
for (i, layer) in self.layers.iter_mut().enumerate() {
tensors.push((&mut layer.self_attn.q_proj, Some(i)));
tensors.push((&mut layer.self_attn.k_proj, Some(i)));
tensors.push((&mut layer.self_attn.v_proj, Some(i)));
tensors.push((&mut layer.self_attn.o_proj, Some(i)));
match &mut layer.mlp {
MoeOrMlp::Mlp(mlp) => {
tensors.push((&mut mlp.gate_proj, Some(i)));
tensors.push((&mut mlp.up_proj, Some(i)));
tensors.push((&mut mlp.down_proj, Some(i)));
}
MoeOrMlp::Moe(moe) => {
for layer in moe.get_isq_layers() {
tensors.push((layer, Some(i)));
}
}
}
}
(tensors, &*self.mapper)
}
fn residual_tensors(&self) -> Vec<(String, Tensor)> {
let uvb = UnVarBuilder::new();
let uvb_m = uvb.pp("model");
uvb_m.pp("embed_tokens").add(&self.embed_tokens);
uvb_m.pp("norm").add(&self.norm);
for (layer_idx, layer) in self.layers.iter().enumerate() {
let uvb_l = uvb_m.pp("layers").pp(layer_idx);
uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
uvb_l
.pp("post_attention_layernorm")
.add(&layer.post_attention_layernorm);
uvb_l
.pp("self_attn")
.pp("q_norm")
.add(&layer.self_attn.q_norm);
uvb_l
.pp("self_attn")
.pp("k_norm")
.add(&layer.self_attn.k_norm);
if let MoeOrMlp::Moe(moe) = &layer.mlp {
uvb_l.pp("mlp").pp("gate").add(moe.gate());
}
}
uvb.to_safetensors()
}
}
impl NormalModel for Model {
fn forward(
&self,
input_ids: &Tensor,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
_position_ids: Vec<usize>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
self.forward(
input_ids,
seqlen_offsets,
context_lens,
metadata,
flash_params,
)
}
fn xlora_forward(
&self,
_input_ids: &Tensor,
_input_ids_full: &Tensor,
_seqlen_offsets: &[usize],
_seqlen_offsets_full: &[usize],
_no_kv_cache: bool,
_non_granular_state: &Option<crate::xlora_models::NonGranularState>,
_context_lens: Vec<(usize, usize)>,
_position_ids: Vec<usize>,
_flash_params: &FlashParams,
_flash_params_full: &FlashParams,
) -> Result<Tensor> {
unimplemented!()
}
fn cache(&self) -> &EitherCache {
&self.cache
}
fn cache_mut(&mut self) -> &mut EitherCache {
&mut self.cache
}
fn device(&self) -> &Device {
&self.device
}
fn is_xlora(&self) -> bool {
false
}
fn max_seq_len(&self) -> usize {
self.max_seq_len
}
fn config(&self) -> &ModelConfigMetadata {
&self.cfg
}
}
impl AnyMoeBaseModelMixin for Model {}