Skip to main content

hermes_llm/mal/
mod.rs

1//! Model Architecture Language (MAL) for Hermes LLM
2//!
3//! A composable DSL for defining LLM model architectures using pest parser.
4//!
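//! # Example MAL - Simple (flat) style
//!
//! Note: the model property handler in this module only applies
//! `description`, `vocab_size`, `max_seq_len`, `hidden_size`, `num_layers`
//! and `block` from a `model` body. The `num_heads` and `intermediate_size`
//! entries shown below are not copied into the resulting `ModelDef`; set them
//! through `attention` / `ffn` definitions as in the composable style.
//!
//! ```text
//! model tiny {
//!     vocab_size: 32000
//!     hidden_size: 128
//!     num_layers: 4
//!     num_heads: 4
//!     intermediate_size: 512
//! }
//! ```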
16//!
17//! # Example MAL - Composable style
18//!
19//! ```text
20//! # Define attention mechanism
21//! attention gqa {
22//!     num_heads: 32
23//!     num_kv_heads: 8
24//!     head_dim: 128
25//!     position_encoding: rope { theta: 10000.0 }
26//! }
27//!
28//! # Define FFN
29//! ffn swiglu_mlp {
30//!     hidden_dim: 14336
31//!     activation: swiglu
32//!     bias: false
33//! }
34//!
35//! # Define transformer block
36//! block llama_block {
37//!     attention: gqa
38//!     ffn: swiglu_mlp
39//!     norm: rmsnorm { eps: 1e-5 }
40//!     norm_position: pre
41//! }
42//!
43//! # Define model using the block
44//! model llama_7b {
45//!     vocab_size: 32000
46//!     max_seq_len: 4096
47//!     hidden_size: 4096
48//!     block: llama_block
49//!     num_layers: 32
50//! }
51//! ```
52
53use anyhow::{Result, anyhow};
54use pest::Parser;
55use pest_derive::Parser;
56use rust_embed::Embed;
57use serde::{Deserialize, Serialize};
58use std::collections::HashMap;
59
/// Embedded well-known model definitions
///
/// The derive embeds the files of the `well-known/` folder that match the
/// `*.mal` include pattern. The rest of this module reads them through
/// `WellKnown::get` (lookup by file name) and `WellKnown::iter` (enumeration).
#[derive(Embed)]
#[folder = "well-known/"]
#[include = "*.mal"]
struct WellKnown;
65
/// pest-derived parser for MAL source text.
///
/// The grammar is read from `mal/mal.pest`; the parsing functions in this
/// module match on the `Rule` variants that come with this derive.
#[derive(Parser)]
#[grammar = "mal/mal.pest"]
pub struct MalParser;
69
70// ============================================================================
71// AST Types
72// ============================================================================
73
/// Position encoding type
///
/// The `Default` implementation selects `Rope` with `theta = 10000.0` and no
/// scaling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PositionEncoding {
    /// `rope` encoding with its `theta` parameter and an optional `scaling` value.
    Rope { theta: f64, scaling: Option<f64> },
    /// `alibi` encoding; `learned_slopes` is carried as a plain flag.
    Alibi { learned_slopes: bool },
    /// Learned encoding with a fixed number of positions.
    Learned { max_positions: usize },
    /// No position encoding.
    None,
}
82
83impl Default for PositionEncoding {
84    fn default() -> Self {
85        Self::Rope {
86            theta: 10000.0,
87            scaling: None,
88        }
89    }
90}
91
/// Attention mechanism definition
///
/// Optional fields stay `None` when the MAL source does not set them; the
/// computed accessors on `ModelDef` supply fallbacks in that case.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionDef {
    /// Name of the definition (`"default"` when built via `Default`).
    pub name: String,
    /// Number of attention heads, if specified.
    pub num_heads: Option<usize>,
    /// Number of key/value heads, if specified.
    pub num_kv_heads: Option<usize>,
    /// Per-head dimension, if specified.
    pub head_dim: Option<usize>,
    /// Dropout value for attention (defaults to `0.0`).
    pub dropout: f64,
    /// Whether bias is enabled (defaults to `false`).
    pub bias: bool,
    /// Position encoding configuration.
    pub position_encoding: PositionEncoding,
    /// Window size, if specified.
    pub window_size: Option<usize>,
    /// Causal flag (defaults to `true`).
    pub causal: bool,
}
105
106impl Default for AttentionDef {
107    fn default() -> Self {
108        Self {
109            name: "default".to_string(),
110            num_heads: None,
111            num_kv_heads: None,
112            head_dim: None,
113            dropout: 0.0,
114            bias: false,
115            position_encoding: PositionEncoding::default(),
116            window_size: None,
117            causal: true,
118        }
119    }
120}
121
/// Normalization type
///
/// `RmsNorm` is the derived default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum NormType {
    /// RMS normalization (the default variant).
    #[default]
    RmsNorm,
    /// Layer normalization.
    LayerNorm,
    /// No normalization.
    None,
}
130
/// Normalization configuration
///
/// Note that the derived `Default` yields `eps == 0.0`; `ModelDef::norm_eps`
/// substitutes `1e-5` whenever `eps` is not strictly positive.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct NormConfig {
    /// Which normalization to use.
    pub norm_type: NormType,
    /// Epsilon value for the normalization.
    pub eps: f64,
}
137
/// Activation function
///
/// `SwiGLU` is the derived default and is also what `parse_activation`
/// returns for names it does not recognize.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum Activation {
    /// Parsed from `"swiglu"`.
    #[default]
    SwiGLU,
    /// Parsed from `"gelu"`.
    GELU,
    /// Parsed from `"silu"`.
    SiLU,
    /// Parsed from `"relu"`.
    ReLU,
    /// Parsed from `"gelu_new"`.
    GELUNew,
    /// Parsed from `"gelu_tanh"`.
    GELUTanh,
}
149
/// Feed-forward network definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FfnDef {
    /// Name of the definition (`"default"` when built via `Default`).
    pub name: String,
    /// Hidden dimension, if specified; `ModelDef::intermediate_size` falls
    /// back to `hidden_size * 4` when this is `None`.
    pub hidden_dim: Option<usize>,
    /// Activation function.
    pub activation: Activation,
    /// Whether bias is enabled (defaults to `false`).
    pub bias: bool,
    /// Dropout value (defaults to `0.0`).
    pub dropout: f64,
    /// Gate flag (defaults to `true`).
    pub gate: bool,
}
160
161impl Default for FfnDef {
162    fn default() -> Self {
163        Self {
164            name: "default".to_string(),
165            hidden_dim: None,
166            activation: Activation::default(),
167            bias: false,
168            dropout: 0.0,
169            gate: true,
170        }
171    }
172}
173
/// Transformer block definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockDef {
    /// Name of the definition (`"default"` when built via `Default`).
    pub name: String,
    /// Attention definition used by the block.
    pub attention: AttentionDef,
    /// Feed-forward definition used by the block.
    pub ffn: FfnDef,
    /// Normalization configuration.
    pub norm: NormConfig,
    /// Where normalization is applied (`Pre` or `Post`).
    pub norm_position: NormPosition,
    /// Residual flag (defaults to `true`).
    pub residual: bool,
    /// Dropout value (defaults to `0.0`).
    pub dropout: f64,
}
185
/// Normalization position in block
///
/// `Pre` is the derived default and also the fallback used by the block
/// property parser for unrecognized values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum NormPosition {
    /// Parsed from `"pre"`.
    #[default]
    Pre,
    /// Parsed from `"post"`.
    Post,
}
193
194impl Default for BlockDef {
195    fn default() -> Self {
196        Self {
197            name: "default".to_string(),
198            attention: AttentionDef::default(),
199            ffn: FfnDef::default(),
200            norm: NormConfig {
201                norm_type: NormType::RmsNorm,
202                eps: 1e-5,
203            },
204            norm_position: NormPosition::Pre,
205            residual: true,
206            dropout: 0.0,
207        }
208    }
209}
210
/// Embeddings configuration
///
/// The derived `Default` gives `tie_weights = false`, `dropout = 0.0` and no
/// scale.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct EmbeddingsConfig {
    /// Weight-tying flag.
    pub tie_weights: bool,
    /// Dropout value.
    pub dropout: f64,
    /// Optional scale value.
    pub scale: Option<f64>,
}
218
/// Output head configuration
///
/// The derived `Default` gives `bias = false` and no normalization.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OutputConfig {
    /// Whether bias is enabled.
    pub bias: bool,
    /// Optional normalization configuration.
    pub norm: Option<NormConfig>,
}
225
/// Parsed model definition from MAL
///
/// Derives `Serialize`/`Deserialize`, which `from_json` and `save_json` rely
/// on for JSON round-tripping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDef {
    /// Model name (`"default"` when built via `Default`).
    pub name: String,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Vocabulary size (defaults to 32000).
    pub vocab_size: usize,
    /// Maximum sequence length (defaults to 2048).
    pub max_seq_len: usize,
    /// Hidden size (defaults to 768).
    pub hidden_size: usize,
    /// Number of layers (defaults to 12).
    pub num_layers: usize,
    /// Block definition used by the model.
    pub block: BlockDef,
    /// Embeddings configuration.
    pub embeddings: EmbeddingsConfig,
    /// Output head configuration.
    pub output: OutputConfig,
}
239
240impl Default for ModelDef {
241    fn default() -> Self {
242        Self {
243            name: "default".to_string(),
244            description: None,
245            vocab_size: 32000,
246            max_seq_len: 2048,
247            hidden_size: 768,
248            num_layers: 12,
249            block: BlockDef::default(),
250            embeddings: EmbeddingsConfig::default(),
251            output: OutputConfig::default(),
252        }
253    }
254}
255
256impl std::fmt::Display for ModelDef {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        writeln!(f, "model {} {{", self.name)?;
259        if let Some(desc) = &self.description {
260            writeln!(f, "    description: \"{}\"", desc)?;
261        }
262        writeln!(f, "    vocab_size: {}", self.vocab_size)?;
263        writeln!(f, "    max_seq_len: {}", self.max_seq_len)?;
264        writeln!(f, "    hidden_size: {}", self.hidden_size)?;
265        writeln!(f, "    num_layers: {}", self.num_layers)?;
266        writeln!(f, "}}")?;
267        writeln!(f)?;
268
269        // Attention
270        writeln!(f, "attention {{")?;
271        if let Some(h) = self.block.attention.num_heads {
272            writeln!(f, "    num_heads: {}", h)?;
273        }
274        if let Some(kv) = self.block.attention.num_kv_heads {
275            writeln!(f, "    num_kv_heads: {}", kv)?;
276        }
277        if let Some(hd) = self.block.attention.head_dim {
278            writeln!(f, "    head_dim: {}", hd)?;
279        }
280        writeln!(f, "    bias: {}", self.block.attention.bias)?;
281        writeln!(f, "}}")?;
282        writeln!(f)?;
283
284        // FFN
285        writeln!(f, "ffn {{")?;
286        if let Some(dim) = self.block.ffn.hidden_dim {
287            writeln!(f, "    hidden_dim: {}", dim)?;
288        }
289        writeln!(f, "    activation: {:?}", self.block.ffn.activation)?;
290        writeln!(f, "    bias: {}", self.block.ffn.bias)?;
291        writeln!(f, "}}")?;
292        writeln!(f)?;
293
294        // Block
295        writeln!(f, "block {{")?;
296        writeln!(f, "    norm: {:?}", self.block.norm.norm_type)?;
297        writeln!(f, "    norm_position: {:?}", self.block.norm_position)?;
298        writeln!(f, "    residual: {}", self.block.residual)?;
299        writeln!(f, "}}")?;
300        writeln!(f)?;
301
302        // Estimated parameters
303        let params = self.estimated_params();
304        writeln!(
305            f,
306            "Estimated parameters: {:.2}B",
307            params as f64 / 1_000_000_000.0
308        )
309    }
310}
311
312impl ModelDef {
313    // ========================================================================
314    // Computed properties for model construction
315    // ========================================================================
316
317    pub fn num_heads(&self) -> usize {
318        self.block.attention.num_heads.unwrap_or(12)
319    }
320
321    pub fn num_kv_heads(&self) -> usize {
322        self.block
323            .attention
324            .num_kv_heads
325            .unwrap_or(self.num_heads())
326    }
327
328    pub fn head_dim(&self) -> usize {
329        self.block
330            .attention
331            .head_dim
332            .unwrap_or(self.hidden_size / self.num_heads())
333    }
334
335    pub fn intermediate_size(&self) -> usize {
336        self.block.ffn.hidden_dim.unwrap_or(self.hidden_size * 4)
337    }
338
339    pub fn dropout(&self) -> f64 {
340        self.block.dropout
341    }
342
343    pub fn use_bias(&self) -> bool {
344        self.block.ffn.bias || self.block.attention.bias
345    }
346
347    pub fn norm_eps(&self) -> f64 {
348        if self.block.norm.eps > 0.0 {
349            self.block.norm.eps
350        } else {
351            1e-5
352        }
353    }
354
355    pub fn rope_theta(&self) -> f64 {
356        match &self.block.attention.position_encoding {
357            PositionEncoding::Rope { theta, .. } => *theta,
358            _ => 10000.0,
359        }
360    }
361
362    pub fn use_swiglu(&self) -> bool {
363        matches!(self.block.ffn.activation, Activation::SwiGLU)
364    }
365
366    pub fn use_rmsnorm(&self) -> bool {
367        matches!(self.block.norm.norm_type, NormType::RmsNorm)
368    }
369
370    /// Estimate total parameters
371    pub fn estimated_params(&self) -> usize {
372        let embed_params = self.vocab_size * self.hidden_size;
373        let attn_params = 4 * self.hidden_size * self.hidden_size;
374        let ff_params = 3 * self.hidden_size * self.intermediate_size();
375        let layer_params = attn_params + ff_params + 2 * self.hidden_size;
376        let head_params = self.hidden_size * self.vocab_size;
377        embed_params + self.num_layers * layer_params + head_params
378    }
379
380    /// Load from JSON file
381    pub fn from_json(path: &str) -> Result<Self> {
382        let content = std::fs::read_to_string(path)?;
383        Ok(serde_json::from_str(&content)?)
384    }
385
386    /// Save to JSON file
387    pub fn save_json(&self, path: &str) -> Result<()> {
388        let content = serde_json::to_string_pretty(self)?;
389        std::fs::write(path, content)?;
390        Ok(())
391    }
392}
393
/// Complete parsed MAL file with all definitions
///
/// Every map is keyed by the definition's name; `parse_mal_full` inserts with
/// `HashMap::insert`, so a later definition with the same name replaces an
/// earlier one and source order is not retained.
#[derive(Debug, Clone, Default)]
pub struct MalFile {
    /// `attention` definitions by name.
    pub attentions: HashMap<String, AttentionDef>,
    /// `ffn` definitions by name.
    pub ffns: HashMap<String, FfnDef>,
    /// `block` definitions by name.
    pub blocks: HashMap<String, BlockDef>,
    /// `model` definitions by name.
    pub models: HashMap<String, ModelDef>,
}
402
403// ============================================================================
404// Parsing Functions
405// ============================================================================
406
407/// Parse activation type from string
408fn parse_activation(s: &str) -> Activation {
409    match s {
410        "swiglu" => Activation::SwiGLU,
411        "gelu" => Activation::GELU,
412        "silu" => Activation::SiLU,
413        "relu" => Activation::ReLU,
414        "gelu_new" => Activation::GELUNew,
415        "gelu_tanh" => Activation::GELUTanh,
416        _ => Activation::SwiGLU,
417    }
418}
419
420/// Parse a model property (block-based only)
421fn parse_model_prop(
422    pair: pest::iterators::Pair<Rule>,
423    def: &mut ModelDef,
424    file: &MalFile,
425) -> Result<()> {
426    for inner in pair.into_inner() {
427        match inner.as_rule() {
428            Rule::vocab_size_prop => {
429                if let Some(val) = inner.into_inner().next() {
430                    def.vocab_size = val.as_str().parse()?;
431                }
432            }
433            Rule::max_seq_len_prop => {
434                if let Some(val) = inner.into_inner().next() {
435                    def.max_seq_len = val.as_str().parse()?;
436                }
437            }
438            Rule::hidden_size_prop => {
439                if let Some(val) = inner.into_inner().next() {
440                    def.hidden_size = val.as_str().parse()?;
441                }
442            }
443            Rule::num_layers_prop => {
444                if let Some(val) = inner.into_inner().next() {
445                    def.num_layers = val.as_str().parse()?;
446                }
447            }
448            Rule::block_ref_prop => {
449                for child in inner.into_inner() {
450                    match child.as_rule() {
451                        Rule::identifier => {
452                            let name = child.as_str();
453                            if let Some(block) = file.blocks.get(name) {
454                                def.block = block.clone();
455                            }
456                        }
457                        Rule::inline_block => {
458                            let mut block = BlockDef::default();
459                            for prop in child.into_inner() {
460                                if prop.as_rule() == Rule::block_prop {
461                                    parse_block_prop(prop, &mut block, file)?;
462                                }
463                            }
464                            def.block = block;
465                        }
466                        _ => {}
467                    }
468                }
469            }
470            Rule::description_prop => {
471                if let Some(val) = inner.into_inner().next() {
472                    let s = val.as_str();
473                    def.description = Some(s[1..s.len() - 1].to_string());
474                }
475            }
476            _ => {}
477        }
478    }
479    Ok(())
480}
481
482/// Parse a model definition from pest pair
483fn parse_model_def(pair: pest::iterators::Pair<Rule>, file: &MalFile) -> Result<ModelDef> {
484    let mut def = ModelDef::default();
485    let mut inner = pair.into_inner();
486
487    // Get model name
488    if let Some(name) = inner.next() {
489        def.name = name.as_str().to_string();
490    }
491
492    // Parse properties
493    for prop in inner {
494        if prop.as_rule() == Rule::model_prop {
495            parse_model_prop(prop, &mut def, file)?;
496        }
497    }
498
499    Ok(def)
500}
501
502/// Parse MAL from a string (returns first model found)
503pub fn parse_mal(input: &str) -> Result<ModelDef> {
504    let file = parse_mal_full(input)?;
505    file.models
506        .into_values()
507        .next()
508        .ok_or_else(|| anyhow!("No model definition found"))
509}
510
511/// Parse complete MAL file with all definitions
512pub fn parse_mal_full(input: &str) -> Result<MalFile> {
513    let pairs = MalParser::parse(Rule::file, input).map_err(|e| anyhow!("Parse error: {}", e))?;
514
515    let mut file = MalFile::default();
516
517    for pair in pairs {
518        if pair.as_rule() == Rule::file {
519            for inner in pair.into_inner() {
520                if inner.as_rule() == Rule::definition {
521                    for def in inner.into_inner() {
522                        match def.as_rule() {
523                            Rule::model_def => {
524                                let model = parse_model_def(def, &file)?;
525                                file.models.insert(model.name.clone(), model);
526                            }
527                            Rule::attention_def => {
528                                let attn = parse_attention_def(def)?;
529                                file.attentions.insert(attn.name.clone(), attn);
530                            }
531                            Rule::ffn_def => {
532                                let ffn = parse_ffn_def(def)?;
533                                file.ffns.insert(ffn.name.clone(), ffn);
534                            }
535                            Rule::block_def => {
536                                let block = parse_block_def(def, &file)?;
537                                file.blocks.insert(block.name.clone(), block);
538                            }
539                            _ => {}
540                        }
541                    }
542                }
543            }
544        }
545    }
546
547    Ok(file)
548}
549
550/// Parse an attention definition
551fn parse_attention_def(pair: pest::iterators::Pair<Rule>) -> Result<AttentionDef> {
552    let mut def = AttentionDef::default();
553    let mut inner = pair.into_inner();
554
555    if let Some(name) = inner.next() {
556        def.name = name.as_str().to_string();
557    }
558
559    for prop in inner {
560        if prop.as_rule() == Rule::attention_prop {
561            parse_attention_prop(prop, &mut def)?;
562        }
563    }
564
565    Ok(def)
566}
567
568/// Parse attention properties
569fn parse_attention_prop(pair: pest::iterators::Pair<Rule>, def: &mut AttentionDef) -> Result<()> {
570    for inner in pair.into_inner() {
571        match inner.as_rule() {
572            Rule::num_heads_prop => {
573                if let Some(val) = inner.into_inner().next() {
574                    def.num_heads = Some(val.as_str().parse()?);
575                }
576            }
577            Rule::num_kv_heads_prop => {
578                if let Some(val) = inner.into_inner().next() {
579                    def.num_kv_heads = Some(val.as_str().parse()?);
580                }
581            }
582            Rule::head_dim_prop => {
583                if let Some(val) = inner.into_inner().next() {
584                    def.head_dim = Some(val.as_str().parse()?);
585                }
586            }
587            Rule::dropout_prop => {
588                if let Some(val) = inner.into_inner().next() {
589                    def.dropout = val.as_str().parse()?;
590                }
591            }
592            Rule::bias_prop => {
593                if let Some(val) = inner.into_inner().next() {
594                    def.bias = val.as_str() == "true";
595                }
596            }
597            Rule::causal_prop => {
598                if let Some(val) = inner.into_inner().next() {
599                    def.causal = val.as_str() == "true";
600                }
601            }
602            Rule::window_size_prop => {
603                if let Some(val) = inner.into_inner().next() {
604                    def.window_size = Some(val.as_str().parse()?);
605                }
606            }
607            _ => {}
608        }
609    }
610    Ok(())
611}
612
613/// Parse an FFN definition
614fn parse_ffn_def(pair: pest::iterators::Pair<Rule>) -> Result<FfnDef> {
615    let mut def = FfnDef::default();
616    let mut inner = pair.into_inner();
617
618    if let Some(name) = inner.next() {
619        def.name = name.as_str().to_string();
620    }
621
622    for prop in inner {
623        if prop.as_rule() == Rule::ffn_prop {
624            parse_ffn_prop(prop, &mut def)?;
625        }
626    }
627
628    Ok(def)
629}
630
631/// Parse FFN properties
632fn parse_ffn_prop(pair: pest::iterators::Pair<Rule>, def: &mut FfnDef) -> Result<()> {
633    for inner in pair.into_inner() {
634        match inner.as_rule() {
635            Rule::hidden_dim_prop => {
636                if let Some(val) = inner.into_inner().next() {
637                    def.hidden_dim = Some(val.as_str().parse()?);
638                }
639            }
640            Rule::activation_prop => {
641                if let Some(val) = inner.into_inner().next() {
642                    def.activation = parse_activation(val.as_str());
643                }
644            }
645            Rule::bias_prop => {
646                if let Some(val) = inner.into_inner().next() {
647                    def.bias = val.as_str() == "true";
648                }
649            }
650            Rule::dropout_prop => {
651                if let Some(val) = inner.into_inner().next() {
652                    def.dropout = val.as_str().parse()?;
653                }
654            }
655            Rule::gate_prop => {
656                if let Some(val) = inner.into_inner().next() {
657                    def.gate = val.as_str() == "true";
658                }
659            }
660            _ => {}
661        }
662    }
663    Ok(())
664}
665
666/// Parse a block definition
667fn parse_block_def(pair: pest::iterators::Pair<Rule>, file: &MalFile) -> Result<BlockDef> {
668    let mut def = BlockDef::default();
669    let mut inner = pair.into_inner();
670
671    if let Some(name) = inner.next() {
672        def.name = name.as_str().to_string();
673    }
674
675    for prop in inner {
676        if prop.as_rule() == Rule::block_prop {
677            parse_block_prop(prop, &mut def, file)?;
678        }
679    }
680
681    Ok(def)
682}
683
684/// Parse block properties
685fn parse_block_prop(
686    pair: pest::iterators::Pair<Rule>,
687    def: &mut BlockDef,
688    file: &MalFile,
689) -> Result<()> {
690    for inner in pair.into_inner() {
691        match inner.as_rule() {
692            Rule::attention_ref_prop => {
693                // Can be identifier or inline definition
694                for child in inner.into_inner() {
695                    match child.as_rule() {
696                        Rule::identifier => {
697                            let name = child.as_str();
698                            if let Some(attn) = file.attentions.get(name) {
699                                def.attention = attn.clone();
700                            }
701                        }
702                        Rule::inline_attention => {
703                            let mut attn = AttentionDef::default();
704                            for prop in child.into_inner() {
705                                if prop.as_rule() == Rule::attention_prop {
706                                    parse_attention_prop(prop, &mut attn)?;
707                                }
708                            }
709                            def.attention = attn;
710                        }
711                        _ => {}
712                    }
713                }
714            }
715            Rule::ffn_ref_prop => {
716                for child in inner.into_inner() {
717                    match child.as_rule() {
718                        Rule::identifier => {
719                            let name = child.as_str();
720                            if let Some(ffn) = file.ffns.get(name) {
721                                def.ffn = ffn.clone();
722                            }
723                        }
724                        Rule::inline_ffn => {
725                            let mut ffn = FfnDef::default();
726                            for prop in child.into_inner() {
727                                if prop.as_rule() == Rule::ffn_prop {
728                                    parse_ffn_prop(prop, &mut ffn)?;
729                                }
730                            }
731                            def.ffn = ffn;
732                        }
733                        _ => {}
734                    }
735                }
736            }
737            Rule::norm_position_prop => {
738                if let Some(val) = inner.into_inner().next() {
739                    def.norm_position = match val.as_str() {
740                        "pre" => NormPosition::Pre,
741                        "post" => NormPosition::Post,
742                        _ => NormPosition::Pre,
743                    };
744                }
745            }
746            Rule::residual_prop => {
747                if let Some(val) = inner.into_inner().next() {
748                    def.residual = val.as_str() == "true";
749                }
750            }
751            Rule::dropout_prop => {
752                if let Some(val) = inner.into_inner().next() {
753                    def.dropout = val.as_str().parse()?;
754                }
755            }
756            _ => {}
757        }
758    }
759    Ok(())
760}
761
762/// Parse MAL from a file
763pub fn parse_mal_file<P: AsRef<std::path::Path>>(path: P) -> Result<ModelDef> {
764    let content = std::fs::read_to_string(path)?;
765    parse_mal(&content)
766}
767
768// ============================================================================
769// Built-in model definitions
770// ============================================================================
771
772/// Get a well-known model definition by name
773///
774/// Accepts:
775/// - Short names: "nano", "tiny", "gpt2-small", etc.
776/// - Well-known paths: "well-known/nano.mal", "well-known/gpt2_small.mal"
777/// - Filenames: "nano.mal", "gpt2_small.mal"
778pub fn get_builtin_model(name: &str) -> Option<ModelDef> {
779    let mal = get_wellknown_mal(name)?;
780    parse_mal(&mal).ok()
781}
782
783/// Get the raw MAL content for a well-known model
784///
785/// Dynamically loads from embedded well-known/ directory.
786pub fn get_wellknown_mal(name: &str) -> Option<String> {
787    // Normalize: strip well-known/ prefix, ensure .mal suffix
788    let name = name.strip_prefix("well-known/").unwrap_or(name);
789    let filename = if name.ends_with(".mal") {
790        name.to_string()
791    } else {
792        // Convert kebab-case to snake_case for filename
793        format!("{}.mal", name.replace('-', "_"))
794    };
795
796    WellKnown::get(&filename).map(|f| String::from_utf8_lossy(&f.data).into_owned())
797}
798
799/// List all well-known model names (auto-discovered from embedded files)
800pub fn list_wellknown_models() -> Vec<String> {
801    WellKnown::iter()
802        .filter_map(|path| {
803            let path: &str = path.as_ref();
804            if path.ends_with(".mal") {
805                Some(path.strip_suffix(".mal").unwrap().replace('_', "-"))
806            } else {
807                None
808            }
809        })
810        .collect()
811}
812
/// Unit tests for MAL parsing and the embedded well-known models.
#[cfg(test)]
mod tests {
    use super::*;

    /// A model that references a named block (which in turn references named
    /// attention and FFN definitions) parses, and its top-level sizes are read.
    #[test]
    fn test_parse_simple_model() {
        let mal = r#"
            attention test_attn {
                num_heads: 8
                bias: false
            }

            ffn test_ffn {
                hidden_dim: 2048
                activation: gelu
            }

            block test_block {
                attention: test_attn
                ffn: test_ffn
                norm_position: pre
            }

            model test {
                vocab_size: 32000
                hidden_size: 512
                num_layers: 8
                block: test_block
            }
        "#;

        let def = parse_mal(mal).unwrap();
        assert_eq!(def.name, "test");
        assert_eq!(def.vocab_size, 32000);
        assert_eq!(def.hidden_size, 512);
        assert_eq!(def.num_layers, 8);
    }

    /// Attention and FFN settings reach the model through the referenced
    /// block, and the model description is stored without its quotes.
    #[test]
    fn test_parse_with_block_props() {
        let mal = r#"
            attention full_attn {
                num_heads: 16
                num_kv_heads: 4
                bias: true
                dropout: 0.1
            }

            ffn full_ffn {
                hidden_dim: 4096
                activation: gelu
                bias: true
                dropout: 0.1
            }

            block full_block {
                attention: full_attn
                ffn: full_ffn
                norm: layernorm { eps: 1e-6 }
                norm_position: pre
                residual: true
            }

            model full_test {
                description: "A test model"
                vocab_size: 50000
                max_seq_len: 4096
                hidden_size: 1024
                num_layers: 12
                block: full_block
            }
        "#;

        let def = parse_mal(mal).unwrap();
        assert_eq!(def.description, Some("A test model".to_string()));
        assert_eq!(def.vocab_size, 50000);
        assert_eq!(def.max_seq_len, 4096);
        assert_eq!(def.block.attention.num_heads, Some(16));
        assert_eq!(def.block.attention.num_kv_heads, Some(4));
        assert_eq!(def.block.ffn.hidden_dim, Some(4096));
        assert!(matches!(def.block.ffn.activation, Activation::GELU));
    }

    /// Every embedded well-known model can be loaded and parsed.
    #[test]
    fn test_wellknown_models() {
        for name in list_wellknown_models() {
            let def = get_builtin_model(&name).unwrap_or_else(|| panic!("Failed to get {}", name));
            // Verify computed properties work
            assert!(def.num_heads() > 0);
            assert!(def.intermediate_size() > 0);
        }
    }

    /// Pins the sizes of the embedded `tiny` model.
    #[test]
    fn test_model_properties() {
        let def = get_builtin_model("tiny").unwrap();

        assert_eq!(def.vocab_size, 32000);
        assert_eq!(def.hidden_size, 128);
        assert_eq!(def.num_layers, 4);
        assert_eq!(def.num_heads(), 4);
    }

    /// `#` comments are accepted at top level and inside a definition body.
    #[test]
    fn test_comments() {
        let mal = r#"
            # This is a comment
            attention test_attn {
                # Comment in attention
                num_heads: 2
            }

            ffn test_ffn {
                hidden_dim: 256
            }

            block test_block {
                attention: test_attn
                ffn: test_ffn
            }

            # Comment before model
            model test {
                vocab_size: 1000
                hidden_size: 64
                num_layers: 2
                block: test_block
            }
        "#;

        let def = parse_mal(mal).unwrap();
        assert_eq!(def.vocab_size, 1000);
    }

    /// `parse_mal_full` exposes every definition kind by name, with the
    /// parsed attention, FFN and block settings.
    #[test]
    fn test_composable_architecture() {
        let mal = r#"
            attention my_attn {
                num_heads: 16
                num_kv_heads: 4
                head_dim: 128
                bias: false
            }

            ffn my_ffn {
                hidden_dim: 11008
                activation: swiglu
                bias: false
            }

            block my_block {
                attention: my_attn
                ffn: my_ffn
                norm: rmsnorm { eps: 1e-5 }
                norm_position: pre
                residual: true
            }

            model my_model {
                description: "LLaMA 7B architecture"
                vocab_size: 32000
                max_seq_len: 4096
                hidden_size: 4096
                num_layers: 32
                block: my_block
            }
        "#;

        let file = parse_mal_full(mal).unwrap();

        assert!(file.attentions.contains_key("my_attn"));
        assert!(file.ffns.contains_key("my_ffn"));
        assert!(file.blocks.contains_key("my_block"));
        assert!(file.models.contains_key("my_model"));

        let attn = file.attentions.get("my_attn").unwrap();
        assert_eq!(attn.num_heads, Some(16));
        assert_eq!(attn.num_kv_heads, Some(4));

        let ffn = file.ffns.get("my_ffn").unwrap();
        assert_eq!(ffn.hidden_dim, Some(11008));
        assert!(matches!(ffn.activation, Activation::SwiGLU));

        let block = file.blocks.get("my_block").unwrap();
        assert!(matches!(block.norm_position, NormPosition::Pre));
        assert!(block.residual);
    }
}