1use std::collections::HashMap;
18use std::fmt;
19
20use crate::error::{AprenderError, Result};
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
28pub enum AttentionType {
29 Mha,
31 Gqa,
33 Mqa,
35}
36
37impl fmt::Display for AttentionType {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 match self {
40 Self::Mha => write!(f, "MHA"),
41 Self::Gqa => write!(f, "GQA"),
42 Self::Mqa => write!(f, "MQA"),
43 }
44 }
45}
46
47impl AttentionType {
48 pub fn from_str_contract(s: &str) -> Result<Self> {
50 match s.to_lowercase().as_str() {
51 "mha" => Ok(Self::Mha),
52 "gqa" => Ok(Self::Gqa),
53 "mqa" => Ok(Self::Mqa),
54 _ => Err(AprenderError::FormatError {
55 message: format!("Unknown attention type: {s}. Expected: mha, gqa, mqa"),
56 }),
57 }
58 }
59}
60
61#[derive(Debug, Clone, Copy, PartialEq, Eq)]
63pub enum Activation {
64 Silu,
66 Gelu,
68 Relu,
70}
71
72impl fmt::Display for Activation {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 match self {
75 Self::Silu => write!(f, "SiLU"),
76 Self::Gelu => write!(f, "GELU"),
77 Self::Relu => write!(f, "ReLU"),
78 }
79 }
80}
81
82impl Activation {
83 pub fn from_str_contract(s: &str) -> Result<Self> {
84 match s.to_lowercase().as_str() {
85 "silu" | "swish" => Ok(Self::Silu),
86 "gelu" => Ok(Self::Gelu),
87 "relu" => Ok(Self::Relu),
88 _ => Err(AprenderError::FormatError {
89 message: format!("Unknown activation: {s}. Expected: silu, gelu, relu"),
90 }),
91 }
92 }
93}
94
95#[derive(Debug, Clone, Copy, PartialEq, Eq)]
97pub enum NormType {
98 RmsNorm,
100 LayerNorm,
102}
103
104impl fmt::Display for NormType {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106 match self {
107 Self::RmsNorm => write!(f, "RMSNorm"),
108 Self::LayerNorm => write!(f, "LayerNorm"),
109 }
110 }
111}
112
113impl NormType {
114 pub fn from_str_contract(s: &str) -> Result<Self> {
115 match s.to_lowercase().as_str() {
116 "rmsnorm" | "rms_norm" => Ok(Self::RmsNorm),
117 "layernorm" | "layer_norm" => Ok(Self::LayerNorm),
118 _ => Err(AprenderError::FormatError {
119 message: format!("Unknown norm type: {s}. Expected: rmsnorm, layernorm"),
120 }),
121 }
122 }
123}
124
125#[derive(Debug, Clone, Copy, PartialEq, Eq)]
127pub enum PositionalEncoding {
128 Rope,
130 Alibi,
132 Absolute,
134 Relative,
136}
137
138impl fmt::Display for PositionalEncoding {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 match self {
141 Self::Rope => write!(f, "RoPE"),
142 Self::Alibi => write!(f, "ALiBi"),
143 Self::Absolute => write!(f, "Absolute"),
144 Self::Relative => write!(f, "Relative"),
145 }
146 }
147}
148
149impl PositionalEncoding {
150 pub fn from_str_contract(s: &str) -> Result<Self> {
151 match s.to_lowercase().as_str() {
152 "rope" => Ok(Self::Rope),
153 "alibi" => Ok(Self::Alibi),
154 "absolute" | "sinusoidal" => Ok(Self::Absolute),
155 "relative" => Ok(Self::Relative),
156 _ => Err(AprenderError::FormatError {
157 message: format!(
158 "Unknown positional encoding: {s}. Expected: rope, alibi, absolute, relative"
159 ),
160 }),
161 }
162 }
163}
164
165#[derive(Debug, Clone, Copy, PartialEq, Eq)]
167pub enum MlpType {
168 SwiGlu,
170 GeluMlp,
172 GatedMlp,
174}
175
176impl fmt::Display for MlpType {
177 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178 match self {
179 Self::SwiGlu => write!(f, "SwiGLU"),
180 Self::GeluMlp => write!(f, "GELU MLP"),
181 Self::GatedMlp => write!(f, "Gated MLP"),
182 }
183 }
184}
185
186impl MlpType {
187 pub fn from_str_contract(s: &str) -> Result<Self> {
188 match s.to_lowercase().as_str() {
189 "swiglu" => Ok(Self::SwiGlu),
190 "gelu_mlp" | "gelu" => Ok(Self::GeluMlp),
191 "gated_mlp" | "gated" => Ok(Self::GatedMlp),
192 _ => Err(AprenderError::FormatError {
193 message: format!("Unknown MLP type: {s}. Expected: swiglu, gelu_mlp, gated_mlp"),
194 }),
195 }
196 }
197}
198
/// Architecture hyper-parameters for one size variant of a model family.
///
/// Stored in `ModelFamilyConfig::size_variants`, keyed by size name. Values are
/// kept exactly as given; no cross-field consistency is checked by this type.
#[derive(Debug, Clone)]
pub struct ModelSizeConfig {
    /// Parameter-count label for this variant, kept as free text.
    // NOTE(review): presumably strings such as "7B" — confirm against the contract files.
    pub parameters: String,
    /// Hidden dimension.
    pub hidden_dim: usize,
    /// Number of layers.
    pub num_layers: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    // NOTE(review): presumably equals `num_heads` for MHA and is smaller for GQA/MQA;
    // nothing here enforces that.
    pub num_kv_heads: usize,
    /// Intermediate (feed-forward) dimension.
    pub intermediate_dim: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Maximum number of position embeddings.
    pub max_position_embeddings: usize,
    /// Per-head dimension, stored as given (not derived from `hidden_dim / num_heads` here).
    pub head_dim: usize,
    /// RoPE theta.
    // NOTE(review): presumably only meaningful when the family uses RoPE — confirm.
    pub rope_theta: f64,
    /// Normalization epsilon.
    pub norm_eps: f64,
}
229
/// Architectural choices for a model family.
///
/// `ModelFamilyConfig` holds one of these alongside many size variants, so the
/// values apply to every size of the family.
#[derive(Debug, Clone)]
pub struct ModelConstraints {
    /// Attention mechanism.
    pub attention_type: AttentionType,
    /// Activation function.
    pub activation: Activation,
    /// Normalization layer type.
    pub norm_type: NormType,
    /// Whether the layers carry bias terms.
    // NOTE(review): which layers this covers is not visible here — confirm in validators.
    pub has_bias: bool,
    /// Whether embeddings are tied.
    // NOTE(review): presumably the output head reuses the token-embedding weights — confirm.
    pub tied_embeddings: bool,
    /// Positional encoding scheme.
    pub positional_encoding: PositionalEncoding,
    /// Feed-forward block layout.
    pub mlp_type: MlpType,
}
241
/// Tensor naming template for a model family.
#[derive(Debug, Clone)]
pub struct TensorTemplate {
    /// Name of the embedding tensor.
    pub embedding: String,
    /// Name of the LM head tensor, if the family has one.
    // NOTE(review): presumably `None` when embeddings are tied — confirm.
    pub lm_head: Option<String>,
    /// Name of the final normalization tensor, if the family has one.
    pub final_norm: Option<String>,
    /// Per-layer tensor names keyed by a string (presumably the tensor role).
    // NOTE(review): a `None` value presumably marks a role this family does not have,
    // and names likely contain a layer-index placeholder — confirm the format.
    pub per_layer: HashMap<String, Option<String>>,
}
254
/// Expected tensor shapes for a model family, stored as unparsed strings.
#[derive(Debug, Clone)]
pub struct ShapeTemplate {
    /// Shape strings keyed by a string (presumably the tensor role).
    // NOTE(review): shape strings presumably reference `ModelSizeConfig` fields
    // (e.g. hidden_dim) — confirm the expression syntax where they are parsed.
    pub shapes: HashMap<String, String>,
}
262
/// Chat template settings for a model family.
#[derive(Debug, Clone)]
pub struct ChatTemplateConfig {
    /// Template format identifier (free text; not validated by this type).
    pub format: String,
    /// Template body.
    pub template: String,
    /// Beginning-of-sequence token text.
    pub bos_token: String,
    /// End-of-sequence token text.
    pub eos_token: String,
    /// Additional special tokens (presumably keyed by token name — confirm).
    pub special_tokens: HashMap<String, String>,
}
272
/// Certification settings for a model family.
#[derive(Debug, Clone)]
pub struct CertificationConfig {
    /// Path to the certification playbook.
    pub playbook_path: String,
    /// Key identifying this family in a CSV file.
    // NOTE(review): presumably a certification results table — confirm the consumer.
    pub csv_family_key: String,
    /// Size-category mapping.
    // NOTE(review): presumably size-variant name -> category label — confirm.
    pub size_categories: HashMap<String, String>,
}
280
/// Complete contract for one model family.
///
/// Bundles identity metadata, per-size hyper-parameters, family-wide architectural
/// constraints, tensor naming/shape templates and optional chat/certification settings.
#[derive(Debug, Clone)]
pub struct ModelFamilyConfig {
    /// Family identifier.
    // NOTE(review): presumably what `ModelFamily::family_name` returns — confirm.
    pub family: String,
    /// Human-readable family name.
    pub display_name: String,
    /// Vendor publishing the family.
    pub vendor: String,
    /// Architecture names associated with the family.
    // NOTE(review): presumably Hugging Face `architectures` values — confirm.
    pub architectures: Vec<String>,
    /// Pattern for Hugging Face identifiers; the matching syntax is not defined here.
    pub hf_pattern: String,
    /// Hyper-parameters per size variant, keyed by size name.
    pub size_variants: HashMap<String, ModelSizeConfig>,
    /// Architectural constraints shared by all size variants.
    pub constraints: ModelConstraints,
    /// Tensor naming template.
    pub tensor_template: TensorTemplate,
    /// Expected tensor shapes.
    pub shape_template: ShapeTemplate,
    /// Quantization names listed for this family (free text).
    pub quantizations: Vec<String>,
    /// Chat template settings, if the family defines them.
    pub chat_template: Option<ChatTemplateConfig>,
    /// Certification settings, if the family defines them.
    pub certification: Option<CertificationConfig>,
}
309
310#[derive(Debug, Clone)]
316pub struct ContractError {
317 pub family: String,
318 pub message: String,
319}
320
321impl fmt::Display for ContractError {
322 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323 write!(
324 f,
325 "Model family contract error [{}]: {}",
326 self.family, self.message
327 )
328 }
329}
330
331impl std::error::Error for ContractError {}
332
333impl From<ContractError> for AprenderError {
334 fn from(err: ContractError) -> Self {
335 AprenderError::FormatError {
336 message: err.to_string(),
337 }
338 }
339}
340
/// Read-only interface to a model family's contract.
///
/// The `Send + Sync` supertraits let implementors be shared across threads. All
/// methods take `&self` and have no generic parameters, so the trait is usable as
/// `dyn ModelFamily`.
pub trait ModelFamily: fmt::Debug + Send + Sync {
    /// Family identifier.
    fn family_name(&self) -> &str;

    /// Human-readable family name.
    fn display_name(&self) -> &str;

    /// Full contract configuration for this family.
    fn config(&self) -> &ModelFamilyConfig;

    /// Hyper-parameters for the named size variant, or `None` if there is none.
    fn size_config(&self, size: &str) -> Option<&ModelSizeConfig>;

    /// Size variant matching the given hidden dimension and layer count, or `None`.
    // NOTE(review): presumably returns a key of `size_variants` — confirm in implementations.
    fn detect_size(&self, hidden_dim: usize, num_layers: usize) -> Option<String>;

    /// Architectural constraints of this family.
    fn constraints(&self) -> &ModelConstraints;

    /// Expected tensor count for the named size variant, or `None` if it cannot be
    /// determined (presumably an unknown size — confirm in implementations).
    fn expected_tensor_count(&self, size: &str) -> Option<usize>;

    /// Validates tensor names for the given size variant.
    ///
    /// # Errors
    ///
    /// Returns a [`ContractError`] when validation fails; the specific checks are
    /// defined by each implementation.
    fn validate_tensor_names(
        &self,
        names: &[&str],
        size: &str,
    ) -> std::result::Result<(), ContractError>;
}
379
/// Model family whose behavior is driven entirely by a [`ModelFamilyConfig`] value.
// NOTE(review): its `ModelFamily` impl is not in this chunk — presumably in the
// included part files; confirm.
#[derive(Debug, Clone)]
pub struct DynModelFamily {
    config: ModelFamilyConfig,
}

impl DynModelFamily {
    /// Wraps `config` as-is; no validation happens here.
    #[must_use]
    pub fn new(config: ModelFamilyConfig) -> Self {
        Self { config }
    }
}
398
// The rest of this module is spliced in textually from these files; their items
// share this module's scope and `use` declarations.
include!("model_family_part_02.rs");
include!("model_family_part_03.rs");