//! Command-line argument types for model quantization and model merging.
//!
//! File: `entrenar/config/cli/quant_merge.rs`.

use clap::Parser;
use std::path::PathBuf;

6#[derive(Parser, Debug, Clone, PartialEq)]
8pub struct QuantizeArgs {
9 #[arg(value_name = "MODEL")]
11 pub model: PathBuf,
12
13 #[arg(short, long)]
15 pub output: PathBuf,
16
17 #[arg(short, long, default_value = "4")]
19 pub bits: u8,
20
21 #[arg(short, long, default_value = "symmetric")]
23 pub method: QuantMethod,
24
25 #[arg(long)]
27 pub per_channel: bool,
28
29 #[arg(long)]
31 pub calibration_data: Option<PathBuf>,
32
33 #[arg(long)]
35 pub safetensors: bool,
36}
37
38#[derive(Parser, Debug, Clone, PartialEq)]
40pub struct MergeArgs {
41 #[arg(value_name = "MODELS", num_args = 0..)]
43 pub models: Vec<PathBuf>,
44
45 #[arg(short, long)]
47 pub output: PathBuf,
48
49 #[arg(short, long, default_value = "ties")]
51 pub method: MergeMethod,
52
53 #[arg(short, long)]
55 pub weight: Option<f32>,
56
57 #[arg(short, long)]
59 pub density: Option<f32>,
60
61 #[arg(long)]
63 pub weights: Option<String>,
64
65 #[arg(long)]
67 pub base: Option<PathBuf>,
68
69 #[arg(long)]
71 pub adapter: Option<PathBuf>,
72}
73
/// Quantization scheme selected on the command line.
///
/// Parsed case-insensitively by its [`FromStr`](std::str::FromStr) impl;
/// defaults to [`QuantMethod::Symmetric`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum QuantMethod {
    /// Selected by `symmetric` or `sym`. This is the default.
    #[default]
    Symmetric,
    /// Selected by `asymmetric` or `asym`.
    Asymmetric,
}

impl std::str::FromStr for QuantMethod {
    type Err = String;

    /// Parses a quantization method name, ignoring case.
    ///
    /// Accepts `symmetric`/`sym` and `asymmetric`/`asym`.
    ///
    /// # Errors
    ///
    /// Returns a message that names the rejected input and lists the valid
    /// methods when `s` matches none of the accepted spellings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Lower-casing first makes "SYM", "Symmetric", etc. all match.
        match s.to_lowercase().as_str() {
            "symmetric" | "sym" => Ok(QuantMethod::Symmetric),
            "asymmetric" | "asym" => Ok(QuantMethod::Asymmetric),
            _ => Err(format!(
                "Unknown quantization method: {s}. Valid methods: symmetric, asymmetric"
            )),
        }
    }
}

/// Model-merging strategy selected on the command line.
///
/// Parsed case-insensitively by its [`FromStr`](std::str::FromStr) impl;
/// defaults to [`MergeMethod::Ties`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MergeMethod {
    /// Selected by `ties`. This is the default.
    #[default]
    Ties,
    /// Selected by `dare`.
    Dare,
    /// Selected by `slerp`.
    Slerp,
    /// Selected by `average` or `avg`.
    Average,
    /// Selected by `lora-adapter` or `lora`.
    LoraAdapter,
}

impl std::str::FromStr for MergeMethod {
    type Err = String;

    /// Parses a merge method name, ignoring case.
    ///
    /// Accepts `ties`, `dare`, `slerp`, `average`/`avg` and
    /// `lora-adapter`/`lora`.
    ///
    /// # Errors
    ///
    /// Returns a message that names the rejected input and lists the valid
    /// methods when `s` matches none of the accepted spellings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Lower-casing first makes "TIES", "LoRA-Adapter", etc. all match.
        match s.to_lowercase().as_str() {
            "ties" => Ok(MergeMethod::Ties),
            "dare" => Ok(MergeMethod::Dare),
            "slerp" => Ok(MergeMethod::Slerp),
            "average" | "avg" => Ok(MergeMethod::Average),
            "lora-adapter" | "lora" => Ok(MergeMethod::LoraAdapter),
            _ => Err(format!(
                "Unknown merge method: {s}. Valid: ties, dare, slerp, average, lora-adapter"
            )),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    // Explicit imports shadow the glob so the clap parsing tests below do not
    // depend on how the parent's private `use` items are re-exported.
    use clap::Parser;
    use std::path::PathBuf;

    #[test]
    fn test_quant_method_from_str() {
        assert_eq!(
            "symmetric".parse::<QuantMethod>().expect("parsing should succeed"),
            QuantMethod::Symmetric
        );
        assert_eq!(
            "sym".parse::<QuantMethod>().expect("parsing should succeed"),
            QuantMethod::Symmetric
        );
        assert_eq!(
            "asymmetric".parse::<QuantMethod>().expect("parsing should succeed"),
            QuantMethod::Asymmetric
        );
        assert_eq!(
            "asym".parse::<QuantMethod>().expect("parsing should succeed"),
            QuantMethod::Asymmetric
        );
        assert!("invalid".parse::<QuantMethod>().is_err());
    }

    #[test]
    fn test_merge_method_from_str() {
        assert_eq!(
            "ties".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::Ties
        );
        assert_eq!(
            "dare".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::Dare
        );
        assert_eq!(
            "slerp".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::Slerp
        );
        assert_eq!(
            "average".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::Average
        );
        assert_eq!(
            "avg".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::Average
        );
        assert_eq!(
            "lora-adapter".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::LoraAdapter
        );
        assert_eq!(
            "lora".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::LoraAdapter
        );
        assert!("invalid".parse::<MergeMethod>().is_err());
    }

    // Both parsers lower-case their input before matching.
    #[test]
    fn test_method_parsing_is_case_insensitive() {
        assert_eq!(
            "SYM".parse::<QuantMethod>().expect("parsing should succeed"),
            QuantMethod::Symmetric
        );
        assert_eq!(
            "Asymmetric".parse::<QuantMethod>().expect("parsing should succeed"),
            QuantMethod::Asymmetric
        );
        assert_eq!(
            "TIES".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::Ties
        );
        assert_eq!(
            "LoRA-Adapter".parse::<MergeMethod>().expect("parsing should succeed"),
            MergeMethod::LoraAdapter
        );
    }

    // Error messages echo the rejected input so users can see what was wrong.
    #[test]
    fn test_parse_errors_name_the_input() {
        let err = "int3"
            .parse::<QuantMethod>()
            .expect_err("parsing should fail");
        assert!(err.contains("int3"));

        let err = "linear"
            .parse::<MergeMethod>()
            .expect_err("parsing should fail");
        assert!(err.contains("linear"));
    }

    #[test]
    fn test_quant_method_default() {
        assert_eq!(QuantMethod::default(), QuantMethod::Symmetric);
    }

    #[test]
    fn test_merge_method_default() {
        assert_eq!(MergeMethod::default(), MergeMethod::Ties);
    }

    // Pins the clap defaults declared on `QuantizeArgs` and its short flags.
    #[test]
    fn test_quantize_args_defaults_and_flags() {
        let args = QuantizeArgs::try_parse_from(["quantize", "model.bin", "-o", "out.bin"])
            .expect("minimal quantize args should parse");
        assert_eq!(args.model, PathBuf::from("model.bin"));
        assert_eq!(args.output, PathBuf::from("out.bin"));
        assert_eq!(args.bits, 4);
        assert_eq!(args.method, QuantMethod::Symmetric);
        assert!(!args.per_channel);
        assert_eq!(args.calibration_data, None);
        assert!(!args.safetensors);

        let args = QuantizeArgs::try_parse_from([
            "quantize",
            "m",
            "-o",
            "o",
            "-b",
            "8",
            "-m",
            "asym",
            "--per-channel",
        ])
        .expect("full quantize args should parse");
        assert_eq!(args.bits, 8);
        assert_eq!(args.method, QuantMethod::Asymmetric);
        assert!(args.per_channel);
    }

    // Pins the clap defaults declared on `MergeArgs` and its short flags.
    #[test]
    fn test_merge_args_defaults_and_flags() {
        let args = MergeArgs::try_parse_from(["merge", "-o", "merged"])
            .expect("merge without positional models should parse");
        assert!(args.models.is_empty());
        assert_eq!(args.method, MergeMethod::Ties);
        assert_eq!(args.weight, None);
        assert_eq!(args.density, None);

        let args = MergeArgs::try_parse_from([
            "merge", "a.bin", "b.bin", "-o", "merged", "-m", "slerp", "-w", "0.5",
        ])
        .expect("merge with models should parse");
        assert_eq!(
            args.models,
            vec![PathBuf::from("a.bin"), PathBuf::from("b.bin")]
        );
        assert_eq!(args.method, MergeMethod::Slerp);
        assert_eq!(args.weight, Some(0.5));
    }
}