llm_base/quantize.rs

//! Implements quantization of weights.

use crate::{
    model::HyperparametersWriteError, Hyperparameters, KnownModel, LoadError, LoadProgress, Loader,
};
use ggml::format::{SaveError, SaveHandler, TensorLoadInfo, TensorSaveInfo};
use half::f16;
use std::{
    collections::HashMap,
    io::{BufRead, Seek, Write},
    path::PathBuf,
    sync::Arc,
};
use thiserror::Error;

/// Progress of quantization.
#[derive(Clone, Debug)]
pub enum QuantizeProgress<'a> {
    /// Hyperparameters have been loaded.
    HyperparametersLoaded,
    /// A tensor is being loaded.
    TensorLoading {
        /// Name of the tensor.
        name: &'a str,
        /// Dimensions of the tensor.
        dims: [usize; 2],
        /// Type of the tensor's elements.
        element_type: ggml::Type,
        /// Number of elements in the tensor.
        n_elements: usize,
    },
    /// A tensor is being quantized.
    TensorQuantizing {
        /// Name of the tensor.
        name: &'a str,
    },
    /// A tensor has been quantized.
    TensorQuantized {
        /// Name of the tensor.
        name: &'a str,
        /// The original size (in bytes) of the tensor data.
        original_size: usize,
        /// The reduced size (in bytes) of the tensor data.
        reduced_size: usize,
        /// The normalized histogram of quantized values for this tensor.
        history: Vec<f32>,
    },
    /// A tensor has been skipped.
    TensorSkipped {
        /// Name of the tensor.
        name: &'a str,
        /// The original size (in bytes) of the tensor data.
        size: usize,
    },
    /// A model has been quantized.
    Finished {
        /// The original size (in bytes) of the model.
        original_size: usize,
        /// The reduced size (in bytes) of the model.
        reduced_size: usize,
        /// The normalized histogram of quantized values across the whole model.
        history: Vec<f32>,
    },
}

/// Errors encountered during the quantization process.
#[derive(Error, Debug)]
pub enum QuantizeError {
    /// There was an error while attempting to load the model.
    #[error("could not load model")]
    Load(#[from] LoadError),
    /// A non-specific I/O error.
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    /// One of the strings encountered was not valid UTF-8.
    #[error("could not convert bytes to a UTF-8 string")]
    InvalidUtf8(#[from] std::string::FromUtf8Error),
    /// One of the integers encountered could not be converted to a more appropriate type.
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
    /// The file could not be created.
    #[error("could not create file {path:?}")]
    CreateFileFailed {
        /// The original error.
        source: std::io::Error,
        /// The path to the file that could not be created.
        path: PathBuf,
    },
    /// An invariant was broken.
    ///
    /// This error is not relevant unless `loader2` is being used.
    #[error("invariant broken: {invariant} in {path:?}")]
    InvariantBroken {
        /// The path of the file in which the invariant was broken.
        path: PathBuf,
        /// The invariant that was broken.
        invariant: String,
    },
    /// Attempted to quantize to an invalid target.
    #[error("invalid quantization target {element_type:?}")]
    InvalidQuantizationTarget {
        /// The quantization target.
        element_type: ggml::Type,
    },
    /// The quantization process encountered an unsupported element type.
    #[error("unsupported element type {element_type:?}")]
    UnsupportedElementType {
        /// The element type.
        element_type: ggml::Type,
    },
    /// An error was encountered while writing the hyperparameters.
    #[error("an error was encountered while writing the hyperparameters")]
    HyperparametersWriteError(#[source] HyperparametersWriteError),
}
impl QuantizeError {
    pub(crate) fn from_format_error(value: SaveError<QuantizeError>, path: PathBuf) -> Self {
        match value {
            SaveError::Io(io) => QuantizeError::Io(io),
            SaveError::InvalidIntegerConversion(e) => QuantizeError::InvalidIntegerConversion(e),
            SaveError::ImplementationError(e) => e,
            SaveError::InvariantBroken(invariant) => {
                QuantizeError::InvariantBroken { path, invariant }
            }
        }
    }
}

/// Quantizes a model, reading a full-precision model from `reader` and writing the
/// quantized model to `writer`.
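///
/// A minimal usage sketch; `Llama` stands in for any `KnownModel` implementation,
/// and the file names are placeholders:
///
/// ```ignore
/// use std::{fs::File, io::{BufReader, BufWriter}};
///
/// let mut reader = BufReader::new(File::open("model-f16.bin")?);
/// let mut writer = BufWriter::new(File::create("model-q4_0.bin")?);
///
/// quantize::<Llama, _, _>(&mut reader, &mut writer, ggml::Type::Q4_0, |progress| {
///     if let QuantizeProgress::TensorQuantized { name, original_size, reduced_size, .. } = progress {
///         println!("{name}: {original_size} -> {reduced_size} bytes");
///     }
/// })?;
/// ```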
pub fn quantize<M: KnownModel, R: BufRead + Seek, W: Write + Seek>(
    reader: &mut R,
    writer: &mut W,
    desired_type: ggml::Type,
    progress_callback: impl Fn(QuantizeProgress),
) -> Result<(), QuantizeError> {
    // Sanity check: only 4-bit quantization targets are currently supported.
    if !matches!(desired_type, ggml::Type::Q4_0 | ggml::Type::Q4_1) {
        return Err(QuantizeError::InvalidQuantizationTarget {
            element_type: desired_type,
        });
    }

    // Load the model
    let progress_callback = Arc::new(progress_callback);

    let mut loader = Loader::<M::Hyperparameters, _>::new({
        let progress_callback = progress_callback.clone();
        move |p| {
            if let LoadProgress::HyperparametersLoaded = p {
                progress_callback(QuantizeProgress::HyperparametersLoaded)
            }
        }
    });
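    // No file path is available for an arbitrary reader, so format errors are
    // reported with a placeholder (empty) path.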
    ggml::format::load(reader, &mut loader)
        .map_err(|err| LoadError::from_format_error(err, PathBuf::default()))?;

    // Save the quantized model, quantizing as we go
    let Loader {
        hyperparameters,
        vocabulary,
        tensors,
        ..
    } = loader;

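    // Pair each token with its score, producing the (token, score) list that
    // the ggml saver expects for the vocabulary.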
    let vocabulary = vocabulary
        .id_to_token
        .iter()
        .cloned()
        .zip(vocabulary.id_to_token_score)
        .collect::<Vec<_>>();

    let mut saver = QuantizeSaver::new(desired_type, &hyperparameters, &tensors, reader, |p| {
        progress_callback(p)
    });
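    // `save` drives the saver below, which reads each tensor's data back out of
    // the source reader and quantizes it on the fly.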
    ggml::format::save(
        writer,
        &mut saver,
        &vocabulary,
        &tensors.keys().cloned().collect::<Vec<_>>(),
    )
    .map_err(|err| QuantizeError::from_format_error(err, PathBuf::default()))?;

    // Final report: normalize the accumulated histogram into fractions.
    let sum_all: i64 = saver.history_all.iter().sum();
    progress_callback(QuantizeProgress::Finished {
        original_size: saver.total_size_original,
        reduced_size: saver.total_size_new,
        history: saver
            .history_all
            .iter()
            .map(|hist| *hist as f32 / sum_all as f32)
            .collect(),
    });

    Ok(())
}

struct QuantizeSaver<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> {
    // Input
    quantization_type: ggml::Type,
    hyperparameters: &'a H,
    tensors: &'a HashMap<String, TensorLoadInfo>,
    source_reader: &'a mut R,
    progress_callback: F,

    // Output
    total_size_original: usize,
    total_size_new: usize,
    history_all: Vec<i64>,
}
impl<'a, F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek>
    QuantizeSaver<'a, F, H, R>
{
    fn new(
        quantization_type: ggml::Type,
        hyperparameters: &'a H,
        tensors: &'a HashMap<String, TensorLoadInfo>,
        source_reader: &'a mut R,
        progress_callback: F,
    ) -> Self {
        Self {
            quantization_type,
            hyperparameters,
            tensors,
            source_reader,
            progress_callback,

            total_size_original: 0,
            total_size_new: 0,
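            // ggml's 4-bit quantizers report a histogram with one bucket per
            // possible quantized value (2^4 = 16), accumulated across tensors here.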
            history_all: vec![0; 16],
        }
    }
}
impl<F: Fn(QuantizeProgress), H: Hyperparameters, R: BufRead + Seek> SaveHandler<QuantizeError>
    for QuantizeSaver<'_, F, H, R>
{
    fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), QuantizeError> {
        self.hyperparameters
            .write_ggml(writer)
            .map_err(QuantizeError::HyperparametersWriteError)?;
        Ok(())
    }

    fn tensor_data(&mut self, tensor_name: &str) -> Result<TensorSaveInfo, QuantizeError> {
        let tensor = self.tensors.get(tensor_name).expect(
            "tensor not found; should be impossible due to handler being populated from loader",
        );

        (self.progress_callback)(QuantizeProgress::TensorLoading {
            name: tensor_name,
            dims: tensor.dims,
            n_elements: tensor.n_elements,
            element_type: tensor.element_type,
        });

        // Only two-dimensional weight tensors are quantized; all other tensors
        // are passed through unchanged.
        let quantize = tensor_name.contains("weight") && tensor.n_dims == 2;
        let raw_data = tensor.read_data(self.source_reader)?;

        if quantize && !matches!(tensor.element_type, ggml::Type::F32 | ggml::Type::F16) {
            return Err(QuantizeError::UnsupportedElementType {
                element_type: tensor.element_type,
            });
        }

        self.total_size_original += raw_data.len();

        let (element_type, data) = if quantize {
            (self.progress_callback)(QuantizeProgress::TensorQuantizing { name: tensor_name });

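            // Decode the raw little-endian bytes into f32s: four bytes per
            // element for F32 tensors, two for F16.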
            let data_f32: Vec<f32> = match tensor.element_type {
                ggml::Type::F32 => raw_data
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
                    .collect(),
                ggml::Type::F16 => raw_data
                    .chunks_exact(2)
                    .map(|chunk| {
                        f16::from_bits(u16::from_le_bytes(chunk.try_into().unwrap())).to_f32()
                    })
                    .collect(),
                _ => unreachable!(),
            };

            let result = match self.quantization_type {
                ggml::Type::Q4_0 => {
                    ggml::quantize_q4_0(&data_f32, tensor.n_elements, tensor.dims[0])
                }
                ggml::Type::Q4_1 => {
                    ggml::quantize_q4_1(&data_f32, tensor.n_elements, tensor.dims[0])
                }
                _ => unreachable!(),
            };
            let new_data = result.output;

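            // Fold this tensor's histogram into the whole-model histogram, and
            // normalize the per-tensor counts into fractions for reporting.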
            let mut history_new = vec![];
            for (i, val) in result.history.iter().enumerate() {
                self.history_all[i] += val;
                history_new.push(*val as f32 / tensor.n_elements as f32);
            }

            (self.progress_callback)(QuantizeProgress::TensorQuantized {
                name: tensor_name,
                original_size: raw_data.len(),
                reduced_size: new_data.len(),
                history: history_new,
            });

            self.total_size_new += new_data.len();

            (self.quantization_type, new_data)
        } else {
            (self.progress_callback)(QuantizeProgress::TensorSkipped {
                name: tensor_name,
                size: raw_data.len(),
            });
            self.total_size_new += raw_data.len();
            (tensor.element_type, raw_data)
        };

        Ok(TensorSaveInfo {
            n_dims: tensor.n_dims,
            dims: tensor.dims,
            element_type,
            data,
        })
    }
}