Skip to main content

entrenar/yaml_mode/manifest/
model.rs

1//! Model Configuration
2//!
3//! Contains model-related configuration types for training manifests.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use super::shorthand::deserialize_human_usize_opt;
9
10/// Model configuration
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ModelConfig {
13    /// Model source URI (pacha://, hf://, or local path)
14    pub source: String,
15
16    /// Model format (safetensors, gguf, apr, pt)
17    #[serde(default, skip_serializing_if = "Option::is_none")]
18    pub format: Option<String>,
19
20    /// Architecture override
21    #[serde(default, skip_serializing_if = "Option::is_none")]
22    pub architecture: Option<ArchitectureConfig>,
23
24    /// Layers to freeze
25    #[serde(default, skip_serializing_if = "Option::is_none")]
26    pub freeze: Option<Vec<String>>,
27
28    /// Device placement (auto, cpu, cuda, cuda:0, mps)
29    #[serde(default, skip_serializing_if = "Option::is_none")]
30    pub device: Option<String>,
31
32    /// Data type (float32, float16, bfloat16)
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub dtype: Option<String>,
35}
36
37/// Model architecture configuration
38///
39/// Supports both preset names and custom architecture parameters.
40/// Custom params override values from config.json or preset defaults.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ArchitectureConfig {
43    /// Architecture type (transformer, sequential)
44    #[serde(rename = "type")]
45    pub arch_type: String,
46
47    /// Hidden size (embedding dimension). Accepts shorthand: `"1K"` = 1024.
48    #[serde(
49        default,
50        skip_serializing_if = "Option::is_none",
51        deserialize_with = "deserialize_human_usize_opt"
52    )]
53    pub hidden_size: Option<usize>,
54
55    /// Number of transformer layers
56    #[serde(
57        default,
58        skip_serializing_if = "Option::is_none",
59        alias = "num_hidden_layers",
60        deserialize_with = "deserialize_human_usize_opt"
61    )]
62    pub num_layers: Option<usize>,
63
64    /// Number of attention heads
65    #[serde(
66        default,
67        skip_serializing_if = "Option::is_none",
68        alias = "num_attention_heads",
69        deserialize_with = "deserialize_human_usize_opt"
70    )]
71    pub num_heads: Option<usize>,
72
73    /// Number of key-value heads (for grouped-query attention)
74    #[serde(
75        default,
76        skip_serializing_if = "Option::is_none",
77        alias = "num_key_value_heads",
78        deserialize_with = "deserialize_human_usize_opt"
79    )]
80    pub num_kv_heads: Option<usize>,
81
82    /// FFN intermediate dimension. Accepts shorthand: `"4K"` = 4096.
83    #[serde(
84        default,
85        skip_serializing_if = "Option::is_none",
86        deserialize_with = "deserialize_human_usize_opt"
87    )]
88    pub intermediate_size: Option<usize>,
89
90    /// Vocabulary size. Accepts shorthand: `"32K"` = 32768.
91    #[serde(
92        default,
93        skip_serializing_if = "Option::is_none",
94        deserialize_with = "deserialize_human_usize_opt"
95    )]
96    pub vocab_size: Option<usize>,
97
98    /// Maximum sequence/position length. Accepts shorthand: `"2K"` = 2048, `"128K"` = 131072.
99    #[serde(
100        default,
101        skip_serializing_if = "Option::is_none",
102        alias = "max_position_embeddings",
103        deserialize_with = "deserialize_human_usize_opt"
104    )]
105    pub max_seq_length: Option<usize>,
106
107    /// RMS normalization epsilon
108    #[serde(default, skip_serializing_if = "Option::is_none")]
109    pub rms_norm_eps: Option<f32>,
110
111    /// RoPE theta (rotary positional encoding base)
112    #[serde(default, skip_serializing_if = "Option::is_none")]
113    pub rope_theta: Option<f32>,
114
115    /// Whether to use bias in linear layers
116    #[serde(default, skip_serializing_if = "Option::is_none")]
117    pub use_bias: Option<bool>,
118
119    /// Per-head dimension override (for models where head_dim != hidden_size / num_heads)
120    #[serde(
121        default,
122        skip_serializing_if = "Option::is_none",
123        deserialize_with = "deserialize_human_usize_opt"
124    )]
125    pub head_dim: Option<usize>,
126
127    /// Sequential layers (for sequential architecture type)
128    #[serde(default, skip_serializing_if = "Option::is_none")]
129    pub layers: Option<Vec<HashMap<String, serde_json::Value>>>,
130}