1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet},
4 env,
5 fs::File,
6 path::PathBuf,
7 str::FromStr,
8 sync::{atomic::AtomicUsize, Arc},
9 time::Instant,
10};
/// A one-dimensional `u8` safetensors view over a byte buffer.
///
/// `quantize` uses it to write each pre-serialized ISQ layer blob as a flat
/// tensor inside a UQFF shard.
#[derive(Clone)]
pub struct CowBytesView<'a> {
    /// The raw bytes exposed through the view.
    data: Cow<'a, [u8]>,
    /// Always `[data.len()]`: the view has exactly one dimension.
    shape: [usize; 1],
}

impl<'a> CowBytesView<'a> {
    /// Wraps `data` as a 1-D view whose only dimension is its byte length.
    pub fn new(data: Cow<'a, [u8]>) -> Self {
        let shape = [data.len()];
        Self { data, shape }
    }
}
31
32impl safetensors::tensor::View for CowBytesView<'_> {
33 fn dtype(&self) -> safetensors::tensor::Dtype {
34 safetensors::tensor::Dtype::U8
36 }
37
38 fn shape(&self) -> &[usize] {
39 &self.shape
40 }
41
42 fn data(&self) -> Cow<'_, [u8]> {
43 assert!(matches!(self.data, Cow::Borrowed(_)));
44 self.data.clone()
46 }
47
48 fn data_len(&self) -> usize {
49 self.data.len()
50 }
51}
52
53use anyhow::Result;
54use hanzo_ml::{quantized, Context, Device, Tensor};
55use hanzo_quant::{
56 AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, F8Q8Linear, FP8Linear,
57 GgufMatMul, HqqLayer, IsqBits, IsqType, MXFP4Layer, QuantMethod, QuantizeOntoGuard,
58 QuantizedSerde, QuantizedSerdeType, ReplicatedLayer, RowParallelLayer, UnquantLinear,
59};
60use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
61use itertools::Itertools;
62use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
63use regex::Regex;
64use serde::Deserialize;
65use tokenizers::Tokenizer;
66use tracing::{info, warn};
67
68use crate::{
69 device_map::DeviceMapper, pipeline::EmbeddingModulePaths, topology::LayerTopology,
70 utils::progress::configure_progress_bar, Topology,
71};
72
/// File name of the safetensors file holding the residual (non-ISQ-layer) tensors.
/// `quantize` writes it into the same directory as the UQFF shards.
pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
/// Upper bound on the serialized layer bytes packed into one UQFF shard (10 GiB).
/// A single layer larger than this still gets a shard of its own.
#[cfg(target_pointer_width = "64")]
const MAX_UQFF_SIZE_BYTES: usize = 10 << 30;
/// On targets where `usize` cannot express 10 GiB, size-based sharding is effectively off.
#[cfg(not(target_pointer_width = "64"))]
const MAX_UQFF_SIZE_BYTES: usize = usize::MAX;
/// Delimiter separating several UQFF file names within a single string
/// (NOTE(review): the parsing site is not visible in this chunk).
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
80
/// Flags describing where weights come from and which ISQ/UQFF work is requested
/// for the current load. Converted into a [`WeightLoadingMode`] for reporting.
pub(crate) struct WeightLoadingState {
    /// Weights are loaded from a UQFF artifact (highest priority when set).
    pub(crate) from_uqff: bool,
    /// ISQ is requested; maps to `PostLoadIsq` unless `immediate_isq` is also set.
    pub(crate) loading_isq: bool,
    /// ISQ is applied immediately during loading.
    pub(crate) immediate_isq: bool,
    /// A UQFF artifact will be serialized.
    pub(crate) write_uqff: bool,
}

/// The single loading strategy derived from a [`WeightLoadingState`].
pub(crate) enum WeightLoadingMode {
    Uqff,
    ImmediateIsq,
    PostLoadIsq,
    UqffSerialization,
    Plain,
}

impl From<WeightLoadingState> for WeightLoadingMode {
    /// Picks the first mode whose flag is set, checking `from_uqff`, then
    /// `immediate_isq`, then `loading_isq`, then `write_uqff`. Falls back to `Plain`.
    fn from(state: WeightLoadingState) -> Self {
        let WeightLoadingState {
            from_uqff,
            loading_isq,
            immediate_isq,
            write_uqff,
        } = state;
        match (from_uqff, immediate_isq, loading_isq, write_uqff) {
            (true, _, _, _) => Self::Uqff,
            (_, true, _, _) => Self::ImmediateIsq,
            (_, _, true, _) => Self::PostLoadIsq,
            (_, _, _, true) => Self::UqffSerialization,
            _ => Self::Plain,
        }
    }
}

impl WeightLoadingMode {
    /// Human-readable status line for this mode, mentioning `target` (e.g. a model
    /// component name) where relevant. The `Uqff` message is static and borrowed.
    pub(crate) fn message(self, target: &'static str) -> Cow<'static, str> {
        let formatted = match self {
            Self::Uqff => {
                return Cow::Borrowed(
                    "Loading residual weights and preparing UQFF placeholders.",
                )
            }
            Self::ImmediateIsq => format!("Loading {target} weights with immediate ISQ."),
            Self::PostLoadIsq => {
                format!("Loading full-precision {target} weights for post-load ISQ.")
            }
            Self::UqffSerialization => {
                format!("Loading {target} weights for UQFF serialization.")
            }
            Self::Plain => format!("Loading {target} weights."),
        };
        Cow::Owned(formatted)
    }
}
131
132pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
161 let lowered = s.to_lowercase();
162
163 if let Ok(bits) = IsqBits::try_from(lowered.as_str()) {
165 let tp = match device {
166 Some(dev) => bits.resolve(dev),
167 None => bits.resolve(&Device::Cpu),
168 };
169 #[cfg(feature = "cuda")]
170 {
171 }
173 return Ok(tp);
174 }
175
176 let tp = match lowered.as_str() {
177 "q4_0" => IsqType::Q4_0,
178 "q4_1" => IsqType::Q4_1,
179 "q5_0" => IsqType::Q5_0,
180 "q5_1" => IsqType::Q5_1,
181 "q8_0" => IsqType::Q8_0,
182 "q8_1" => IsqType::Q8_1,
183 "q2k" => IsqType::Q2K,
184 "q3k" => IsqType::Q3K,
185 "q4k" => IsqType::Q4K,
186 "q5k" => IsqType::Q5K,
187 "q6k" => IsqType::Q6K,
188 "q8k" => IsqType::Q8K,
189 "hqq8" => IsqType::HQQ8,
190 "hqq4" => IsqType::HQQ4,
191 "fp8" => IsqType::F8E4M3,
192 "afq8" => IsqType::AFQ8,
193 "afq6" => IsqType::AFQ6,
194 "afq4" => IsqType::AFQ4,
195 "afq3" => IsqType::AFQ3,
196 "afq2" => IsqType::AFQ2,
197 "f8q8" => IsqType::F8Q8,
198 "mxfp4" => IsqType::MXFP4,
199 _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`, `F8Q8`, `MXFP4`.")),
203 };
204 #[cfg(feature = "cuda")]
205 {
206 if !matches!(
207 tp,
208 IsqType::Q4_0
209 | IsqType::Q4_1
210 | IsqType::Q5_0
211 | IsqType::Q5_1
212 | IsqType::Q8_0
213 | IsqType::Q2K
214 | IsqType::Q3K
215 | IsqType::Q4K
216 | IsqType::Q5K
217 | IsqType::Q6K
218 | IsqType::HQQ8
219 | IsqType::HQQ4
220 | IsqType::F8E4M3
221 | IsqType::AFQ2
222 | IsqType::AFQ3
223 | IsqType::AFQ4
224 | IsqType::AFQ6
225 | IsqType::AFQ8
226 | IsqType::F8Q8
227 | IsqType::MXFP4 ) {
231 return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`, `F8Q8`, `MXFP4`".to_string());
232 }
233 }
234 Ok(tp)
235}
236
237pub fn expand_isq_value(s: &str) -> anyhow::Result<Vec<IsqType>> {
241 if let Ok(bits) = IsqBits::try_from(s.to_lowercase().as_str()) {
242 return Ok(bits.expand());
243 }
244 let isq = parse_isq_value(s, None).map_err(|e| anyhow::anyhow!("{e}"))?;
245 Ok(vec![isq])
246}
247
/// Splits a shard file name of the form `<prefix>-<index>.<ext>` into its prefix
/// and numeric index. Directory components and the extension are ignored.
///
/// Returns `None` when the file stem has no `-`, or when the text after the last
/// `-` does not parse as a `u64`.
pub fn parse_uqff_shard(filename: &str) -> Option<(String, u64)> {
    let stem = std::path::Path::new(filename).file_stem()?.to_str()?;
    let dash = stem.rfind('-')?;
    // `-` is a single byte, so `dash + 1` is always a char boundary.
    let index: u64 = stem[dash + 1..].parse().ok()?;
    Some((stem[..dash].to_owned(), index))
}

/// Expands the name of a sharded UQFF file into the contiguous run of shard names
/// `<prefix>-0.uqff`, `<prefix>-1.uqff`, ... that are present in `available_files`.
///
/// Enumeration always starts at index 0 (whatever index `first_file` carries) and
/// stops at the first missing index. The generated names carry no directory and
/// always use the `.uqff` extension.
///
/// Falls back to `[first_file]` when it is not a shard name or when shard 0 is not
/// available.
pub fn expand_uqff_shards(first_file: &str, available_files: &[String]) -> Vec<String> {
    let Some((prefix, _)) = parse_uqff_shard(first_file) else {
        return vec![first_file.to_string()];
    };
    let shards: Vec<String> = (0u64..)
        .map(|index| format!("{prefix}-{index}.uqff"))
        .take_while(|candidate| available_files.contains(candidate))
        .collect();
    if shards.is_empty() {
        vec![first_file.to_string()]
    } else {
        shards
    }
}
284
285pub fn resolve_uqff_shorthand(input: &str, available_files: &[String]) -> Option<String> {
291 let lowered = input.to_lowercase();
292
293 if let Ok(bits) = IsqBits::try_from(lowered.as_str()) {
295 for isq_type in bits.expand() {
296 let candidate = format!("{isq_type}-0.uqff");
297 if available_files.iter().any(|f| f == &candidate) {
298 return Some(candidate);
299 }
300 }
301 return None;
302 }
303
304 if let Ok(isq_type) = parse_isq_value(&lowered, None) {
306 let candidate = format!("{isq_type}-0.uqff");
307 if available_files.iter().any(|f| f == &candidate) {
308 return Some(candidate);
309 }
310 }
311
312 None
313}
314
315#[derive(Clone, Debug, Copy, Default, Deserialize, serde::Serialize)]
316pub enum IsqOrganization {
317 #[default]
318 #[serde(rename = "default")]
319 Default,
320 #[serde(rename = "moqe")]
323 MoeExpertsOnly,
324}
325
326impl FromStr for IsqOrganization {
327 type Err = String;
328 fn from_str(s: &str) -> Result<Self, Self::Err> {
329 match s {
330 "default" => Ok(Self::Default),
331 "moqe" => Ok(Self::MoeExpertsOnly),
332 other => Err(format!(
333 "Expected ISQ organization `default` or `moqe`, got `{other}`"
334 )),
335 }
336 }
337}
338
339pub struct UqffFullSer<'a> {
340 pub tokenizer: &'a Tokenizer,
341 pub template_filename: &'a Option<PathBuf>,
342 pub modules: Option<&'a String>,
343 pub module_paths: Option<&'a [EmbeddingModulePaths]>,
344 pub generation_config: Option<&'a PathBuf>,
345 pub config: String,
346 pub processor_filename: &'a Option<PathBuf>,
347 pub preprocessor_filename: &'a Option<PathBuf>,
348}
349
/// Where importance-matrix (imatrix) data for ISQ comes from.
#[derive(Clone, Copy, Debug)]
pub enum ImatrixDataSource<'a> {
    /// A file on disk: `.cimatrix` (collected format) or a GGUF-format imatrix.
    File(&'a PathBuf),
    /// Statistics collected from the model's layers at runtime.
    Collected,
}
355
356pub trait IsqModel {
    /// Returns the model's ISQ-quantizable layers together with its device mapper.
    ///
    /// Each entry pairs a mutable handle to a layer with an optional model layer
    /// index. Callers use that index to look up per-layer topology overrides and
    /// `DeviceMapper::device_for`; entries with `None` use the default device. An
    /// entry's position in the `Vec` is the key used for imatrix data and for UQFF
    /// artifact tensor names.
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );
365
366 fn begin_track_stats(&mut self) -> anyhow::Result<()> {
368 let layers = self
369 .get_layers()
370 .0
371 .into_iter()
372 .map(|(layer, _)| layer)
373 .collect::<Vec<_>>();
374 for layer in layers {
375 Arc::get_mut(layer).unwrap().begin_track_stats()?;
376 }
377 Ok(())
378 }
379
380 fn extract_imatrix_data(&mut self) -> hanzo_ml::Result<CollectedImatrixData> {
382 let layers = self
383 .get_layers()
384 .0
385 .into_iter()
386 .enumerate()
387 .map(|(i, (layer, _))| (i, layer))
388 .collect::<Vec<_>>();
389 let mut data = HashMap::new();
390 for (i, layer) in layers {
391 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
392 }
393 Ok(CollectedImatrixData(data))
394 }
395
    /// Like `get_layers`, but limited to the layers handled under
    /// `IsqOrganization::MoeExpertsOnly`.
    ///
    /// The default returns every layer from `get_layers`. Presumably MoE models
    /// override this to return only their expert layers (confirm in implementors).
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }
407
408 fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
411 let layers = self
412 .get_layers()
413 .0
414 .into_iter()
415 .map(|(layer, _)| layer)
416 .collect::<Vec<_>>();
417 for layer in layers {
418 Arc::get_mut(layer).unwrap().begin_track_stats()?;
419 }
420 Ok(())
421 }
422
423 fn extract_imatrix_data_moe_experts_only(&mut self) -> hanzo_ml::Result<CollectedImatrixData> {
426 let layers = self
427 .get_layers()
428 .0
429 .into_iter()
430 .enumerate()
431 .map(|(i, (layer, _))| (i, layer))
432 .collect::<Vec<_>>();
433 let mut data = HashMap::new();
434 for (i, layer) in layers {
435 data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
436 }
437 Ok(CollectedImatrixData(data))
438 }
439
    /// For each layer from `get_layers` (by position), returns the tensor name used
    /// to look up its weights in a GGUF-format imatrix file, or `None` for layers
    /// without an imatrix entry.
    ///
    /// The default implementation always errors. A model must override this to
    /// support quantizing with a GGUF imatrix file.
    fn imatrix_names(&self) -> hanzo_ml::Result<Vec<Option<String>>> {
        hanzo_ml::bail!("This model does not support quantizing with an imatrix.");
    }
450
    /// Returns the model's residual named tensors — presumably everything not
    /// returned by `get_layers` (confirm in implementors). `quantize` writes them to
    /// the residual safetensors file next to the UQFF shards.
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;
453
    /// Residual tensors to serialize under `IsqOrganization::MoeExpertsOnly`.
    ///
    /// Defaults to `None`, in which case `quantize` falls back to `residual_tensors`.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }
458
459 #[allow(clippy::too_many_arguments)]
464 fn quantize(
465 &mut self,
466 dtype: Option<IsqType>,
467 device: Device,
468 topology: Option<&Topology>,
469 silent: bool,
470 imatrix_source: Option<ImatrixDataSource<'_>>,
471 organization: IsqOrganization,
472 apply_quantization: bool,
473 write_artifacts: Option<&PathBuf>,
474 full_ser: UqffFullSer<'_>,
475 multi_progress: Arc<MultiProgress>,
476 ) -> hanzo_ml::Result<()> {
477 {
478 let mut imatrix_source = imatrix_source;
479 let mut imatrix_to_weight_map: Option<HashMap<usize, Option<Vec<f32>>>> =
480 if apply_quantization {
481 match imatrix_source.take() {
482 Some(ImatrixDataSource::File(imatrix)) => {
483 let ext = imatrix.extension().ok_or(hanzo_ml::Error::msg(
484 "Expected an extension for the imatrix source file.",
485 ))?;
486 if ext == "cimatrix" {
487 info!(
488 "Loading collected imatrix source file: `{}`",
489 imatrix.display()
490 );
491 let data = CollectedImatrixData::load_imatrix(imatrix)?;
492 info!(
493 "Quantizing with collected imatrix data, {} imatrix weights",
494 data.0.iter().filter(|(_, x)| x.is_some()).count()
495 );
496 Some(data.0)
497 } else {
498 if ext != "imatrix" {
499 warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
500 }
501 info!(
502 "Loading GGUF-format imatrix source file: `{}`",
503 imatrix.display()
504 );
505 let mut imatrix_data =
506 quantized::imatrix_file::load_imatrix(imatrix.clone())?;
507 let imatrix_mapping = self
508 .imatrix_names()?
509 .into_iter()
510 .enumerate()
511 .collect::<HashMap<_, _>>();
512
513 let layer_to_weight = imatrix_mapping
514 .into_iter()
515 .map(|(i, name)| {
516 if let Some(name) = name {
517 (i, Some(imatrix_data.remove(&name).unwrap()))
518 } else {
519 (i, None)
520 }
521 })
522 .collect::<HashMap<_, _>>();
523 info!(
524 "Quantizing with imatrix file `{}`, {} imatrix weights",
525 imatrix.display(),
526 layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
527 );
528 Some(layer_to_weight)
529 }
530 }
531 Some(ImatrixDataSource::Collected) => {
532 let data = match organization {
533 IsqOrganization::Default => self.extract_imatrix_data()?,
534 IsqOrganization::MoeExpertsOnly => {
535 self.extract_imatrix_data_moe_experts_only()?
536 }
537 };
538 let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
540 let save_path = format!("collected-{count}.cimatrix");
541 info!("Saving collected imatrix data to `{save_path}`");
542 data.save_imatrix(save_path)?;
543 info!(
544 "Quantizing with collected imatrix data, {count} imatrix weights"
545 );
546 Some(data.0)
547 }
548 None => None,
549 }
550 } else {
551 if imatrix_source.is_some() {
552 info!("Imatrix source provided but quantization disabled; ignoring input.");
553 }
554 None
555 };
556
557 let (mut tensors, mapper) = match organization {
558 IsqOrganization::Default => self.get_layers(),
559 IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
560 };
561
562 let total_tensors = tensors.len();
563
564 if apply_quantization {
565 let imatrix_to_weight: Vec<Option<Vec<f32>>> =
566 if let Some(mut imatrix_to_weight) = imatrix_to_weight_map.take() {
567 let ordered_keys = imatrix_to_weight
568 .keys()
569 .copied()
570 .sorted()
571 .collect::<Vec<_>>();
572 ordered_keys
573 .into_iter()
574 .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
575 .collect()
576 } else {
577 vec![None; tensors.len()]
578 };
579
580 let n_quantized = AtomicUsize::new(0);
581 if let Some(topology) = topology {
582 let mut dtypes = HashSet::new();
583 for layer in topology.layers.iter().flatten() {
584 if let LayerTopology {
585 isq: Some(isq_dtype),
586 device: _,
587 } = layer
588 {
589 dtypes.insert(isq_dtype);
590 }
591 }
592 info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
593 } else {
594 info!(
595 "Applying in-situ quantization into {dtype:?} to {total_tensors} tensors."
596 );
597 }
598 let bar = ProgressBar::new(total_tensors as u64);
599 configure_progress_bar(&bar);
600 bar.set_style(
601 ProgressStyle::default_bar()
602 .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
603 .unwrap()
604 .progress_chars("#>-"),
605 );
606 multi_progress.add(bar.clone());
607
608 let layers = topology.map(|x| {
609 x.layers
610 .iter()
611 .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
612 .collect::<Vec<_>>()
613 });
614
615 let mut devices_and_dtypes = Vec::new();
616 for (_, layer_num) in &tensors {
617 let device = if let Some(ref layers) = layers {
618 if let Some(layer) = layer_num {
619 layers
620 .get(*layer)
621 .as_ref()
622 .map(|x| x.1.clone())
623 .unwrap_or(Some(device.clone()))
624 .unwrap_or(device.clone())
625 } else {
626 device.clone()
627 }
628 } else if let Some(layer_num) = layer_num {
629 mapper
630 .device_for(*layer_num, false)
631 .cloned()
632 .unwrap_or(device.clone())
633 } else {
634 device.clone()
635 };
636 let dtype = if let Some(ref layers) = layers {
637 if let Some(layer) = layer_num {
638 layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
639 } else {
640 dtype
641 }
642 } else {
643 dtype
644 };
645 devices_and_dtypes.push((device, dtype));
646 }
647
648 let t_start = Instant::now();
649
650 let mut minimum_max_threads = {
652 let current_rayon_threads = rayon::current_num_threads();
653 if let Some(dtype) = dtype {
654 dtype
655 .get_max_isq_cpu_threads()
656 .map(usize::from)
657 .unwrap_or(current_rayon_threads)
658 } else {
659 current_rayon_threads
660 }
661 };
662 if env::var("HANZO_ISQ_SINGLETHREAD").is_ok() {
663 minimum_max_threads = 1;
664 }
665
666 if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
667 minimum_max_threads = 1;
669 }
670
671 info!("Applying ISQ on {minimum_max_threads} threads.");
672
673 let pool = rayon::ThreadPoolBuilder::new()
674 .num_threads(minimum_max_threads)
675 .build()
676 .map_err(hanzo_ml::Error::msg)?;
677
678 let guard = QuantizeOntoGuard::new();
679
680 pool.install(|| {
681 use indicatif::ParallelProgressIterator;
682 use rayon::iter::{
683 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
684 };
685 if silent {
686 tensors
687 .par_iter_mut()
688 .zip(devices_and_dtypes)
689 .zip(imatrix_to_weight)
690 .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
691 **tensor = tensor
692 .clone()
693 .apply_isq(
694 dtype,
695 device.clone(),
696 &n_quantized,
697 imatrix_weight,
698 guard.clone(),
699 )
700 .unwrap();
701 device.synchronize().unwrap();
702 });
703 } else {
704 tensors
705 .par_iter_mut()
706 .zip(devices_and_dtypes)
707 .zip(imatrix_to_weight)
708 .progress_with(bar)
709 .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
710 **tensor = tensor
711 .clone()
712 .apply_isq(
713 dtype,
714 device.clone(),
715 &n_quantized,
716 imatrix_weight,
717 guard.clone(),
718 )
719 .unwrap();
720 device.synchronize().unwrap();
721 });
722 }
723 });
724
725 let t_end = Instant::now();
726 info!(
727 "Finished quantization pass in {:.2}s ({} tensors).",
728 t_end.duration_since(t_start).as_secs_f32(),
729 total_tensors
730 );
731 } else if imatrix_source.is_some() {
732 info!(
733 "Imatrix data provided but quantization was skipped; existing tensors will be serialized as-is."
734 );
735 } else if write_artifacts.is_some() {
736 info!(
737 "Skipping additional quantization; serializing {total_tensors} existing tensors."
738 );
739 }
740
741 if let Some(serialized) = write_artifacts {
742 info!(
743 "Serializing {total_tensors} ISQ tensors to `{}`.",
744 serialized.display()
745 );
746
747 if serialized.extension().is_none_or(|ext| ext != "uqff") {
748 hanzo_ml::bail!("UQFF output path extension must be `.uqff`",);
749 }
750
751 let bar = ProgressBar::new(total_tensors as u64);
752 configure_progress_bar(&bar);
753 bar.set_style(
754 ProgressStyle::default_bar()
755 .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
756 .unwrap()
757 .progress_chars("#>-"),
758 );
759
760 #[cfg(any(feature = "metal", feature = "cuda"))]
763 let quantized_values: hanzo_ml::Result<Vec<_>> = {
764 tensors
765 .iter()
766 .enumerate()
767 .filter(|(_, (layer, _))| layer.isq_serde_supported())
768 .map(|(i, (layer, _))| {
769 if !silent {
770 bar.inc(1);
771 }
772 Ok((
773 i.to_string(),
774 match layer.serialize()? {
775 Cow::Borrowed(_) => unreachable!(),
776 Cow::Owned(owned) => owned,
777 },
778 ))
779 })
780 .collect()
781 };
782
783 #[cfg(not(any(feature = "metal", feature = "cuda")))]
784 let quantized_values: hanzo_ml::Result<Vec<_>> = {
785 let pool = rayon::ThreadPoolBuilder::new()
786 .num_threads(2)
787 .build()
788 .map_err(hanzo_ml::Error::msg)?;
789
790 pool.install(|| {
791 use rayon::iter::IntoParallelRefIterator;
792 if silent {
793 tensors
794 .par_iter()
795 .enumerate()
796 .filter(|(_, (layer, _))| layer.isq_serde_supported())
797 .map(|(i, (layer, _))| {
798 Ok((
799 i.to_string(),
800 match layer.serialize()? {
801 Cow::Borrowed(_) => unreachable!(),
802 Cow::Owned(owned) => owned,
803 },
804 ))
805 })
806 .collect::<hanzo_ml::Result<Vec<_>>>()
807 } else {
808 tensors
809 .par_iter()
810 .enumerate()
811 .progress_with(bar)
812 .filter(|(_, (layer, _))| layer.isq_serde_supported())
813 .map(|(i, (layer, _))| {
814 Ok((
815 i.to_string(),
816 match layer.serialize()? {
817 Cow::Borrowed(_) => unreachable!(),
818 Cow::Owned(owned) => owned,
819 },
820 ))
821 })
822 .collect::<hanzo_ml::Result<Vec<_>>>()
823 }
824 })
825 };
826
827 let quantized_values = quantized_values?;
828
829 let parent = serialized
830 .parent()
831 .context("Target UQFF path must have a filename!")?;
832
833 std::fs::create_dir_all(parent)?;
834
835 let file_stem = serialized
836 .file_stem()
837 .context("Target UQFF path must have a file stem!")?
838 .to_string_lossy()
839 .to_string();
840
841 let mut current_chunk = Vec::new();
843 let mut current_bytes: usize = 0;
844 let mut shard_index = 0;
845
846 for (name, tensor) in quantized_values.iter() {
848 let tensor_bytes = tensor.len();
849 if !current_chunk.is_empty()
850 && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
851 {
852 let mut shard_path = parent.to_path_buf();
853 shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
854 info!(
855 "Writing shard {} to `{}`",
856 shard_index,
857 shard_path.display()
858 );
859 safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
860 shard_index += 1;
861 current_chunk.clear();
862 current_bytes = 0;
863 }
864 current_bytes += tensor_bytes;
865 current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
866 }
867
868 if !current_chunk.is_empty() {
869 let mut shard_path = parent.to_path_buf();
870 shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
871 info!(
872 "Writing final shard {} to `{}`",
873 shard_index,
874 shard_path.display()
875 );
876 safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
877 }
878
879 let residual = match organization {
880 IsqOrganization::Default => self.residual_tensors(),
881 IsqOrganization::MoeExpertsOnly => self
882 .residual_tensors_moe_experts_only()
883 .unwrap_or(self.residual_tensors()),
884 };
885
886 let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
887 let config_out = parent.join("config.json");
888 let modules_out = parent.join("modules.json");
889 let tokenizer_out = parent.join("tokenizer.json");
890 let tokenizer_cfg_out = parent.join("tokenizer_config.json");
891 let chat_template_jinja_out = parent.join("chat_template.jinja");
892 let gen_cfg_out = parent.join("generation_config.json");
893 let processor_out = parent.join("processor_config.json");
894 let preprocessor_out = parent.join("preprocessor_config.json");
895
896 info!(
897 "Serializing {} residual tensors to `{}`.",
898 residual.len(),
899 residual_out.display()
900 );
901
902 safetensors::serialize_to_file(residual, None, &residual_out)?;
903
904 let UqffFullSer {
905 tokenizer,
906 template_filename,
907 modules,
908 module_paths,
909 generation_config,
910 config,
911 processor_filename,
912 preprocessor_filename,
913 } = full_ser;
914
915 info!("Serializing configuration to `{}`.", config_out.display());
916
917 std::fs::write(config_out, config)?;
918
919 info!("Serializing tokenizer to `{}`.", tokenizer_out.display());
920
921 serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
922 .map_err(hanzo_ml::Error::msg)?;
923
924 if let Some(template_filename) = template_filename {
925 let template =
926 std::fs::read(template_filename).map_err(hanzo_ml::Error::msg)?;
927
928 if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
929 info!(
930 "Serializing chat template to `{}`.",
931 chat_template_jinja_out.display()
932 );
933 std::fs::write(&chat_template_jinja_out, template)
934 .map_err(hanzo_ml::Error::msg)?;
935
936 let sibling_cfg = template_filename
942 .parent()
943 .map(|dir| dir.join("tokenizer_config.json"));
944 if let Some(cfg_path) = sibling_cfg.filter(|p| p.exists()) {
945 info!(
946 "Serializing tokenizer config to `{}`.",
947 tokenizer_cfg_out.display()
948 );
949 std::fs::copy(&cfg_path, &tokenizer_cfg_out)
950 .map_err(hanzo_ml::Error::msg)?;
951 }
952 } else {
953 info!(
954 "Serializing tokenizer config to `{}`.",
955 tokenizer_cfg_out.display()
956 );
957 std::fs::write(&tokenizer_cfg_out, template)
958 .map_err(hanzo_ml::Error::msg)?;
959 }
960 }
961
962 if let Some(generation_config) = generation_config {
963 info!(
964 "Serializing generation config to `{}`.",
965 gen_cfg_out.display()
966 );
967
968 let cfg = std::fs::read(generation_config).map_err(hanzo_ml::Error::msg)?;
969 std::fs::write(&gen_cfg_out, cfg).map_err(hanzo_ml::Error::msg)?;
970 }
971
972 if let Some(processor_config) = processor_filename {
973 info!(
974 "Serializing processor config to `{}`.",
975 processor_out.display()
976 );
977
978 let cfg = std::fs::read(processor_config).map_err(hanzo_ml::Error::msg)?;
979 std::fs::write(&processor_out, cfg).map_err(hanzo_ml::Error::msg)?;
980 }
981
982 if let Some(preprocessor_config) = preprocessor_filename {
983 info!(
984 "Serializing preprocessor config to `{}`.",
985 preprocessor_out.display()
986 );
987
988 let cfg = std::fs::read(preprocessor_config).map_err(hanzo_ml::Error::msg)?;
989 std::fs::write(&preprocessor_out, cfg).map_err(hanzo_ml::Error::msg)?;
990 }
991
992 if let Some(modules) = modules {
993 info!(
994 "Serializing modules manifest to `{}`.",
995 modules_out.display()
996 );
997
998 std::fs::write(&modules_out, modules).map_err(hanzo_ml::Error::msg)?;
999
1000 if let Some(module_paths) = module_paths {
1001 for module in module_paths {
1002 match module {
1003 EmbeddingModulePaths::Transformer { path }
1004 | EmbeddingModulePaths::Pooling { path, .. }
1005 | EmbeddingModulePaths::Dense { path, .. }
1006 | EmbeddingModulePaths::Normalize { path } => {
1007 if path.is_empty() {
1008 continue;
1009 }
1010 let module_dir = parent.join(path.as_str());
1011 std::fs::create_dir_all(&module_dir)
1012 .map_err(hanzo_ml::Error::msg)?;
1013
1014 match module {
1015 EmbeddingModulePaths::Pooling { config, .. } => {
1016 let dest = module_dir.join("config.json");
1017 if config != &dest {
1018 std::fs::copy(config, &dest)
1019 .map_err(hanzo_ml::Error::msg)?;
1020 }
1021 }
1022 EmbeddingModulePaths::Dense { config, model, .. } => {
1023 let dest_cfg = module_dir.join("config.json");
1024 if config != &dest_cfg {
1025 std::fs::copy(config, &dest_cfg)
1026 .map_err(hanzo_ml::Error::msg)?;
1027 }
1028 let dest_model = module_dir.join("model.safetensors");
1029 if model != &dest_model {
1030 std::fs::copy(model, &dest_model)
1031 .map_err(hanzo_ml::Error::msg)?;
1032 }
1033 }
1034 EmbeddingModulePaths::Transformer { .. }
1035 | EmbeddingModulePaths::Normalize { .. } => {}
1036 }
1037 }
1038 }
1039 }
1040 }
1041 }
1042 }
1043 }
1044 Ok(())
1045 }
1046
1047 fn load_from_artifacts(
1048 &mut self,
1049 device: Device,
1050 topology: Option<&Topology>,
1051 silent: bool,
1052 artifacts: &[PathBuf],
1053 ) -> hanzo_ml::Result<()> {
1054 let (tensors, mapper) = self.get_layers();
1055 let total_tensors = tensors.len();
1056
1057 let layers = topology.map(|x| {
1058 x.layers
1059 .iter()
1060 .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
1061 .collect::<Vec<_>>()
1062 });
1063
1064 let mut devices = Vec::new();
1065 let mut comms = Vec::new();
1066 for (_, layer_num) in &tensors {
1067 let device = if let Some(ref layers) = layers {
1068 if let Some(layer) = layer_num {
1069 layers
1070 .get(*layer)
1071 .as_ref()
1072 .map(|x| x.1.clone())
1073 .unwrap_or(Some(device.clone()))
1074 .unwrap_or(device.clone())
1075 } else {
1076 device.clone()
1077 }
1078 } else if let Some(layer_num) = layer_num {
1079 mapper
1080 .device_for(*layer_num, false)
1081 .cloned()
1082 .unwrap_or(device.clone())
1083 } else {
1084 device.clone()
1085 };
1086 devices.push(device);
1087 comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
1088 }
1089
1090 let artifacts = unsafe { hanzo_ml::safetensors::MmapedSafetensors::multi(artifacts)? };
1091
1092 let artifact_isqs = artifacts
1093 .tensors()
1094 .into_iter()
1095 .map(|(name, tensor)| {
1096 (
1097 name.parse::<usize>()
1098 .expect("Name should be parseable as usize"),
1099 tensor,
1100 )
1101 })
1102 .collect::<HashMap<_, _>>();
1103
1104 if artifact_isqs.len() != total_tensors {
1105 hanzo_ml::bail!(
1106 "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
1107 artifact_isqs.len(),
1108 );
1109 }
1110 info!("Loading UQFF artifacts into {total_tensors} quantized tensors.");
1111
1112 let bar = ProgressBar::new(total_tensors as u64);
1113 configure_progress_bar(&bar);
1114 bar.set_style(
1115 ProgressStyle::default_bar()
1116 .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
1117 .unwrap()
1118 .progress_chars("#>-"),
1119 );
1120
1121 let t_start = Instant::now();
1122
1123 let guard = QuantizeOntoGuard::new();
1124
1125 if silent {
1126 (0..tensors.len())
1127 .into_par_iter()
1128 .zip(tensors)
1129 .map(|(i, (tensor, _))| {
1130 if let Some(artifact) = artifact_isqs.get(&i) {
1131 let artifact = artifact.data();
1132
1133 let comm = comms[i].clone();
1134 let deserialized = match tensor.is_distributed() {
1135 Some(DistributedKind::ColumnParallel) => {
1136 ColumnParallelLayer::deserialize(
1137 Cow::from(artifact),
1138 &devices[i],
1139 &comm,
1140 guard.clone(),
1141 )?
1142 }
1143 Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
1144 Cow::from(artifact),
1145 &devices[i],
1146 &comm,
1147 guard.clone(),
1148 )?,
1149 Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
1150 Cow::from(artifact),
1151 &devices[i],
1152 &comm,
1153 guard.clone(),
1154 )?,
1155 None => {
1156 let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
1158 match QuantizedSerdeType::try_from(isq_type as usize)? {
1159 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
1160 Cow::from(artifact),
1161 &devices[i],
1162 &comm,
1163 guard.clone(),
1164 )?,
1165 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
1166 Cow::from(artifact),
1167 &devices[i],
1168 &comm,
1169 guard.clone(),
1170 )?,
1171 QuantizedSerdeType::Hqq => HqqLayer::deserialize(
1172 Cow::from(artifact),
1173 &devices[i],
1174 &comm,
1175 guard.clone(),
1176 )?,
1177 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
1178 Cow::from(artifact),
1179 &devices[i],
1180 &comm,
1181 guard.clone(),
1182 )?,
1183 QuantizedSerdeType::Afq => AfqLayer::deserialize(
1184 Cow::from(artifact),
1185 &devices[i],
1186 &comm,
1187 guard.clone(),
1188 )?,
1189 QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(
1190 Cow::from(artifact),
1191 &devices[i],
1192 &comm,
1193 guard.clone(),
1194 )?,
1195 QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(
1196 Cow::from(artifact),
1197 &devices[i],
1198 &comm,
1199 guard.clone(),
1200 )?,
1201 }
1202 }
1203 };
1204 *tensor = deserialized;
1205 }
1206 Ok(())
1207 })
1208 .collect::<hanzo_ml::Result<Vec<_>>>()?;
1209 } else {
1210 (0..tensors.len())
1211 .into_par_iter()
1212 .zip(tensors)
1213 .progress_with(bar)
1214 .map(|(i, (tensor, _))| {
1215 if let Some(artifact) = artifact_isqs.get(&i) {
1216 let artifact = artifact.data();
1217
1218 let comm = comms[i].clone();
1219 let deserialized = match tensor.is_distributed() {
1220 Some(DistributedKind::ColumnParallel) => {
1221 ColumnParallelLayer::deserialize(
1222 Cow::from(artifact),
1223 &devices[i],
1224 &comm,
1225 guard.clone(),
1226 )?
1227 }
1228 Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
1229 Cow::from(artifact),
1230 &devices[i],
1231 &comm,
1232 guard.clone(),
1233 )?,
1234 Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
1235 Cow::from(artifact),
1236 &devices[i],
1237 &comm,
1238 guard.clone(),
1239 )?,
1240 None => {
1241 let isq_type = artifact[hanzo_quant::UQFF_QUANT_TYPE_OFFSET];
1243 match QuantizedSerdeType::try_from(isq_type as usize)? {
1244 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
1245 Cow::from(artifact),
1246 &devices[i],
1247 &comm,
1248 guard.clone(),
1249 )?,
1250 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
1251 Cow::from(artifact),
1252 &devices[i],
1253 &comm,
1254 guard.clone(),
1255 )?,
1256 QuantizedSerdeType::Hqq => HqqLayer::deserialize(
1257 Cow::from(artifact),
1258 &devices[i],
1259 &comm,
1260 guard.clone(),
1261 )?,
1262 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
1263 Cow::from(artifact),
1264 &devices[i],
1265 &comm,
1266 guard.clone(),
1267 )?,
1268 QuantizedSerdeType::Afq => AfqLayer::deserialize(
1269 Cow::from(artifact),
1270 &devices[i],
1271 &comm,
1272 guard.clone(),
1273 )?,
1274 QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(
1275 Cow::from(artifact),
1276 &devices[i],
1277 &comm,
1278 guard.clone(),
1279 )?,
1280 QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(
1281 Cow::from(artifact),
1282 &devices[i],
1283 &comm,
1284 guard.clone(),
1285 )?,
1286 }
1287 }
1288 };
1289 *tensor = deserialized;
1290 }
1291 Ok(())
1292 })
1293 .collect::<hanzo_ml::Result<Vec<_>>>()?;
1294 }
1295
1296 {
1298 let (check_tensors, _) = self.get_layers();
1299 for (i, (tensor, layer_num)) in check_tensors.iter().enumerate() {
1300 if let Some(info) = tensor.dummy_info() {
1301 let artifact_note = if artifact_isqs.contains_key(&i) {
1302 "the matching UQFF artifact did not deserialize into a real layer"
1303 } else {
1304 "the UQFF artifact set did not contain an entry for this layer index"
1305 };
1306 hanzo_ml::bail!(
1307 "UQFF placeholder was not replaced at artifact index {i}, model layer {layer_num:?}: {artifact_note}. {}",
1308 info.message("UQFF artifact loading")
1309 );
1310 }
1311 }
1312 }
1313
1314 let delta = Instant::now().duration_since(t_start).as_secs_f32();
1315 info!("Loaded UQFF artifacts into {total_tensors} quantized tensors. Took {delta:.2}s");
1316
1317 Ok(())
1318 }
1319}
1320
1321pub(crate) trait IsqModelLoader {
1323 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1327 Ok(Vec::new())
1328 }
1329
1330 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1334 self.isq_layer_regexes(config)
1335 }
1336
1337 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1341 Ok(Vec::new())
1342 }
1343
1344 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
1348 self.isq_layer_regexes(config)
1349 }
1350}
1351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned file listing from string literals.
    fn listing(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn test_resolve_uqff_shorthand_numeric_q8() {
        let available = listing(&["q8_0-0.uqff", "config.json"]);
        assert_eq!(
            resolve_uqff_shorthand("8", &available).as_deref(),
            Some("q8_0-0.uqff")
        );
    }

    #[test]
    fn test_resolve_uqff_shorthand_numeric_afq8() {
        let available = listing(&["afq8-0.uqff", "config.json"]);
        assert_eq!(
            resolve_uqff_shorthand("8", &available).as_deref(),
            Some("afq8-0.uqff")
        );
    }

    #[test]
    fn test_resolve_uqff_shorthand_prefers_platform_variant() {
        let available = listing(&["q8_0-0.uqff", "afq8-0.uqff"]);
        // Metal builds prefer the AFQ variant; other builds prefer Q8_0.
        let expected = if cfg!(feature = "metal") {
            "afq8-0.uqff"
        } else {
            "q8_0-0.uqff"
        };
        assert_eq!(
            resolve_uqff_shorthand("8", &available).as_deref(),
            Some(expected)
        );
    }

    #[test]
    fn test_resolve_uqff_shorthand_numeric_q4() {
        let available = listing(&["q4k-0.uqff"]);
        assert_eq!(
            resolve_uqff_shorthand("4", &available).as_deref(),
            Some("q4k-0.uqff")
        );
    }

    #[test]
    fn test_resolve_uqff_shorthand_numeric_q5() {
        let available = listing(&["q5k-0.uqff"]);
        assert_eq!(
            resolve_uqff_shorthand("5", &available).as_deref(),
            Some("q5k-0.uqff")
        );
    }

    #[test]
    fn test_resolve_uqff_shorthand_isq_name() {
        let available = listing(&["q4k-0.uqff", "q8_0-0.uqff"]);
        assert_eq!(
            resolve_uqff_shorthand("q4k", &available).as_deref(),
            Some("q4k-0.uqff")
        );
    }

    #[test]
    fn test_resolve_uqff_shorthand_explicit_filename_returns_none() {
        let available = listing(&["q8_0-0.uqff"]);
        assert!(resolve_uqff_shorthand("q8_0-0.uqff", &available).is_none());
    }

    #[test]
    fn test_resolve_uqff_shorthand_no_match() {
        let available = listing(&["config.json"]);
        assert!(resolve_uqff_shorthand("8", &available).is_none());
    }
}