Skip to main content

entrenar/config/cli/
quant_merge.rs

1//! Quantization and merge command types
2
3use clap::Parser;
4use std::path::PathBuf;
5
6/// Arguments for the quantize command
7#[derive(Parser, Debug, Clone, PartialEq)]
8pub struct QuantizeArgs {
9    /// Path to model file
10    #[arg(value_name = "MODEL")]
11    pub model: PathBuf,
12
13    /// Output path for quantized model
14    #[arg(short, long)]
15    pub output: PathBuf,
16
17    /// Quantization bits (4 or 8)
18    #[arg(short, long, default_value = "4")]
19    pub bits: u8,
20
21    /// Quantization method (symmetric or asymmetric)
22    #[arg(short, long, default_value = "symmetric")]
23    pub method: QuantMethod,
24
25    /// Use per-channel quantization
26    #[arg(long)]
27    pub per_channel: bool,
28
29    /// Path to calibration data
30    #[arg(long)]
31    pub calibration_data: Option<PathBuf>,
32
33    /// Output as SafeTensors with I8 dtype + scale tensors (for WASM deployment)
34    #[arg(long)]
35    pub safetensors: bool,
36}
37
38/// Arguments for the merge command
39#[derive(Parser, Debug, Clone, PartialEq)]
40pub struct MergeArgs {
41    /// Paths to models to merge (for ties/dare/slerp/average)
42    #[arg(value_name = "MODELS", num_args = 0..)]
43    pub models: Vec<PathBuf>,
44
45    /// Output path for merged model
46    #[arg(short, long)]
47    pub output: PathBuf,
48
49    /// Merge method (ties, dare, slerp, average, lora-adapter)
50    #[arg(short, long, default_value = "ties")]
51    pub method: MergeMethod,
52
53    /// Interpolation weight (for slerp)
54    #[arg(short, long)]
55    pub weight: Option<f32>,
56
57    /// Density threshold (for ties/dare)
58    #[arg(short, long)]
59    pub density: Option<f32>,
60
61    /// Model weights (comma-separated, for weighted average)
62    #[arg(long)]
63    pub weights: Option<String>,
64
65    /// Base model path (for lora-adapter merge, ENT-LoRA-017)
66    #[arg(long)]
67    pub base: Option<PathBuf>,
68
69    /// LoRA adapter directory (for lora-adapter merge, ENT-LoRA-017)
70    #[arg(long)]
71    pub adapter: Option<PathBuf>,
72}
73
/// Quantization method selected by `--method` on the quantize command.
///
/// Parsed case-insensitively through its [`FromStr`](std::str::FromStr) impl.
/// Derives `Eq` and `Hash` (it is a fieldless `Copy` enum) so it can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum QuantMethod {
    /// Symmetric quantization (CLI: `symmetric` or `sym`). This is the default.
    #[default]
    Symmetric,
    /// Asymmetric quantization (CLI: `asymmetric` or `asym`).
    Asymmetric,
}
81
82impl std::str::FromStr for QuantMethod {
83    type Err = String;
84
85    fn from_str(s: &str) -> Result<Self, Self::Err> {
86        match s.to_lowercase().as_str() {
87            "symmetric" | "sym" => Ok(QuantMethod::Symmetric),
88            "asymmetric" | "asym" => Ok(QuantMethod::Asymmetric),
89            _ => Err(format!(
90                "Unknown quantization method: {s}. Valid methods: symmetric, asymmetric"
91            )),
92        }
93    }
94}
95
/// Model-merging strategy selected by `--method` on the merge command.
///
/// Parsed case-insensitively through its [`FromStr`](std::str::FromStr) impl.
/// Derives `Eq` and `Hash` (it is a fieldless `Copy` enum) so it can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MergeMethod {
    /// TIES merge (CLI: `ties`). This is the default; tuned with `--density`.
    #[default]
    Ties,
    /// DARE merge (CLI: `dare`); tuned with `--density`.
    Dare,
    /// SLERP interpolation between models (CLI: `slerp`); uses `--weight`.
    Slerp,
    /// Average of the input models (CLI: `average` or `avg`); optional per-model
    /// weights come from `--weights`.
    Average,
    /// Merge LoRA adapter into base model (ENT-LoRA-017).
    ///
    /// CLI: `lora-adapter` or `lora`. Takes its inputs from `--base` and `--adapter`
    /// rather than the positional model list.
    LoraAdapter,
}
107
108impl std::str::FromStr for MergeMethod {
109    type Err = String;
110
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        match s.to_lowercase().as_str() {
113            "ties" => Ok(MergeMethod::Ties),
114            "dare" => Ok(MergeMethod::Dare),
115            "slerp" => Ok(MergeMethod::Slerp),
116            "average" | "avg" => Ok(MergeMethod::Average),
117            "lora-adapter" | "lora" => Ok(MergeMethod::LoraAdapter),
118            _ => Err(format!(
119                "Unknown merge method: {s}. Valid: ties, dare, slerp, average, lora-adapter"
120            )),
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as `T`, panicking with the offending input if parsing fails.
    fn parse_ok<T>(input: &str) -> T
    where
        T: std::str::FromStr<Err = String>,
    {
        input
            .parse::<T>()
            .unwrap_or_else(|e| panic!("parsing {input:?} should succeed: {e}"))
    }

    #[test]
    fn test_quant_method_from_str() {
        for (input, expected) in [
            ("symmetric", QuantMethod::Symmetric),
            ("sym", QuantMethod::Symmetric),
            ("asymmetric", QuantMethod::Asymmetric),
            ("asym", QuantMethod::Asymmetric),
        ] {
            assert_eq!(parse_ok::<QuantMethod>(input), expected, "input: {input}");
        }
        assert!("invalid".parse::<QuantMethod>().is_err());
    }

    #[test]
    fn test_quant_method_from_str_ignores_case() {
        assert_eq!(parse_ok::<QuantMethod>("SYMMETRIC"), QuantMethod::Symmetric);
        assert_eq!(parse_ok::<QuantMethod>("Asym"), QuantMethod::Asymmetric);
    }

    #[test]
    fn test_quant_method_error_names_input() {
        let err = "int3"
            .parse::<QuantMethod>()
            .expect_err("parsing should fail");
        assert!(err.contains("int3"), "error should echo the input: {err}");
        assert!(err.contains("asymmetric"), "error should list valid methods: {err}");
    }

    #[test]
    fn test_merge_method_from_str() {
        for (input, expected) in [
            ("ties", MergeMethod::Ties),
            ("dare", MergeMethod::Dare),
            ("slerp", MergeMethod::Slerp),
            ("average", MergeMethod::Average),
            ("avg", MergeMethod::Average),
            ("lora-adapter", MergeMethod::LoraAdapter),
            ("lora", MergeMethod::LoraAdapter),
        ] {
            assert_eq!(parse_ok::<MergeMethod>(input), expected, "input: {input}");
        }
        assert!("invalid".parse::<MergeMethod>().is_err());
    }

    #[test]
    fn test_merge_method_from_str_ignores_case() {
        assert_eq!(parse_ok::<MergeMethod>("TIES"), MergeMethod::Ties);
        assert_eq!(parse_ok::<MergeMethod>("AVG"), MergeMethod::Average);
        assert_eq!(parse_ok::<MergeMethod>("Lora-Adapter"), MergeMethod::LoraAdapter);
    }

    #[test]
    fn test_merge_method_error_names_input() {
        let err = "frankenmerge"
            .parse::<MergeMethod>()
            .expect_err("parsing should fail");
        assert!(err.contains("frankenmerge"), "error should echo the input: {err}");
        assert!(err.contains("lora-adapter"), "error should list valid methods: {err}");
    }

    #[test]
    fn test_quant_method_default() {
        assert_eq!(QuantMethod::default(), QuantMethod::Symmetric);
    }

    #[test]
    fn test_merge_method_default() {
        assert_eq!(MergeMethod::default(), MergeMethod::Ties);
    }
}