use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    env,
    fs::File,
    path::PathBuf,
    str::FromStr,
    sync::{atomic::AtomicUsize, Arc},
    time::Instant,
};
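/// Byte view over a `Cow<'a, [u8]>` implementing `safetensors::tensor::View` with a `U8`
/// dtype and a 1-D shape, so already-serialized layer bytes can be written into a UQFF
/// safetensors file without copying.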
#[derive(Clone)]
pub struct CowBytesView<'a> {
    data: Cow<'a, [u8]>,
    shape: [usize; 1],
}

impl<'a> CowBytesView<'a> {
    pub fn new(data: Cow<'a, [u8]>) -> Self {
        let len = data.len();
        Self { data, shape: [len] }
    }
}

impl safetensors::tensor::View for CowBytesView<'_> {
    fn dtype(&self) -> safetensors::tensor::Dtype {
        safetensors::tensor::Dtype::U8
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn data(&self) -> Cow<'_, [u8]> {
        assert!(matches!(self.data, Cow::Borrowed(_)));
        self.data.clone()
    }

    fn data_len(&self) -> usize {
        self.data.len()
    }
}

use anyhow::Result;
use candle_core::{quantized, Context, Device, Tensor};
use indicatif::{MultiProgress, ParallelProgressIterator, ProgressBar, ProgressStyle};
use itertools::Itertools;
use mistralrs_quant::{
    AfqLayer, CollectedImatrixData, ColumnParallelLayer, DistributedKind, FP8Linear, GgufMatMul,
    HqqLayer, IsqType, QuantMethod, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
    ReplicatedLayer, RowParallelLayer, UnquantLinear,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
use regex::Regex;
use serde::Deserialize;
use tokenizers::Tokenizer;
use tracing::{info, warn};

use crate::{
    device_map::DeviceMapper, pipeline::EmbeddingModulePaths, topology::LayerTopology,
    utils::progress::configure_progress_bar, Topology,
};

pub(crate) const UQFF_RESIDUAL_SAFETENSORS: &str = "residual.safetensors";
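/// Maximum size of a single UQFF shard (10 GiB); larger outputs are split across multiple
/// `.uqff` files.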
const MAX_UQFF_SIZE_BYTES: usize = 10 * 1024 * 1024 * 1024;
pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";

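/// Parse a user-facing ISQ type string (e.g. `"q4k"`, `"afq4"`, or a bare bit width such
/// as `"4"`) into an [`IsqType`]. Bare bit widths select AFQ variants on Metal devices and
/// GGUF quants elsewhere (`Q2K`..`Q6K`, or `Q8_0` for `"8"`).
///
/// Illustrative usage (a sketch, not a doctest; the names are from this module):
/// ```ignore
/// let ty = parse_isq_value("q4k", None).unwrap();
/// assert!(matches!(ty, IsqType::Q4K));
/// ```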
pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
    let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
    let tp = match s.to_lowercase().as_str() {
        "2" if is_metal => IsqType::AFQ2,
        "2" if !is_metal => IsqType::Q2K,
        "3" if is_metal => IsqType::AFQ3,
        "3" if !is_metal => IsqType::Q3K,
        "4" if is_metal => IsqType::AFQ4,
        "4" if !is_metal => IsqType::Q4K,
        "5" => IsqType::Q5K,
        "6" if is_metal => IsqType::AFQ6,
        "6" if !is_metal => IsqType::Q6K,
        "8" if is_metal => IsqType::AFQ8,
        "8" if !is_metal => IsqType::Q8_0,
        "q4_0" => IsqType::Q4_0,
        "q4_1" => IsqType::Q4_1,
        "q5_0" => IsqType::Q5_0,
        "q5_1" => IsqType::Q5_1,
        "q8_0" => IsqType::Q8_0,
        "q8_1" => IsqType::Q8_1,
        "q2k" => IsqType::Q2K,
        "q3k" => IsqType::Q3K,
        "q4k" => IsqType::Q4K,
        "q5k" => IsqType::Q5K,
        "q6k" => IsqType::Q6K,
        "q8k" => IsqType::Q8K,
        "hqq8" => IsqType::HQQ8,
        "hqq4" => IsqType::HQQ4,
        "fp8" => IsqType::F8E4M3,
        "afq8" => IsqType::AFQ8,
        "afq6" => IsqType::AFQ6,
        "afq4" => IsqType::AFQ4,
        "afq3" => IsqType::AFQ3,
        "afq2" => IsqType::AFQ2,
        _ => return Err(format!("ISQ type {s} unknown, choose one of `2`, `3`, `4`, `5`, `6`, `8`, `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q8_1`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `Q8K`, `HQQ8`, `HQQ4`, `FP8`, `AFQ8`, `AFQ6`, `AFQ4`, `AFQ3`, `AFQ2`.")),
    };
    #[cfg(feature = "cuda")]
    {
        if !matches!(
            tp,
            IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q8_0
                | IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q5K
                | IsqType::Q6K
                | IsqType::HQQ8
                | IsqType::HQQ4
                | IsqType::F8E4M3
        ) {
            return Err("ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`, `FP8`".to_string());
        }
    }
    Ok(tp)
}

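/// Which layers ISQ is applied to: every quantizable layer (`default`) or only the MoE
/// expert layers (`moqe`).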
#[derive(Clone, Debug, Copy, Default, Deserialize, serde::Serialize)]
pub enum IsqOrganization {
    #[default]
    #[serde(rename = "default")]
    Default,
    #[serde(rename = "moqe")]
    MoeExpertsOnly,
}

impl FromStr for IsqOrganization {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "default" => Ok(Self::Default),
            "moqe" => Ok(Self::MoeExpertsOnly),
            other => Err(format!(
                "Expected ISQ organization `default` or `moqe`, got `{other}`"
            )),
        }
    }
}

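/// Auxiliary files (tokenizer, chat template, generation/processor configs, module
/// manifests) that are written next to the UQFF shards so the quantized model can be
/// loaded on its own.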
pub struct UqffFullSer<'a> {
    pub tokenizer: &'a Tokenizer,
    pub template_filename: &'a Option<PathBuf>,
    pub modules: Option<&'a String>,
    pub module_paths: Option<&'a [EmbeddingModulePaths]>,
    pub generation_config: Option<&'a PathBuf>,
    pub config: String,
    pub processor_filename: &'a Option<PathBuf>,
    pub preprocessor_filename: &'a Option<PathBuf>,
}

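/// Where imatrix (importance matrix) data comes from: a `.imatrix`/`.cimatrix` file on
/// disk, or activation statistics collected in-process via the stat-tracking hooks below.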
#[derive(Debug, Clone, Copy)]
pub enum ImatrixDataSource<'a> {
    File(&'a PathBuf),
    Collected,
}

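/// Model-side hooks for in-situ quantization (ISQ): access to the quantizable layers,
/// optional imatrix statistics collection, and serialization/deserialization of UQFF
/// artifacts.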
pub trait IsqModel {
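    /// All quantizable layers of the model, each paired with its optional layer index,
    /// plus the device mapper used to place per-layer work.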
    #[allow(clippy::type_complexity)]
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    );

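    /// Begin tracking activation statistics on every quantizable layer, in preparation
    /// for collecting imatrix data with [`Self::extract_imatrix_data`].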
    fn begin_track_stats(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

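    /// Stop tracking and gather the collected statistics into [`CollectedImatrixData`],
    /// keyed by layer index.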
    fn extract_imatrix_data(&mut self) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

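    /// Like [`Self::get_layers`], but restricted to MoE expert layers for the `moqe`
    /// organization. The default implementation falls back to all layers.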
    #[allow(clippy::type_complexity)]
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        self.get_layers()
    }

    fn begin_track_stats_moe_experts_only(&mut self) -> anyhow::Result<()> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .map(|(layer, _)| layer)
            .collect::<Vec<_>>();
        for layer in layers {
            Arc::get_mut(layer).unwrap().begin_track_stats()?;
        }
        Ok(())
    }

    fn extract_imatrix_data_moe_experts_only(
        &mut self,
    ) -> candle_core::Result<CollectedImatrixData> {
        let layers = self
            .get_layers()
            .0
            .into_iter()
            .enumerate()
            .map(|(i, (layer, _))| (i, layer))
            .collect::<Vec<_>>();
        let mut data = HashMap::new();
        for (i, layer) in layers {
            data.insert(i, Some(layer.end_track_stats()?.to_vec1::<f32>()?));
        }
        Ok(CollectedImatrixData(data))
    }

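    /// Tensor names in the order expected by GGUF-format imatrix files, used to map
    /// imatrix entries onto this model's layers. Unsupported unless overridden.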
    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        candle_core::bail!("This model does not support quantizing with an imatrix.");
    }

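    /// Tensors that are not quantized by ISQ and must be serialized separately (written
    /// to `residual.safetensors` alongside the UQFF shards).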
    fn residual_tensors(&self) -> Vec<(String, Tensor)>;

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        None
    }

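    /// Quantize the model in place and/or serialize it in the UQFF format.
    ///
    /// - `dtype` is the requested ISQ type; a `topology` can override it (and the target
    ///   device) per layer.
    /// - `imatrix_source` optionally supplies imatrix data used to weight quantization.
    /// - When `apply_quantization` is `false`, no quantization is performed and the
    ///   existing tensors are only serialized if `write_artifacts` is set.
    /// - `write_artifacts` is the target `.uqff` path; output is sharded once a file
    ///   would exceed `MAX_UQFF_SIZE_BYTES`, and the `full_ser` side files are written
    ///   next to the shards.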
    #[allow(clippy::too_many_arguments)]
    fn quantize(
        &mut self,
        dtype: Option<IsqType>,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        imatrix_source: Option<ImatrixDataSource<'_>>,
        organization: IsqOrganization,
        apply_quantization: bool,
        write_artifacts: Option<&PathBuf>,
        full_ser: UqffFullSer<'_>,
        multi_progress: Arc<MultiProgress>,
    ) -> candle_core::Result<()> {
        {
            // `ImatrixDataSource` is `Copy`, so match on it directly; it is inspected
            // again below, e.g. to force single-threaded quantization for collected data.
            let mut imatrix_to_weight_map: Option<HashMap<usize, Option<Vec<f32>>>> =
                if apply_quantization {
                    match imatrix_source {
                        Some(ImatrixDataSource::File(imatrix)) => {
                            let ext = imatrix.extension().ok_or(candle_core::Error::msg(
                                "Expected an extension for the imatrix source file.",
                            ))?;
                            if ext == "cimatrix" {
                                info!(
                                    "Loading collected imatrix source file: `{}`",
                                    imatrix.display()
                                );
                                let data = CollectedImatrixData::load_imatrix(imatrix)?;
                                info!(
                                    "Quantizing with collected imatrix data, {} imatrix weights",
                                    data.0.iter().filter(|(_, x)| x.is_some()).count()
                                );
                                Some(data.0)
                            } else {
                                if ext != "imatrix" {
                                    warn!("Imatrix source file extension is {ext:?}, expected .imatrix/.cimatrix. Assuming GGUF specification");
                                }
                                info!(
                                    "Loading GGUF-format imatrix source file: `{}`",
                                    imatrix.display()
                                );
                                let mut imatrix_data =
                                    quantized::imatrix_file::load_imatrix(imatrix.clone())?;
                                let imatrix_mapping = self
                                    .imatrix_names()?
                                    .into_iter()
                                    .enumerate()
                                    .collect::<HashMap<_, _>>();

                                let layer_to_weight = imatrix_mapping
                                    .into_iter()
                                    .map(|(i, name)| {
                                        if let Some(name) = name {
                                            (i, Some(imatrix_data.remove(&name).unwrap()))
                                        } else {
                                            (i, None)
                                        }
                                    })
                                    .collect::<HashMap<_, _>>();
                                info!(
                                    "Quantizing with imatrix file `{}`, {} imatrix weights",
                                    imatrix.display(),
                                    layer_to_weight.iter().filter(|(_, x)| x.is_some()).count()
                                );
                                Some(layer_to_weight)
                            }
                        }
                        Some(ImatrixDataSource::Collected) => {
                            let data = match organization {
                                IsqOrganization::Default => self.extract_imatrix_data()?,
                                IsqOrganization::MoeExpertsOnly => {
                                    self.extract_imatrix_data_moe_experts_only()?
                                }
                            };
                            let count = data.0.iter().filter(|(_, x)| x.is_some()).count();
                            let save_path = format!("collected-{count}.cimatrix");
                            info!("Saving collected imatrix data to `{save_path}`");
                            data.save_imatrix(save_path)?;
                            info!(
                                "Quantizing with collected imatrix data, {count} imatrix weights"
                            );
                            Some(data.0)
                        }
                        None => None,
                    }
                } else {
                    if imatrix_source.is_some() {
                        info!("Imatrix source provided but quantization disabled; ignoring input.");
                    }
                    None
                };

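            // Select the layer set (and device mapper) according to the requested ISQ
            // organization.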
            let (mut tensors, mapper) = match organization {
                IsqOrganization::Default => self.get_layers(),
                IsqOrganization::MoeExpertsOnly => self.get_layers_moe_experts_only(),
            };

            let total_tensors = tensors.len();

            if apply_quantization {
                let imatrix_to_weight: Vec<Option<Vec<f32>>> =
                    if let Some(mut imatrix_to_weight) = imatrix_to_weight_map.take() {
                        let ordered_keys = imatrix_to_weight
                            .keys()
                            .copied()
                            .sorted()
                            .collect::<Vec<_>>();
                        ordered_keys
                            .into_iter()
                            .map(|layer| imatrix_to_weight.remove(&layer).unwrap())
                            .collect()
                    } else {
                        vec![None; tensors.len()]
                    };

                let n_quantized = AtomicUsize::new(0);
                if let Some(topology) = topology {
                    let mut dtypes = HashSet::new();
                    for layer in topology.layers.iter().flatten() {
                        if let LayerTopology {
                            isq: Some(isq_dtype),
                            device: _,
                        } = layer
                        {
                            dtypes.insert(isq_dtype);
                        }
                    }
                    info!("Applying in-situ quantization into {:?} to {total_tensors} tensors according to topology.", dtypes.into_iter().collect::<Vec<_>>());
                } else {
                    info!(
                        "Applying in-situ quantization into {dtype:?} to {total_tensors} tensors."
                    );
                }
                let bar = ProgressBar::new(total_tensors as u64);
                configure_progress_bar(&bar);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );
                multi_progress.add(bar.clone());

                let layers = topology.map(|x| {
                    x.layers
                        .iter()
                        .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                        .collect::<Vec<_>>()
                });

                let mut devices_and_dtypes = Vec::new();
                for (_, layer_num) in &tensors {
                    let device = if let Some(ref layers) = layers {
                        if let Some(layer) = layer_num {
                            layers
                                .get(*layer)
                                .as_ref()
                                .map(|x| x.1.clone())
                                .unwrap_or(Some(device.clone()))
                                .unwrap_or(device.clone())
                        } else {
                            device.clone()
                        }
                    } else if let Some(layer_num) = layer_num {
                        mapper
                            .device_for(*layer_num, false)
                            .cloned()
                            .unwrap_or(device.clone())
                    } else {
                        device.clone()
                    };
                    let dtype = if let Some(ref layers) = layers {
                        if let Some(layer) = layer_num {
                            layers.get(*layer).cloned().map(|x| x.0).unwrap_or(dtype)
                        } else {
                            dtype
                        }
                    } else {
                        dtype
                    };
                    devices_and_dtypes.push((device, dtype));
                }

                let t_start = Instant::now();

                let mut minimum_max_threads = {
                    let current_rayon_threads = rayon::current_num_threads();
                    if let Some(dtype) = dtype {
                        dtype
                            .get_max_isq_cpu_threads()
                            .map(usize::from)
                            .unwrap_or(current_rayon_threads)
                    } else {
                        current_rayon_threads
                    }
                };
                if env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
                    minimum_max_threads = 1;
                }

                if matches!(imatrix_source, Some(ImatrixDataSource::Collected)) {
                    minimum_max_threads = 1;
                }

                info!("Applying ISQ on {minimum_max_threads} threads.");

                let pool = rayon::ThreadPoolBuilder::new()
                    .num_threads(minimum_max_threads)
                    .build()
                    .map_err(candle_core::Error::msg)?;

                let guard = QuantizeOntoGuard::new();

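                // Quantize every layer in parallel on the dedicated thread pool,
                // quantizing each tensor onto its resolved device and dtype.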
                pool.install(|| {
                    use indicatif::ParallelProgressIterator;
                    use rayon::iter::{
                        IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
                    };
                    if silent {
                        tensors
                            .par_iter_mut()
                            .zip(devices_and_dtypes)
                            .zip(imatrix_to_weight)
                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                                **tensor = tensor
                                    .clone()
                                    .apply_isq(
                                        dtype,
                                        device.clone(),
                                        &n_quantized,
                                        imatrix_weight,
                                        guard.clone(),
                                    )
                                    .unwrap();
                                device.synchronize().unwrap();
                            });
                    } else {
                        tensors
                            .par_iter_mut()
                            .zip(devices_and_dtypes)
                            .zip(imatrix_to_weight)
                            .progress_with(bar)
                            .for_each(|(((tensor, _), (device, dtype)), imatrix_weight)| {
                                **tensor = tensor
                                    .clone()
                                    .apply_isq(
                                        dtype,
                                        device.clone(),
                                        &n_quantized,
                                        imatrix_weight,
                                        guard.clone(),
                                    )
                                    .unwrap();
                                device.synchronize().unwrap();
                            });
                    }
                });

                let t_end = Instant::now();
                info!(
                    "Finished quantization pass in {:.2}s ({} tensors).",
                    t_end.duration_since(t_start).as_secs_f32(),
                    total_tensors
                );
            } else if imatrix_source.is_some() {
                info!(
                    "Imatrix data provided but quantization was skipped; existing tensors will be serialized as-is."
                );
            } else if write_artifacts.is_some() {
                info!(
                    "Skipping additional quantization; serializing {total_tensors} existing tensors."
                );
            }

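            // Optionally serialize the (possibly quantized) layers to sharded UQFF files,
            // together with residual tensors and auxiliary configuration files.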
            if let Some(serialized) = write_artifacts {
                info!(
                    "Serializing {total_tensors} ISQ tensors to `{}`.",
                    serialized.display()
                );

                if serialized.extension().is_none_or(|ext| ext != "uqff") {
                    candle_core::bail!("UQFF output path extension must be `.uqff`");
                }

                let bar = ProgressBar::new(total_tensors as u64);
                configure_progress_bar(&bar);
                bar.set_style(
                    ProgressStyle::default_bar()
                        .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                        .unwrap()
                        .progress_chars("#>-"),
                );

                #[cfg(any(feature = "metal", feature = "cuda"))]
                let quantized_values: candle_core::Result<Vec<_>> = {
                    tensors
                        .iter()
                        .enumerate()
                        .filter(|(_, (layer, _))| layer.isq_serde_supported())
                        .map(|(i, (layer, _))| {
                            if !silent {
                                bar.inc(1);
                            }
                            Ok((
                                i.to_string(),
                                match layer.serialize()? {
                                    Cow::Borrowed(_) => unreachable!(),
                                    Cow::Owned(owned) => owned,
                                },
                            ))
                        })
                        .collect()
                };

                #[cfg(not(any(feature = "metal", feature = "cuda")))]
                let quantized_values: candle_core::Result<Vec<_>> = {
                    let pool = rayon::ThreadPoolBuilder::new()
                        .num_threads(2)
                        .build()
                        .map_err(candle_core::Error::msg)?;

                    pool.install(|| {
                        use rayon::iter::IntoParallelRefIterator;
                        if silent {
                            tensors
                                .par_iter()
                                .enumerate()
                                .filter(|(_, (layer, _))| layer.isq_serde_supported())
                                .map(|(i, (layer, _))| {
                                    Ok((
                                        i.to_string(),
                                        match layer.serialize()? {
                                            Cow::Borrowed(_) => unreachable!(),
                                            Cow::Owned(owned) => owned,
                                        },
                                    ))
                                })
                                .collect::<candle_core::Result<Vec<_>>>()
                        } else {
                            tensors
                                .par_iter()
                                .enumerate()
                                .progress_with(bar)
                                .filter(|(_, (layer, _))| layer.isq_serde_supported())
                                .map(|(i, (layer, _))| {
                                    Ok((
                                        i.to_string(),
                                        match layer.serialize()? {
                                            Cow::Borrowed(_) => unreachable!(),
                                            Cow::Owned(owned) => owned,
                                        },
                                    ))
                                })
                                .collect::<candle_core::Result<Vec<_>>>()
                        }
                    })
                };

                let quantized_values = quantized_values?;

                let parent = serialized
                    .parent()
                    .context("Target UQFF path must have a filename!")?;

                std::fs::create_dir_all(parent)?;

                let file_stem = serialized
                    .file_stem()
                    .context("Target UQFF path must have a file stem!")?
                    .to_string_lossy()
                    .to_string();

                let mut current_chunk = Vec::new();
                let mut current_bytes: usize = 0;
                let mut shard_index = 0;

                for (name, tensor) in quantized_values.iter() {
                    let tensor_bytes = tensor.len();
                    if !current_chunk.is_empty()
                        && current_bytes + tensor_bytes > MAX_UQFF_SIZE_BYTES
                    {
                        let mut shard_path = parent.to_path_buf();
                        shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                        info!(
                            "Writing shard {} to `{}`",
                            shard_index,
                            shard_path.display()
                        );
                        safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                        shard_index += 1;
                        current_chunk.clear();
                        current_bytes = 0;
                    }
                    current_bytes += tensor_bytes;
                    current_chunk.push((name, CowBytesView::new(Cow::Borrowed(tensor))));
                }

                if !current_chunk.is_empty() {
                    let mut shard_path = parent.to_path_buf();
                    shard_path.push(format!("{file_stem}-{shard_index}.uqff"));
                    info!(
                        "Writing final shard {} to `{}`",
                        shard_index,
                        shard_path.display()
                    );
                    safetensors::serialize_to_file(current_chunk.clone(), None, &shard_path)?;
                }

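                // Write the residual (non-ISQ) tensors and the side files needed to
                // reload the model standalone.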
                let residual = match organization {
                    IsqOrganization::Default => self.residual_tensors(),
                    IsqOrganization::MoeExpertsOnly => self
                        .residual_tensors_moe_experts_only()
                        .unwrap_or(self.residual_tensors()),
                };

                let residual_out = parent.join(UQFF_RESIDUAL_SAFETENSORS);
                let config_out = parent.join("config.json");
                let modules_out = parent.join("modules.json");
                let tokenizer_out = parent.join("tokenizer.json");
                let tokenizer_cfg_out = parent.join("tokenizer_config.json");
                let chat_template_jinja_out = parent.join("chat_template.jinja");
                let gen_cfg_out = parent.join("generation_config.json");
                let processor_out = parent.join("processor_config.json");
                let preprocessor_out = parent.join("preprocessor_config.json");

                info!(
                    "Serializing {} residual tensors to `{}`.",
                    residual.len(),
                    residual_out.display()
                );

                safetensors::serialize_to_file(residual, None, &residual_out)?;

                let UqffFullSer {
                    tokenizer,
                    template_filename,
                    modules,
                    module_paths,
                    generation_config,
                    config,
                    processor_filename,
                    preprocessor_filename,
                } = full_ser;

                info!("Serializing configuration to `{}`.", config_out.display());

                std::fs::write(config_out, config)?;

                info!("Serializing tokenizer to `{}`.", tokenizer_out.display());

                serde_json::to_writer_pretty(File::create(&tokenizer_out)?, tokenizer)
                    .map_err(candle_core::Error::msg)?;

                if let Some(template_filename) = template_filename {
                    let template =
                        std::fs::read(template_filename).map_err(candle_core::Error::msg)?;

                    if template_filename.extension().map(|e| e.to_str()) == Some(Some("jinja")) {
                        info!(
                            "Serializing chat template to `{}`.",
                            chat_template_jinja_out.display()
                        );
                        std::fs::write(&chat_template_jinja_out, template)
                            .map_err(candle_core::Error::msg)?;
                    } else {
                        info!(
                            "Serializing tokenizer config to `{}`.",
                            tokenizer_cfg_out.display()
                        );
                        std::fs::write(&tokenizer_cfg_out, template)
                            .map_err(candle_core::Error::msg)?;
                    }
                }

                if let Some(generation_config) = generation_config {
                    info!(
                        "Serializing generation config to `{}`.",
                        gen_cfg_out.display()
                    );

                    let cfg = std::fs::read(generation_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&gen_cfg_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(processor_config) = processor_filename {
                    info!(
                        "Serializing processor config to `{}`.",
                        processor_out.display()
                    );

                    let cfg = std::fs::read(processor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&processor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(preprocessor_config) = preprocessor_filename {
                    info!(
                        "Serializing preprocessor config to `{}`.",
                        preprocessor_out.display()
                    );

                    let cfg =
                        std::fs::read(preprocessor_config).map_err(candle_core::Error::msg)?;
                    std::fs::write(&preprocessor_out, cfg).map_err(candle_core::Error::msg)?;
                }

                if let Some(modules) = modules {
                    info!(
                        "Serializing modules manifest to `{}`.",
                        modules_out.display()
                    );

                    std::fs::write(&modules_out, modules).map_err(candle_core::Error::msg)?;

                    if let Some(module_paths) = module_paths {
                        for module in module_paths {
                            match module {
                                EmbeddingModulePaths::Transformer { path }
                                | EmbeddingModulePaths::Pooling { path, .. }
                                | EmbeddingModulePaths::Dense { path, .. }
                                | EmbeddingModulePaths::Normalize { path } => {
                                    if path.is_empty() {
                                        continue;
                                    }
                                    let module_dir = parent.join(path.as_str());
                                    std::fs::create_dir_all(&module_dir)
                                        .map_err(candle_core::Error::msg)?;

                                    match module {
                                        EmbeddingModulePaths::Pooling { config, .. } => {
                                            let dest = module_dir.join("config.json");
                                            if config != &dest {
                                                std::fs::copy(config, &dest)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                        }
                                        EmbeddingModulePaths::Dense { config, model, .. } => {
                                            let dest_cfg = module_dir.join("config.json");
                                            if config != &dest_cfg {
                                                std::fs::copy(config, &dest_cfg)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                            let dest_model = module_dir.join("model.safetensors");
                                            if model != &dest_model {
                                                std::fs::copy(model, &dest_model)
                                                    .map_err(candle_core::Error::msg)?;
                                            }
                                        }
                                        EmbeddingModulePaths::Transformer { .. }
                                        | EmbeddingModulePaths::Normalize { .. } => {}
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        Ok(())
    }

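    /// Load previously serialized UQFF artifacts back into this model's layers, placing
    /// each deserialized tensor on the device chosen by the topology or device mapper.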
    fn load_from_artifacts(
        &mut self,
        device: Device,
        topology: Option<&Topology>,
        silent: bool,
        artifacts: &[PathBuf],
    ) -> candle_core::Result<()> {
        let (tensors, mapper) = self.get_layers();
        let total_tensors = tensors.len();

        let layers = topology.map(|x| {
            x.layers
                .iter()
                .filter_map(|topo| topo.as_ref().map(|x| (x.isq, x.device.clone())))
                .collect::<Vec<_>>()
        });

        let mut devices = Vec::new();
        let mut comms = Vec::new();
        for (_, layer_num) in &tensors {
            let device = if let Some(ref layers) = layers {
                if let Some(layer) = layer_num {
                    layers
                        .get(*layer)
                        .as_ref()
                        .map(|x| x.1.clone())
                        .unwrap_or(Some(device.clone()))
                        .unwrap_or(device.clone())
                } else {
                    device.clone()
                }
            } else if let Some(layer_num) = layer_num {
                mapper
                    .device_for(*layer_num, false)
                    .cloned()
                    .unwrap_or(device.clone())
            } else {
                device.clone()
            };
            devices.push(device);
            comms.push(mapper.get_comm_for(layer_num.unwrap_or(0))?)
        }

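        // Memory-map all UQFF shards and index their tensors by layer number.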
        let artifacts = unsafe { candle_core::safetensors::MmapedSafetensors::multi(artifacts)? };

        let artifact_isqs = artifacts
            .tensors()
            .into_iter()
            .map(|(name, tensor)| {
                (
                    name.parse::<usize>()
                        .expect("Name should be parseable as usize"),
                    tensor,
                )
            })
            .collect::<HashMap<_, _>>();

        if artifact_isqs.len() != total_tensors {
            candle_core::bail!(
                "Number of artifacts ({}) does not match the number of ISQ layers ({total_tensors})",
                artifact_isqs.len(),
            );
        }

        let bar = ProgressBar::new(total_tensors as u64);
        configure_progress_bar(&bar);
        bar.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] [{bar:40.red/magenta}] {pos}/{len} ({eta})")
                .unwrap()
                .progress_chars("#>-"),
        );

        let t_start = Instant::now();

        let guard = QuantizeOntoGuard::new();

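        // Deserialize each artifact into its layer, dispatching on the layer's distributed
        // kind and, for plain layers, on the serialized quantization format.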
        if silent {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        } else {
            (0..tensors.len())
                .into_par_iter()
                .zip(tensors)
                .progress_with(bar)
                .map(|(i, (tensor, _))| {
                    if let Some(artifact) = artifact_isqs.get(&i) {
                        let artifact = artifact.data();

                        let comm = comms[i].clone();
                        let deserialized = match tensor.is_distributed() {
                            Some(DistributedKind::ColumnParallel) => {
                                ColumnParallelLayer::deserialize(
                                    Cow::from(artifact),
                                    &devices[i],
                                    &comm,
                                    guard.clone(),
                                )?
                            }
                            Some(DistributedKind::RowParallel) => RowParallelLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            Some(DistributedKind::Replicated) => ReplicatedLayer::deserialize(
                                Cow::from(artifact),
                                &devices[i],
                                &comm,
                                guard.clone(),
                            )?,
                            None => {
                                let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                                match QuantizedSerdeType::try_from(isq_type as usize)? {
                                    QuantizedSerdeType::Gguf => GgufMatMul::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Unquant => UnquantLinear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Hqq => HqqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Fp8 => FP8Linear::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                    QuantizedSerdeType::Afq => AfqLayer::deserialize(
                                        Cow::from(artifact),
                                        &devices[i],
                                        &comm,
                                        guard.clone(),
                                    )?,
                                }
                            }
                        };
                        *tensor = deserialized;
                    }
                    Ok(())
                })
                .collect::<candle_core::Result<Vec<_>>>()?;
        }

        let delta = Instant::now().duration_since(t_start).as_secs_f32();
        info!("Loaded in-situ quantization artifacts into {total_tensors} total tensors. Took {delta:.2}s");

        Ok(())
    }
}

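/// Loader-side ISQ support: per-architecture regexes selecting which weights are
/// quantized, for both standard and immediate (load-time) ISQ.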
pub(crate) trait IsqModelLoader {
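    /// Regexes matching weights that should be quantized immediately as they are loaded.
    /// Defaults to none.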
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

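    /// Immediate-ISQ predicates for the `moqe` organization; defaults to the standard
    /// layer regexes.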
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }

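    /// Regexes matching the weight names eligible for ISQ for this architecture.
    /// Defaults to an empty list.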
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(Vec::new())
    }

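    /// ISQ regexes for the MoE-experts-only (`moqe`) organization; defaults to the
    /// standard set.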
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}