use serde::{Deserialize, Serialize};
use trustformers_core::errors::invalid_config;
use trustformers_core::traits::Config;
/// Selector for the expert top-k routing strategy stored in the configuration.
///
/// No serde rename attributes are present, so the derived `Serialize` /
/// `Deserialize` impls use the variant identifiers as written; `Display`
/// prints the same identifiers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TopKMethod {
/// Default strategy of `DeepSeekV2Config::default()`.
/// NOTE(review): presumably limits selection to the best `topk_group` of
/// `n_group` expert groups — confirm against the router implementation.
GroupLimitedGreedy,
/// NOTE(review): presumably the auxiliary-loss-free routing variant —
/// confirm against the router implementation.
Noaux,
}
/// Renders a [`TopKMethod`] as its variant identifier.
impl std::fmt::Display for TopKMethod {
    /// Writes the variant name verbatim; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            TopKMethod::GroupLimitedGreedy => "GroupLimitedGreedy",
            TopKMethod::Noaux => "Noaux",
        };
        f.write_str(label)
    }
}
/// Activation function selector used by `DeepSeekV2Config::hidden_act`.
///
/// `Display` renders the variants in lowercase (`silu` / `gelu`), whereas the
/// derived serde impls carry no rename attributes and therefore use the
/// variant identifiers as written (`SiLU` / `GeLU`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ActivationType {
/// Displayed as `silu`; the default `hidden_act`.
SiLU,
/// Displayed as `gelu`.
GeLU,
}
/// Renders an [`ActivationType`] as its lowercase name.
impl std::fmt::Display for ActivationType {
    /// Writes `silu` or `gelu`; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            ActivationType::SiLU => "silu",
            ActivationType::GeLU => "gelu",
        };
        f.write_str(label)
    }
}
/// Configuration for the architecture reported as `"DeepSeek-V2"` by `architecture()`.
///
/// `Default` supplies a complete set of values. `validate` rejects a
/// configuration when one of the size fields it checks is zero or when
/// `num_experts_per_tok` exceeds `n_routed_experts`. Field descriptions that
/// are not backed by code in this file are derived from the field names only
/// and are marked for confirmation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekV2Config {
/// Vocabulary size; must be > 0 to pass `validate`.
pub vocab_size: usize,
/// Hidden dimension; must be > 0 to pass `validate`.
pub hidden_size: usize,
/// Feed-forward intermediate dimension (from the name — confirm); not checked by `validate`.
pub intermediate_size: usize,
/// Number of layers; must be > 0 to pass `validate`.
pub num_hidden_layers: usize,
/// Number of attention heads; must be > 0 to pass `validate`. Scales the
/// per-head sizes in `mha_kv_cache_per_token_per_layer`.
pub num_attention_heads: usize,
/// Added to `qk_rope_head_dim` in `mla_kv_cache_per_token_per_layer`;
/// must be > 0 to pass `validate`.
pub kv_lora_rank: usize,
/// Presumably the low-rank dimension of the query projection — confirm.
/// Not checked by `validate`.
pub q_lora_rank: usize,
/// One of the two summands of `qk_head_dim()`; must be > 0 to pass `validate`.
pub qk_rope_head_dim: usize,
/// The other summand of `qk_head_dim()`; must be > 0 to pass `validate`.
pub qk_nope_head_dim: usize,
/// Per-head value dimension as used by the MHA cache estimate; must be > 0
/// to pass `validate`.
pub v_head_dim: usize,
/// Experts selected per token; must be > 0 and <= `n_routed_experts` to
/// pass `validate`.
pub num_experts_per_tok: usize,
/// Number of routed experts; must be > 0 to pass `validate`.
pub n_routed_experts: usize,
/// Number of shared experts (from the name — confirm); not checked by `validate`.
pub n_shared_experts: usize,
/// Presumably a multiplier applied to routed-expert outputs — confirm.
pub routed_scaling_factor: f32,
/// Expert top-k routing strategy.
pub topk_method: TopKMethod,
/// Number of expert groups (from the name — confirm); must be > 0 to pass `validate`.
pub n_group: usize,
/// Presumably how many groups are kept during group-limited routing —
/// confirm. Not checked by `validate`.
pub topk_group: usize,
/// Presumably the weight of the auxiliary load-balancing loss — confirm.
pub aux_loss_alpha: f32,
/// Maximum number of positions (from the name — confirm); not checked by `validate`.
pub max_position_embeddings: usize,
/// Presumably the epsilon used by RMS normalisation — confirm.
pub rms_norm_eps: f64,
/// Presumably the rotary-embedding base frequency — confirm.
pub rope_theta: f64,
/// Activation function selector.
pub hidden_act: ActivationType,
/// Presumably the standard deviation used for weight initialisation — confirm.
pub initializer_range: f32,
/// Layers with an index below this value are reported as dense by `is_dense_layer`.
pub first_k_dense_replace: usize,
/// Period used by `is_dense_layer` for layers at or beyond
/// `first_k_dense_replace`: a layer is non-dense when its offset from
/// `first_k_dense_replace` is a multiple of this value.
pub moe_layer_freq: usize,
}
/// Built-in default values for [`DeepSeekV2Config`].
impl Default for DeepSeekV2Config {
    /// Returns the configuration with the stock values listed below.
    fn default() -> Self {
        Self {
            // Core transformer dimensions.
            vocab_size: 102_400,
            hidden_size: 5_120,
            intermediate_size: 12_288,
            num_hidden_layers: 60,
            num_attention_heads: 128,

            // Attention projection ranks and per-head sizes.
            kv_lora_rank: 512,
            q_lora_rank: 1_536,
            qk_rope_head_dim: 64,
            qk_nope_head_dim: 128,
            v_head_dim: 128,

            // Expert routing.
            num_experts_per_tok: 6,
            n_routed_experts: 160,
            n_shared_experts: 2,
            routed_scaling_factor: 1.0,
            topk_method: TopKMethod::GroupLimitedGreedy,
            n_group: 8,
            topk_group: 3,
            aux_loss_alpha: 0.001,

            // Positions, normalisation, activation and initialisation.
            max_position_embeddings: 163_840,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            hidden_act: ActivationType::SiLU,
            initializer_range: 0.02,

            // Dense / non-dense layer layout.
            first_k_dense_replace: 1,
            moe_layer_freq: 1,
        }
    }
}
/// `Config` implementation: sanity checks plus the architecture identifier.
impl Config for DeepSeekV2Config {
    /// Checks the configuration's basic invariants.
    ///
    /// Checks run in a fixed order and the first violation is returned, so a
    /// configuration with several problems reports only the earliest one.
    ///
    /// # Errors
    ///
    /// Returns `invalid_config(field, reason)` when
    /// - any of `vocab_size`, `hidden_size`, `num_attention_heads`,
    ///   `kv_lora_rank`, `qk_rope_head_dim`, `qk_nope_head_dim`, `v_head_dim`,
    ///   `n_routed_experts`, `num_experts_per_tok`, `n_group` or
    ///   `num_hidden_layers` is zero (`"must be > 0"`), or
    /// - `num_experts_per_tok` exceeds `n_routed_experts`
    ///   (`"must be <= n_routed_experts"`).
    fn validate(&self) -> trustformers_core::errors::Result<()> {
        // Shared guard for every "must be non-zero" field.
        let require_positive =
            |field: &'static str, value: usize| -> trustformers_core::errors::Result<()> {
                if value == 0 {
                    Err(invalid_config(field, "must be > 0".to_string()))
                } else {
                    Ok(())
                }
            };

        require_positive("vocab_size", self.vocab_size)?;
        require_positive("hidden_size", self.hidden_size)?;
        require_positive("num_attention_heads", self.num_attention_heads)?;
        require_positive("kv_lora_rank", self.kv_lora_rank)?;
        require_positive("qk_rope_head_dim", self.qk_rope_head_dim)?;
        require_positive("qk_nope_head_dim", self.qk_nope_head_dim)?;
        require_positive("v_head_dim", self.v_head_dim)?;
        require_positive("n_routed_experts", self.n_routed_experts)?;
        require_positive("num_experts_per_tok", self.num_experts_per_tok)?;

        // Cannot pick more experts per token than there are routed experts.
        if self.num_experts_per_tok > self.n_routed_experts {
            return Err(invalid_config(
                "num_experts_per_tok",
                "must be <= n_routed_experts".to_string(),
            ));
        }

        require_positive("n_group", self.n_group)?;
        require_positive("num_hidden_layers", self.num_hidden_layers)
    }

    /// Returns the fixed architecture identifier, `"DeepSeek-V2"`.
    fn architecture(&self) -> &'static str {
        "DeepSeek-V2"
    }
}
/// Derived quantities computed from the configuration fields.
impl DeepSeekV2Config {
    /// Total query/key head dimension: `qk_rope_head_dim + qk_nope_head_dim`.
    pub fn qk_head_dim(&self) -> usize {
        self.qk_nope_head_dim + self.qk_rope_head_dim
    }

    /// Reports whether the layer at `layer_idx` is a dense layer.
    ///
    /// Every layer below `first_k_dense_replace` is dense. From there on, a
    /// layer is non-dense exactly when its offset from `first_k_dense_replace`
    /// is a multiple of `moe_layer_freq`. With `moe_layer_freq == 0` only the
    /// layer at offset zero counts as a multiple, so all later layers are dense.
    pub fn is_dense_layer(&self, layer_idx: usize) -> bool {
        match layer_idx.checked_sub(self.first_k_dense_replace) {
            // `layer_idx` lies inside the leading dense prefix.
            None => true,
            Some(offset) => !offset.is_multiple_of(self.moe_layer_freq),
        }
    }

    /// Per-token, per-layer cache size of the MHA baseline:
    /// `num_attention_heads * (qk_nope_head_dim + qk_rope_head_dim + v_head_dim)`.
    pub fn mha_kv_cache_per_token_per_layer(&self) -> usize {
        let per_head = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim;
        self.num_attention_heads * per_head
    }

    /// Per-token, per-layer cache size of the MLA estimate:
    /// `kv_lora_rank + qk_rope_head_dim`.
    pub fn mla_kv_cache_per_token_per_layer(&self) -> usize {
        self.kv_lora_rank + self.qk_rope_head_dim
    }

    /// Ratio of the MLA cache estimate to the MHA cache estimate.
    ///
    /// Returns `1.0` when the MHA estimate is zero, which avoids a division
    /// by zero.
    pub fn kv_cache_compression_ratio(&self) -> f64 {
        let compressed = self.mla_kv_cache_per_token_per_layer() as f64;
        let baseline = self.mha_kv_cache_per_token_per_layer() as f64;
        if baseline == 0.0 {
            1.0
        } else {
            compressed / baseline
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Default values -------------------------------------------------

    /// Default vocabulary size.
    #[test]
    fn test_deepseekv2_default_vocab_size() {
        assert_eq!(DeepSeekV2Config::default().vocab_size, 102_400);
    }

    /// Default hidden size.
    #[test]
    fn test_deepseekv2_default_hidden_size() {
        assert_eq!(DeepSeekV2Config::default().hidden_size, 5_120);
    }

    /// Default layer count.
    #[test]
    fn test_deepseekv2_default_num_hidden_layers() {
        assert_eq!(DeepSeekV2Config::default().num_hidden_layers, 60);
    }

    /// Default attention head count.
    #[test]
    fn test_deepseekv2_default_num_attention_heads() {
        assert_eq!(DeepSeekV2Config::default().num_attention_heads, 128);
    }

    /// Default `kv_lora_rank`.
    #[test]
    fn test_deepseekv2_default_kv_lora_rank() {
        assert_eq!(DeepSeekV2Config::default().kv_lora_rank, 512);
    }

    /// Default `q_lora_rank`.
    #[test]
    fn test_deepseekv2_default_q_lora_rank() {
        assert_eq!(DeepSeekV2Config::default().q_lora_rank, 1_536);
    }

    /// Default `qk_rope_head_dim`.
    #[test]
    fn test_deepseekv2_default_qk_rope_head_dim() {
        assert_eq!(DeepSeekV2Config::default().qk_rope_head_dim, 64);
    }

    /// Default `qk_nope_head_dim`.
    #[test]
    fn test_deepseekv2_default_qk_nope_head_dim() {
        assert_eq!(DeepSeekV2Config::default().qk_nope_head_dim, 128);
    }

    /// Default `v_head_dim`.
    #[test]
    fn test_deepseekv2_default_v_head_dim() {
        assert_eq!(DeepSeekV2Config::default().v_head_dim, 128);
    }

    /// Default routed expert count.
    #[test]
    fn test_deepseekv2_default_n_routed_experts() {
        assert_eq!(DeepSeekV2Config::default().n_routed_experts, 160);
    }

    /// Default number of experts per token.
    #[test]
    fn test_deepseekv2_default_num_experts_per_tok() {
        assert_eq!(DeepSeekV2Config::default().num_experts_per_tok, 6);
    }

    /// Default shared expert count.
    #[test]
    fn test_deepseekv2_default_n_shared_experts() {
        assert_eq!(DeepSeekV2Config::default().n_shared_experts, 2);
    }

    // ---- Validation -----------------------------------------------------

    /// The stock configuration must validate.
    #[test]
    fn test_deepseekv2_validate_passes_default() {
        let outcome = DeepSeekV2Config::default().validate();
        assert!(outcome.is_ok());
    }

    /// A zero vocabulary size is rejected.
    #[test]
    fn test_deepseekv2_validate_fails_zero_vocab_size() {
        let invalid = DeepSeekV2Config {
            vocab_size: 0,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    /// A zero hidden size is rejected.
    #[test]
    fn test_deepseekv2_validate_fails_zero_hidden_size() {
        let invalid = DeepSeekV2Config {
            hidden_size: 0,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    /// A zero `kv_lora_rank` is rejected.
    #[test]
    fn test_deepseekv2_validate_fails_zero_kv_lora_rank() {
        let invalid = DeepSeekV2Config {
            kv_lora_rank: 0,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    /// Selecting more experts per token than exist is rejected.
    #[test]
    fn test_deepseekv2_validate_fails_experts_per_tok_exceeds_total() {
        let invalid = DeepSeekV2Config {
            num_experts_per_tok: 200,
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    // ---- Derived quantities ---------------------------------------------

    /// `qk_head_dim` is the sum of the rope and nope head dimensions.
    #[test]
    fn test_deepseekv2_qk_head_dim() {
        let config = DeepSeekV2Config::default();
        let summed = config.qk_rope_head_dim + config.qk_nope_head_dim;
        assert_eq!(config.qk_head_dim(), summed);
        assert_eq!(config.qk_head_dim(), 192);
    }

    /// With the defaults, layer 0 falls in the dense prefix.
    #[test]
    fn test_deepseekv2_is_dense_layer_first() {
        assert!(DeepSeekV2Config::default().is_dense_layer(0));
    }

    /// The MLA cache estimate is smaller than the MHA one for the defaults.
    #[test]
    fn test_deepseekv2_mla_kv_cache_smaller_than_mha() {
        let compression = DeepSeekV2Config::default().kv_cache_compression_ratio();
        assert!(
            compression < 1.0,
            "MLA should compress KV cache relative to MHA"
        );
    }

    /// Default routing strategy.
    #[test]
    fn test_deepseekv2_topk_method_default() {
        assert_eq!(
            DeepSeekV2Config::default().topk_method,
            TopKMethod::GroupLimitedGreedy
        );
    }

    /// Default activation function.
    #[test]
    fn test_deepseekv2_hidden_act_default() {
        assert_eq!(DeepSeekV2Config::default().hidden_act, ActivationType::SiLU);
    }

    /// The MHA cache estimate matches its formula.
    #[test]
    fn test_deepseekv2_mha_kv_cache_size() {
        let config = DeepSeekV2Config::default();
        let per_head = config.qk_nope_head_dim + config.qk_rope_head_dim + config.v_head_dim;
        assert_eq!(
            config.mha_kv_cache_per_token_per_layer(),
            config.num_attention_heads * per_head
        );
    }

    /// The MLA cache estimate matches its formula.
    #[test]
    fn test_deepseekv2_mla_kv_cache_size() {
        let config = DeepSeekV2Config::default();
        assert_eq!(
            config.mla_kv_cache_per_token_per_layer(),
            config.kv_lora_rank + config.qk_rope_head_dim
        );
    }

    /// One step of a 64-bit LCG, reduced modulo 1000, lands in `[0, 1)`.
    #[test]
    fn test_deepseekv2_lcg_values_in_range() {
        let seed = 42u64;
        let state = seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        let sample = (state % 1000) as f32 / 1000.0;
        assert!((0.0..1.0).contains(&sample));
    }

    /// The architecture identifier is fixed.
    #[test]
    fn test_deepseekv2_architecture_name() {
        assert_eq!(DeepSeekV2Config::default().architecture(), "DeepSeek-V2");
    }
}