Skip to main content

aprender/format/
model_family.rs

1//! Model Family Contract Types (PMAT-241)
2//!
3//! Defines the `ModelFamily` trait and associated configuration types for
4//! compiler-enforced model family contracts.
5//!
6//! # Theoretical Foundation
7//!
8//! - Shingo (1986): Poka-Yoke / Zero Quality Control
9//! - Strom & Yemini (1986): Typestate programming
10//! - Parsons (2019): Parse, Don't Validate
11//!
12//! # Contract
13//!
14//! See `contracts/model-families/*.yaml` and
15//! `docs/specifications/compiler-enforced-model-types-model-oracle.md`
16
17use std::collections::HashMap;
18use std::fmt;
19
20use crate::error::{AprenderError, Result};
21
22// ============================================================================
23// Enums
24// ============================================================================
25
/// Kind of attention mechanism used by a model family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttentionType {
    /// Multi-Head Attention (MHA), the standard transformer form
    Mha,
    /// Grouped Query Attention (GQA)
    Gqa,
    /// Multi-Query Attention (MQA)
    Mqa,
}

impl fmt::Display for AttentionType {
    /// Writes the upper-case acronym of the variant (`MHA`, `GQA` or `MQA`).
    ///
    /// Width/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let acronym = match self {
            Self::Mha => "MHA",
            Self::Gqa => "GQA",
            Self::Mqa => "MQA",
        };
        f.write_str(acronym)
    }
}
46
47impl AttentionType {
48    /// Parse from YAML string
49    pub fn from_str_contract(s: &str) -> Result<Self> {
50        match s.to_lowercase().as_str() {
51            "mha" => Ok(Self::Mha),
52            "gqa" => Ok(Self::Gqa),
53            "mqa" => Ok(Self::Mqa),
54            _ => Err(AprenderError::FormatError {
55                message: format!("Unknown attention type: {s}. Expected: mha, gqa, mqa"),
56            }),
57        }
58    }
59}
60
/// Kind of activation function used by a model family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Activation {
    /// SiLU activation (also known as Swish)
    Silu,
    /// GELU activation
    Gelu,
    /// ReLU activation
    Relu,
}

impl fmt::Display for Activation {
    /// Writes the conventional spelling of the activation
    /// (`SiLU`, `GELU` or `ReLU`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Silu => "SiLU",
            Self::Gelu => "GELU",
            Self::Relu => "ReLU",
        };
        f.write_str(label)
    }
}
81
82impl Activation {
83    pub fn from_str_contract(s: &str) -> Result<Self> {
84        match s.to_lowercase().as_str() {
85            "silu" | "swish" => Ok(Self::Silu),
86            "gelu" => Ok(Self::Gelu),
87            "relu" => Ok(Self::Relu),
88            _ => Err(AprenderError::FormatError {
89                message: format!("Unknown activation: {s}. Expected: silu, gelu, relu"),
90            }),
91        }
92    }
93}
94
/// Kind of normalization layer used by a model family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// RMS normalization (e.g. LLaMA, Qwen2)
    RmsNorm,
    /// Layer normalization (e.g. BERT, Whisper)
    LayerNorm,
}

impl fmt::Display for NormType {
    /// Writes `RMSNorm` or `LayerNorm`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::RmsNorm => "RMSNorm",
            Self::LayerNorm => "LayerNorm",
        };
        f.write_str(label)
    }
}
112
113impl NormType {
114    pub fn from_str_contract(s: &str) -> Result<Self> {
115        match s.to_lowercase().as_str() {
116            "rmsnorm" | "rms_norm" => Ok(Self::RmsNorm),
117            "layernorm" | "layer_norm" => Ok(Self::LayerNorm),
118            _ => Err(AprenderError::FormatError {
119                message: format!("Unknown norm type: {s}. Expected: rmsnorm, layernorm"),
120            }),
121        }
122    }
123}
124
/// Kind of positional encoding used by a model family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionalEncoding {
    /// Rotary Position Embeddings (e.g. LLaMA, Qwen2)
    Rope,
    /// ALiBi (e.g. Bloom)
    Alibi,
    /// Absolute position embeddings (e.g. BERT, Whisper)
    Absolute,
    /// Relative position embeddings
    Relative,
}

impl fmt::Display for PositionalEncoding {
    /// Writes the conventional spelling of the encoding
    /// (`RoPE`, `ALiBi`, `Absolute` or `Relative`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Rope => "RoPE",
            Self::Alibi => "ALiBi",
            Self::Absolute => "Absolute",
            Self::Relative => "Relative",
        };
        f.write_str(label)
    }
}
148
149impl PositionalEncoding {
150    pub fn from_str_contract(s: &str) -> Result<Self> {
151        match s.to_lowercase().as_str() {
152            "rope" => Ok(Self::Rope),
153            "alibi" => Ok(Self::Alibi),
154            "absolute" | "sinusoidal" => Ok(Self::Absolute),
155            "relative" => Ok(Self::Relative),
156            _ => Err(AprenderError::FormatError {
157                message: format!(
158                    "Unknown positional encoding: {s}. Expected: rope, alibi, absolute, relative"
159                ),
160            }),
161        }
162    }
163}
164
/// Kind of MLP (feed-forward) block used by a model family.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlpType {
    /// SwiGLU, gated with SiLU (e.g. LLaMA, Qwen2)
    SwiGlu,
    /// Standard GELU MLP (e.g. BERT, Whisper)
    GeluMlp,
    /// Generic gated MLP
    GatedMlp,
}

impl fmt::Display for MlpType {
    /// Writes `SwiGLU`, `GELU MLP` or `Gated MLP`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::SwiGlu => "SwiGLU",
            Self::GeluMlp => "GELU MLP",
            Self::GatedMlp => "Gated MLP",
        };
        f.write_str(label)
    }
}
185
186impl MlpType {
187    pub fn from_str_contract(s: &str) -> Result<Self> {
188        match s.to_lowercase().as_str() {
189            "swiglu" => Ok(Self::SwiGlu),
190            "gelu_mlp" | "gelu" => Ok(Self::GeluMlp),
191            "gated_mlp" | "gated" => Ok(Self::GatedMlp),
192            _ => Err(AprenderError::FormatError {
193                message: format!("Unknown MLP type: {s}. Expected: swiglu, gelu_mlp, gated_mlp"),
194            }),
195        }
196    }
197}
198
199// ============================================================================
200// Configuration Structs
201// ============================================================================
202
/// Dimensions and hyper-parameters of one size variant inside a model family.
#[derive(Debug, Clone)]
pub struct ModelSizeConfig {
    /// Parameter count as a human-readable label (e.g., "0.5B", "7B")
    pub parameters: String,
    /// Width of the hidden state
    pub hidden_dim: usize,
    /// How many transformer layers the variant has
    pub num_layers: usize,
    /// How many attention heads the variant has
    pub num_heads: usize,
    /// How many key-value heads the variant has (relevant for GQA)
    pub num_kv_heads: usize,
    /// Width of the intermediate (FFN) projection
    pub intermediate_dim: usize,
    /// Size of the vocabulary
    pub vocab_size: usize,
    /// Upper bound on position embeddings
    pub max_position_embeddings: usize,
    /// Dimension of a single head (`hidden_dim / num_heads`)
    pub head_dim: usize,
    /// RoPE theta frequency; 0.0 when RoPE is not used
    pub rope_theta: f64,
    /// Epsilon used by the normalization layers
    pub norm_eps: f64,
}
229
230/// Architectural constraints for a model family.
231#[derive(Debug, Clone)]
232pub struct ModelConstraints {
233    pub attention_type: AttentionType,
234    pub activation: Activation,
235    pub norm_type: NormType,
236    pub has_bias: bool,
237    pub tied_embeddings: bool,
238    pub positional_encoding: PositionalEncoding,
239    pub mlp_type: MlpType,
240}
241
/// Naming template for the tensors of a model family.
#[derive(Debug, Clone)]
pub struct TensorTemplate {
    /// Name of the embedding tensor (e.g., "model.embed_tokens.weight")
    pub embedding: String,
    /// Name of the LM head tensor (e.g., "lm_head.weight"), if any
    pub lm_head: Option<String>,
    /// Name of the final normalization tensor, if any
    pub final_norm: Option<String>,
    /// Per-layer tensor name patterns, keyed by role (`q_proj`, `k_proj`, etc.);
    /// pattern values contain an `{n}` placeholder
    pub per_layer: HashMap<String, Option<String>>,
}
254
/// Shape template for a model family, expressed as parameterized expressions.
#[derive(Debug, Clone)]
pub struct ShapeTemplate {
    /// Parameterized shape expression for each tensor role,
    /// e.g., `q_proj` maps to `[num_heads * head_dim, hidden_dim]`
    pub shapes: HashMap<String, String>,
}
262
/// Chat template settings of a model family.
#[derive(Debug, Clone)]
pub struct ChatTemplateConfig {
    /// Chat template format identifier (valid values are defined by the contract)
    pub format: String,
    /// Template text
    pub template: String,
    /// Beginning-of-sequence token
    pub bos_token: String,
    /// End-of-sequence token
    pub eos_token: String,
    /// Additional special tokens (key/value meaning is defined by the contract)
    pub special_tokens: HashMap<String, String>,
}
272
/// Cross-reference settings linking a model family to certification data.
#[derive(Debug, Clone)]
pub struct CertificationConfig {
    /// Path of the certification playbook
    pub playbook_path: String,
    /// Family key used in the certification CSV
    pub csv_family_key: String,
    /// Size categories (presumably size variant name -> category; confirm against the contract)
    pub size_categories: HashMap<String, String>,
}
280
281/// Complete configuration for a model family.
282#[derive(Debug, Clone)]
283pub struct ModelFamilyConfig {
284    /// Canonical family name (e.g., "qwen2")
285    pub family: String,
286    /// Human-readable display name
287    pub display_name: String,
288    /// Vendor/organization
289    pub vendor: String,
290    /// HuggingFace architecture identifiers
291    pub architectures: Vec<String>,
292    /// HuggingFace repo name pattern
293    pub hf_pattern: String,
294    /// Size variants keyed by name (e.g., "0.5b", "7b")
295    pub size_variants: HashMap<String, ModelSizeConfig>,
296    /// Architectural constraints
297    pub constraints: ModelConstraints,
298    /// Tensor name template
299    pub tensor_template: TensorTemplate,
300    /// Shape template
301    pub shape_template: ShapeTemplate,
302    /// Supported quantization formats
303    pub quantizations: Vec<String>,
304    /// Chat template (None for non-chat models like Whisper, BERT)
305    pub chat_template: Option<ChatTemplateConfig>,
306    /// Certification cross-reference
307    pub certification: Option<CertificationConfig>,
308}
309
310// ============================================================================
311// Contract Error
312// ============================================================================
313
/// Error raised when a model family contract is violated.
#[derive(Debug, Clone)]
pub struct ContractError {
    /// Name of the family the error refers to
    pub family: String,
    /// Description of what went wrong
    pub message: String,
}

impl fmt::Display for ContractError {
    /// Renders as `Model family contract error [<family>]: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { family, message } = self;
        write!(f, "Model family contract error [{family}]: {message}")
    }
}

// No underlying source error is carried; the default trait methods suffice.
impl std::error::Error for ContractError {}
332
333impl From<ContractError> for AprenderError {
334    fn from(err: ContractError) -> Self {
335        AprenderError::FormatError {
336            message: err.to_string(),
337        }
338    }
339}
340
341// ============================================================================
342// ModelFamily Trait
343// ============================================================================
344
345/// Trait implemented by each model family.
346///
347/// This trait is the compile-time bridge between YAML contracts and Rust code.
348/// Implementations can be generated by build.rs from model family YAMLs (PMAT-250)
349/// or loaded at runtime from YAML files (PMAT-242).
350pub trait ModelFamily: fmt::Debug + Send + Sync {
351    /// Canonical family name (e.g., "qwen2")
352    fn family_name(&self) -> &str;
353
354    /// Human-readable display name
355    fn display_name(&self) -> &str;
356
357    /// Get the full configuration
358    fn config(&self) -> &ModelFamilyConfig;
359
360    /// Get configuration for a specific size variant
361    fn size_config(&self, size: &str) -> Option<&ModelSizeConfig>;
362
363    /// Detect size variant from model config (`hidden_dim`, `num_layers`)
364    fn detect_size(&self, hidden_dim: usize, num_layers: usize) -> Option<String>;
365
366    /// Get architectural constraints
367    fn constraints(&self) -> &ModelConstraints;
368
369    /// Expected tensor count for a given size variant
370    fn expected_tensor_count(&self, size: &str) -> Option<usize>;
371
372    /// Validate that a set of tensor names matches the contract
373    fn validate_tensor_names(
374        &self,
375        names: &[&str],
376        size: &str,
377    ) -> std::result::Result<(), ContractError>;
378}
379
380// ============================================================================
381// DynModelFamily - Runtime implementation backed by ModelFamilyConfig
382// ============================================================================
383
384/// Dynamic model family implementation backed by a `ModelFamilyConfig`.
385/// Used when family is loaded from YAML at runtime.
386#[derive(Debug, Clone)]
387pub struct DynModelFamily {
388    config: ModelFamilyConfig,
389}
390
391impl DynModelFamily {
392    /// Create from a loaded config
393    #[must_use]
394    pub fn new(config: ModelFamilyConfig) -> Self {
395        Self { config }
396    }
397}
398
399include!("model_family_part_02.rs");
400include!("model_family_part_03.rs");