Skip to main content

hanzo_engine/pipeline/
isq.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet},
4    env,
5    fs::File,
6    path::PathBuf,
7    str::FromStr,
8    sync::{atomic::AtomicUsize, Arc},
9    time::Instant,
10};
11/// Wrapper around a `Cow<'a, [u8]>` buffer that implements
12/// `safetensors::tensor::View`.
13///
14/// *Purpose*: lets us pass raw byte buffers to
15/// `safetensors::serialize_to_file` without cloning them into a `Vec<u8>` or
16/// converting to a higher‑level tensor type.
17/// We expose the buffer as a 1‑D `u8` tensor of shape `[len]`.
18#[derive(Clone)]
19pub struct CowBytesView<'a> {
20    data: Cow<'a, [u8]>,
21    shape: [usize; 1],
22}
23
24impl<'a> CowBytesView<'a> {
25    /// Convenience constructor.
26    pub fn new(data: Cow<'a, [u8]>) -> Self {
27        let len = data.len();
28        Self { data, shape: [len] }
29    }
30}
31
32impl safetensors::tensor::View for CowBytesView<'_> {
33    fn dtype(&self) -> safetensors::tensor::Dtype {
34        // Serialize as raw bytes
35        safetensors::tensor::Dtype::U8
36    }
37
38    fn shape(&self) -> &[usize] {
39        &self.shape
40    }
41
42    fn data(&self) -> Cow<'_, [u8]> {
43        assert!(matches!(self.data, Cow::Borrowed(_)));
44        // Cloning a `Cow` is cheap (only clones the enum, not the data).
45        self.data.clone()
46    }
47
48    fn data_len(&self) -> usize {
49        self.data.len()
50    }
51}
52
53use anyhow::Result;
54use hanzo_ml::{quantized, Context, Device, Tensor};
55use hanzo_quant::{
56    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, F8Q8Linear, FP8Linear,
57    GgufMatMul, HqqLayer, IsqBits, IsqType, MXFP4Layer, QuantMethod, QuantizeOntoGuard,
58    QuantizedSerde, QuantizedSerdeType, ReplicatedLayer, RowParallelLayer, UnquantLinear,
59};
60use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
61use itertools::Itertools;
62use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
63use regex::Regex;
64use serde::Deserialize;
65use tokenizers::Tokenizer;
66use tracing::{info, warn};
67
68use crate::{
69    device_map::DeviceMapper, pipeline::EmbeddingModulePaths, topology::LayerTopology,
70    utils::progress::configure_progress_bar, Topology,
71};
72
73pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
// 10 GiB max per file on 64-bit targets; on narrower targets the limit is `usize::MAX`.
75#[cfg(target_pointer_width = "64")]
76const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
77#[cfg(not(target_pointer_width = "64"))]
78const MAX_UQFF_SIZE_BYTES: usize = usize::MAX;
79pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
80
/// Flags describing how a model's weights are being loaded.
///
/// Collapsed into a single [`WeightLoadingMode`] through its `From` impl. When
/// several flags are set, the precedence is `from_uqff`, then `immediate_isq`,
/// then `loading_isq`, then `write_uqff`.
pub(crate) struct WeightLoadingState {
    /// Selects [`WeightLoadingMode::Uqff`]; wins over every other flag.
    pub(crate) from_uqff: bool,
    /// Selects [`WeightLoadingMode::PostLoadIsq`] unless `from_uqff` or `immediate_isq` is set.
    pub(crate) loading_isq: bool,
    /// Selects [`WeightLoadingMode::ImmediateIsq`] unless `from_uqff` is set.
    pub(crate) immediate_isq: bool,
    /// Selects [`WeightLoadingMode::UqffSerialization`] when no other flag is set.
    pub(crate) write_uqff: bool,
}
87
/// The single weight-loading mode derived from a [`WeightLoadingState`].
///
/// Used to pick the user-facing status line returned by `message`.
pub(crate) enum WeightLoadingMode {
    /// Chosen when `from_uqff` is set.
    Uqff,
    /// Chosen when `immediate_isq` is set (and `from_uqff` is not).
    ImmediateIsq,
    /// Chosen when `loading_isq` is set without a higher-priority flag.
    PostLoadIsq,
    /// Chosen when only `write_uqff` is set.
    UqffSerialization,
    /// Chosen when no flag is set.
    Plain,
}
95
96impl From<WeightLoadingState> for WeightLoadingMode {
97    fn from(state: WeightLoadingState) -> Self {
98        if state.from_uqff {
99            Self::Uqff
100        } else if state.immediate_isq {
101            Self::ImmediateIsq
102        } else if state.loading_isq {
103            Self::PostLoadIsq
104        } else if state.write_uqff {
105            Self::UqffSerialization
106        } else {
107            Self::Plain
108        }
109    }
110}
111
112impl WeightLoadingMode {
113    pub(crate) fn message(self, target: &'static str) -> Cow<'static, str> {
114        match self {
115            Self::Uqff => {
116                Cow::Borrowed("Loading residual weights and preparing UQFF placeholders.")
117            }
118            Self::ImmediateIsq => {
119                Cow::Owned(format!("Loading {target} weights with immediate ISQ."))
120            }
121            Self::PostLoadIsq => Cow::Owned(format!(
122                "Loading full-precision {target} weights for post-load ISQ."
123            )),
124            Self::UqffSerialization => {
125                Cow::Owned(format!("Loading {target} weights for UQFF serialization."))
126            }
127            Self::Plain => Cow::Owned(format!("Loading {target} weights.")),
128        }
129    }
130}
131
132/// Parse ISQ value.
133///
134/// If the provided value is a valid integer (one of 2,3,4,5,6,8), the best quantization type will be chosen.
135/// Note that the fallback is always a Q/K quantization but on Metal 2,3,4,6,8 uses the fast AFQ.
136///
137/// One of:
138/// - `Q4_0`
139/// - `Q4_1`
140/// - `Q5_0`
141/// - `Q5_1`
142/// - `Q8_0`
143/// - `Q8_1`
144/// - `Q2K`
145/// - `Q3K`
146/// - `Q4K`
147/// - `Q5K`
148/// - `Q6K`
149/// - `Q8K`
150/// - `HQQ1`
151/// - `HQQ2`
152/// - `HQQ3`
153/// - `HQQ4`
154/// - `HQQ8`
155/// - `AFQ2`
156/// - `AFQ3`
157/// - `AFQ4`
158/// - `AFQ6`
159/// - `AFQ8`
160pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
161    let lowered = s.to_lowercase();
162
163    // Numeric shorthands resolve via IsqBits
164    if let Ok(bits) = IsqBits::try_from(lowered.as_str()) {
165        let tp = match device {
166            Some(dev) => bits.resolve(dev),
167            None => bits.resolve(&Device::Cpu),
168        };
169        #[cfg(feature = "cuda")]
170        {
171            // All IsqBits resolutions are CUDA-safe, so no extra check needed.
172        }
173        return Ok(tp);
174    }
175
176    let tp = match lowered.as_str() {
177        "q4_0" => IsqType::Q4_0,
178        "q4_1" => IsqType::Q4_1,
179        "q5_0" => IsqType::Q5_0,
180        "q5_1" => IsqType::Q5_1,
181        "q8_0" => IsqType::Q8_0,
182        "q8_1" => IsqType::Q8_1,
183        "q2k" => IsqType::Q2K,
184        "q3k" => IsqType::Q3K,
185        "q4k" => IsqType::Q4K,
186        "q5k" => IsqType::Q5K,
187        "q6k" => IsqType::Q6K,
188        "q8k" => IsqType::Q8K,
189        "hqq8" => IsqType::HQQ8,
190        "hqq4" => IsqType::HQQ4,
191        "fp8" => IsqType::F8E4M3,
192        "afq8" => IsqType::AFQ8,
193        "afq6" => IsqType::AFQ6,
194        "afq4" => IsqType::AFQ4,
195        "afq3" => IsqType::AFQ3,
196        "afq2" => IsqType::AFQ2,
197        "f8q8" => IsqType::F8Q8,
198        "mxfp4" => IsqType::MXFP4,
199        // "hqq3" => IsqType::HQQ3,
200        // "hqq2" => IsqType::HQQ2,
201        // "hqq1" => IsqType::HQQ1,
202        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`, `F8Q8`, `MXFP4`.")),
203    };
204    #[cfg(feature = "cuda")]
205    {
206        if !matches!(
207            tp,
208            IsqType::Q4_0
209                | IsqType::Q4_1
210                | IsqType::Q5_0
211                | IsqType::Q5_1
212                | IsqType::Q8_0
213                | IsqType::Q2K
214                | IsqType::Q3K
215                | IsqType::Q4K
216                | IsqType::Q5K
217                | IsqType::Q6K
218                | IsqType::HQQ8
219                | IsqType::HQQ4
220                | IsqType::F8E4M3
221                | IsqType::AFQ2
222                | IsqType::AFQ3
223                | IsqType::AFQ4
224                | IsqType::AFQ6
225                | IsqType::AFQ8
226                | IsqType::F8Q8
227                | IsqType::MXFP4 // | IsqType::HQQ3
228                                 // | IsqType::HQQ2
229                                 // | IsqType::HQQ1
230        ) {
231            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`, `F8Q8`, `MXFP4`".to_string());
232        }
233    }
234    Ok(tp)
235}
236
237/// Expand an ISQ specifier into concrete `IsqType` variants.
238/// Numeric shorthands (2-8) produce both the non-Metal and Metal variants;
239/// explicit method names resolve to a single variant.
240pub fn expand_isq_value(s: &str) -> anyhow::Result<Vec<IsqType>> {
241    if let Ok(bits) = IsqBits::try_from(s.to_lowercase().as_str()) {
242        return Ok(bits.expand());
243    }
244    let isq = parse_isq_value(s, None).map_err(|e| anyhow::anyhow!("{e}"))?;
245    Ok(vec![isq])
246}
247
/// Given a UQFF filename like `"q4k-0.uqff"`, returns `Some(("q4k", 0))`.
///
/// The file stem is split at its last `-`; the part after it must consist solely
/// of ASCII digits and fit in a `u64`. Returns `None` for non-sharded filenames
/// like `"model.uqff"` (no `-`) or `"my-model.uqff"` (suffix is not a number).
pub fn parse_uqff_shard(filename: &str) -> Option<(String, u64)> {
    let stem = std::path::Path::new(filename).file_stem()?.to_str()?;
    let (prefix, suffix) = stem.rsplit_once('-')?;
    // `u64::from_str` tolerates a leading `+`; a shard index is plain digits only,
    // matching the `{prefix}-{index}.uqff` names that shard expansion generates.
    if !suffix.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    // An empty suffix or an out-of-range number fails to parse and yields `None`.
    let index = suffix.parse::<u64>().ok()?;
    Some((prefix.to_string(), index))
}
259
260/// Expand a single UQFF filename to include all sibling shards.
261///
262/// Given `"q4k-0.uqff"` and a list of available files, returns
263/// `["q4k-0.uqff", "q4k-1.uqff", ...]` for all sequential indices found.
264/// Non-sharded filenames (those not matching `{prefix}-{N}.uqff`) are returned as-is.
265pub fn expand_uqff_shards(first_file: &str, available_files: &[String]) -> Vec<String> {
266    let Some((prefix, _)) = parse_uqff_shard(first_file) else {
267        return vec![first_file.to_string()];
268    };
269    let mut shards = Vec::new();
270    for index in 0u64.. {
271        let candidate = format!("{prefix}-{index}.uqff");
272        if available_files.iter().any(|f| f == &candidate) {
273            shards.push(candidate);
274        } else {
275            break;
276        }
277    }
278    if shards.is_empty() {
279        vec![first_file.to_string()]
280    } else {
281        shards
282    }
283}
284
285/// Resolve a UQFF shorthand (numeric like `"8"` or ISQ name like `"q4k"`) to an
286/// actual UQFF filename from the available files list.
287///
288/// Returns `Some("q8_0-0.uqff")` if a matching file is found, `None` otherwise.
289/// For numeric shorthands, tries all platform variants via `IsqBits::expand()`.
290pub fn resolve_uqff_shorthand(input: &str, available_files: &[String]) -> Option<String> {
291    let lowered = input.to_lowercase();
292
293    // Try numeric shorthand first (2/3/4/5/6/8)
294    if let Ok(bits) = IsqBits::try_from(lowered.as_str()) {
295        for isq_type in bits.expand() {
296            let candidate = format!("{isq_type}-0.uqff");
297            if available_files.iter().any(|f| f == &candidate) {
298                return Some(candidate);
299            }
300        }
301        return None;
302    }
303
304    // Try explicit ISQ type name (e.g., "q4k", "afq8", "q8_0")
305    if let Ok(isq_type) = parse_isq_value(&lowered, None) {
306        let candidate = format!("{isq_type}-0.uqff");
307        if available_files.iter().any(|f| f == &candidate) {
308            return Some(candidate);
309        }
310    }
311
312    None
313}
314
/// Selects which set of model layers in-situ quantization operates on.
#[derive(Clone, Debug, Copy, Default, Deserialize, serde::Serialize)]
pub enum IsqOrganization {
    /// Use the model's regular ISQ layer set (`IsqModel::get_layers`).
    #[default]
    #[serde(rename = "default")]
    Default,
    /// Only quantize MoE experts, if applicable. This enables MoQE.
    /// <https://arxiv.org/abs/2310.02410>
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}
325
326impl FromStr for IsqOrganization {
327    type Err = String;
328    fn from_str(s: &str) -> Result<Self, Self::Err> {
329        match s {
330            "default" => Ok(Self::Default),
331            "moqe" => Ok(Self::MoeExpertsOnly),
332            other => Err(format!(
333                "Expected ISQ organization `default` or `moqe`, got `{other}`"
334            )),
335        }
336    }
337}
338
/// Borrowed bundle of auxiliary model artifacts handed to `IsqModel::quantize`
/// as its `full_ser` argument.
///
/// NOTE(review): the code that consumes these fields is outside this chunk; the
/// per-field notes below are inferred from names and types — confirm against
/// the serialization code.
pub struct UqffFullSer<'a> {
    /// The model's tokenizer.
    pub tokenizer: &'a Tokenizer,
    /// Path of the template file, if any (presumably the chat template — TODO confirm).
    pub template_filename: &'a Option<PathBuf>,
    /// Optional modules description (presumably serialized module config — TODO confirm).
    pub modules: Option<&'a String>,
    /// Optional paths of embedding modules.
    pub module_paths: Option<&'a [EmbeddingModulePaths]>,
    /// Optional path of the generation config file.
    pub generation_config: Option<&'a PathBuf>,
    /// Model configuration as an owned string (presumably the raw config contents — TODO confirm).
    pub config: String,
    /// Path of the processor file, if any.
    pub processor_filename: &'a Option<PathBuf>,
    /// Path of the preprocessor file, if any.
    pub preprocessor_filename: &'a Option<PathBuf>,
}
349
/// Where `IsqModel::quantize` obtains its imatrix data from.
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    /// Load from a file. A `.cimatrix` extension is read as previously collected
    /// imatrix data; any other extension is treated as a GGUF-format imatrix file.
    File(&'a PathBuf),
    /// Use the statistics tracked on the model's own layers, extracted at
    /// quantization time.
    Collected,
}
355
356pub trait IsqModel {
    /// Returns the layers subject to ISQ together with the model's device mapper.
    ///
    /// Each entry pairs a mutable handle to a quantizable layer with an optional
    /// layer index. `quantize` uses that index to look up per-layer topology
    /// overrides and the device assigned by the mapper; entries with `None` use
    /// the device and dtype passed to `quantize`.
    ///
    /// Corresponds to `IsqOrganization::Default`
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );
365
366    /// This is used for imatrix generation internally. Begin stats tracking.
367    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
368        let layers = self
369            .get_layers()
370            .0
371            .into_iter()
372            .map(|(layer, _)| layer)
373            .collect::<Vec<_>>();
374        for layer in layers {
375            Arc::get_mut(layer).unwrap().begin_track_stats()?;
376        }
377        Ok(())
378    }
379
380    /// End stats tracking and return the imatrix data
381    fn extract_imatrix_data(&mut self) -> hanzo_ml::Result<CollectedImatrixData> {
382        let layers = self
383            .get_layers()
384            .0
385            .into_iter()
386            .enumerate()
387            .map(|(i, (layer, _))| (i, layer))
388            .collect::<Vec<_>>();
389        let mut data = HashMap::new();
390        for (i, layer) in layers {
391            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
392        }
393        Ok(CollectedImatrixData(data))
394    }
395
    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
    /// <https://arxiv.org/abs/2310.02410>
    ///
    /// The default implementation falls back to [`Self::get_layers`], i.e. it
    /// returns the same layers as the default organization.
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }
407
408    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
409    /// This is used for imatrix generation internally. Begin stats tracking.
410    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
411        let layers = self
412            .get_layers()
413            .0
414            .into_iter()
415            .map(|(layer, _)| layer)
416            .collect::<Vec<_>>();
417        for layer in layers {
418            Arc::get_mut(layer).unwrap().begin_track_stats()?;
419        }
420        Ok(())
421    }
422
423    /// Corresponds to `IsqOrganization::MoeExpertsOnly`
424    /// End stats tracking and return the imatrix data
425    fn extract_imatrix_data_moe_experts_only(&mut self) -> hanzo_ml::Result<CollectedImatrixData> {
426        let layers = self
427            .get_layers()
428            .0
429            .into_iter()
430            .enumerate()
431            .map(|(i, (layer, _))| (i, layer))
432            .collect::<Vec<_>>();
433        let mut data = HashMap::new();
434        for (i, layer) in layers {
435            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
436        }
437        Ok(CollectedImatrixData(data))
438    }
439
    /// Corresponding to the specific order the model produces ISQ layers (None means
    /// do not search for in the imatrix file). This is used to pair ISQ layers with the
    /// corresponding imatrix weights.
    ///
    /// - This is only for loading from a llama.cpp imatrix file.
    /// - Corresponds to `IsqOrganization::Default`
    ///
    /// The default implementation always returns an error stating that the model
    /// does not support imatrix quantization; models that do must override it.
    fn imatrix_names(&self) -> hanzo_ml::Result<Vec<Option<String>>> {
        // TODO: make this required.
        hanzo_ml::bail!("This model does not support quantizing with an imatrix.");
    }
450
    /// Residual tensors for generating a UQFF file. Counterpart to [`Self::get_layers`].
    ///
    /// Each entry is a `(name, tensor)` pair.
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;
453
    /// Residual tensors for generating a UQFF file. Counterpart to [`Self::get_layers_moe_experts_only`].
    ///
    /// The default implementation returns `None`.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }
458
459    /// Quantize the model in-situ.
460    ///
461    /// This function will also create a UQFF file, or, if the model supports it (residual tensors are returned),
462    /// a full serialization is created.
463    #[allow(clippy::too_many_arguments)]
464    fn quantize(
465        &mut self,
466        dtype: Option<IsqType>,
467        device: Device,
468        topology: Option<&Topology>,
469        silent: bool,
470        imatrix_source: Option<ImatrixDataSource<'_>>,
471        organization: IsqOrganization,
472        apply_quantization: bool,
473        write_artifacts: Option<&PathBuf>,
474        full_ser: UqffFullSer<'_>,
475        multi_progress: Arc<MultiProgress>,
476    ) -> hanzo_ml::Result<()> {
477        {
478            let mut imatrix_source = imatrix_source;
479            let mut imatrix_to_weight_map: Option<HashMap<usize, Option<Vec<f32>>>> =
480                if apply_quantization {
481                    match imatrix_source.take() {
482                        Some(ImatrixDataSource::File(imatrix)) => {
483                            let ext = imatrix.extension().ok_or(hanzo_ml::Error::msg(
484                                "Expected an extension for the imatrix source file.",
485                            ))?;
486                            if ext == "cimatrix" {
487                                info!(
488                                    "Loading collected imatrix source file: `{}`",
489                                    imatrix.display()
490                                );
491                                let data = CollectedImatrixData::load_imatrix(imatrix)?;
492                                info!(
493                                    "Quantizing with collected imatrix data, {} imatrix weights",
494                                    data.0.iter().filter(|(_, x)| x.is_some()).count()
495                                );
496                                Some(data.0)
497                            } else {
498                                if ext != "imatrix" {
499                                    warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
500                                }
501                                info!(
502                                    "Loading GGUF-format imatrix source file: `{}`",
503                                    imatrix.display()
504                                );
505                                let mut imatrix_data =
506                                    quantized::imatrix_file::load_imatrix(imatrix.clone())?;
507                                let imatrix_mapping = self
508                                    .imatrix_names()?
509                                    .into_iter()
510                                    .enumerate()
511                                    .collect::<HashMap<_, _>>();
512
513                                let layer_to_weight = imatrix_mapping
514                                    .into_iter()
515                                    .map(|(i, name)| {
516                                        if let Some(name) = name {
517                                            (i, Some(imatrix_data.remove(&name).unwrap()))
518                                        } else {
519                                            (i, None)
520                                        }
521                                    })
522                                    .collect::<HashMap<_, _>>();
523                                info!(
524                                    "Quantizing with imatrix file `{}`, {} imatrix weights",
525                                    imatrix.display(),
526                                    layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
527                                );
528                                Some(layer_to_weight)
529                            }
530                        }
531                        Some(ImatrixDataSource::Collected) => {
532                            let data = match organization {
533                                IsqOrganization::Default => self.extract_imatrix_data()?,
534                                IsqOrganization::MoeExpertsOnly => {
535                                    self.extract_imatrix_data_moe_experts_only()?
536                                }
537                            };
538                            // Save the collected imatrix data so users can reuse it
539                            let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
540                            let save_path = format!("collected-{count}.cimatrix");
541                            info!("Saving collected imatrix data to `{save_path}`");
542                            data.save_imatrix(save_path)?;
543                            info!(
544                                "Quantizing with collected imatrix data, {count} imatrix weights"
545                            );
546                            Some(data.0)
547                        }
548                        None => None,
549                    }
550                } else {
551                    if imatrix_source.is_some() {
552                        info!("Imatrix source provided but quantization disabled; ignoring input.");
553                    }
554                    None
555                };
556
557            let (mut tensors, mapper) = match organization {
558                IsqOrganization::Default => self.get_layers(),
559                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
560            };
561
562            let total_tensors = tensors.len();
563
564            if apply_quantization {
565                let imatrix_to_weight: Vec<Option<Vec<f32>>> =
566                    if let Some(mut imatrix_to_weight) = imatrix_to_weight_map.take() {
567                        let ordered_keys = imatrix_to_weight
568                            .keys()
569                            .copied()
570                            .sorted()
571                            .collect::<Vec<_>>();
572                        ordered_keys
573                            .into_iter()
574                            .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
575                            .collect()
576                    } else {
577                        vec![None; tensors.len()]
578                    };
579
580                let n_quantized = AtomicUsize::new(0);
581                if let Some(topology) = topology {
582                    let mut dtypes = HashSet::new();
583                    for layer in topology.layers.iter().flatten() {
584                        if let LayerTopology {
585                            isq: Some(isq_dtype),
586                            device: _,
587                        } = layer
588                        {
589                            dtypes.insert(isq_dtype);
590                        }
591                    }
592                    info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
593                } else {
594                    info!(
595                        "Applying in-situ quantization into {dtype:?} to {total_tensors} tensors."
596                    );
597                }
598                let bar = ProgressBar::new(total_tensors as u64);
599                configure_progress_bar(&bar);
600                bar.set_style(
601                    ProgressStyle::default_bar()
602                        .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
603                        .unwrap()
604                        .progress_chars("#>-"),
605                );
606                multi_progress.add(bar.clone());
607
608                let layers = topology.map(|x| {
609                    x.layers
610                        .iter()
611                        .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
612                        .collect::<Vec<_>>()
613                });
614
615                let mut devices_and_dtypes = Vec::new();
616                for (_, layer_num) in &tensors {
617                    let device = if let Some(ref layers) = layers {
618                        if let Some(layer) = layer_num {
619                            layers
620                                .get(*layer)
621                                .as_ref()
622                                .map(|x| x.1.clone())
623                                .unwrap_or(Some(device.clone()))
624                                .unwrap_or(device.clone())
625                        } else {
626                            device.clone()
627                        }
628                    } else if let Some(layer_num) = layer_num {
629                        mapper
630                            .device_for(*layer_num, false)
631                            .cloned()
632                            .unwrap_or(device.clone())
633                    } else {
634                        device.clone()
635                    };
636                    let dtype = if let Some(ref layers) = layers {
637                        if let Some(layer) = layer_num {
638                            layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
639                        } else {
640                            dtype
641                        }
642                    } else {
643                        dtype
644                    };
645                    devices_and_dtypes.push((device, dtype));
646                }
647
648                let t_start = Instant::now();
649
650                // Get the MINIMUM of the max isq threads the quant method
651                let mut minimum_max_threads = {
652                    let current_rayon_threads = rayon::current_num_threads();
653                    if let Some(dtype) = dtype {
654                        dtype
655                            .get_max_isq_cpu_threads()
656                            .map(usize::from)
657                            .unwrap_or(current_rayon_threads)
658                    } else {
659                        current_rayon_threads
660                    }
661                };
662                if env::var("HANZO_ISQ_SINGLETHREAD").is_ok() {
663                    minimum_max_threads = 1;
664                }
665
666                if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
667                    // Collected imatrix means that the model is potentially on the gpu already
668                    minimum_max_threads = 1;
669                }
670
671                info!("Applying ISQ on {minimum_max_threads} threads.");
672
673                let pool = rayon::ThreadPoolBuilder::new()
674                    .num_threads(minimum_max_threads)
675                    .build()
676                    .map_err(hanzo_ml::Error::msg)?;
677
678                let guard = QuantizeOntoGuard::new();
679
680                pool.install(|| {
681                    use indicatif::ParallelProgressIterator;
682                    use rayon::iter::{
683                        IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
684                    };
685                    if silent {
686                        tensors
687                            .par_iter_mut()
688                            .zip(devices_and_dtypes)
689                            .zip(imatrix_to_weight)
690                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
691                                **tensor = tensor
692                                    .clone()
693                                    .apply_isq(
694                                        dtype,
695                                        device.clone(),
696                                        &n_quantized,
697                                        imatrix_weight,
698                                        guard.clone(),
699                                    )
700                                    .unwrap();
701                                device.synchronize().unwrap();
702                            });
703                    } else {
704                        tensors
705                            .par_iter_mut()
706                            .zip(devices_and_dtypes)
707                            .zip(imatrix_to_weight)
708                            .progress_with(bar)
709                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
710                                **tensor = tensor
711                                    .clone()
712                                    .apply_isq(
713                                        dtype,
714                                        device.clone(),
715                                        &n_quantized,
716                                        imatrix_weight,
717                                        guard.clone(),
718                                    )
719                                    .unwrap();
720                                device.synchronize().unwrap();
721                            });
722                    }
723                });
724
725                let t_end = Instant::now();
726                info!(
727                    "Finished quantization pass in {:.2}s ({} tensors).",
728                    t_end.duration_since(t_start).as_secs_f32(),
729                    total_tensors
730                );
731            } else if imatrix_source.is_some() {
732                info!(
733                    "Imatrix data provided but quantization was skipped; existing tensors will be serialized as-is."
734                );
735            } else if write_artifacts.is_some() {
736                info!(
737                    "Skipping additional quantization; serializing {total_tensors} existing tensors."
738                );
739            }
740
741            if let Some(serialized) = write_artifacts {
742                info!(
743                    "Serializing {total_tensors} ISQ tensors to `{}`.",
744                    serialized.display()
745                );
746
747                if serialized.extension().is_none_or(|ext| ext != "uqff") {
748                    hanzo_ml::bail!("UQFF output path extension must be `.uqff`",);
749                }
750
751                let bar = ProgressBar::new(total_tensors as u64);
752                configure_progress_bar(&bar);
753                bar.set_style(
754                    ProgressStyle::default_bar()
755                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
756                        .unwrap()
757                        .progress_chars("#>-"),
758                );
759
760                // Metal and CUDA require serialization on the current thread because GPU contexts are thread-local.
761                // Using a rayon thread pool (even with n_threads=1) creates a new thread without the GPU context.
762                #[cfg(any(feature = "metal", feature = "cuda"))]
763                let quantized_values: hanzo_ml::Result<Vec<_>> = {
764                    tensors
765                        .iter()
766                        .enumerate()
767                        .filter(|(_, (layer, _))| layer.isq_serde_supported())
768                        .map(|(i, (layer, _))| {
769                            if !silent {
770                                bar.inc(1);
771                            }
772                            Ok((
773                                i.to_string(),
774                                match layer.serialize()? {
775                                    Cow::Borrowed(_) => unreachable!(),
776                                    Cow::Owned(owned) => owned,
777                                },
778                            ))
779                        })
780                        .collect()
781                };
782
783                #[cfg(not(any(feature = "metal", feature = "cuda")))]
784                let quantized_values: hanzo_ml::Result<Vec<_>> = {
785                    let pool = rayon::ThreadPoolBuilder::new()
786                        .num_threads(2)
787                        .build()
788                        .map_err(hanzo_ml::Error::msg)?;
789
790                    pool.install(|| {
791                        use rayon::iter::IntoParallelRefIterator;
792                        if silent {
793                            tensors
794                                .par_iter()
795                                .enumerate()
796                                .filter(|(_, (layer, _))| layer.isq_serde_supported())
797                                .map(|(i, (layer, _))| {
798                                    Ok((
799                                        i.to_string(),
800                                        match layer.serialize()? {
801                                            Cow::Borrowed(_) => unreachable!(),
802                                            Cow::Owned(owned) => owned,
803                                        },
804                                    ))
805                                })
806                                .collect::<hanzo_ml::Result<Vec<_>>>()
807                        } else {
808                            tensors
809                                .par_iter()
810                                .enumerate()
811                                .progress_with(bar)
812                                .filter(|(_, (layer, _))| layer.isq_serde_supported())
813                                .map(|(i, (layer, _))| {
814                                    Ok((
815                                        i.to_string(),
816                                        match layer.serialize()? {
817                                            Cow::Borrowed(_) => unreachable!(),
818                                            Cow::Owned(owned) => owned,
819                                        },
820                                    ))
821                                })
822                                .collect::<hanzo_ml::Result<Vec<_>>>()
823                        }
824                    })
825                };
826
827                let quantized_values = quantized_values?;
828
829                let parent = serialized
830                    .parent()
831                    .context("Target UQFF path must have a filename!")?;
832
833                std::fs::create_dir_all(parent)?;
834
835                let file_stem = serialized
836                    .file_stem()
837                    .context("Target UQFF path must have a file stem!")?
838                    .to_string_lossy()
839                    .to_string();
840
841                // Shard quantized values by cumulative byte size, max MAX_UQFF_SIZE_BYTES per file
842                let mut current_chunk = Vec::new();
843                let mut current_bytes: usize = 0;
844                let mut shard_index = 0;
845
846                // Every 10GB, flush the file. Then save any remaining tensors
847                for (name, tensor) in quantized_values.iter() {
848                    let tensor_bytes = tensor.len();
849                    if !current_chunk.is_empty()
850                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
851                    {
852                        let mut shard_path = parent.to_path_buf();
853                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
854                        info!(
855                            "Writing shard {} to `{}`",
856                            shard_index,
857                            shard_path.display()
858                        );
859                        safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
860                        shard_index += 1;
861                        current_chunk.clear();
862                        current_bytes = 0;
863                    }
864                    current_bytes += tensor_bytes;
865                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
866                }
867
868                if !current_chunk.is_empty() {
869                    let mut shard_path = parent.to_path_buf();
870                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
871                    info!(
872                        "Writing final shard {} to `{}`",
873                        shard_index,
874                        shard_path.display()
875                    );
876                    safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
877                }
878
879                let residual = match organization {
880                    IsqOrganization::Default => self.residual_tensors(),
881                    IsqOrganization::MoeExpertsOnly => self
882                        .residual_tensors_moe_experts_only()
883                        .unwrap_or(self.residual_tensors()),
884                };
885
886                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
887                let config_out = parent.join("config.json");
888                let modules_out = parent.join("modules.json");
889                let tokenizer_out = parent.join("tokenizer.json");
890                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
891                let chat_template_jinja_out = parent.join("chat_template.jinja");
892                let gen_cfg_out = parent.join("generation_config.json");
893                let processor_out = parent.join("processor_config.json");
894                let preprocessor_out = parent.join("preprocessor_config.json");
895
896                info!(
897                    "Serializing {} residual tensors to `{}`.",
898                    residual.len(),
899                    residual_out.display()
900                );
901
902                safetensors::serialize_to_file(residual, None, &residual_out)?;
903
904                let UqffFullSer {
905                    tokenizer,
906                    template_filename,
907                    modules,
908                    module_paths,
909                    generation_config,
910                    config,
911                    processor_filename,
912                    preprocessor_filename,
913                } = full_ser;
914
915                info!("Serializing configuration to `{}`.", config_out.display());
916
917                std::fs::write(config_out, config)?;
918
919                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
920
921                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
922                    .map_err(hanzo_ml::Error::msg)?;
923
924                if let Some(template_filename) = template_filename {
925                    let template =
926                        std::fs::read(template_filename).map_err(hanzo_ml::Error::msg)?;
927
928                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
929                        info!(
930                            "Serializing chat template to `{}`.",
931                            chat_template_jinja_out.display()
932                        );
933                        std::fs::write(&chat_template_jinja_out, template)
934                            .map_err(hanzo_ml::Error::msg)?;
935
936                        // When the chat template is a .jinja file, also save the
937                        // tokenizer_config.json that lives alongside it. This file
938                        // contains bos_token/eos_token/unk_token which are needed
939                        // to render the template correctly. Without it, special
940                        // tokens render as "none" in minijinja.
941                        let sibling_cfg = template_filename
942                            .parent()
943                            .map(|dir| dir.join("tokenizer_config.json"));
944                        if let Some(cfg_path) = sibling_cfg.filter(|p| p.exists()) {
945                            info!(
946                                "Serializing tokenizer config to `{}`.",
947                                tokenizer_cfg_out.display()
948                            );
949                            std::fs::copy(&cfg_path, &tokenizer_cfg_out)
950                                .map_err(hanzo_ml::Error::msg)?;
951                        }
952                    } else {
953                        info!(
954                            "Serializing tokenizer config to `{}`.",
955                            tokenizer_cfg_out.display()
956                        );
957                        std::fs::write(&tokenizer_cfg_out, template)
958                            .map_err(hanzo_ml::Error::msg)?;
959                    }
960                }
961
962                if let Some(generation_config) = generation_config {
963                    info!(
964                        "Serializing generation config to `{}`.",
965                        gen_cfg_out.display()
966                    );
967
968                    let cfg = std::fs::read(generation_config).map_err(hanzo_ml::Error::msg)?;
969                    std::fs::write(&gen_cfg_out, cfg).map_err(hanzo_ml::Error::msg)?;
970                }
971
972                if let Some(processor_config) = processor_filename {
973                    info!(
974                        "Serializing processor config to `{}`.",
975                        processor_out.display()
976                    );
977
978                    let cfg = std::fs::read(processor_config).map_err(hanzo_ml::Error::msg)?;
979                    std::fs::write(&processor_out, cfg).map_err(hanzo_ml::Error::msg)?;
980                }
981
982                if let Some(preprocessor_config) = preprocessor_filename {
983                    info!(
984                        "Serializing preprocessor config to `{}`.",
985                        preprocessor_out.display()
986                    );
987
988                    let cfg = std::fs::read(preprocessor_config).map_err(hanzo_ml::Error::msg)?;
989                    std::fs::write(&preprocessor_out, cfg).map_err(hanzo_ml::Error::msg)?;
990                }
991
992                if let Some(modules) = modules {
993                    info!(
994                        "Serializing modules manifest to `{}`.",
995                        modules_out.display()
996                    );
997
998                    std::fs::write(&modules_out, modules).map_err(hanzo_ml::Error::msg)?;
999
1000                    if let Some(module_paths) = module_paths {
1001                        for module in module_paths {
1002                            match module {
1003                                EmbeddingModulePaths::Transformer { path }
1004                                | EmbeddingModulePaths::Pooling { path, .. }
1005                                | EmbeddingModulePaths::Dense { path, .. }
1006                                | EmbeddingModulePaths::Normalize { path } => {
1007                                    if path.is_empty() {
1008                                        continue;
1009                                    }
1010                                    let module_dir = parent.join(path.as_str());
1011                                    std::fs::create_dir_all(&module_dir)
1012                                        .map_err(hanzo_ml::Error::msg)?;
1013
1014                                    match module {
1015                                        EmbeddingModulePaths::Pooling { config, .. } => {
1016                                            let dest = module_dir.join("config.json");
1017                                            if config != &dest {
1018                                                std::fs::copy(config, &dest)
1019                                                    .map_err(hanzo_ml::Error::msg)?;
1020                                            }
1021                                        }
1022                                        EmbeddingModulePaths::Dense { config, model, .. } => {
1023                                            let dest_cfg = module_dir.join("config.json");
1024                                            if config != &dest_cfg {
1025                                                std::fs::copy(config, &dest_cfg)
1026                                                    .map_err(hanzo_ml::Error::msg)?;
1027                                            }
1028                                            let dest_model = module_dir.join("model.safetensors");
1029                                            if model != &dest_model {
1030                                                std::fs::copy(model, &dest_model)
1031                                                    .map_err(hanzo_ml::Error::msg)?;
1032                                            }
1033                                        }
1034                                        EmbeddingModulePaths::Transformer { .. }
1035                                        | EmbeddingModulePaths::Normalize { .. } => {}
1036                                    }
1037                                }
1038                            }
1039                        }
1040                    }
1041                }
1042            }
1043        }
1044        Ok(())
1045    }
1046
1047    fn load_from_artifacts(
1048        &mut self,
1049        device: Device,
1050        topology: Option<&Topology>,
1051        silent: bool,
1052        artifacts: &[PathBuf],
1053    ) -> hanzo_ml::Result<()> {
1054        let (tensors, mapper) = self.get_layers();
1055        let total_tensors = tensors.len();
1056
1057        let layers = topology.map(|x| {
1058            x.layers
1059                .iter()
1060                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
1061                .collect::<Vec<_>>()
1062        });
1063
1064        let mut devices = Vec::new();
1065        let mut comms = Vec::new();
1066        for (_, layer_num) in &tensors {
1067            let device = if let Some(ref layers) = layers {
1068                if let Some(layer) = layer_num {
1069                    layers
1070                        .get(*layer)
1071                        .as_ref()
1072                        .map(|x| x.1.clone())
1073                        .unwrap_or(Some(device.clone()))
1074                        .unwrap_or(device.clone())
1075                } else {
1076                    device.clone()
1077                }
1078            } else if let Some(layer_num) = layer_num {
1079                mapper
1080                    .device_for(*layer_num, false)
1081                    .cloned()
1082                    .unwrap_or(device.clone())
1083            } else {
1084                device.clone()
1085            };
1086            devices.push(device);
1087            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
1088        }
1089
1090        let artifacts = unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(artifacts)? };
1091
1092        let artifact_isqs = artifacts
1093            .tensors()
1094            .into_iter()
1095            .map(|(name, tensor)| {
1096                (
1097                    name.parse::<usize>()
1098                        .expect("Name should be parseable as usize"),
1099                    tensor,
1100                )
1101            })
1102            .collect::<HashMap<_, _>>();
1103
1104        if artifact_isqs.len() != total_tensors {
1105            hanzo_ml::bail!(
1106                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
1107                artifact_isqs.len(),
1108            );
1109        }
1110        info!("Loading UQFF artifacts into {total_tensors} quantized tensors.");
1111
1112        let bar = ProgressBar::new(total_tensors as u64);
1113        configure_progress_bar(&bar);
1114        bar.set_style(
1115            ProgressStyle::default_bar()
1116                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
1117                .unwrap()
1118                .progress_chars("#>-"),
1119        );
1120
1121        let t_start = Instant::now();
1122
1123        let guard = QuantizeOntoGuard::new();
1124
1125        if silent {
1126            (0..tensors.len())
1127                .into_par_iter()
1128                .zip(tensors)
1129                .map(|(i, (tensor, _))| {
1130                    if let Some(artifact) = artifact_isqs.get(&i) {
1131                        let artifact = artifact.data();
1132
1133                        let comm = comms[i].clone();
1134                        let deserialized = match tensor.is_distributed() {
1135                            Some(DistributedKind::ColumnParallel) => {
1136                                ColumnParallelLayer::deserialize(
1137                                    Cow::from(artifact),
1138                                    &devices[i],
1139                                    &comm,
1140                                    guard.clone(),
1141                                )?
1142                            }
1143                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
1144                                Cow::from(artifact),
1145                                &devices[i],
1146                                &comm,
1147                                guard.clone(),
1148                            )?,
1149                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
1150                                Cow::from(artifact),
1151                                &devices[i],
1152                                &comm,
1153                                guard.clone(),
1154                            )?,
1155                            None => {
1156                                // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
1157                                let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
1158                                match QuantizedSerdeType::try_from(isq_type as usize)? {
1159                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
1160                                        Cow::from(artifact),
1161                                        &devices[i],
1162                                        &comm,
1163                                        guard.clone(),
1164                                    )?,
1165                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
1166                                        Cow::from(artifact),
1167                                        &devices[i],
1168                                        &comm,
1169                                        guard.clone(),
1170                                    )?,
1171                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
1172                                        Cow::from(artifact),
1173                                        &devices[i],
1174                                        &comm,
1175                                        guard.clone(),
1176                                    )?,
1177                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
1178                                        Cow::from(artifact),
1179                                        &devices[i],
1180                                        &comm,
1181                                        guard.clone(),
1182                                    )?,
1183                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
1184                                        Cow::from(artifact),
1185                                        &devices[i],
1186                                        &comm,
1187                                        guard.clone(),
1188                                    )?,
1189                                    QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(
1190                                        Cow::from(artifact),
1191                                        &devices[i],
1192                                        &comm,
1193                                        guard.clone(),
1194                                    )?,
1195                                    QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(
1196                                        Cow::from(artifact),
1197                                        &devices[i],
1198                                        &comm,
1199                                        guard.clone(),
1200                                    )?,
1201                                }
1202                            }
1203                        };
1204                        *tensor = deserialized;
1205                    }
1206                    Ok(())
1207                })
1208                .collect::<hanzo_ml::Result<Vec<_>>>()?;
1209        } else {
1210            (0..tensors.len())
1211                .into_par_iter()
1212                .zip(tensors)
1213                .progress_with(bar)
1214                .map(|(i, (tensor, _))| {
1215                    if let Some(artifact) = artifact_isqs.get(&i) {
1216                        let artifact = artifact.data();
1217
1218                        let comm = comms[i].clone();
1219                        let deserialized = match tensor.is_distributed() {
1220                            Some(DistributedKind::ColumnParallel) => {
1221                                ColumnParallelLayer::deserialize(
1222                                    Cow::from(artifact),
1223                                    &devices[i],
1224                                    &comm,
1225                                    guard.clone(),
1226                                )?
1227                            }
1228                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
1229                                Cow::from(artifact),
1230                                &devices[i],
1231                                &comm,
1232                                guard.clone(),
1233                            )?,
1234                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
1235                                Cow::from(artifact),
1236                                &devices[i],
1237                                &comm,
1238                                guard.clone(),
1239                            )?,
1240                            None => {
1241                                // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
1242                                let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
1243                                match QuantizedSerdeType::try_from(isq_type as usize)? {
1244                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
1245                                        Cow::from(artifact),
1246                                        &devices[i],
1247                                        &comm,
1248                                        guard.clone(),
1249                                    )?,
1250                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
1251                                        Cow::from(artifact),
1252                                        &devices[i],
1253                                        &comm,
1254                                        guard.clone(),
1255                                    )?,
1256                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
1257                                        Cow::from(artifact),
1258                                        &devices[i],
1259                                        &comm,
1260                                        guard.clone(),
1261                                    )?,
1262                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
1263                                        Cow::from(artifact),
1264                                        &devices[i],
1265                                        &comm,
1266                                        guard.clone(),
1267                                    )?,
1268                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
1269                                        Cow::from(artifact),
1270                                        &devices[i],
1271                                        &comm,
1272                                        guard.clone(),
1273                                    )?,
1274                                    QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(
1275                                        Cow::from(artifact),
1276                                        &devices[i],
1277                                        &comm,
1278                                        guard.clone(),
1279                                    )?,
1280                                    QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(
1281                                        Cow::from(artifact),
1282                                        &devices[i],
1283                                        &comm,
1284                                        guard.clone(),
1285                                    )?,
1286                                }
1287                            }
1288                        };
1289                        *tensor = deserialized;
1290                    }
1291                    Ok(())
1292                })
1293                .collect::<hanzo_ml::Result<Vec<_>>>()?;
1294        }
1295
1296        // Verify no DummyLayers remain after deserialization
1297        {
1298            let (check_tensors, _) = self.get_layers();
1299            for (i, (tensor, layer_num)) in check_tensors.iter().enumerate() {
1300                if let Some(info) = tensor.dummy_info() {
1301                    let artifact_note = if artifact_isqs.contains_key(&i) {
1302                        "the matching UQFF artifact did not deserialize into a real layer"
1303                    } else {
1304                        "the UQFF artifact set did not contain an entry for this layer index"
1305                    };
1306                    hanzo_ml::bail!(
1307                        "UQFF placeholder was not replaced at artifact index {i}, model layer {layer_num:?}: {artifact_note}. {}",
1308                        info.message("UQFF artifact loading")
1309                    );
1310                }
1311            }
1312        }
1313
1314        let delta = Instant::now().duration_since(t_start).as_secs_f32();
1315        info!("Loaded UQFF artifacts into {total_tensors} quantized tensors. Took {delta:.2}s");
1316
1317        Ok(())
1318    }
1319}
1320
1321/// Trait for loading models with ISQ.
1322pub(crate) trait IsqModelLoader {
1323    /// Regex to match layers which will have standard *immediate* ISQ applied.
1324    ///
1325    /// Only called on non-adapter models!
1326    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1327        Ok(Vec::new())
1328    }
1329
1330    /// Regex to match layers which will have standard MoQE *immediate* ISQ applied.
1331    ///
1332    /// Only called on non-adapter models!
1333    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1334        self.isq_layer_regexes(config)
1335    }
1336
1337    /// Regex to match layers which will have standard ISQ applied.
1338    ///
1339    /// Only called on non-adapter models!
1340    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1341        Ok(Vec::new())
1342    }
1343
1344    /// Regex to match layers which will have standard MoQE ISQ applied.
1345    ///
1346    /// Only called on non-adapter models!
1347    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1348        self.isq_layer_regexes(config)
1349    }
1350}
1351
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns a list of string literals into the owned file listing the resolver takes.
    fn listing(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// Shorthand `8` resolves to the Q8_0 shard when that is the only 8-bit file.
    #[test]
    fn test_resolve_uqff_shorthand_numeric_q8() {
        let available = listing(&["q8_0-0.uqff", "config.json"]);
        let resolved = resolve_uqff_shorthand("8", &available);
        assert_eq!(resolved.as_deref(), Some("q8_0-0.uqff"));
    }

    /// Shorthand `8` resolves to the AFQ8 shard when that is the only 8-bit file.
    #[test]
    fn test_resolve_uqff_shorthand_numeric_afq8() {
        let available = listing(&["afq8-0.uqff", "config.json"]);
        let resolved = resolve_uqff_shorthand("8", &available);
        assert_eq!(resolved.as_deref(), Some("afq8-0.uqff"));
    }

    /// With both 8-bit variants present, the platform-preferred one wins.
    #[test]
    fn test_resolve_uqff_shorthand_prefers_platform_variant() {
        // expand() returns platform-preferred variant first:
        // Metal: [AFQ8, Q8_0], non-Metal: [Q8_0, AFQ8]
        let available = listing(&["q8_0-0.uqff", "afq8-0.uqff"]);
        let preferred = if cfg!(feature = "metal") {
            "afq8-0.uqff"
        } else {
            "q8_0-0.uqff"
        };
        let resolved = resolve_uqff_shorthand("8", &available);
        assert_eq!(resolved.as_deref(), Some(preferred));
    }

    /// Shorthand `4` resolves to the Q4K shard.
    #[test]
    fn test_resolve_uqff_shorthand_numeric_q4() {
        let available = listing(&["q4k-0.uqff"]);
        let resolved = resolve_uqff_shorthand("4", &available);
        assert_eq!(resolved.as_deref(), Some("q4k-0.uqff"));
    }

    /// Shorthand `5` resolves to the Q5K shard.
    #[test]
    fn test_resolve_uqff_shorthand_numeric_q5() {
        let available = listing(&["q5k-0.uqff"]);
        let resolved = resolve_uqff_shorthand("5", &available);
        assert_eq!(resolved.as_deref(), Some("q5k-0.uqff"));
    }

    /// An ISQ type name selects the shard of exactly that type.
    #[test]
    fn test_resolve_uqff_shorthand_isq_name() {
        let available = listing(&["q4k-0.uqff", "q8_0-0.uqff"]);
        let resolved = resolve_uqff_shorthand("q4k", &available);
        assert_eq!(resolved.as_deref(), Some("q4k-0.uqff"));
    }

    /// A full file name is not a shorthand, so nothing is resolved.
    #[test]
    fn test_resolve_uqff_shorthand_explicit_filename_returns_none() {
        let available = listing(&["q8_0-0.uqff"]);
        let resolved = resolve_uqff_shorthand("q8_0-0.uqff", &available);
        assert!(resolved.is_none());
    }

    /// A shorthand with no matching shard in the listing resolves to nothing.
    #[test]
    fn test_resolve_uqff_shorthand_no_match() {
        let available = listing(&["config.json"]);
        let resolved = resolve_uqff_shorthand("8", &available);
        assert!(resolved.is_none());
    }
}