// mistralrs_core/pipeline/isq.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet},
4    env,
5    fs::File,
6    path::PathBuf,
7    str::FromStr,
8    sync::{atomic::AtomicUsize, Arc},
9    time::Instant,
10};
/// Byte-buffer adapter implementing `safetensors::tensor::View`.
///
/// It exists so that raw serialized bytes can be handed straight to
/// `safetensors::serialize_to_file` without first copying them into a
/// `Vec<u8>` or turning them into a typed tensor. safetensors sees the
/// buffer as a one-dimensional `u8` tensor whose only dimension is the
/// byte length.
#[derive(Clone)]
pub struct CowBytesView<'a> {
    /// The bytes that will be written out.
    data: Cow<'a, [u8]>,
    /// `[data.len()]`, stored so that `shape()` can hand out a slice borrowed from `self`.
    shape: [usize; 1],
}

impl<'a> CowBytesView<'a> {
    /// Wrap `data`, recording its byte length as the tensor shape.
    pub fn new(data: Cow<'a, [u8]>) -> Self {
        let shape = [data.len()];
        CowBytesView { data, shape }
    }
}
31
32impl safetensors::tensor::View for CowBytesView<'_> {
33    fn dtype(&self) -> safetensors::tensor::Dtype {
34        // Serialize as raw bytes
35        safetensors::tensor::Dtype::U8
36    }
37
38    fn shape(&self) -> &[usize] {
39        &self.shape
40    }
41
42    fn data(&self) -> Cow<'_, [u8]> {
43        assert!(matches!(self.data, Cow::Borrowed(_)));
44        // Cloning a `Cow` is cheap (only clones the enum, not the data).
45        self.data.clone()
46    }
47
48    fn data_len(&self) -> usize {
49        self.data.len()
50    }
51}
52
53use anyhow::Result;
54use candle_core::{quantized, Context, Device, Tensor};
55use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
56use itertools::Itertools;
57use mistralrs_quant::{
58    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
59    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
60    ReplicatedLayer, RowParallelLayer, UnquantLinear,
61};
62use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
63use regex::Regex;
64use serde::Deserialize;
65use tokenizers::Tokenizer;
66use tracing::{info, warn};
67
68use crate::{
69    device_map::DeviceMapper, pipeline::EmbeddingModulePaths, topology::LayerTopology,
70    utils::progress::configure_progress_bar, Topology,
71};
72
/// File name, inside the UQFF output directory, of the safetensors file that
/// holds the model's residual (non-ISQ) tensors.
pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";

/// Soft cap on the summed byte size of quantized tensors per UQFF shard: 10 GiB.
/// A shard is flushed before adding a tensor would exceed this. A single
/// tensor larger than the cap still gets a shard of its own.
const MAX_UQFF_SIZE_BYTES: usize = 10 << 30;

/// Separator for several UQFF file names packed into one string
/// (presumably used when parsing such a list — usage is outside this file section).
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
77
78/// Parse ISQ value.
79///
80/// If the provided value is a valid integer (one of 2,3,4,5,6,8), the best quantization type will be chosen.
81/// Note that the fallback is always a Q/K quantization but on Metal 2,3,4,6,8 uses the fast AFQ.
82///
83/// One of:
84/// - `Q4_0`
85/// - `Q4_1`
86/// - `Q5_0`
87/// - `Q5_1`
88/// - `Q8_0`
89/// - `Q8_1`
90/// - `Q2K`
91/// - `Q3K`
92/// - `Q4K`
93/// - `Q5K`
94/// - `Q6K`
95/// - `Q8K`
96/// - `HQQ1`
97/// - `HQQ2`
98/// - `HQQ3`
99/// - `HQQ4`
100/// - `HQQ8`
101/// - `AFQ2`
102/// - `AFQ3`
103/// - `AFQ4`
104/// - `AFQ6`
105/// - `AFQ8`
106pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
107    let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
108    let tp = match s.to_lowercase().as_str() {
109        "2" if is_metal => IsqType::AFQ2,
110        "2" if !is_metal => IsqType::Q2K,
111        "3" if is_metal => IsqType::AFQ3,
112        "3" if !is_metal => IsqType::Q3K,
113        "4" if is_metal => IsqType::AFQ4,
114        "4" if !is_metal => IsqType::Q4K,
115        "5" => IsqType::Q5K,
116        "6" if is_metal => IsqType::AFQ6,
117        "6" if !is_metal => IsqType::Q6K,
118        "8" if is_metal => IsqType::AFQ8,
119        "8" if !is_metal => IsqType::Q8_0,
120        "q4_0" => IsqType::Q4_0,
121        "q4_1" => IsqType::Q4_1,
122        "q5_0" => IsqType::Q5_0,
123        "q5_1" => IsqType::Q5_1,
124        "q8_0" => IsqType::Q8_0,
125        "q8_1" => IsqType::Q8_1,
126        "q2k" => IsqType::Q2K,
127        "q3k" => IsqType::Q3K,
128        "q4k" => IsqType::Q4K,
129        "q5k" => IsqType::Q5K,
130        "q6k" => IsqType::Q6K,
131        "q8k" => IsqType::Q8K,
132        "hqq8" => IsqType::HQQ8,
133        "hqq4" => IsqType::HQQ4,
134        "fp8" => IsqType::F8E4M3,
135        "afq8" => IsqType::AFQ8,
136        "afq6" => IsqType::AFQ6,
137        "afq4" => IsqType::AFQ4,
138        "afq3" => IsqType::AFQ3,
139        "afq2" => IsqType::AFQ2,
140        // "hqq3" => IsqType::HQQ3,
141        // "hqq2" => IsqType::HQQ2,
142        // "hqq1" => IsqType::HQQ1,
143        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
144    };
145    #[cfg(feature = "cuda")]
146    {
147        if !matches!(
148            tp,
149            IsqType::Q4_0
150                | IsqType::Q4_1
151                | IsqType::Q5_0
152                | IsqType::Q5_1
153                | IsqType::Q8_0
154                | IsqType::Q2K
155                | IsqType::Q3K
156                | IsqType::Q4K
157                | IsqType::Q5K
158                | IsqType::Q6K
159                | IsqType::HQQ8
160                | IsqType::HQQ4
161                | IsqType::F8E4M3 // | IsqType::HQQ3
162                                  // | IsqType::HQQ2
163                                  // | IsqType::HQQ1
164        ) {
165            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
166        }
167    }
168    Ok(tp)
169}
170
171#[derive(Clone, Debug, Copy, Default, Deserialize, serde::Serialize)]
172pub enum IsqOrganization {
173    #[default]
174    #[serde(rename = "default")]
175    Default,
176    /// Only quantize MoE experts, if applicable. The enables MoQE.
177    /// <https://arxiv.org/abs/2310.02410>
178    #[serde(rename = "moqe")]
179    MoeExpertsOnly,
180}
181
182impl FromStr for IsqOrganization {
183    type Err = String;
184    fn from_str(s: &str) -> Result<Self, Self::Err> {
185        match s {
186            "default" => Ok(Self::Default),
187            "moqe" => Ok(Self::MoeExpertsOnly),
188            other => Err(format!(
189                "Expected ISQ organization `default` or `moqe`, got `{other}`"
190            )),
191        }
192    }
193}
194
195pub struct UqffFullSer<'a> {
196    pub tokenizer: &'a Tokenizer,
197    pub template_filename: &'a Option<PathBuf>,
198    pub modules: Option<&'a String>,
199    pub module_paths: Option<&'a [EmbeddingModulePaths]>,
200    pub generation_config: Option<&'a PathBuf>,
201    pub config: String,
202    pub processor_filename: &'a Option<PathBuf>,
203    pub preprocessor_filename: &'a Option<PathBuf>,
204}
205
/// Where [`IsqModel::quantize`] obtains importance-matrix (imatrix) weights.
#[derive(Clone, Copy, Debug)]
pub enum ImatrixDataSource<'p> {
    /// Load from a file on disk.
    ///
    /// A `.cimatrix` file is read as collected data. Any other extension is parsed
    /// as a GGUF-format imatrix, with a warning unless the extension is `.imatrix`.
    File(&'p PathBuf),
    /// Use statistics gathered in-process after `begin_track_stats`. These are
    /// extracted at quantization time and also saved as `collected-{count}.cimatrix`.
    Collected,
}
211
212pub trait IsqModel {
213    /// Corresponds to `IsqOrganization::Default`
214    #[allow(clippy::type_complexity)]
215    fn get_layers(
216        &mut self,
217    ) -> (
218        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
219        &dyn DeviceMapper,
220    );
221
222    /// This is used for imatrix generation internally. Begin stats tracking.
223    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
224        let layers = self
225            .get_layers()
226            .0
227            .into_iter()
228            .map(|(layer, _)| layer)
229            .collect::<Vec<_>>();
230        for layer in layers {
231            Arc::get_mut(layer).unwrap().begin_track_stats()?;
232        }
233        Ok(())
234    }
235
236    /// End stats tracking and return the imatrix data
237    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
238        let layers = self
239            .get_layers()
240            .0
241            .into_iter()
242            .enumerate()
243            .map(|(i, (layer, _))| (i, layer))
244            .collect::<Vec<_>>();
245        let mut data = HashMap::new();
246        for (i, layer) in layers {
247            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
248        }
249        Ok(CollectedImatrixData(data))
250    }
251
252    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
253    /// https://arxiv.org/abs/2310.02410
254    #[allow(clippy::type_complexity)]
255    fn get_layers_moe_experts_only(
256        &mut self,
257    ) -> (
258        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
259        &dyn DeviceMapper,
260    ) {
261        self.get_layers()
262    }
263
264    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
265    /// This is used for imatrix generation internally. Begin stats tracking.
266    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
267        let layers = self
268            .get_layers()
269            .0
270            .into_iter()
271            .map(|(layer, _)| layer)
272            .collect::<Vec<_>>();
273        for layer in layers {
274            Arc::get_mut(layer).unwrap().begin_track_stats()?;
275        }
276        Ok(())
277    }
278
279    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
280    /// End stats tracking and return the imatrix data
281    fn extract_imatrix_data_moe_experts_only(
282        &mut self,
283    ) -> candle_core::Result<CollectedImatrixData> {
284        let layers = self
285            .get_layers()
286            .0
287            .into_iter()
288            .enumerate()
289            .map(|(i, (layer, _))| (i, layer))
290            .collect::<Vec<_>>();
291        let mut data = HashMap::new();
292        for (i, layer) in layers {
293            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
294        }
295        Ok(CollectedImatrixData(data))
296    }
297
298    /// Corresponding to the specific order the model produces ISQ layers (None means
299    /// do not search for in the imatrix file). This is used to pair ISQ layers with the
300    /// corresponding imatrix weights.
301    ///
302    /// - This is only for loading from a llama.cpp imatrix file.
303    /// - Corresponds to `IsqOrganization::Default`
304    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
305        // TODO: make this required.
306        candle_core::bail!("This model does not support quantizing with an imatrix.");
307    }
308
309    /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers`].
310    fn residual_tensors(&self) -> Vec<(String, Tensor)>;
311
312    /// Residual tensors for generating a UQFF file. Counterpart to [`get_layers_moe_experts_only`].
313    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
314        None
315    }
316
317    /// Quantize the model in-situ.
318    ///
319    /// This function will also create a UQFF file, or, if the model supports it (residual tensors are returned),
320    /// a full serialization is created.
321    #[allow(clippy::too_many_arguments)]
322    fn quantize(
323        &mut self,
324        dtype: Option<IsqType>,
325        device: Device,
326        topology: Option<&Topology>,
327        silent: bool,
328        imatrix_source: Option<ImatrixDataSource<'_>>,
329        organization: IsqOrganization,
330        apply_quantization: bool,
331        write_artifacts: Option<&PathBuf>,
332        full_ser: UqffFullSer<'_>,
333        multi_progress: Arc<MultiProgress>,
334    ) -> candle_core::Result<()> {
335        {
336            let mut imatrix_source = imatrix_source;
337            let mut imatrix_to_weight_map: Option<HashMap<usize, Option<Vec<f32>>>> =
338                if apply_quantization {
339                    match imatrix_source.take() {
340                        Some(ImatrixDataSource::File(imatrix)) => {
341                            let ext = imatrix.extension().ok_or(candle_core::Error::msg(
342                                "Expected an extension for the imatrix source file.",
343                            ))?;
344                            if ext == "cimatrix" {
345                                info!(
346                                    "Loading collected imatrix source file: `{}`",
347                                    imatrix.display()
348                                );
349                                let data = CollectedImatrixData::load_imatrix(imatrix)?;
350                                info!(
351                                    "Quantizing with collected imatrix data, {} imatrix weights",
352                                    data.0.iter().filter(|(_, x)| x.is_some()).count()
353                                );
354                                Some(data.0)
355                            } else {
356                                if ext != "imatrix" {
357                                    warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
358                                }
359                                info!(
360                                    "Loading GGUF-format imatrix source file: `{}`",
361                                    imatrix.display()
362                                );
363                                let mut imatrix_data =
364                                    quantized::imatrix_file::load_imatrix(imatrix.clone())?;
365                                let imatrix_mapping = self
366                                    .imatrix_names()?
367                                    .into_iter()
368                                    .enumerate()
369                                    .collect::<HashMap<_, _>>();
370
371                                let layer_to_weight = imatrix_mapping
372                                    .into_iter()
373                                    .map(|(i, name)| {
374                                        if let Some(name) = name {
375                                            (i, Some(imatrix_data.remove(&name).unwrap()))
376                                        } else {
377                                            (i, None)
378                                        }
379                                    })
380                                    .collect::<HashMap<_, _>>();
381                                info!(
382                                    "Quantizing with imatrix file `{}`, {} imatrix weights",
383                                    imatrix.display(),
384                                    layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
385                                );
386                                Some(layer_to_weight)
387                            }
388                        }
389                        Some(ImatrixDataSource::Collected) => {
390                            let data = match organization {
391                                IsqOrganization::Default => self.extract_imatrix_data()?,
392                                IsqOrganization::MoeExpertsOnly => {
393                                    self.extract_imatrix_data_moe_experts_only()?
394                                }
395                            };
396                            // Save the collected imatrix data so users can reuse it
397                            let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
398                            let save_path = format!("collected-{count}.cimatrix");
399                            info!("Saving collected imatrix data to `{save_path}`");
400                            data.save_imatrix(save_path)?;
401                            info!(
402                                "Quantizing with collected imatrix data, {count} imatrix weights"
403                            );
404                            Some(data.0)
405                        }
406                        None => None,
407                    }
408                } else {
409                    if imatrix_source.is_some() {
410                        info!("Imatrix source provided but quantization disabled; ignoring input.");
411                    }
412                    None
413                };
414
415            let (mut tensors, mapper) = match organization {
416                IsqOrganization::Default => self.get_layers(),
417                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
418            };
419
420            let total_tensors = tensors.len();
421
422            if apply_quantization {
423                let imatrix_to_weight: Vec<Option<Vec<f32>>> =
424                    if let Some(mut imatrix_to_weight) = imatrix_to_weight_map.take() {
425                        let ordered_keys = imatrix_to_weight
426                            .keys()
427                            .copied()
428                            .sorted()
429                            .collect::<Vec<_>>();
430                        ordered_keys
431                            .into_iter()
432                            .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
433                            .collect()
434                    } else {
435                        vec![None; tensors.len()]
436                    };
437
438                let n_quantized = AtomicUsize::new(0);
439                if let Some(topology) = topology {
440                    let mut dtypes = HashSet::new();
441                    for layer in topology.layers.iter().flatten() {
442                        if let LayerTopology {
443                            isq: Some(isq_dtype),
444                            device: _,
445                        } = layer
446                        {
447                            dtypes.insert(isq_dtype);
448                        }
449                    }
450                    info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
451                } else {
452                    info!(
453                        "Applying in-situ quantization into {dtype:?} to {total_tensors} tensors."
454                    );
455                }
456                let bar = ProgressBar::new(total_tensors as u64);
457                configure_progress_bar(&bar);
458                bar.set_style(
459                    ProgressStyle::default_bar()
460                        .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
461                        .unwrap()
462                        .progress_chars("#>-"),
463                );
464                multi_progress.add(bar.clone());
465
466                let layers = topology.map(|x| {
467                    x.layers
468                        .iter()
469                        .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
470                        .collect::<Vec<_>>()
471                });
472
473                let mut devices_and_dtypes = Vec::new();
474                for (_, layer_num) in &tensors {
475                    let device = if let Some(ref layers) = layers {
476                        if let Some(layer) = layer_num {
477                            layers
478                                .get(*layer)
479                                .as_ref()
480                                .map(|x| x.1.clone())
481                                .unwrap_or(Some(device.clone()))
482                                .unwrap_or(device.clone())
483                        } else {
484                            device.clone()
485                        }
486                    } else if let Some(layer_num) = layer_num {
487                        mapper
488                            .device_for(*layer_num, false)
489                            .cloned()
490                            .unwrap_or(device.clone())
491                    } else {
492                        device.clone()
493                    };
494                    let dtype = if let Some(ref layers) = layers {
495                        if let Some(layer) = layer_num {
496                            layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
497                        } else {
498                            dtype
499                        }
500                    } else {
501                        dtype
502                    };
503                    devices_and_dtypes.push((device, dtype));
504                }
505
506                let t_start = Instant::now();
507
508                // Get the MINIMUM of the max isq threads the quant method
509                let mut minimum_max_threads = {
510                    let current_rayon_threads = rayon::current_num_threads();
511                    if let Some(dtype) = dtype {
512                        dtype
513                            .get_max_isq_cpu_threads()
514                            .map(usize::from)
515                            .unwrap_or(current_rayon_threads)
516                    } else {
517                        current_rayon_threads
518                    }
519                };
520                if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
521                    minimum_max_threads = 1;
522                }
523
524                if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
525                    // Collected imatrix means that the model is potentially on the gpu already
526                    minimum_max_threads = 1;
527                }
528
529                info!("Applying ISQ on {minimum_max_threads} threads.");
530
531                let pool = rayon::ThreadPoolBuilder::new()
532                    .num_threads(minimum_max_threads)
533                    .build()
534                    .map_err(candle_core::Error::msg)?;
535
536                let guard = QuantizeOntoGuard::new();
537
538                pool.install(|| {
539                    use indicatif::ParallelProgressIterator;
540                    use rayon::iter::{
541                        IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
542                    };
543                    if silent {
544                        tensors
545                            .par_iter_mut()
546                            .zip(devices_and_dtypes)
547                            .zip(imatrix_to_weight)
548                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
549                                **tensor = tensor
550                                    .clone()
551                                    .apply_isq(
552                                        dtype,
553                                        device.clone(),
554                                        &n_quantized,
555                                        imatrix_weight,
556                                        guard.clone(),
557                                    )
558                                    .unwrap();
559                                device.synchronize().unwrap();
560                            });
561                    } else {
562                        tensors
563                            .par_iter_mut()
564                            .zip(devices_and_dtypes)
565                            .zip(imatrix_to_weight)
566                            .progress_with(bar)
567                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
568                                **tensor = tensor
569                                    .clone()
570                                    .apply_isq(
571                                        dtype,
572                                        device.clone(),
573                                        &n_quantized,
574                                        imatrix_weight,
575                                        guard.clone(),
576                                    )
577                                    .unwrap();
578                                device.synchronize().unwrap();
579                            });
580                    }
581                });
582
583                let t_end = Instant::now();
584                info!(
585                    "Finished quantization pass in {:.2}s ({} tensors).",
586                    t_end.duration_since(t_start).as_secs_f32(),
587                    total_tensors
588                );
589            } else if imatrix_source.is_some() {
590                info!(
591                    "Imatrix data provided but quantization was skipped; existing tensors will be serialized as-is."
592                );
593            } else if write_artifacts.is_some() {
594                info!(
595                    "Skipping additional quantization; serializing {total_tensors} existing tensors."
596                );
597            }
598
599            if let Some(serialized) = write_artifacts {
600                info!(
601                    "Serializing {total_tensors} ISQ tensors to `{}`.",
602                    serialized.display()
603                );
604
605                if serialized.extension().is_none_or(|ext| ext != "uqff") {
606                    candle_core::bail!("UQFF output path extension must be `.uqff`",);
607                }
608
609                let bar = ProgressBar::new(total_tensors as u64);
610                configure_progress_bar(&bar);
611                bar.set_style(
612                    ProgressStyle::default_bar()
613                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
614                        .unwrap()
615                        .progress_chars("#>-"),
616                );
617
618                // Metal and CUDA require serialization on the current thread because GPU contexts are thread-local.
619                // Using a rayon thread pool (even with n_threads=1) creates a new thread without the GPU context.
620                #[cfg(any(feature = "metal", feature = "cuda"))]
621                let quantized_values: candle_core::Result<Vec<_>> = {
622                    tensors
623                        .iter()
624                        .enumerate()
625                        .filter(|(_, (layer, _))| layer.isq_serde_supported())
626                        .map(|(i, (layer, _))| {
627                            if !silent {
628                                bar.inc(1);
629                            }
630                            Ok((
631                                i.to_string(),
632                                match layer.serialize()? {
633                                    Cow::Borrowed(_) => unreachable!(),
634                                    Cow::Owned(owned) => owned,
635                                },
636                            ))
637                        })
638                        .collect()
639                };
640
641                #[cfg(not(any(feature = "metal", feature = "cuda")))]
642                let quantized_values: candle_core::Result<Vec<_>> = {
643                    let pool = rayon::ThreadPoolBuilder::new()
644                        .num_threads(2)
645                        .build()
646                        .map_err(candle_core::Error::msg)?;
647
648                    pool.install(|| {
649                        use rayon::iter::IntoParallelRefIterator;
650                        if silent {
651                            tensors
652                                .par_iter()
653                                .enumerate()
654                                .filter(|(_, (layer, _))| layer.isq_serde_supported())
655                                .map(|(i, (layer, _))| {
656                                    Ok((
657                                        i.to_string(),
658                                        match layer.serialize()? {
659                                            Cow::Borrowed(_) => unreachable!(),
660                                            Cow::Owned(owned) => owned,
661                                        },
662                                    ))
663                                })
664                                .collect::<candle_core::Result<Vec<_>>>()
665                        } else {
666                            tensors
667                                .par_iter()
668                                .enumerate()
669                                .progress_with(bar)
670                                .filter(|(_, (layer, _))| layer.isq_serde_supported())
671                                .map(|(i, (layer, _))| {
672                                    Ok((
673                                        i.to_string(),
674                                        match layer.serialize()? {
675                                            Cow::Borrowed(_) => unreachable!(),
676                                            Cow::Owned(owned) => owned,
677                                        },
678                                    ))
679                                })
680                                .collect::<candle_core::Result<Vec<_>>>()
681                        }
682                    })
683                };
684
685                let quantized_values = quantized_values?;
686
687                let parent = serialized
688                    .parent()
689                    .context("Target UQFF path must have a filename!")?;
690
691                std::fs::create_dir_all(parent)?;
692
693                let file_stem = serialized
694                    .file_stem()
695                    .context("Target UQFF path must have a file stem!")?
696                    .to_string_lossy()
697                    .to_string();
698
699                // Shard quantized values by cumulative byte size, max MAX_UQFF_SIZE_BYTES per file
700                let mut current_chunk = Vec::new();
701                let mut current_bytes: usize = 0;
702                let mut shard_index = 0;
703
704                // Every 10GB, flush the file. Then save any remaining tensors
705                for (name, tensor) in quantized_values.iter() {
706                    let tensor_bytes = tensor.len();
707                    if !current_chunk.is_empty()
708                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
709                    {
710                        let mut shard_path = parent.to_path_buf();
711                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
712                        info!(
713                            "Writing shard {} to `{}`",
714                            shard_index,
715                            shard_path.display()
716                        );
717                        safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
718                        shard_index += 1;
719                        current_chunk.clear();
720                        current_bytes = 0;
721                    }
722                    current_bytes += tensor_bytes;
723                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
724                }
725
726                if !current_chunk.is_empty() {
727                    let mut shard_path = parent.to_path_buf();
728                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
729                    info!(
730                        "Writing final shard {} to `{}`",
731                        shard_index,
732                        shard_path.display()
733                    );
734                    safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
735                }
736
737                let residual = match organization {
738                    IsqOrganization::Default => self.residual_tensors(),
739                    IsqOrganization::MoeExpertsOnly => self
740                        .residual_tensors_moe_experts_only()
741                        .unwrap_or(self.residual_tensors()),
742                };
743
744                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
745                let config_out = parent.join("config.json");
746                let modules_out = parent.join("modules.json");
747                let tokenizer_out = parent.join("tokenizer.json");
748                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
749                let chat_template_jinja_out = parent.join("chat_template.jinja");
750                let gen_cfg_out = parent.join("generation_config.json");
751                let processor_out = parent.join("processor_config.json");
752                let preprocessor_out = parent.join("preprocessor_config.json");
753
754                info!(
755                    "Serializing {} residual tensors to `{}`.",
756                    residual.len(),
757                    residual_out.display()
758                );
759
760                safetensors::serialize_to_file(residual, None, &residual_out)?;
761
762                let UqffFullSer {
763                    tokenizer,
764                    template_filename,
765                    modules,
766                    module_paths,
767                    generation_config,
768                    config,
769                    processor_filename,
770                    preprocessor_filename,
771                } = full_ser;
772
773                info!("Serializing configuration to `{}`.", config_out.display());
774
775                std::fs::write(config_out, config)?;
776
777                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
778
779                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
780                    .map_err(candle_core::Error::msg)?;
781
782                if let Some(template_filename) = template_filename {
783                    let template =
784                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;
785
786                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
787                        info!(
788                            "Serializing chat template to `{}`.",
789                            chat_template_jinja_out.display()
790                        );
791                        std::fs::write(&chat_template_jinja_out, template)
792                            .map_err(candle_core::Error::msg)?;
793                    } else {
794                        info!(
795                            "Serializing tokenizer config to `{}`.",
796                            tokenizer_cfg_out.display()
797                        );
798                        std::fs::write(&tokenizer_cfg_out, template)
799                            .map_err(candle_core::Error::msg)?;
800                    }
801                }
802
803                if let Some(generation_config) = generation_config {
804                    info!(
805                        "Serializing generation config to `{}`.",
806                        gen_cfg_out.display()
807                    );
808
809                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
810                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
811                }
812
813                if let Some(processor_config) = processor_filename {
814                    info!(
815                        "Serializing processor config to `{}`.",
816                        processor_out.display()
817                    );
818
819                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
820                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
821                }
822
823                if let Some(preprocessor_config) = preprocessor_filename {
824                    info!(
825                        "Serializing preprocessor config to `{}`.",
826                        preprocessor_out.display()
827                    );
828
829                    let cfg =
830                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
831                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
832                }
833
834                if let Some(modules) = modules {
835                    info!(
836                        "Serializing modules manifest to `{}`.",
837                        modules_out.display()
838                    );
839
840                    std::fs::write(&modules_out, modules).map_err(candle_core::Error::msg)?;
841
842                    if let Some(module_paths) = module_paths {
843                        for module in module_paths {
844                            match module {
845                                EmbeddingModulePaths::Transformer { path }
846                                | EmbeddingModulePaths::Pooling { path, .. }
847                                | EmbeddingModulePaths::Dense { path, .. }
848                                | EmbeddingModulePaths::Normalize { path } => {
849                                    if path.is_empty() {
850                                        continue;
851                                    }
852                                    let module_dir = parent.join(path.as_str());
853                                    std::fs::create_dir_all(&module_dir)
854                                        .map_err(candle_core::Error::msg)?;
855
856                                    match module {
857                                        EmbeddingModulePaths::Pooling { config, .. } => {
858                                            let dest = module_dir.join("config.json");
859                                            if config != &dest {
860                                                std::fs::copy(config, &dest)
861                                                    .map_err(candle_core::Error::msg)?;
862                                            }
863                                        }
864                                        EmbeddingModulePaths::Dense { config, model, .. } => {
865                                            let dest_cfg = module_dir.join("config.json");
866                                            if config != &dest_cfg {
867                                                std::fs::copy(config, &dest_cfg)
868                                                    .map_err(candle_core::Error::msg)?;
869                                            }
870                                            let dest_model = module_dir.join("model.safetensors");
871                                            if model != &dest_model {
872                                                std::fs::copy(model, &dest_model)
873                                                    .map_err(candle_core::Error::msg)?;
874                                            }
875                                        }
876                                        EmbeddingModulePaths::Transformer { .. }
877                                        | EmbeddingModulePaths::Normalize { .. } => {}
878                                    }
879                                }
880                            }
881                        }
882                    }
883                }
884            }
885        }
886        Ok(())
887    }
888
889    fn load_from_artifacts(
890        &mut self,
891        device: Device,
892        topology: Option<&Topology>,
893        silent: bool,
894        artifacts: &[PathBuf],
895    ) -> candle_core::Result<()> {
896        let (tensors, mapper) = self.get_layers();
897        let total_tensors = tensors.len();
898
899        let layers = topology.map(|x| {
900            x.layers
901                .iter()
902                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
903                .collect::<Vec<_>>()
904        });
905
906        let mut devices = Vec::new();
907        let mut comms = Vec::new();
908        for (_, layer_num) in &tensors {
909            let device = if let Some(ref layers) = layers {
910                if let Some(layer) = layer_num {
911                    layers
912                        .get(*layer)
913                        .as_ref()
914                        .map(|x| x.1.clone())
915                        .unwrap_or(Some(device.clone()))
916                        .unwrap_or(device.clone())
917                } else {
918                    device.clone()
919                }
920            } else if let Some(layer_num) = layer_num {
921                mapper
922                    .device_for(*layer_num, false)
923                    .cloned()
924                    .unwrap_or(device.clone())
925            } else {
926                device.clone()
927            };
928            devices.push(device);
929            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
930        }
931
932        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };
933
934        let artifact_isqs = artifacts
935            .tensors()
936            .into_iter()
937            .map(|(name, tensor)| {
938                (
939                    name.parse::<usize>()
940                        .expect("Name should be parseable as usize"),
941                    tensor,
942                )
943            })
944            .collect::<HashMap<_, _>>();
945
946        if artifact_isqs.len() != total_tensors {
947            candle_core::bail!(
948                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
949                artifact_isqs.len(),
950            );
951        }
952
953        let bar = ProgressBar::new(total_tensors as u64);
954        configure_progress_bar(&bar);
955        bar.set_style(
956            ProgressStyle::default_bar()
957                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
958                .unwrap()
959                .progress_chars("#>-"),
960        );
961
962        let t_start = Instant::now();
963
964        let guard = QuantizeOntoGuard::new();
965
966        if silent {
967            (0..tensors.len())
968                .into_par_iter()
969                .zip(tensors)
970                .map(|(i, (tensor, _))| {
971                    if let Some(artifact) = artifact_isqs.get(&i) {
972                        let artifact = artifact.data();
973
974                        let comm = comms[i].clone();
975                        let deserialized = match tensor.is_distributed() {
976                            Some(DistributedKind::ColumnParallel) => {
977                                ColumnParallelLayer::deserialize(
978                                    Cow::from(artifact),
979                                    &devices[i],
980                                    &comm,
981                                    guard.clone(),
982                                )?
983                            }
984                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
985                                Cow::from(artifact),
986                                &devices[i],
987                                &comm,
988                                guard.clone(),
989                            )?,
990                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
991                                Cow::from(artifact),
992                                &devices[i],
993                                &comm,
994                                guard.clone(),
995                            )?,
996                            None => {
997                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
998                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
999                                match QuantizedSerdeType::try_from(isq_type as usize)? {
1000                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
1001                                        Cow::from(artifact),
1002                                        &devices[i],
1003                                        &comm,
1004                                        guard.clone(),
1005                                    )?,
1006                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
1007                                        Cow::from(artifact),
1008                                        &devices[i],
1009                                        &comm,
1010                                        guard.clone(),
1011                                    )?,
1012                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
1013                                        Cow::from(artifact),
1014                                        &devices[i],
1015                                        &comm,
1016                                        guard.clone(),
1017                                    )?,
1018                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
1019                                        Cow::from(artifact),
1020                                        &devices[i],
1021                                        &comm,
1022                                        guard.clone(),
1023                                    )?,
1024                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
1025                                        Cow::from(artifact),
1026                                        &devices[i],
1027                                        &comm,
1028                                        guard.clone(),
1029                                    )?,
1030                                }
1031                            }
1032                        };
1033                        *tensor = deserialized;
1034                    }
1035                    Ok(())
1036                })
1037                .collect::<candle_core::Result<Vec<_>>>()?;
1038        } else {
1039            (0..tensors.len())
1040                .into_par_iter()
1041                .zip(tensors)
1042                .progress_with(bar)
1043                .map(|(i, (tensor, _))| {
1044                    if let Some(artifact) = artifact_isqs.get(&i) {
1045                        let artifact = artifact.data();
1046
1047                        let comm = comms[i].clone();
1048                        let deserialized = match tensor.is_distributed() {
1049                            Some(DistributedKind::ColumnParallel) => {
1050                                ColumnParallelLayer::deserialize(
1051                                    Cow::from(artifact),
1052                                    &devices[i],
1053                                    &comm,
1054                                    guard.clone(),
1055                                )?
1056                            }
1057                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
1058                                Cow::from(artifact),
1059                                &devices[i],
1060                                &comm,
1061                                guard.clone(),
1062                            )?,
1063                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
1064                                Cow::from(artifact),
1065                                &devices[i],
1066                                &comm,
1067                                guard.clone(),
1068                            )?,
1069                            None => {
1070                                // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
1071                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
1072                                match QuantizedSerdeType::try_from(isq_type as usize)? {
1073                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
1074                                        Cow::from(artifact),
1075                                        &devices[i],
1076                                        &comm,
1077                                        guard.clone(),
1078                                    )?,
1079                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
1080                                        Cow::from(artifact),
1081                                        &devices[i],
1082                                        &comm,
1083                                        guard.clone(),
1084                                    )?,
1085                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
1086                                        Cow::from(artifact),
1087                                        &devices[i],
1088                                        &comm,
1089                                        guard.clone(),
1090                                    )?,
1091                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
1092                                        Cow::from(artifact),
1093                                        &devices[i],
1094                                        &comm,
1095                                        guard.clone(),
1096                                    )?,
1097                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
1098                                        Cow::from(artifact),
1099                                        &devices[i],
1100                                        &comm,
1101                                        guard.clone(),
1102                                    )?,
1103                                }
1104                            }
1105                        };
1106                        *tensor = deserialized;
1107                    }
1108                    Ok(())
1109                })
1110                .collect::<candle_core::Result<Vec<_>>>()?;
1111        }
1112
1113        let delta = Instant::now().duration_since(t_start).as_secs_f32();
1114        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s", );
1115
1116        Ok(())
1117    }
1118}
1119
1120/// Trait for loading models with ISQ.
1121pub(crate) trait IsqModelLoader {
1122    /// Regex to match layers which will have standard *immediate* ISQ applied.
1123    ///
1124    /// Only called on non-adapter models!
1125    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1126        Ok(Vec::new())
1127    }
1128
1129    /// Regex to match layers which will have standard MoQE *immediate* ISQ applied.
1130    ///
1131    /// Only called on non-adapter models!
1132    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1133        self.isq_layer_regexes(config)
1134    }
1135
1136    /// Regex to match layers which will have standard ISQ applied.
1137    ///
1138    /// Only called on non-adapter models!
1139    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1140        Ok(Vec::new())
1141    }
1142
1143    /// Regex to match layers which will have standard MoQE ISQ applied.
1144    ///
1145    /// Only called on non-adapter models!
1146    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1147        self.isq_layer_regexes(config)
1148    }
1149}