//! Implements quantization of model weights.

use crate::{
    model::HyperparametersWriteError, Hyperparameters, KnownModel, LoadError, LoadProgress, Loader,
};
use ggml::format::{SaveError, SaveHandler, TensorLoadInfo, TensorSaveInfo};
use half::f16;
use std::{
    collections::HashMap,
    io::{BufRead, Seek, Write},
    path::PathBuf,
    sync::Arc,
};
use thiserror::Error;

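/// Progress of a quantization operation, reported through the callback
/// passed to [quantize].
///
/// A sketch of a consumer is shown below; the logging choices are
/// illustrative assumptions, not part of this crate:
///
/// ```ignore
/// |progress| match progress {
///     QuantizeProgress::HyperparametersLoaded => println!("loaded hyperparameters"),
///     QuantizeProgress::TensorQuantized {
///         name,
///         original_size,
///         reduced_size,
///         ..
///     } => println!("quantized {name}: {original_size} -> {reduced_size} bytes"),
///     _ => {}
/// }
/// ```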
#[derive(Clone, Debug)]
pub enum QuantizeProgress<'a> {
    /// The model's hyperparameters have been loaded.
    HyperparametersLoaded,
    /// A tensor is being loaded.
    TensorLoading {
        /// The name of the tensor.
        name: &'a str,
        /// The dimensions of the tensor.
        dims: [usize; 2],
        /// The element type of the tensor.
        element_type: ggml::Type,
        /// The number of elements in the tensor.
        n_elements: usize,
    },
    /// A tensor is being quantized.
    TensorQuantizing {
        /// The name of the tensor.
        name: &'a str,
    },
    /// A tensor has been quantized.
    TensorQuantized {
        /// The name of the tensor.
        name: &'a str,
        /// The original size of the tensor, in bytes.
        original_size: usize,
        /// The size of the tensor after quantization, in bytes.
        reduced_size: usize,
        /// The histogram of quantized values for this tensor, normalized by
        /// its element count.
        history: Vec<f32>,
    },
    /// A tensor was not quantized and was copied through unchanged.
    TensorSkipped {
        /// The name of the tensor.
        name: &'a str,
        /// The size of the tensor, in bytes.
        size: usize,
    },
    /// Quantization has finished.
    Finished {
        /// The original size of the model, in bytes.
        original_size: usize,
        /// The size of the model after quantization, in bytes.
        reduced_size: usize,
        /// The histogram of quantized values across the whole model.
        history: Vec<f32>,
    },
}

/// An error encountered while quantizing a model.
#[derive(Error, Debug)]
pub enum QuantizeError {
    /// The model could not be loaded.
    #[error("could not load model")]
    Load(#[from] LoadError),
    /// A non-specific I/O error occurred.
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    /// Bytes could not be converted to a UTF-8 string.
    #[error("could not convert bytes to a UTF-8 string")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    /// An integer conversion failed.
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    /// The output file could not be created.
    #[error("could not create file {path:?}")]
    CreateFileFailed {
        /// The underlying I/O error.
        source: std::io::Error,
        /// The path to the file that could not be created.
        path: PathBuf,
    },
    /// An invariant of the file format was broken.
    #[error("invariant broken: {invariant} in {path:?}")]
    InvariantBroken {
        /// The path to the file in which the invariant was broken.
        path: PathBuf,
        /// The invariant that was broken.
        invariant: String,
    },
    /// The requested element type is not a valid quantization target.
    #[error("invalid quantization target {element_type:?}")]
    InvalidQuantizationTarget {
        /// The requested element type.
        element_type: ggml::Type,
    },
    /// A tensor to be quantized had an unsupported element type.
    #[error("unsupported element type {element_type:?}")]
    UnsupportedElementType {
        /// The element type of the tensor.
        element_type: ggml::Type,
    },
    /// An error was encountered while writing the hyperparameters.
    #[error("an error was encountered while writing the hyperparameters")]
    HyperparametersWriteError(#[source] HyperparametersWriteError),
}
impl QuantizeError {
    /// Converts a format-level [SaveError] into a [QuantizeError], attaching
    /// the path of the file being written.
    pub(crate) fn from_format_error(value: SaveError<QuantizeError>, path: PathBuf) -> Self {
        match value {
            SaveError::Io(io) => QuantizeError::Io(io),
            SaveError::InvalidIntegerConversion(e) => QuantizeError::InvalidIntegerConversion(e),
            SaveError::ImplementationError(e) => e,
            SaveError::InvariantBroken(invariant) => {
                QuantizeError::InvariantBroken { path, invariant }
            }
        }
    }
}

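/// Quantizes a model of type `M`, read from `reader`, and writes the result
/// to `writer`. Only [ggml::Type::Q4_0] and [ggml::Type::Q4_1] are valid
/// values for `desired_type`; two-dimensional weight tensors are quantized,
/// while all other tensors are copied through unchanged. `progress_callback`
/// is invoked as each tensor is processed.
///
/// A minimal usage sketch. The model type `Llama` and the file names are
/// illustrative assumptions; substitute your own [KnownModel] implementation
/// and paths:
///
/// ```ignore
/// use std::{
///     fs::File,
///     io::{BufReader, BufWriter},
/// };
///
/// // Buffered file handles satisfy the `BufRead + Seek` / `Write + Seek` bounds.
/// let mut reader = BufReader::new(File::open("model-f16.bin")?);
/// let mut writer = BufWriter::new(File::create("model-q4_0.bin")?);
///
/// quantize::<Llama, _, _>(
///     &mut reader,
///     &mut writer,
///     ggml::Type::Q4_0,
///     |progress| println!("{progress:?}"),
/// )?;
/// ```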
pub fn quantize<M: KnownModel, R: BufRead + Seek, W: Write + Seek>(
    reader: &mut R,
    writer: &mut W,
    desired_type: ggml::Type,
    progress_callback: impl Fn(QuantizeProgress),
) -> Result<(), QuantizeError> {
    // Sanity check: only Q4_0 and Q4_1 are valid quantization targets.
    if !matches!(desired_type, ggml::Type::Q4_0 | ggml::Type::Q4_1) {
        return Err(QuantizeError::InvalidQuantizationTarget {
            element_type: desired_type,
        });
    }

    // Load the model's hyperparameters, vocabulary, and tensor metadata.
    let progress_callback = Arc::new(progress_callback);

    let mut loader = Loader::<M::Hyperparameters, _>::new({
        let progress_callback = progress_callback.clone();
        move |p| {
            if let LoadProgress::HyperparametersLoaded = p {
                progress_callback(QuantizeProgress::HyperparametersLoaded)
            }
        }
    });
    ggml::format::load(reader, &mut loader)
        .map_err(|err| LoadError::from_format_error(err, PathBuf::default()))?;

    let Loader {
        hyperparameters,
        vocabulary,
        tensors,
        ..
    } = loader;

    let vocabulary = vocabulary
        .id_to_token
        .iter()
        .cloned()
        .zip(vocabulary.id_to_token_score)
        .collect::<Vec<_>>();

    // Save the model, quantizing each tensor as it is written.
    let mut saver = QuantizeSaver::new(desired_type, &hyperparameters, &tensors, reader, |p| {
        progress_callback(p)
    });
    ggml::format::save(
        writer,
        &mut saver,
        &vocabulary,
        &tensors.keys().cloned().collect::<Vec<_>>(),
    )
    .map_err(|err| QuantizeError::from_format_error(err, PathBuf::default()))?;

    // Final report: normalize the accumulated histogram across all tensors.
    let sum_all: i64 = saver.history_all.iter().sum();
    progress_callback(QuantizeProgress::Finished {
        original_size: saver.total_size_original,
        reduced_size: saver.total_size_new,
        history: saver
            .history_all
            .iter()
            .map(|hist| *hist as f32 / sum_all as f32)
            .collect(),
    });

    Ok(())
}

struct QuantizeSaver<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> {
    // Input
    quantization_type: ggml::Type,
    hyperparameters: &'a H,
    tensors: &'a HashMap<String, TensorLoadInfo>,
    source_reader: &'a mut R,
    progress_callback: F,

    // Output
    total_size_original: usize,
    total_size_new: usize,
    history_all: Vec<i64>,
}
impl<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek>
    QuantizeSaver<'a, F, H, R>
{
    fn new(
        quantization_type: ggml::Type,
        hyperparameters: &'a H,
        tensors: &'a HashMap<String, TensorLoadInfo>,
        source_reader: &'a mut R,
        progress_callback: F,
    ) -> Self {
        Self {
            quantization_type,
            hyperparameters,
            tensors,
            source_reader,
            progress_callback,

            total_size_original: 0,
            total_size_new: 0,
            // ggml's quantization functions report a 16-bucket histogram.
            history_all: vec![0; 16],
        }
    }
}
impl<F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> SaveHandler<QuantizeError>
    for QuantizeSaver<'_, F, H, R>
{
    fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), QuantizeError> {
        self.hyperparameters
            .write_ggml(writer)
            .map_err(QuantizeError::HyperparametersWriteError)?;
        Ok(())
    }

    fn tensor_data(&mut self, tensor_name: &str) -> Result<TensorSaveInfo, QuantizeError> {
        let tensor = self.tensors.get(tensor_name).expect(
            "tensor not found; should be impossible due to handler being populated from loader",
        );

        (self.progress_callback)(QuantizeProgress::TensorLoading {
            name: tensor_name,
            dims: tensor.dims,
            n_elements: tensor.n_elements,
            element_type: tensor.element_type,
        });

        // Only two-dimensional weight tensors are quantized.
        let quantize = tensor_name.contains("weight") && tensor.n_dims == 2;
        let raw_data = tensor.read_data(self.source_reader)?;

        if quantize && !matches!(tensor.element_type, ggml::Type::F32 | ggml::Type::F16) {
            return Err(QuantizeError::UnsupportedElementType {
                element_type: tensor.element_type,
            });
        }

        self.total_size_original += raw_data.len();

        let (element_type, data) = if quantize {
            (self.progress_callback)(QuantizeProgress::TensorQuantizing { name: tensor_name });

            // Convert the source data to f32 before quantizing.
            let data_f32: Vec<f32> = match tensor.element_type {
                ggml::Type::F32 => raw_data
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
                    .collect(),
                ggml::Type::F16 => raw_data
                    .chunks_exact(2)
                    .map(|chunk| {
                        f16::from_bits(u16::from_le_bytes(chunk.try_into().unwrap())).to_f32()
                    })
                    .collect(),
                _ => unreachable!(),
            };

            let result = match self.quantization_type {
                ggml::Type::Q4_0 => {
                    ggml::quantize_q4_0(&data_f32, tensor.n_elements, tensor.dims[0])
                }
                ggml::Type::Q4_1 => {
                    ggml::quantize_q4_1(&data_f32, tensor.n_elements, tensor.dims[0])
                }
                _ => unreachable!(),
            };
            let new_data = result.output;

            // Accumulate the histogram into the whole-model total, and
            // normalize this tensor's histogram by its element count.
            let mut history_new = vec![];
            for (i, val) in result.history.iter().enumerate() {
                self.history_all[i] += val;
                history_new.push(*val as f32 / tensor.n_elements as f32);
            }

            (self.progress_callback)(QuantizeProgress::TensorQuantized {
                name: tensor_name,
                original_size: raw_data.len(),
                reduced_size: new_data.len(),
                history: history_new,
            });

            self.total_size_new += new_data.len();

            (self.quantization_type, new_data)
        } else {
            (self.progress_callback)(QuantizeProgress::TensorSkipped {
                name: tensor_name,
                size: raw_data.len(),
            });
            self.total_size_new += raw_data.len();
            (tensor.element_type, raw_data)
        };

        Ok(TensorSaveInfo {
            n_dims: tensor.n_dims,
            dims: tensor.dims,
            element_type,
            data,
        })
    }
}