1use std::fmt;
27use std::io::Read as _;
28use std::path::Path;
29
30use serde_json::Value;
31
32use crate::error::{MIError, Result};
33
/// Model types that have a dedicated, hand-tuned parser in
/// [`TransformerConfig::from_hf_config`]. Kept sorted alphabetically.
pub const SUPPORTED_MODEL_TYPES: &[&str] = &[
    "gemma",
    "gemma2",
    "llama",
    "mistral",
    "phi3",
    "qwen2",
    "starcoder2",
];
52
/// Normalization flavor used inside the transformer blocks.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// Standard RMSNorm (used by llama/qwen2/phi3/mistral parsers here).
    RmsNorm,
    /// Classic LayerNorm (selected for starcoder2 when the config says
    /// `norm_type: "layer_norm"`, or inferred from an `input_layernorm.bias`
    /// tensor in auto mode).
    LayerNorm,
    /// Gemma's RMSNorm variant. NOTE(review): how it differs from plain
    /// `RmsNorm` (presumably the `1 + weight` scaling) is defined by the norm
    /// implementation, not this file — confirm there.
    GemmaRmsNorm,
}
69
70impl fmt::Display for NormType {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 match self {
73 Self::RmsNorm => write!(f, "RmsNorm"),
74 Self::LayerNorm => write!(f, "LayerNorm"),
75 Self::GemmaRmsNorm => write!(f, "GemmaRmsNorm"),
76 }
77 }
78}
79
/// MLP activation function.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// SiLU / swish (llama, qwen2, phi3, mistral).
    Silu,
    /// Exact GELU (HF `hidden_act: "gelu"`).
    Gelu,
    /// GELU with tanh approximation (HF `hidden_act: "gelu_pytorch_tanh"`;
    /// gemma and starcoder2).
    GeluApprox,
}
94
95impl fmt::Display for Activation {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 match self {
98 Self::Silu => write!(f, "SiLU"),
99 Self::Gelu => write!(f, "GELU"),
100 Self::GeluApprox => write!(f, "GELU (tanh approx)"),
101 }
102 }
103}
104
/// How the attention Q/K/V projection weights are stored in the checkpoint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QkvLayout {
    /// Separate per-projection tensors (detected via `self_attn.q_proj.weight`).
    Separate,
    /// One fused tensor (detected via `self_attn.qkv_proj.weight`; phi3).
    Fused,
}
114
115impl fmt::Display for QkvLayout {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 match self {
118 Self::Separate => write!(f, "Separate"),
119 Self::Fused => write!(f, "Fused"),
120 }
121 }
122}
123
/// How the MLP weights are stored in the checkpoint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlpLayout {
    /// Gated MLP with a separate `mlp.gate_proj.weight` tensor.
    GatedSeparate,
    /// Gated MLP with a fused `mlp.gate_up_proj.weight` tensor (phi3).
    GatedFused,
    /// Ungated MLP (detected via `mlp.c_fc.weight`; starcoder2).
    Plain,
}
137
138impl fmt::Display for MlpLayout {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 match self {
141 Self::GatedSeparate => write!(f, "GatedSeparate"),
142 Self::GatedFused => write!(f, "GatedFused"),
143 Self::Plain => write!(f, "Plain"),
144 }
145 }
146}
147
/// Normalized architecture hyper-parameters for a decoder-only transformer,
/// derived from a Hugging Face `config.json` (and, in auto mode, from the
/// checkpoint's tensor names).
#[derive(Debug, Clone, PartialEq)]
#[allow(clippy::struct_excessive_bools)] pub struct TransformerConfig {
    /// Width of the residual stream / embeddings.
    pub hidden_size: usize,
    /// Number of transformer layers (HF `num_hidden_layers`).
    pub num_layers: usize,
    /// Number of attention (query) heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads; defaults to `num_attention_heads` when the
    /// config omits `num_key_value_heads` (i.e. plain multi-head attention).
    pub num_kv_heads: usize,
    /// Per-head dimension: explicit `head_dim` if present, otherwise
    /// `hidden_size / num_attention_heads` (see [`get_head_dim`]).
    pub head_dim: usize,
    /// Inner width of the MLP.
    pub intermediate_size: usize,
    /// Vocabulary size.
    pub vocab_size: usize,

    /// Normalization flavor used by the blocks.
    pub norm_type: NormType,
    /// Numerical-stability epsilon passed to the norm.
    pub norm_eps: f64,
    /// MLP activation function.
    pub activation: Activation,
    /// Storage layout of the Q/K/V projections.
    pub qkv_layout: QkvLayout,
    /// Storage layout of the MLP weights.
    pub mlp_layout: MlpLayout,
    /// Whether the Q/K/V projections carry bias terms.
    pub qkv_bias: bool,
    /// Whether the attention output projection carries a bias.
    pub o_proj_bias: bool,
    /// Whether the MLP projections carry biases.
    pub mlp_bias: bool,
    /// Optional multiplier applied to token embeddings; Gemma models set this
    /// to `sqrt(hidden_size)`.
    pub embedding_scale: Option<f64>,
    /// Whether the LM head shares weights with the input embeddings.
    pub tie_word_embeddings: bool,

    /// RoPE base frequency (HF `rope_theta`).
    pub rope_theta: f64,
    /// Maximum sequence length the model was configured for.
    pub max_position_embeddings: usize,

    /// Gemma-2 style soft cap on attention logits, when present.
    pub attn_logit_softcapping: Option<f64>,
    /// Gemma-2 style soft cap on final logits, when present.
    pub final_logit_softcapping: Option<f64>,
    /// Optional override for the query pre-attention scaling; Gemma-2 defaults
    /// this to 256 when the config omits it.
    pub query_pre_attn_scalar: Option<f64>,
    /// Whether layers use the extra pre/post-feedforward norms (Gemma-2).
    pub use_post_norms: bool,

    /// Sliding attention window size, when the model uses one.
    pub sliding_window: Option<usize>,
    /// Whether sliding-window attention alternates across layers. Set only for
    /// Gemma-2 here; NOTE(review): the actual alternation pattern is decided by
    /// the attention implementation — confirm there.
    pub alternating_sliding_window: bool,
}
311
312impl TransformerConfig {
317 pub fn from_hf_config(config: &Value) -> Result<Self> {
329 let model_type = config
330 .get("model_type")
331 .and_then(Value::as_str)
332 .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;
333
334 match model_type {
336 "llama" => Self::parse_llama(config),
337 "qwen2" => Self::parse_qwen2(config),
338 "gemma" => Self::parse_gemma(config),
339 "gemma2" => Self::parse_gemma2(config),
340 "phi3" => Self::parse_phi3(config),
341 "starcoder2" => Self::parse_starcoder2(config),
342 "mistral" => Self::parse_mistral(config),
343 other => Err(MIError::Config(format!(
344 "unsupported model_type: '{other}'"
345 ))),
346 }
347 }
348}
349
impl TransformerConfig {
    /// Llama-family configs: RMSNorm + SiLU, separate gated MLP, no biases.
    fn parse_llama(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            // MHA configs omit this key; fall back to one KV head per query head.
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Qwen2 configs: like llama but with Q/K/V biases on by default and a
    /// much larger default RoPE theta / context length.
    fn parse_qwen2(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            // Qwen2 defaults its attention bias to true, unlike the other models.
            qkv_bias: get_bool_or(config, "attention_bias", true),
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 1_000_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 32_768),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Gemma 1 configs: Gemma-style RMSNorm, tanh-approx GELU, scaled and
    /// (by default) tied embeddings.
    fn parse_gemma(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::GemmaRmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
            activation: Activation::GeluApprox,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            // Gemma scales token embeddings by sqrt(hidden_size).
            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
            embedding_scale: Some((hidden_size as f64).sqrt()),
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(
                config,
                "max_position_embeddings",
                8192,
            ),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// Gemma 2 configs: Gemma 1 plus logit softcapping, extra pre/post-FFN
    /// norms, alternating sliding-window attention, and a query scaling
    /// override that defaults to 256.
    fn parse_gemma2(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::GemmaRmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-6),
            activation: Activation::GeluApprox,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            // Same sqrt(hidden_size) embedding scaling as Gemma 1.
            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
            embedding_scale: Some((hidden_size as f64).sqrt()),
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(
                config,
                "max_position_embeddings",
                8192,
            ),

            attn_logit_softcapping: get_optional_f64(config, "attn_logit_softcapping"),
            final_logit_softcapping: get_optional_f64(config, "final_logit_softcapping"),
            // Missing query_pre_attn_scalar falls back to 256, not to None.
            query_pre_attn_scalar: get_optional_f64(config, "query_pre_attn_scalar")
                .or(Some(256.0)),
            use_post_norms: true,
            sliding_window: get_optional_usize(config, "sliding_window"),
            alternating_sliding_window: true,
        })
    }

    /// Phi-3 configs: fused QKV projection and fused gate+up MLP tensor.
    fn parse_phi3(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Fused,
            mlp_layout: MlpLayout::GatedFused,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: None,
            alternating_sliding_window: false,
        })
    }

    /// StarCoder2 configs: plain (ungated) MLP, biases everywhere by default,
    /// and a `norm_type`/`norm_epsilon` naming scheme that differs from the
    /// other families.
    fn parse_starcoder2(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;
        // One `use_bias` flag drives all three bias settings.
        let use_bias = get_bool_or(config, "use_bias", true);

        // StarCoder2 spells its norm choice out as a string.
        let norm_type = match config.get("norm_type").and_then(Value::as_str) {
            Some("layer_norm") => NormType::LayerNorm,
            _ => NormType::RmsNorm,
        };

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type,
            // Note the key is 'norm_epsilon', not 'rms_norm_eps'.
            norm_eps: get_f64_or(config, "norm_epsilon", 1e-5),
            activation: Activation::GeluApprox,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::Plain,
            qkv_bias: use_bias,
            o_proj_bias: use_bias,
            mlp_bias: use_bias,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", true),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 16_384),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: get_optional_usize(config, "sliding_window"),
            alternating_sliding_window: false,
        })
    }

    /// Mistral configs: llama-like but with an optional sliding window and a
    /// 32k default context length.
    fn parse_mistral(config: &Value) -> Result<Self> {
        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type: NormType::RmsNorm,
            norm_eps: get_f64_or(config, "rms_norm_eps", 1e-5),
            activation: Activation::Silu,
            qkv_layout: QkvLayout::Separate,
            mlp_layout: MlpLayout::GatedSeparate,
            qkv_bias: false,
            o_proj_bias: false,
            mlp_bias: false,
            embedding_scale: None,
            tie_word_embeddings: get_bool_or(config, "tie_word_embeddings", false),

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 32_768),

            attn_logit_softcapping: None,
            final_logit_softcapping: None,
            query_pre_attn_scalar: None,
            use_post_norms: false,
            sliding_window: get_optional_usize(config, "sliding_window"),
            alternating_sliding_window: false,
        })
    }
}
682
683pub(crate) fn get_usize(config: &Value, key: &str) -> Result<usize> {
689 let val = config
690 .get(key)
691 .and_then(Value::as_u64)
692 .ok_or_else(|| MIError::Config(format!("missing or invalid field '{key}'")))?;
693 usize::try_from(val)
694 .map_err(|_| MIError::Config(format!("field '{key}' value {val} overflows usize")))
695}
696
697pub(crate) fn get_usize_or(config: &Value, key: &str, default: usize) -> usize {
699 config
700 .get(key)
701 .and_then(Value::as_u64)
702 .and_then(|v| usize::try_from(v).ok())
703 .unwrap_or(default)
704}
705
706pub(crate) fn get_optional_usize(config: &Value, key: &str) -> Option<usize> {
708 config
709 .get(key)
710 .and_then(Value::as_u64)
711 .and_then(|v| usize::try_from(v).ok())
712}
713
714pub(crate) fn get_f64_or(config: &Value, key: &str, default: f64) -> f64 {
716 config.get(key).and_then(Value::as_f64).unwrap_or(default)
717}
718
719pub(crate) fn get_optional_f64(config: &Value, key: &str) -> Option<f64> {
721 config.get(key).and_then(Value::as_f64)
722}
723
724pub(crate) fn get_bool_or(config: &Value, key: &str, default: bool) -> bool {
726 config.get(key).and_then(Value::as_bool).unwrap_or(default)
727}
728
729pub(crate) fn get_head_dim(
731 config: &Value,
732 hidden_size: usize,
733 num_attention_heads: usize,
734) -> Result<usize> {
735 let explicit = config.get("head_dim").and_then(Value::as_u64).map(|hd| {
737 usize::try_from(hd).map_err(|_| MIError::Config("head_dim overflows usize".into()))
738 });
739
740 match explicit {
741 Some(result) => result,
742 None if num_attention_heads == 0 => Err(MIError::Config(
743 "num_attention_heads is 0, cannot compute head_dim".into(),
744 )),
745 None => Ok(hidden_size / num_attention_heads),
746 }
747}
748
749fn parse_activation_str(config: &Value) -> Activation {
758 let act_str = config
759 .get("hidden_activation")
760 .or_else(|| config.get("hidden_act"))
761 .and_then(Value::as_str);
762 match act_str {
763 Some("gelu_pytorch_tanh") => Activation::GeluApprox,
764 Some("gelu") => Activation::Gelu,
765 _ => Activation::Silu,
766 }
767}
768
769pub fn tensor_names_from_safetensors(path: &Path) -> Result<Vec<String>> {
783 let mut file = std::fs::File::open(path)?;
784 let mut len_buf = [0u8; 8];
785 file.read_exact(&mut len_buf)?;
786 let header_len = u64::from_le_bytes(len_buf);
787 let header_len = usize::try_from(header_len)
788 .map_err(|_| MIError::Config("safetensors header length overflows usize".into()))?;
789 let mut header_buf = vec![0u8; header_len];
790 file.read_exact(&mut header_buf)?;
791 let header: Value = serde_json::from_slice(&header_buf)
792 .map_err(|e| MIError::Config(format!("failed to parse safetensors header: {e}")))?;
793 let obj = header
794 .as_object()
795 .ok_or_else(|| MIError::Config("safetensors header is not a JSON object".into()))?;
796 Ok(obj
797 .keys()
798 .filter(|k| *k != "__metadata__")
799 .cloned()
800 .collect())
801}
802
803pub fn tensor_names_from_index(path: &Path) -> Result<Vec<String>> {
812 let content = std::fs::read_to_string(path)?;
813 let index: Value = serde_json::from_str(&content)
814 .map_err(|e| MIError::Config(format!("failed to parse safetensors index: {e}")))?;
815 let weight_map = index
816 .get("weight_map")
817 .and_then(Value::as_object)
818 .ok_or_else(|| MIError::Config("missing 'weight_map' in safetensors index".into()))?;
819 Ok(weight_map.keys().cloned().collect())
820}
821
impl TransformerConfig {
    /// Builds a config for a known `model_type` via [`Self::from_hf_config`],
    /// or falls back to tensor-name-based inference ([`Self::parse_auto`]) for
    /// unknown types.
    ///
    /// # Errors
    /// Returns [`MIError::Config`] when `model_type` is missing or when the
    /// chosen parser fails.
    pub fn from_hf_config_auto(config: &Value, tensor_names: &[String]) -> Result<Self> {
        let model_type = config
            .get("model_type")
            .and_then(Value::as_str)
            .ok_or_else(|| MIError::Config("missing 'model_type' field".into()))?;

        if SUPPORTED_MODEL_TYPES.contains(&model_type) {
            return Self::from_hf_config(config);
        }

        Self::parse_auto(config, tensor_names, model_type)
    }

    /// Best-effort config inference for unknown model types: layout, biases,
    /// and norm flavor are guessed from which tensors exist in layer 0 of the
    /// checkpoint; scalar hyper-parameters come from common HF config keys.
    /// Assumes HF-standard 'model.layers.{i}.*' tensor naming.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    fn parse_auto(config: &Value, tensor_names: &[String], model_type: &str) -> Result<Self> {
        // True when layer 0 has a tensor ending in `suffix` (layer 0 is taken
        // as representative of every layer).
        let has_layer0 = |suffix: &str| {
            tensor_names
                .iter()
                .any(|n| n.contains("layers.0.") && n.ends_with(suffix))
        };

        let hidden_size = get_usize(config, "hidden_size")?;
        let num_attention_heads = get_usize(config, "num_attention_heads")?;

        // Accept either the RMSNorm or the LayerNorm spelling of the epsilon key.
        let norm_eps = config
            .get("rms_norm_eps")
            .and_then(Value::as_f64)
            .or_else(|| config.get("norm_epsilon").and_then(Value::as_f64))
            .unwrap_or(1e-5);

        let activation = parse_activation_str(config);

        // An explicit `use_sliding_window: false` disables the window even if
        // a `sliding_window` size is present (qwen2-style configs).
        let sliding_window =
            if config.get("use_sliding_window").and_then(Value::as_bool) == Some(false) {
                None
            } else {
                get_optional_usize(config, "sliding_window")
            };

        // When the config is silent, infer tying from the absence of a
        // standalone lm_head tensor.
        let tie_word_embeddings = config
            .get("tie_word_embeddings")
            .and_then(Value::as_bool)
            .unwrap_or_else(|| !tensor_names.iter().any(|n| n == "lm_head.weight"));

        let attn_logit_softcapping = get_optional_f64(config, "attn_logit_softcapping");
        let final_logit_softcapping = get_optional_f64(config, "final_logit_softcapping");
        let query_pre_attn_scalar = get_optional_f64(config, "query_pre_attn_scalar");

        // Layouts are inferred purely from which weight tensors exist.
        let qkv_layout = if has_layer0("self_attn.qkv_proj.weight") {
            QkvLayout::Fused
        } else {
            QkvLayout::Separate
        };

        let mlp_layout = if has_layer0("mlp.gate_up_proj.weight") {
            MlpLayout::GatedFused
        } else if has_layer0("mlp.gate_proj.weight") {
            MlpLayout::GatedSeparate
        } else if has_layer0("mlp.c_fc.weight") {
            MlpLayout::Plain
        } else {
            MlpLayout::GatedSeparate };

        // Bias flags are inferred from the presence of bias tensors.
        let qkv_bias = has_layer0("self_attn.q_proj.bias") || has_layer0("self_attn.qkv_proj.bias");
        let o_proj_bias = has_layer0("self_attn.o_proj.bias");
        let mlp_bias = has_layer0("mlp.down_proj.bias")
            || has_layer0("mlp.c_fc.bias")
            || has_layer0("mlp.gate_proj.bias")
            || has_layer0("mlp.gate_up_proj.bias");

        // A norm bias tensor implies LayerNorm; RMSNorm has no bias.
        let has_norm_bias = has_layer0("input_layernorm.bias");
        let base_norm_type = if has_norm_bias {
            NormType::LayerNorm
        } else {
            NormType::RmsNorm
        };

        let use_post_norms = has_layer0("post_feedforward_layernorm.weight")
            || has_layer0("pre_feedforward_layernorm.weight");

        // Gemma variants get Gemma-specific overrides below.
        let is_gemma = model_type.contains("gemma");

        let norm_type = if is_gemma {
            NormType::GemmaRmsNorm
        } else {
            base_norm_type
        };

        // Gemma scales token embeddings by sqrt(hidden_size).
        let embedding_scale = if is_gemma {
            Some((hidden_size as f64).sqrt())
        } else {
            None
        };

        // Gemma with extra FFN norms is treated as gemma2-like; mirrors the
        // defaults in parse_gemma2 (alternating window, scalar default 256).
        let alternating_sliding_window = is_gemma && use_post_norms;

        let query_pre_attn_scalar = if is_gemma && use_post_norms {
            query_pre_attn_scalar.or(Some(256.0))
        } else {
            query_pre_attn_scalar
        };

        Ok(Self {
            hidden_size,
            num_layers: get_usize(config, "num_hidden_layers")?,
            num_attention_heads,
            num_kv_heads: get_usize_or(config, "num_key_value_heads", num_attention_heads),
            head_dim: get_head_dim(config, hidden_size, num_attention_heads)?,
            intermediate_size: get_usize(config, "intermediate_size")?,
            vocab_size: get_usize(config, "vocab_size")?,

            norm_type,
            norm_eps,
            activation,
            qkv_layout,
            mlp_layout,
            qkv_bias,
            o_proj_bias,
            mlp_bias,
            embedding_scale,
            tie_word_embeddings,

            rope_theta: get_f64_or(config, "rope_theta", 10_000.0),
            max_position_embeddings: get_usize_or(config, "max_position_embeddings", 4096),

            attn_logit_softcapping,
            final_logit_softcapping,
            query_pre_attn_scalar,
            use_post_norms,
            sliding_window,
            alternating_sliding_window,
        })
    }
}
1012
/// Outcome of a compatibility pre-check against a model's config and weights.
#[derive(Debug, Clone)]
pub struct CompatibilityReport {
    /// True when no blocking issues were found.
    pub compatible: bool,
    /// Human-readable descriptions of each problem (empty when compatible).
    pub issues: Vec<String>,
}
1027
1028impl CompatibilityReport {
1029 pub fn into_result(self) -> Result<()> {
1036 if self.compatible {
1037 Ok(())
1038 } else {
1039 Err(MIError::Config(format!(
1040 "model is not compatible with GenericTransformer:\n - {}",
1041 self.issues.join("\n - ")
1042 )))
1043 }
1044 }
1045}
1046
1047impl TransformerConfig {
1048 #[must_use]
1059 pub fn check_config_fields(config: &Value) -> CompatibilityReport {
1060 let required = [
1061 "hidden_size",
1062 "num_hidden_layers",
1063 "num_attention_heads",
1064 "intermediate_size",
1065 "vocab_size",
1066 ];
1067 let mut issues = Vec::new();
1068 for key in &required {
1069 if config.get(*key).and_then(Value::as_u64).is_none() {
1070 issues.push(format!("missing or invalid required field '{key}'"));
1071 }
1072 }
1073 CompatibilityReport {
1074 compatible: issues.is_empty(),
1075 issues,
1076 }
1077 }
1078
1079 #[must_use]
1097 pub fn check_auto_compatibility(
1098 config: &Value,
1099 tensor_names: &[String],
1100 ) -> CompatibilityReport {
1101 let mut issues = Vec::new();
1102
1103 let field_report = Self::check_config_fields(config);
1105 issues.extend(field_report.issues);
1106
1107 let has_tensor_issues = check_tensor_names(config, tensor_names, &mut issues);
1109
1110 if has_tensor_issues && !tensor_names.is_empty() {
1112 if let Some(hint) = detect_naming_convention(tensor_names) {
1113 issues.push(hint);
1114 }
1115 }
1116
1117 CompatibilityReport {
1118 compatible: issues.is_empty(),
1119 issues,
1120 }
1121 }
1122}
1123
/// Verifies that the checkpoint contains the HF-standard tensors the generic
/// transformer expects (embedding, per-layer norms, attention and MLP
/// projections, final norm, and — when embeddings are untied — an LM head),
/// pushing a descriptive message into `issues` for each missing piece.
///
/// Returns true when any *naming-related* tensor was missing, which the caller
/// uses to decide whether to add a naming-convention hint.
#[allow(clippy::too_many_lines)]
fn check_tensor_names(config: &Value, tensor_names: &[String], issues: &mut Vec<String>) -> bool {
    // Exact-name membership test.
    let has = |name: &str| tensor_names.iter().any(|n| n == name);
    // True when layer 0 has a tensor ending in `suffix`.
    let has_layer0 = |suffix: &str| {
        tensor_names
            .iter()
            .any(|n| n.contains("layers.0.") && n.ends_with(suffix))
    };

    // Collects up to `limit` tensor names containing `keyword`, used to build
    // "did you mean" hints in the messages below.
    let find_matching = |keyword: &str, limit: usize| -> Vec<&str> {
        tensor_names
            .iter()
            .filter(|n| n.to_lowercase().contains(keyword))
            .take(limit)
            .map(String::as_str)
            .collect::<Vec<_>>()
    };

    let mut has_issues = false;

    if !has("model.embed_tokens.weight") {
        has_issues = true;
        let found: Vec<&str> = tensor_names
            .iter()
            .filter(|n| n.contains("embed") || n.contains("wte") || n.contains("word_embeddings"))
            .take(3)
            .map(String::as_str)
            .collect();
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found embedding-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing embedding tensor 'model.embed_tokens.weight'{hint}"
        ));
    }

    if !has_layer0("input_layernorm.weight") {
        has_issues = true;
        let found = find_matching("norm", 4);
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found norm-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing normalization tensor \
             'model.layers.0.input_layernorm.weight'{hint}"
        ));
    }
    // Gemma-2-style checkpoints use pre_feedforward_layernorm instead of
    // post_attention_layernorm, so either satisfies this check.
    if !has_layer0("post_attention_layernorm.weight")
        && !has_layer0("pre_feedforward_layernorm.weight")
    {
        has_issues = true;
        issues.push(
            "missing normalization tensor \
             'model.layers.0.post_attention_layernorm.weight'"
                .into(),
        );
    }

    if !has("model.norm.weight") {
        has_issues = true;
        let found: Vec<&str> = tensor_names
            .iter()
            .filter(|n| {
                (n.contains("ln_f") || n.contains("final_layer_norm") || n.contains("ln_out"))
                    && n.ends_with(".weight")
            })
            .take(2)
            .map(String::as_str)
            .collect();
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found final-norm-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing final norm tensor 'model.norm.weight'{hint}"
        ));
    }

    // Either separate q/k/v projections or one fused qkv projection is fine.
    let has_separate_attn = has_layer0("self_attn.q_proj.weight");
    let has_fused_attn = has_layer0("self_attn.qkv_proj.weight");
    if !has_separate_attn && !has_fused_attn {
        has_issues = true;
        let found = find_matching("attn", 4);
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found attention-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing attention projections: expected \
             'self_attn.q_proj.weight' or 'self_attn.qkv_proj.weight'{hint}"
        ));
    }

    // Any of the known MLP layouts counts as present.
    let has_gated_separate = has_layer0("mlp.gate_proj.weight");
    let has_gated_fused = has_layer0("mlp.gate_up_proj.weight");
    let has_plain = has_layer0("mlp.c_fc.weight");
    let has_down = has_layer0("mlp.down_proj.weight");
    if !has_gated_separate && !has_gated_fused && !has_plain && !has_down {
        has_issues = true;
        let found: Vec<&str> = tensor_names
            .iter()
            .filter(|n| n.contains("mlp") || n.contains("ffn") || n.contains("fc"))
            .take(4)
            .map(String::as_str)
            .collect();
        let hint = if found.is_empty() {
            String::new()
        } else {
            format!("; found MLP-like tensors: {}", found.join(", "))
        };
        issues.push(format!(
            "missing MLP projections: expected 'mlp.gate_proj.weight', \
             'mlp.gate_up_proj.weight', or 'mlp.c_fc.weight'{hint}"
        ));
    }

    // NOTE(review): this check pushes an issue without setting has_issues —
    // presumably deliberate, since a missing lm_head is a config problem, not
    // a naming-convention problem; confirm.
    let tie = config
        .get("tie_word_embeddings")
        .and_then(Value::as_bool)
        .unwrap_or_else(|| !tensor_names.iter().any(|n| n == "lm_head.weight"));
    if !tie && !has("lm_head.weight") {
        issues.push("tie_word_embeddings is false but 'lm_head.weight' tensor is missing".into());
    }

    has_issues
}
1270
/// Produces a human-readable hint when tensor names follow a known
/// non-HF-standard convention (GPT-2/GPT-J, Falcon/MPT, GPT-NeoX/Pythia,
/// BLOOM), or a generic "unrecognized naming" message when no
/// `model.layers.` prefix is present at all. Returns `None` when the names
/// look HF-standard.
fn detect_naming_convention(tensor_names: &[String]) -> Option<String> {
    // (layer-prefix, description) pairs for known alternative conventions.
    let patterns: &[(&str, &str)] = &[
        (
            "transformer.h.",
            "GPT-2 / GPT-J / GPT-NeoX (uses 'transformer.h.{i}' prefix)",
        ),
        (
            "transformer.blocks.",
            "Falcon / MPT (uses 'transformer.blocks.{i}' prefix)",
        ),
        (
            "gpt_neox.layers.",
            "GPT-NeoX / Pythia (uses 'gpt_neox.layers.{i}' prefix)",
        ),
        (
            "transformer.layer.",
            "BLOOM (uses 'transformer.layer.{i}' prefix)",
        ),
    ];

    for &(prefix, description) in patterns {
        if tensor_names.iter().any(|n| n.starts_with(prefix)) {
            return Some(format!(
                "this model uses {description} — candle-mi currently requires \
                 HF-standard 'model.layers.{{i}}' weight naming. \
                 Support for this architecture is planned in Phase 9 \
                 (tensor name remapping)"
            ));
        }
    }

    // No known pattern matched; if the names also lack the HF-standard layer
    // prefix, report a sample of what was seen.
    if !tensor_names.iter().any(|n| n.starts_with("model.layers.")) {
        let sample: Vec<&str> = tensor_names.iter().take(5).map(String::as_str).collect();
        return Some(format!(
            "weight tensors use an unrecognized naming convention \
             (first 5: {}). candle-mi expects 'model.layers.{{i}}.self_attn.*' / \
             'model.layers.{{i}}.mlp.*' naming",
            sample.join(", ")
        ));
    }

    None
}
1321
1322#[cfg(test)]
1327#[allow(clippy::unwrap_used)]
1328mod tests {
1329 use super::*;
1330
    /// Minimal Llama-3.2-1B-shaped config used by several tests below.
    fn llama_config_json() -> Value {
        serde_json::json!({
            "model_type": "llama",
            "hidden_size": 2048,
            "num_hidden_layers": 16,
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "intermediate_size": 8192,
            "vocab_size": 128256,
            "rms_norm_eps": 1e-5,
            "rope_theta": 500000.0,
            "max_position_embeddings": 131072
        })
    }
1346
    /// A llama config parses into the expected normalized fields, including
    /// the derived head_dim (2048 / 32 = 64) and llama's fixed layout choices.
    #[test]
    fn parse_llama_basic() {
        let config = TransformerConfig::from_hf_config(&llama_config_json()).unwrap();
        assert_eq!(config.hidden_size, 2048);
        assert_eq!(config.num_layers, 16);
        assert_eq!(config.num_attention_heads, 32);
        assert_eq!(config.num_kv_heads, 8);
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.intermediate_size, 8192);
        assert_eq!(config.vocab_size, 128256);
        assert_eq!(config.norm_type, NormType::RmsNorm);
        assert_eq!(config.activation, Activation::Silu);
        assert_eq!(config.qkv_layout, QkvLayout::Separate);
        assert_eq!(config.mlp_layout, MlpLayout::GatedSeparate);
        assert!(!config.qkv_bias);
        assert!(!config.o_proj_bias);
        assert!(!config.mlp_bias);
        assert!(config.embedding_scale.is_none());
        assert!(!config.tie_word_embeddings);
        assert!((config.rope_theta - 500_000.0).abs() < f64::EPSILON);
        assert!(config.attn_logit_softcapping.is_none());
        assert!(config.sliding_window.is_none());
    }
1370
    /// Qwen2 applies attention_bias to the QKV projections only, and honors
    /// explicit tie_word_embeddings.
    #[test]
    fn parse_qwen2_bias() {
        let json = serde_json::json!({
            "model_type": "qwen2",
            "hidden_size": 896,
            "num_hidden_layers": 24,
            "num_attention_heads": 14,
            "num_key_value_heads": 2,
            "intermediate_size": 4864,
            "vocab_size": 151936,
            "attention_bias": true,
            "tie_word_embeddings": true
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert!(config.qkv_bias);
        assert!(!config.o_proj_bias);
        assert!(config.tie_word_embeddings);
    }
1389
    /// Gemma-2 extensions all come through: softcaps, explicit head_dim,
    /// embedding scaling, post norms, and alternating sliding window.
    #[test]
    fn parse_gemma2_extensions() {
        let json = serde_json::json!({
            "model_type": "gemma2",
            "hidden_size": 2304,
            "num_hidden_layers": 26,
            "num_attention_heads": 8,
            "num_key_value_heads": 4,
            "head_dim": 256,
            "intermediate_size": 9216,
            "vocab_size": 256000,
            "attn_logit_softcapping": 50.0,
            "final_logit_softcapping": 30.0,
            "query_pre_attn_scalar": 256,
            "sliding_window": 4096
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.norm_type, NormType::GemmaRmsNorm);
        // Explicit head_dim beats hidden_size / num_attention_heads.
        assert_eq!(config.head_dim, 256);
        assert!(config.embedding_scale.is_some());
        assert!((config.attn_logit_softcapping.unwrap() - 50.0).abs() < f64::EPSILON);
        assert!((config.final_logit_softcapping.unwrap() - 30.0).abs() < f64::EPSILON);
        assert!((config.query_pre_attn_scalar.unwrap() - 256.0).abs() < f64::EPSILON);
        assert!(config.use_post_norms);
        assert_eq!(config.sliding_window, Some(4096));
        assert!(config.alternating_sliding_window);
    }
1417
    /// Phi-3 always selects the fused QKV and fused gate+up MLP layouts.
    #[test]
    fn parse_phi3_fused() {
        let json = serde_json::json!({
            "model_type": "phi3",
            "hidden_size": 3072,
            "num_hidden_layers": 32,
            "num_attention_heads": 32,
            "num_key_value_heads": 32,
            "intermediate_size": 8192,
            "vocab_size": 32064
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.qkv_layout, QkvLayout::Fused);
        assert_eq!(config.mlp_layout, MlpLayout::GatedFused);
    }
1433
    /// StarCoder2's single use_bias flag drives all three bias settings, and
    /// its norm_type string selects LayerNorm.
    #[test]
    fn parse_starcoder2_bias_and_plain_mlp() {
        let json = serde_json::json!({
            "model_type": "starcoder2",
            "hidden_size": 3072,
            "num_hidden_layers": 30,
            "num_attention_heads": 24,
            "num_key_value_heads": 2,
            "intermediate_size": 12288,
            "vocab_size": 49152,
            "use_bias": true,
            "norm_type": "layer_norm"
        });
        let config = TransformerConfig::from_hf_config(&json).unwrap();
        assert_eq!(config.mlp_layout, MlpLayout::Plain);
        assert_eq!(config.activation, Activation::GeluApprox);
        assert_eq!(config.norm_type, NormType::LayerNorm);
        assert!(config.qkv_bias);
        assert!(config.o_proj_bias);
        assert!(config.mlp_bias);
    }
1455
1456 #[test]
1457 fn parse_mistral_sliding_window() {
1458 let json = serde_json::json!({
1459 "model_type": "mistral",
1460 "hidden_size": 4096,
1461 "num_hidden_layers": 32,
1462 "num_attention_heads": 32,
1463 "num_key_value_heads": 8,
1464 "intermediate_size": 14336,
1465 "vocab_size": 32000,
1466 "sliding_window": 4096
1467 });
1468 let config = TransformerConfig::from_hf_config(&json).unwrap();
1469 assert_eq!(config.sliding_window, Some(4096));
1470 assert!(!config.alternating_sliding_window);
1471 }
1472
1473 #[test]
1474 fn unsupported_model_type_errors() {
1475 let json = serde_json::json!({ "model_type": "bert" });
1476 let result = TransformerConfig::from_hf_config(&json);
1477 assert!(result.is_err());
1478 }
1479
1480 #[test]
1481 fn missing_model_type_errors() {
1482 let json = serde_json::json!({ "hidden_size": 768 });
1483 let result = TransformerConfig::from_hf_config(&json);
1484 assert!(result.is_err());
1485 }
1486
1487 fn tensor_names(names: &[&str]) -> Vec<String> {
1504 names.iter().map(|s| (*s).to_owned()).collect()
1505 }
1506
1507 #[test]
1508 fn auto_config_matches_llama() {
1509 let json = serde_json::json!({
1511 "model_type": "llama",
1512 "hidden_size": 2048,
1513 "num_hidden_layers": 16,
1514 "num_attention_heads": 32,
1515 "num_key_value_heads": 8,
1516 "head_dim": 64,
1517 "intermediate_size": 8192,
1518 "vocab_size": 128256,
1519 "rms_norm_eps": 1e-5,
1520 "rope_theta": 500000.0,
1521 "max_position_embeddings": 131072,
1522 "hidden_act": "silu",
1523 "attention_bias": false,
1524 "mlp_bias": false,
1525 "tie_word_embeddings": true
1526 });
1527 let names = tensor_names(&[
1528 "model.embed_tokens.weight",
1529 "model.layers.0.input_layernorm.weight",
1530 "model.layers.0.mlp.down_proj.weight",
1531 "model.layers.0.mlp.gate_proj.weight",
1532 "model.layers.0.mlp.up_proj.weight",
1533 "model.layers.0.post_attention_layernorm.weight",
1534 "model.layers.0.self_attn.k_proj.weight",
1535 "model.layers.0.self_attn.o_proj.weight",
1536 "model.layers.0.self_attn.q_proj.weight",
1537 "model.layers.0.self_attn.v_proj.weight",
1538 "model.norm.weight",
1539 ]);
1540
1541 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1542 let auto = TransformerConfig::parse_auto(&json, &names, "llama").unwrap();
1543 assert_eq!(auto, manual);
1544 }
1545
1546 #[test]
1547 fn auto_config_matches_qwen2() {
1548 let json = serde_json::json!({
1550 "model_type": "qwen2",
1551 "hidden_size": 2048,
1552 "num_hidden_layers": 36,
1553 "num_attention_heads": 16,
1554 "num_key_value_heads": 2,
1555 "intermediate_size": 11008,
1556 "vocab_size": 151936,
1557 "rms_norm_eps": 1e-6,
1558 "rope_theta": 1000000.0,
1559 "max_position_embeddings": 32768,
1560 "hidden_act": "silu",
1561 "tie_word_embeddings": true,
1562 "sliding_window": 32768,
1563 "use_sliding_window": false
1564 });
1565 let names = tensor_names(&[
1566 "model.embed_tokens.weight",
1567 "model.layers.0.input_layernorm.weight",
1568 "model.layers.0.mlp.down_proj.weight",
1569 "model.layers.0.mlp.gate_proj.weight",
1570 "model.layers.0.mlp.up_proj.weight",
1571 "model.layers.0.post_attention_layernorm.weight",
1572 "model.layers.0.self_attn.k_proj.bias",
1573 "model.layers.0.self_attn.k_proj.weight",
1574 "model.layers.0.self_attn.o_proj.weight",
1575 "model.layers.0.self_attn.q_proj.bias",
1576 "model.layers.0.self_attn.q_proj.weight",
1577 "model.layers.0.self_attn.v_proj.bias",
1578 "model.layers.0.self_attn.v_proj.weight",
1579 "model.norm.weight",
1580 ]);
1581
1582 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1583 let auto = TransformerConfig::parse_auto(&json, &names, "qwen2").unwrap();
1584 assert_eq!(auto, manual);
1585 }
1586
1587 #[test]
1588 fn auto_config_matches_gemma() {
1589 let json = serde_json::json!({
1591 "model_type": "gemma",
1592 "hidden_size": 3072,
1593 "num_hidden_layers": 28,
1594 "num_attention_heads": 16,
1595 "num_key_value_heads": 16,
1596 "head_dim": 256,
1597 "intermediate_size": 24576,
1598 "vocab_size": 256000,
1599 "rms_norm_eps": 1e-6,
1600 "rope_theta": 10000.0,
1601 "max_position_embeddings": 8192,
1602 "hidden_activation": "gelu_pytorch_tanh"
1603 });
1604 let names = tensor_names(&[
1605 "model.embed_tokens.weight",
1606 "model.layers.0.input_layernorm.weight",
1607 "model.layers.0.mlp.down_proj.weight",
1608 "model.layers.0.mlp.gate_proj.weight",
1609 "model.layers.0.mlp.up_proj.weight",
1610 "model.layers.0.post_attention_layernorm.weight",
1611 "model.layers.0.self_attn.k_proj.weight",
1612 "model.layers.0.self_attn.o_proj.weight",
1613 "model.layers.0.self_attn.q_proj.weight",
1614 "model.layers.0.self_attn.v_proj.weight",
1615 "model.norm.weight",
1616 ]);
1617
1618 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1619 let auto = TransformerConfig::parse_auto(&json, &names, "gemma").unwrap();
1620 assert_eq!(auto, manual);
1621 }
1622
1623 #[test]
1624 fn auto_config_matches_gemma2() {
1625 let json = serde_json::json!({
1627 "model_type": "gemma2",
1628 "hidden_size": 2304,
1629 "num_hidden_layers": 26,
1630 "num_attention_heads": 8,
1631 "num_key_value_heads": 4,
1632 "head_dim": 256,
1633 "intermediate_size": 9216,
1634 "vocab_size": 256000,
1635 "rms_norm_eps": 1e-6,
1636 "rope_theta": 10000.0,
1637 "max_position_embeddings": 8192,
1638 "hidden_act": "gelu_pytorch_tanh",
1639 "hidden_activation": "gelu_pytorch_tanh",
1640 "attn_logit_softcapping": 50.0,
1641 "final_logit_softcapping": 30.0,
1642 "query_pre_attn_scalar": 256,
1643 "sliding_window": 4096
1644 });
1645 let names = tensor_names(&[
1646 "model.embed_tokens.weight",
1647 "model.layers.0.input_layernorm.weight",
1648 "model.layers.0.mlp.down_proj.weight",
1649 "model.layers.0.mlp.gate_proj.weight",
1650 "model.layers.0.mlp.up_proj.weight",
1651 "model.layers.0.post_attention_layernorm.weight",
1652 "model.layers.0.post_feedforward_layernorm.weight",
1653 "model.layers.0.pre_feedforward_layernorm.weight",
1654 "model.layers.0.self_attn.k_proj.weight",
1655 "model.layers.0.self_attn.o_proj.weight",
1656 "model.layers.0.self_attn.q_proj.weight",
1657 "model.layers.0.self_attn.v_proj.weight",
1658 "model.norm.weight",
1659 ]);
1660
1661 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1662 let auto = TransformerConfig::parse_auto(&json, &names, "gemma2").unwrap();
1663 assert_eq!(auto, manual);
1664 }
1665
1666 #[test]
1667 fn auto_config_matches_phi3() {
1668 let json = serde_json::json!({
1675 "model_type": "phi3",
1676 "hidden_size": 3072,
1677 "num_hidden_layers": 32,
1678 "num_attention_heads": 32,
1679 "num_key_value_heads": 32,
1680 "intermediate_size": 8192,
1681 "vocab_size": 32064,
1682 "rms_norm_eps": 1e-5,
1683 "rope_theta": 10000.0,
1684 "max_position_embeddings": 4096,
1685 "hidden_act": "silu",
1686 "tie_word_embeddings": false,
1687 "sliding_window": 2047,
1688 "attention_bias": false
1689 });
1690 let names = tensor_names(&[
1691 "lm_head.weight",
1692 "model.embed_tokens.weight",
1693 "model.layers.0.input_layernorm.weight",
1694 "model.layers.0.mlp.down_proj.weight",
1695 "model.layers.0.mlp.gate_up_proj.weight",
1696 "model.layers.0.post_attention_layernorm.weight",
1697 "model.layers.0.self_attn.o_proj.weight",
1698 "model.layers.0.self_attn.qkv_proj.weight",
1699 "model.norm.weight",
1700 ]);
1701
1702 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1703 let auto = TransformerConfig::parse_auto(&json, &names, "phi3").unwrap();
1704
1705 assert_eq!(manual.sliding_window, None);
1707 assert_eq!(auto.sliding_window, Some(2047));
1708
1709 let mut auto_adjusted = auto;
1712 auto_adjusted.sliding_window = None;
1713 assert_eq!(auto_adjusted, manual);
1714 }
1715
1716 #[test]
1717 fn auto_config_matches_starcoder2() {
1718 let json = serde_json::json!({
1720 "model_type": "starcoder2",
1721 "hidden_size": 3072,
1722 "num_hidden_layers": 30,
1723 "num_attention_heads": 24,
1724 "num_key_value_heads": 2,
1725 "intermediate_size": 12288,
1726 "vocab_size": 49152,
1727 "norm_epsilon": 1e-5,
1728 "norm_type": "layer_norm",
1729 "rope_theta": 999999.4420358813,
1730 "max_position_embeddings": 16384,
1731 "hidden_act": "gelu_pytorch_tanh",
1732 "use_bias": true,
1733 "sliding_window": 4096
1734 });
1735 let names = tensor_names(&[
1736 "model.embed_tokens.weight",
1737 "model.layers.0.input_layernorm.bias",
1738 "model.layers.0.input_layernorm.weight",
1739 "model.layers.0.mlp.c_fc.bias",
1740 "model.layers.0.mlp.c_fc.weight",
1741 "model.layers.0.mlp.c_proj.bias",
1742 "model.layers.0.mlp.c_proj.weight",
1743 "model.layers.0.post_attention_layernorm.bias",
1744 "model.layers.0.post_attention_layernorm.weight",
1745 "model.layers.0.self_attn.k_proj.bias",
1746 "model.layers.0.self_attn.k_proj.weight",
1747 "model.layers.0.self_attn.o_proj.bias",
1748 "model.layers.0.self_attn.o_proj.weight",
1749 "model.layers.0.self_attn.q_proj.bias",
1750 "model.layers.0.self_attn.q_proj.weight",
1751 "model.layers.0.self_attn.v_proj.bias",
1752 "model.layers.0.self_attn.v_proj.weight",
1753 "model.norm.bias",
1754 "model.norm.weight",
1755 ]);
1756
1757 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1758 let auto = TransformerConfig::parse_auto(&json, &names, "starcoder2").unwrap();
1759 assert_eq!(auto, manual);
1760 }
1761
1762 #[test]
1763 fn auto_config_matches_mistral() {
1764 let json = serde_json::json!({
1766 "model_type": "mistral",
1767 "hidden_size": 4096,
1768 "num_hidden_layers": 32,
1769 "num_attention_heads": 32,
1770 "num_key_value_heads": 8,
1771 "intermediate_size": 14336,
1772 "vocab_size": 32000,
1773 "rms_norm_eps": 1e-5,
1774 "rope_theta": 10000.0,
1775 "max_position_embeddings": 32768,
1776 "hidden_act": "silu",
1777 "tie_word_embeddings": false,
1778 "sliding_window": 4096
1779 });
1780 let names = tensor_names(&[
1781 "lm_head.weight",
1782 "model.embed_tokens.weight",
1783 "model.layers.0.input_layernorm.weight",
1784 "model.layers.0.mlp.down_proj.weight",
1785 "model.layers.0.mlp.gate_proj.weight",
1786 "model.layers.0.mlp.up_proj.weight",
1787 "model.layers.0.post_attention_layernorm.weight",
1788 "model.layers.0.self_attn.k_proj.weight",
1789 "model.layers.0.self_attn.o_proj.weight",
1790 "model.layers.0.self_attn.q_proj.weight",
1791 "model.layers.0.self_attn.v_proj.weight",
1792 "model.norm.weight",
1793 ]);
1794
1795 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1796 let auto = TransformerConfig::parse_auto(&json, &names, "mistral").unwrap();
1797 assert_eq!(auto, manual);
1798 }
1799
1800 #[test]
1801 fn auto_config_unknown_model_type() {
1802 let json = serde_json::json!({
1805 "model_type": "my_custom_llama",
1806 "hidden_size": 2048,
1807 "num_hidden_layers": 16,
1808 "num_attention_heads": 32,
1809 "num_key_value_heads": 8,
1810 "intermediate_size": 8192,
1811 "vocab_size": 32000,
1812 "rms_norm_eps": 1e-5,
1813 "rope_theta": 10000.0,
1814 "max_position_embeddings": 4096,
1815 "hidden_act": "silu"
1816 });
1817 let names = tensor_names(&[
1818 "lm_head.weight",
1819 "model.embed_tokens.weight",
1820 "model.layers.0.input_layernorm.weight",
1821 "model.layers.0.mlp.down_proj.weight",
1822 "model.layers.0.mlp.gate_proj.weight",
1823 "model.layers.0.mlp.up_proj.weight",
1824 "model.layers.0.post_attention_layernorm.weight",
1825 "model.layers.0.self_attn.k_proj.weight",
1826 "model.layers.0.self_attn.o_proj.weight",
1827 "model.layers.0.self_attn.q_proj.weight",
1828 "model.layers.0.self_attn.v_proj.weight",
1829 "model.norm.weight",
1830 ]);
1831
1832 let config = TransformerConfig::from_hf_config_auto(&json, &names).unwrap();
1834 assert_eq!(config.hidden_size, 2048);
1835 assert_eq!(config.num_layers, 16);
1836 assert_eq!(config.num_attention_heads, 32);
1837 assert_eq!(config.num_kv_heads, 8);
1838 assert_eq!(config.head_dim, 64);
1839 assert_eq!(config.norm_type, NormType::RmsNorm);
1840 assert_eq!(config.activation, Activation::Silu);
1841 assert_eq!(config.qkv_layout, QkvLayout::Separate);
1842 assert_eq!(config.mlp_layout, MlpLayout::GatedSeparate);
1843 assert!(!config.qkv_bias);
1844 assert!(!config.o_proj_bias);
1845 assert!(!config.mlp_bias);
1846 assert!(config.embedding_scale.is_none());
1847 assert!(!config.tie_word_embeddings);
1848 assert!(config.sliding_window.is_none());
1849 }
1850
1851 #[test]
1852 fn auto_config_dispatches_known_families() {
1853 let json = llama_config_json();
1855 let names = tensor_names(&["model.embed_tokens.weight"]);
1856
1857 let auto = TransformerConfig::from_hf_config_auto(&json, &names).unwrap();
1858 let manual = TransformerConfig::from_hf_config(&json).unwrap();
1859 assert_eq!(auto, manual);
1860 }
1861
1862 #[test]
1867 fn compatibility_check_passes_standard_model() {
1868 let json = serde_json::json!({
1869 "model_type": "my_custom",
1870 "hidden_size": 2048,
1871 "num_hidden_layers": 16,
1872 "num_attention_heads": 32,
1873 "intermediate_size": 8192,
1874 "vocab_size": 32000,
1875 "tie_word_embeddings": true
1876 });
1877 let names = tensor_names(&[
1878 "model.embed_tokens.weight",
1879 "model.layers.0.input_layernorm.weight",
1880 "model.layers.0.post_attention_layernorm.weight",
1881 "model.layers.0.self_attn.q_proj.weight",
1882 "model.layers.0.mlp.gate_proj.weight",
1883 "model.norm.weight",
1884 ]);
1885 let report = TransformerConfig::check_auto_compatibility(&json, &names);
1886 assert!(report.compatible, "issues: {:?}", report.issues);
1887 }
1888
1889 #[test]
1890 fn compatibility_check_detects_missing_norms() {
1891 let json = serde_json::json!({
1893 "model_type": "olmo",
1894 "hidden_size": 2048,
1895 "num_hidden_layers": 16,
1896 "num_attention_heads": 16,
1897 "intermediate_size": 8192,
1898 "vocab_size": 50304
1899 });
1900 let names = tensor_names(&[
1901 "model.embed_tokens.weight",
1902 "model.layers.0.self_attn.q_proj.weight",
1903 "model.layers.0.mlp.gate_proj.weight",
1904 "model.layers.0.mlp.down_proj.weight",
1905 ]);
1906 let report = TransformerConfig::check_auto_compatibility(&json, &names);
1907 assert!(!report.compatible);
1908 assert!(report.issues.len() >= 3, "issues: {:?}", report.issues);
1910 assert!(
1911 report.issues.iter().any(|i| i.contains("input_layernorm")),
1912 "should mention input_layernorm"
1913 );
1914 assert!(
1915 report.issues.iter().any(|i| i.contains("model.norm")),
1916 "should mention model.norm"
1917 );
1918 }
1919
1920 #[test]
1921 fn compatibility_check_detects_missing_config_fields() {
1922 let json = serde_json::json!({
1923 "model_type": "mystery",
1924 "hidden_size": 768
1925 });
1926 let names = tensor_names(&[]);
1927 let report = TransformerConfig::check_auto_compatibility(&json, &names);
1928 assert!(!report.compatible);
1929 assert!(
1931 report
1932 .issues
1933 .iter()
1934 .any(|i| i.contains("num_hidden_layers")),
1935 "should mention num_hidden_layers"
1936 );
1937 }
1938
1939 #[test]
1940 fn compatibility_check_detects_missing_lm_head() {
1941 let json = serde_json::json!({
1942 "model_type": "custom",
1943 "hidden_size": 2048,
1944 "num_hidden_layers": 16,
1945 "num_attention_heads": 32,
1946 "intermediate_size": 8192,
1947 "vocab_size": 32000,
1948 "tie_word_embeddings": false
1949 });
1950 let names = tensor_names(&[
1951 "model.embed_tokens.weight",
1952 "model.layers.0.input_layernorm.weight",
1953 "model.layers.0.post_attention_layernorm.weight",
1954 "model.layers.0.self_attn.q_proj.weight",
1955 "model.layers.0.mlp.gate_proj.weight",
1956 "model.norm.weight",
1957 ]);
1959 let report = TransformerConfig::check_auto_compatibility(&json, &names);
1960 assert!(!report.compatible);
1961 assert!(
1962 report.issues.iter().any(|i| i.contains("lm_head")),
1963 "should mention lm_head"
1964 );
1965 }
1966
1967 #[test]
1968 fn compatibility_check_config_only() {
1969 let good = serde_json::json!({
1970 "hidden_size": 2048,
1971 "num_hidden_layers": 16,
1972 "num_attention_heads": 32,
1973 "intermediate_size": 8192,
1974 "vocab_size": 32000
1975 });
1976 assert!(TransformerConfig::check_config_fields(&good).compatible);
1977
1978 let bad = serde_json::json!({
1979 "hidden_size": 2048
1980 });
1981 let report = TransformerConfig::check_config_fields(&bad);
1982 assert!(!report.compatible);
1983 assert_eq!(report.issues.len(), 4); }
1985
1986 #[test]
1987 fn compatibility_into_result_error_message() {
1988 let json = serde_json::json!({
1989 "model_type": "olmo",
1990 "hidden_size": 2048,
1991 "num_hidden_layers": 16,
1992 "num_attention_heads": 16,
1993 "intermediate_size": 8192,
1994 "vocab_size": 50304
1995 });
1996 let names = tensor_names(&[
1997 "model.embed_tokens.weight",
1998 "model.layers.0.self_attn.q_proj.weight",
1999 "model.layers.0.mlp.gate_proj.weight",
2000 ]);
2001 let result = TransformerConfig::check_auto_compatibility(&json, &names).into_result();
2002 assert!(result.is_err());
2003 let msg = result.unwrap_err().to_string();
2004 assert!(
2005 msg.contains("not compatible with GenericTransformer"),
2006 "error should explain incompatibility: {msg}"
2007 );
2008 }
2009
2010 #[test]
2011 fn compatibility_check_shows_gpt2_naming_hint() {
2012 let json = serde_json::json!({
2013 "model_type": "gpt2",
2014 "hidden_size": 768,
2015 "num_hidden_layers": 12,
2016 "num_attention_heads": 12,
2017 "intermediate_size": 3072,
2018 "vocab_size": 50257
2019 });
2020 let names = tensor_names(&[
2021 "transformer.wte.weight",
2022 "transformer.wpe.weight",
2023 "transformer.h.0.ln_1.weight",
2024 "transformer.h.0.attn.c_attn.weight",
2025 "transformer.h.0.mlp.c_fc.weight",
2026 "transformer.ln_f.weight",
2027 ]);
2028 let report = TransformerConfig::check_auto_compatibility(&json, &names);
2029 assert!(!report.compatible);
2030 assert!(
2032 report.issues.iter().any(|i| i.contains("GPT-2")),
2033 "should detect GPT-2 naming convention: {:?}",
2034 report.issues
2035 );
2036 assert!(
2038 report
2039 .issues
2040 .iter()
2041 .any(|i| i.contains("transformer.wte.weight")),
2042 "should show found embedding tensor: {:?}",
2043 report.issues
2044 );
2045 assert!(
2047 report.issues.iter().any(|i| i.contains("c_attn")),
2048 "should show found attention tensor: {:?}",
2049 report.issues
2050 );
2051 }
2052
2053 #[test]
2054 fn compatibility_check_shows_found_tensors_for_unknown_naming() {
2055 let json = serde_json::json!({
2056 "model_type": "custom_arch",
2057 "hidden_size": 512,
2058 "num_hidden_layers": 6,
2059 "num_attention_heads": 8,
2060 "intermediate_size": 2048,
2061 "vocab_size": 30000
2062 });
2063 let names = tensor_names(&[
2064 "encoder.layer.0.attention.query.weight",
2065 "encoder.layer.0.attention.key.weight",
2066 "encoder.layer.0.ffn.dense.weight",
2067 "encoder.embeddings.weight",
2068 ]);
2069 let report = TransformerConfig::check_auto_compatibility(&json, &names);
2070 assert!(!report.compatible);
2071 assert!(
2073 report
2074 .issues
2075 .iter()
2076 .any(|i| i.contains("unrecognized naming convention")),
2077 "should flag unrecognized naming: {:?}",
2078 report.issues
2079 );
2080 assert!(
2082 report
2083 .issues
2084 .iter()
2085 .any(|i| i.contains("encoder.embeddings.weight")),
2086 "should show found embedding: {:?}",
2087 report.issues
2088 );
2089 }
2090
2091 #[test]
2092 fn compatibility_check_shows_found_norm_tensors() {
2093 let json = serde_json::json!({
2095 "model_type": "custom",
2096 "hidden_size": 2048,
2097 "num_hidden_layers": 16,
2098 "num_attention_heads": 32,
2099 "intermediate_size": 8192,
2100 "vocab_size": 32000,
2101 "tie_word_embeddings": true
2102 });
2103 let names = tensor_names(&[
2104 "model.embed_tokens.weight",
2105 "model.layers.0.self_attn.q_proj.weight",
2106 "model.layers.0.mlp.gate_proj.weight",
2107 "model.layers.0.attention_norm.weight",
2108 "model.layers.0.ffn_norm.weight",
2109 "model.final_norm.weight",
2110 ]);
2111 let report = TransformerConfig::check_auto_compatibility(&json, &names);
2112 assert!(!report.compatible);
2113 assert!(
2115 report.issues.iter().any(|i| i.contains("attention_norm")),
2116 "should show found norm tensors: {:?}",
2117 report.issues
2118 );
2119 }
2120}