1use anyhow::{Result, anyhow};
54use pest::Parser;
55use pest_derive::Parser;
56use rust_embed::Embed;
57use serde::{Deserialize, Serialize};
58use std::collections::HashMap;
59
60#[derive(Embed)]
62#[folder = "well-known/"]
63#[include = "*.mal"]
64struct WellKnown;
65
66#[derive(Parser)]
67#[grammar = "mal/mal.pest"]
68pub struct MalParser;
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum PositionEncoding {
77 Rope { theta: f64, scaling: Option<f64> },
78 Alibi { learned_slopes: bool },
79 Learned { max_positions: usize },
80 None,
81}
82
83impl Default for PositionEncoding {
84 fn default() -> Self {
85 Self::Rope {
86 theta: 10000.0,
87 scaling: None,
88 }
89 }
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct AttentionDef {
95 pub name: String,
96 pub num_heads: Option<usize>,
97 pub num_kv_heads: Option<usize>,
98 pub head_dim: Option<usize>,
99 pub dropout: f64,
100 pub bias: bool,
101 pub position_encoding: PositionEncoding,
102 pub window_size: Option<usize>,
103 pub causal: bool,
104}
105
106impl Default for AttentionDef {
107 fn default() -> Self {
108 Self {
109 name: "default".to_string(),
110 num_heads: None,
111 num_kv_heads: None,
112 head_dim: None,
113 dropout: 0.0,
114 bias: false,
115 position_encoding: PositionEncoding::default(),
116 window_size: None,
117 causal: true,
118 }
119 }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
124pub enum NormType {
125 #[default]
126 RmsNorm,
127 LayerNorm,
128 None,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, Default)]
133pub struct NormConfig {
134 pub norm_type: NormType,
135 pub eps: f64,
136}
137
138#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
140pub enum Activation {
141 #[default]
142 SwiGLU,
143 GELU,
144 SiLU,
145 ReLU,
146 GELUNew,
147 GELUTanh,
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct FfnDef {
153 pub name: String,
154 pub hidden_dim: Option<usize>,
155 pub activation: Activation,
156 pub bias: bool,
157 pub dropout: f64,
158 pub gate: bool,
159}
160
161impl Default for FfnDef {
162 fn default() -> Self {
163 Self {
164 name: "default".to_string(),
165 hidden_dim: None,
166 activation: Activation::default(),
167 bias: false,
168 dropout: 0.0,
169 gate: true,
170 }
171 }
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct BlockDef {
177 pub name: String,
178 pub attention: AttentionDef,
179 pub ffn: FfnDef,
180 pub norm: NormConfig,
181 pub norm_position: NormPosition,
182 pub residual: bool,
183 pub dropout: f64,
184}
185
186#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
188pub enum NormPosition {
189 #[default]
190 Pre,
191 Post,
192}
193
194impl Default for BlockDef {
195 fn default() -> Self {
196 Self {
197 name: "default".to_string(),
198 attention: AttentionDef::default(),
199 ffn: FfnDef::default(),
200 norm: NormConfig {
201 norm_type: NormType::RmsNorm,
202 eps: 1e-5,
203 },
204 norm_position: NormPosition::Pre,
205 residual: true,
206 dropout: 0.0,
207 }
208 }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize, Default)]
213pub struct EmbeddingsConfig {
214 pub tie_weights: bool,
215 pub dropout: f64,
216 pub scale: Option<f64>,
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize, Default)]
221pub struct OutputConfig {
222 pub bias: bool,
223 pub norm: Option<NormConfig>,
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct ModelDef {
229 pub name: String,
230 pub description: Option<String>,
231 pub vocab_size: usize,
232 pub max_seq_len: usize,
233 pub hidden_size: usize,
234 pub num_layers: usize,
235 pub block: BlockDef,
236 pub embeddings: EmbeddingsConfig,
237 pub output: OutputConfig,
238}
239
240impl Default for ModelDef {
241 fn default() -> Self {
242 Self {
243 name: "default".to_string(),
244 description: None,
245 vocab_size: 32000,
246 max_seq_len: 2048,
247 hidden_size: 768,
248 num_layers: 12,
249 block: BlockDef::default(),
250 embeddings: EmbeddingsConfig::default(),
251 output: OutputConfig::default(),
252 }
253 }
254}
255
256impl std::fmt::Display for ModelDef {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 writeln!(f, "model {} {{", self.name)?;
259 if let Some(desc) = &self.description {
260 writeln!(f, " description: \"{}\"", desc)?;
261 }
262 writeln!(f, " vocab_size: {}", self.vocab_size)?;
263 writeln!(f, " max_seq_len: {}", self.max_seq_len)?;
264 writeln!(f, " hidden_size: {}", self.hidden_size)?;
265 writeln!(f, " num_layers: {}", self.num_layers)?;
266 writeln!(f, "}}")?;
267 writeln!(f)?;
268
269 writeln!(f, "attention {{")?;
271 if let Some(h) = self.block.attention.num_heads {
272 writeln!(f, " num_heads: {}", h)?;
273 }
274 if let Some(kv) = self.block.attention.num_kv_heads {
275 writeln!(f, " num_kv_heads: {}", kv)?;
276 }
277 if let Some(hd) = self.block.attention.head_dim {
278 writeln!(f, " head_dim: {}", hd)?;
279 }
280 writeln!(f, " bias: {}", self.block.attention.bias)?;
281 writeln!(f, "}}")?;
282 writeln!(f)?;
283
284 writeln!(f, "ffn {{")?;
286 if let Some(dim) = self.block.ffn.hidden_dim {
287 writeln!(f, " hidden_dim: {}", dim)?;
288 }
289 writeln!(f, " activation: {:?}", self.block.ffn.activation)?;
290 writeln!(f, " bias: {}", self.block.ffn.bias)?;
291 writeln!(f, "}}")?;
292 writeln!(f)?;
293
294 writeln!(f, "block {{")?;
296 writeln!(f, " norm: {:?}", self.block.norm.norm_type)?;
297 writeln!(f, " norm_position: {:?}", self.block.norm_position)?;
298 writeln!(f, " residual: {}", self.block.residual)?;
299 writeln!(f, "}}")?;
300 writeln!(f)?;
301
302 let params = self.estimated_params();
304 writeln!(
305 f,
306 "Estimated parameters: {:.2}B",
307 params as f64 / 1_000_000_000.0
308 )
309 }
310}
311
312impl ModelDef {
313 pub fn num_heads(&self) -> usize {
318 self.block.attention.num_heads.unwrap_or(12)
319 }
320
321 pub fn num_kv_heads(&self) -> usize {
322 self.block
323 .attention
324 .num_kv_heads
325 .unwrap_or(self.num_heads())
326 }
327
328 pub fn head_dim(&self) -> usize {
329 self.block
330 .attention
331 .head_dim
332 .unwrap_or(self.hidden_size / self.num_heads())
333 }
334
335 pub fn intermediate_size(&self) -> usize {
336 self.block.ffn.hidden_dim.unwrap_or(self.hidden_size * 4)
337 }
338
339 pub fn dropout(&self) -> f64 {
340 self.block.dropout
341 }
342
343 pub fn use_bias(&self) -> bool {
344 self.block.ffn.bias || self.block.attention.bias
345 }
346
347 pub fn norm_eps(&self) -> f64 {
348 if self.block.norm.eps > 0.0 {
349 self.block.norm.eps
350 } else {
351 1e-5
352 }
353 }
354
355 pub fn rope_theta(&self) -> f64 {
356 match &self.block.attention.position_encoding {
357 PositionEncoding::Rope { theta, .. } => *theta,
358 _ => 10000.0,
359 }
360 }
361
362 pub fn use_swiglu(&self) -> bool {
363 matches!(self.block.ffn.activation, Activation::SwiGLU)
364 }
365
366 pub fn use_rmsnorm(&self) -> bool {
367 matches!(self.block.norm.norm_type, NormType::RmsNorm)
368 }
369
370 pub fn estimated_params(&self) -> usize {
372 let embed_params = self.vocab_size * self.hidden_size;
373 let attn_params = 4 * self.hidden_size * self.hidden_size;
374 let ff_params = 3 * self.hidden_size * self.intermediate_size();
375 let layer_params = attn_params + ff_params + 2 * self.hidden_size;
376 let head_params = self.hidden_size * self.vocab_size;
377 embed_params + self.num_layers * layer_params + head_params
378 }
379
380 pub fn from_json(path: &str) -> Result<Self> {
382 let content = std::fs::read_to_string(path)?;
383 Ok(serde_json::from_str(&content)?)
384 }
385
386 pub fn save_json(&self, path: &str) -> Result<()> {
388 let content = serde_json::to_string_pretty(self)?;
389 std::fs::write(path, content)?;
390 Ok(())
391 }
392}
393
394#[derive(Debug, Clone, Default)]
396pub struct MalFile {
397 pub attentions: HashMap<String, AttentionDef>,
398 pub ffns: HashMap<String, FfnDef>,
399 pub blocks: HashMap<String, BlockDef>,
400 pub models: HashMap<String, ModelDef>,
401}
402
403fn parse_activation(s: &str) -> Activation {
409 match s {
410 "swiglu" => Activation::SwiGLU,
411 "gelu" => Activation::GELU,
412 "silu" => Activation::SiLU,
413 "relu" => Activation::ReLU,
414 "gelu_new" => Activation::GELUNew,
415 "gelu_tanh" => Activation::GELUTanh,
416 _ => Activation::SwiGLU,
417 }
418}
419
420fn parse_model_prop(
422 pair: pest::iterators::Pair<Rule>,
423 def: &mut ModelDef,
424 file: &MalFile,
425) -> Result<()> {
426 for inner in pair.into_inner() {
427 match inner.as_rule() {
428 Rule::vocab_size_prop => {
429 if let Some(val) = inner.into_inner().next() {
430 def.vocab_size = val.as_str().parse()?;
431 }
432 }
433 Rule::max_seq_len_prop => {
434 if let Some(val) = inner.into_inner().next() {
435 def.max_seq_len = val.as_str().parse()?;
436 }
437 }
438 Rule::hidden_size_prop => {
439 if let Some(val) = inner.into_inner().next() {
440 def.hidden_size = val.as_str().parse()?;
441 }
442 }
443 Rule::num_layers_prop => {
444 if let Some(val) = inner.into_inner().next() {
445 def.num_layers = val.as_str().parse()?;
446 }
447 }
448 Rule::block_ref_prop => {
449 for child in inner.into_inner() {
450 match child.as_rule() {
451 Rule::identifier => {
452 let name = child.as_str();
453 if let Some(block) = file.blocks.get(name) {
454 def.block = block.clone();
455 }
456 }
457 Rule::inline_block => {
458 let mut block = BlockDef::default();
459 for prop in child.into_inner() {
460 if prop.as_rule() == Rule::block_prop {
461 parse_block_prop(prop, &mut block, file)?;
462 }
463 }
464 def.block = block;
465 }
466 _ => {}
467 }
468 }
469 }
470 Rule::description_prop => {
471 if let Some(val) = inner.into_inner().next() {
472 let s = val.as_str();
473 def.description = Some(s[1..s.len() - 1].to_string());
474 }
475 }
476 _ => {}
477 }
478 }
479 Ok(())
480}
481
482fn parse_model_def(pair: pest::iterators::Pair<Rule>, file: &MalFile) -> Result<ModelDef> {
484 let mut def = ModelDef::default();
485 let mut inner = pair.into_inner();
486
487 if let Some(name) = inner.next() {
489 def.name = name.as_str().to_string();
490 }
491
492 for prop in inner {
494 if prop.as_rule() == Rule::model_prop {
495 parse_model_prop(prop, &mut def, file)?;
496 }
497 }
498
499 Ok(def)
500}
501
502pub fn parse_mal(input: &str) -> Result<ModelDef> {
504 let file = parse_mal_full(input)?;
505 file.models
506 .into_values()
507 .next()
508 .ok_or_else(|| anyhow!("No model definition found"))
509}
510
511pub fn parse_mal_full(input: &str) -> Result<MalFile> {
513 let pairs = MalParser::parse(Rule::file, input).map_err(|e| anyhow!("Parse error: {}", e))?;
514
515 let mut file = MalFile::default();
516
517 for pair in pairs {
518 if pair.as_rule() == Rule::file {
519 for inner in pair.into_inner() {
520 if inner.as_rule() == Rule::definition {
521 for def in inner.into_inner() {
522 match def.as_rule() {
523 Rule::model_def => {
524 let model = parse_model_def(def, &file)?;
525 file.models.insert(model.name.clone(), model);
526 }
527 Rule::attention_def => {
528 let attn = parse_attention_def(def)?;
529 file.attentions.insert(attn.name.clone(), attn);
530 }
531 Rule::ffn_def => {
532 let ffn = parse_ffn_def(def)?;
533 file.ffns.insert(ffn.name.clone(), ffn);
534 }
535 Rule::block_def => {
536 let block = parse_block_def(def, &file)?;
537 file.blocks.insert(block.name.clone(), block);
538 }
539 _ => {}
540 }
541 }
542 }
543 }
544 }
545 }
546
547 Ok(file)
548}
549
550fn parse_attention_def(pair: pest::iterators::Pair<Rule>) -> Result<AttentionDef> {
552 let mut def = AttentionDef::default();
553 let mut inner = pair.into_inner();
554
555 if let Some(name) = inner.next() {
556 def.name = name.as_str().to_string();
557 }
558
559 for prop in inner {
560 if prop.as_rule() == Rule::attention_prop {
561 parse_attention_prop(prop, &mut def)?;
562 }
563 }
564
565 Ok(def)
566}
567
568fn parse_attention_prop(pair: pest::iterators::Pair<Rule>, def: &mut AttentionDef) -> Result<()> {
570 for inner in pair.into_inner() {
571 match inner.as_rule() {
572 Rule::num_heads_prop => {
573 if let Some(val) = inner.into_inner().next() {
574 def.num_heads = Some(val.as_str().parse()?);
575 }
576 }
577 Rule::num_kv_heads_prop => {
578 if let Some(val) = inner.into_inner().next() {
579 def.num_kv_heads = Some(val.as_str().parse()?);
580 }
581 }
582 Rule::head_dim_prop => {
583 if let Some(val) = inner.into_inner().next() {
584 def.head_dim = Some(val.as_str().parse()?);
585 }
586 }
587 Rule::dropout_prop => {
588 if let Some(val) = inner.into_inner().next() {
589 def.dropout = val.as_str().parse()?;
590 }
591 }
592 Rule::bias_prop => {
593 if let Some(val) = inner.into_inner().next() {
594 def.bias = val.as_str() == "true";
595 }
596 }
597 Rule::causal_prop => {
598 if let Some(val) = inner.into_inner().next() {
599 def.causal = val.as_str() == "true";
600 }
601 }
602 Rule::window_size_prop => {
603 if let Some(val) = inner.into_inner().next() {
604 def.window_size = Some(val.as_str().parse()?);
605 }
606 }
607 _ => {}
608 }
609 }
610 Ok(())
611}
612
613fn parse_ffn_def(pair: pest::iterators::Pair<Rule>) -> Result<FfnDef> {
615 let mut def = FfnDef::default();
616 let mut inner = pair.into_inner();
617
618 if let Some(name) = inner.next() {
619 def.name = name.as_str().to_string();
620 }
621
622 for prop in inner {
623 if prop.as_rule() == Rule::ffn_prop {
624 parse_ffn_prop(prop, &mut def)?;
625 }
626 }
627
628 Ok(def)
629}
630
631fn parse_ffn_prop(pair: pest::iterators::Pair<Rule>, def: &mut FfnDef) -> Result<()> {
633 for inner in pair.into_inner() {
634 match inner.as_rule() {
635 Rule::hidden_dim_prop => {
636 if let Some(val) = inner.into_inner().next() {
637 def.hidden_dim = Some(val.as_str().parse()?);
638 }
639 }
640 Rule::activation_prop => {
641 if let Some(val) = inner.into_inner().next() {
642 def.activation = parse_activation(val.as_str());
643 }
644 }
645 Rule::bias_prop => {
646 if let Some(val) = inner.into_inner().next() {
647 def.bias = val.as_str() == "true";
648 }
649 }
650 Rule::dropout_prop => {
651 if let Some(val) = inner.into_inner().next() {
652 def.dropout = val.as_str().parse()?;
653 }
654 }
655 Rule::gate_prop => {
656 if let Some(val) = inner.into_inner().next() {
657 def.gate = val.as_str() == "true";
658 }
659 }
660 _ => {}
661 }
662 }
663 Ok(())
664}
665
666fn parse_block_def(pair: pest::iterators::Pair<Rule>, file: &MalFile) -> Result<BlockDef> {
668 let mut def = BlockDef::default();
669 let mut inner = pair.into_inner();
670
671 if let Some(name) = inner.next() {
672 def.name = name.as_str().to_string();
673 }
674
675 for prop in inner {
676 if prop.as_rule() == Rule::block_prop {
677 parse_block_prop(prop, &mut def, file)?;
678 }
679 }
680
681 Ok(def)
682}
683
684fn parse_block_prop(
686 pair: pest::iterators::Pair<Rule>,
687 def: &mut BlockDef,
688 file: &MalFile,
689) -> Result<()> {
690 for inner in pair.into_inner() {
691 match inner.as_rule() {
692 Rule::attention_ref_prop => {
693 for child in inner.into_inner() {
695 match child.as_rule() {
696 Rule::identifier => {
697 let name = child.as_str();
698 if let Some(attn) = file.attentions.get(name) {
699 def.attention = attn.clone();
700 }
701 }
702 Rule::inline_attention => {
703 let mut attn = AttentionDef::default();
704 for prop in child.into_inner() {
705 if prop.as_rule() == Rule::attention_prop {
706 parse_attention_prop(prop, &mut attn)?;
707 }
708 }
709 def.attention = attn;
710 }
711 _ => {}
712 }
713 }
714 }
715 Rule::ffn_ref_prop => {
716 for child in inner.into_inner() {
717 match child.as_rule() {
718 Rule::identifier => {
719 let name = child.as_str();
720 if let Some(ffn) = file.ffns.get(name) {
721 def.ffn = ffn.clone();
722 }
723 }
724 Rule::inline_ffn => {
725 let mut ffn = FfnDef::default();
726 for prop in child.into_inner() {
727 if prop.as_rule() == Rule::ffn_prop {
728 parse_ffn_prop(prop, &mut ffn)?;
729 }
730 }
731 def.ffn = ffn;
732 }
733 _ => {}
734 }
735 }
736 }
737 Rule::norm_position_prop => {
738 if let Some(val) = inner.into_inner().next() {
739 def.norm_position = match val.as_str() {
740 "pre" => NormPosition::Pre,
741 "post" => NormPosition::Post,
742 _ => NormPosition::Pre,
743 };
744 }
745 }
746 Rule::residual_prop => {
747 if let Some(val) = inner.into_inner().next() {
748 def.residual = val.as_str() == "true";
749 }
750 }
751 Rule::dropout_prop => {
752 if let Some(val) = inner.into_inner().next() {
753 def.dropout = val.as_str().parse()?;
754 }
755 }
756 _ => {}
757 }
758 }
759 Ok(())
760}
761
762pub fn parse_mal_file<P: AsRef<std::path::Path>>(path: P) -> Result<ModelDef> {
764 let content = std::fs::read_to_string(path)?;
765 parse_mal(&content)
766}
767
768pub fn get_builtin_model(name: &str) -> Option<ModelDef> {
779 let mal = get_wellknown_mal(name)?;
780 parse_mal(&mal).ok()
781}
782
783pub fn get_wellknown_mal(name: &str) -> Option<String> {
787 let name = name.strip_prefix("well-known/").unwrap_or(name);
789 let filename = if name.ends_with(".mal") {
790 name.to_string()
791 } else {
792 format!("{}.mal", name.replace('-', "_"))
794 };
795
796 WellKnown::get(&filename).map(|f| String::from_utf8_lossy(&f.data).into_owned())
797}
798
799pub fn list_wellknown_models() -> Vec<String> {
801 WellKnown::iter()
802 .filter_map(|path| {
803 let path: &str = path.as_ref();
804 if path.ends_with(".mal") {
805 Some(path.strip_suffix(".mal").unwrap().replace('_', "-"))
806 } else {
807 None
808 }
809 })
810 .collect()
811}
812
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal file with one definition of each kind parses, and the model
    /// picks up the referenced block together with its attention and ffn.
    #[test]
    fn test_parse_simple_model() {
        let mal = r#"
        attention test_attn {
            num_heads: 8
            bias: false
        }

        ffn test_ffn {
            hidden_dim: 2048
            activation: gelu
        }

        block test_block {
            attention: test_attn
            ffn: test_ffn
            norm_position: pre
        }

        model test {
            vocab_size: 32000
            hidden_size: 512
            num_layers: 8
            block: test_block
        }
        "#;

        let def = parse_mal(mal).unwrap();
        assert_eq!(def.name, "test");
        assert_eq!(def.vocab_size, 32000);
        assert_eq!(def.hidden_size, 512);
        assert_eq!(def.num_layers, 8);

        // The by-name references must have been resolved, not left at defaults.
        assert_eq!(def.block.name, "test_block");
        assert_eq!(def.block.attention.num_heads, Some(8));
        assert_eq!(def.block.ffn.hidden_dim, Some(2048));
        assert!(matches!(def.block.ffn.activation, Activation::GELU));
    }

    /// A fuller file exercising description, sequence length and the
    /// attention/ffn properties reached through the block reference.
    #[test]
    fn test_parse_with_block_props() {
        let mal = r#"
        attention full_attn {
            num_heads: 16
            num_kv_heads: 4
            bias: true
            dropout: 0.1
        }

        ffn full_ffn {
            hidden_dim: 4096
            activation: gelu
            bias: true
            dropout: 0.1
        }

        block full_block {
            attention: full_attn
            ffn: full_ffn
            norm: layernorm { eps: 1e-6 }
            norm_position: pre
            residual: true
        }

        model full_test {
            description: "A test model"
            vocab_size: 50000
            max_seq_len: 4096
            hidden_size: 1024
            num_layers: 12
            block: full_block
        }
        "#;

        let def = parse_mal(mal).unwrap();
        assert_eq!(def.description, Some("A test model".to_string()));
        assert_eq!(def.vocab_size, 50000);
        assert_eq!(def.max_seq_len, 4096);
        assert_eq!(def.block.attention.num_heads, Some(16));
        assert_eq!(def.block.attention.num_kv_heads, Some(4));
        assert_eq!(def.block.ffn.hidden_dim, Some(4096));
        assert!(matches!(def.block.ffn.activation, Activation::GELU));
    }

    /// Every bundled model must load and report usable sizes.
    #[test]
    fn test_wellknown_models() {
        for name in list_wellknown_models() {
            let def = get_builtin_model(&name).unwrap_or_else(|| panic!("Failed to get {}", name));
            assert!(def.num_heads() > 0);
            assert!(def.intermediate_size() > 0);
        }
    }

    /// Spot-check of the bundled `tiny` model.
    #[test]
    fn test_model_properties() {
        let def = get_builtin_model("tiny").unwrap();

        assert_eq!(def.vocab_size, 32000);
        assert_eq!(def.hidden_size, 128);
        assert_eq!(def.num_layers, 4);
        assert_eq!(def.num_heads(), 4);
    }

    /// `#` comments are accepted at top level and inside definitions, and do
    /// not disturb the surrounding properties.
    #[test]
    fn test_comments() {
        let mal = r#"
        # This is a comment
        attention test_attn {
            # Comment in attention
            num_heads: 2
        }

        ffn test_ffn {
            hidden_dim: 256
        }

        block test_block {
            attention: test_attn
            ffn: test_ffn
        }

        # Comment before model
        model test {
            vocab_size: 1000
            hidden_size: 64
            num_layers: 2
            block: test_block
        }
        "#;

        let def = parse_mal(mal).unwrap();
        assert_eq!(def.vocab_size, 1000);

        // Properties that follow a comment must still be applied.
        assert_eq!(def.hidden_size, 64);
        assert_eq!(def.num_layers, 2);
        assert_eq!(def.block.attention.num_heads, Some(2));
        assert_eq!(def.block.ffn.hidden_dim, Some(256));
    }

    /// `parse_mal_full` exposes every definition by name.
    #[test]
    fn test_composable_architecture() {
        let mal = r#"
        attention my_attn {
            num_heads: 16
            num_kv_heads: 4
            head_dim: 128
            bias: false
        }

        ffn my_ffn {
            hidden_dim: 11008
            activation: swiglu
            bias: false
        }

        block my_block {
            attention: my_attn
            ffn: my_ffn
            norm: rmsnorm { eps: 1e-5 }
            norm_position: pre
            residual: true
        }

        model my_model {
            description: "LLaMA 7B architecture"
            vocab_size: 32000
            max_seq_len: 4096
            hidden_size: 4096
            num_layers: 32
            block: my_block
        }
        "#;

        let file = parse_mal_full(mal).unwrap();

        assert!(file.attentions.contains_key("my_attn"));
        assert!(file.ffns.contains_key("my_ffn"));
        assert!(file.blocks.contains_key("my_block"));
        assert!(file.models.contains_key("my_model"));

        let attn = file.attentions.get("my_attn").unwrap();
        assert_eq!(attn.num_heads, Some(16));
        assert_eq!(attn.num_kv_heads, Some(4));

        let ffn = file.ffns.get("my_ffn").unwrap();
        assert_eq!(ffn.hidden_dim, Some(11008));
        assert!(matches!(ffn.activation, Activation::SwiGLU));

        let block = file.blocks.get("my_block").unwrap();
        assert!(matches!(block.norm_position, NormPosition::Pre));
        assert!(block.residual);
    }
}