use std::path::Path;

use candle_core::{safetensors::MmapedSafetensors, DType, Device, Shape, Tensor};
use candle_nn::{var_builder::SimpleBackend, Init, VarBuilder};

const FP8_BLOCK_SIZE: usize = 128;

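/// Returns true when the model's `config.json` declares FP8 quantization,
/// i.e. `quantization_config.quant_method == "fp8"`, either at the top level
/// or nested under `text_config` (as in multimodal configs).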
pub fn is_fp8_quantized(config_path: &Path) -> bool {
    let Ok(data) = std::fs::read_to_string(config_path) else {
        return false;
    };
    let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) else {
        return false;
    };
    for root in [&json, json.get("text_config").unwrap_or(&json)] {
        let is_fp8 = root
            .get("quantization_config")
            .and_then(|qc| qc.get("quant_method"))
            .and_then(|qm| qm.as_str())
            .map(|s| s == "fp8")
            .unwrap_or(false);
        if is_fp8 {
            return true;
        }
    }
    false
}

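/// Dequantizes a 2D FP8 weight using 128x128 blockwise scales (as used by
/// DeepSeek-style FP8 checkpoints): `scale_inv` holds one factor per block,
/// shaped `(ceil(m / 128), ceil(n / 128))`, and each block of `weight` is
/// multiplied by its factor.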
fn dequantize_fp8_blockwise(weight: &Tensor, scale_inv: &Tensor) -> candle_core::Result<Tensor> {
    let (m, n) = weight.dims2()?;
    let bm = FP8_BLOCK_SIZE;
    let bn = FP8_BLOCK_SIZE;
    let blocks_m = (m + bm - 1) / bm;
    let blocks_n = (n + bn - 1) / bn;

    let weight_f32 = weight.to_dtype(DType::F32)?;
    let scale_f32 = scale_inv.to_dtype(DType::F32)?;

    // Scales are stored per ceil-divided block, so m and n need not be exact
    // multiples of the block size; zero-pad any trailing partial blocks so the
    // reshape into (blocks_m, bm, blocks_n, bn) is exact.
    let weight_f32 = if blocks_m * bm != m || blocks_n * bn != n {
        weight_f32
            .pad_with_zeros(0, 0, blocks_m * bm - m)?
            .pad_with_zeros(1, 0, blocks_n * bn - n)?
    } else {
        weight_f32
    };

    // Broadcast each per-block scale across its 128x128 block.
    let weight_blocked = weight_f32.reshape((blocks_m, bm, blocks_n, bn))?;
    let scale_blocked = scale_f32.reshape((blocks_m, 1usize, blocks_n, 1usize))?;
    let dequantized = weight_blocked.broadcast_mul(&scale_blocked)?;

    // Drop any padding and restore the original (m, n) shape.
    dequantized
        .reshape((blocks_m * bm, blocks_n * bn))?
        .narrow(0, 0, m)?
        .narrow(1, 0, n)
}

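/// A `VarBuilder` backend that transparently dequantizes FP8 tensors as they
/// are fetched from the memory-mapped safetensors files.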
struct Fp8Backend {
    inner: MmapedSafetensors,
}

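// `get` is the shape-checked path used when model code requests a tensor of a
// known shape; `get_unchecked` skips that check.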
impl SimpleBackend for Fp8Backend {
    fn get(
        &self,
        s: Shape,
        name: &str,
        _h: Init,
        dtype: DType,
        dev: &Device,
    ) -> candle_core::Result<Tensor> {
        let tensor = self.load_tensor(name, dtype, dev)?;
        if tensor.shape() != &s {
            Err(candle_core::Error::UnexpectedShape {
                msg: format!("shape mismatch for {name}"),
                expected: s,
                got: tensor.shape().clone(),
            }
            .bt())?
        }
        Ok(tensor)
    }

    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
        self.load_tensor(name, dtype, dev)
    }

    fn contains_tensor(&self, name: &str) -> bool {
        self.inner.get(name).is_ok()
    }
}

impl Fp8Backend {
    /// Loads `name`, dequantizing when necessary:
    /// - if a companion `{name}_scale_inv` tensor exists, the weight is FP8
    ///   blockwise-quantized and is dequantized on the CPU,
    /// - if the stored dtype is F8E4M3 but there is no scale, the tensor is
    ///   cast on the CPU (FP8 casts may not be supported on every device),
    /// - otherwise it is loaded directly on the target device.
    fn load_tensor(
        &self,
        name: &str,
        dtype: DType,
        dev: &Device,
    ) -> candle_core::Result<Tensor> {
        let scale_name = format!("{name}_scale_inv");

        if self.inner.get(&scale_name).is_ok() {
            let weight = self.inner.load(name, &Device::Cpu)?;
            let scale = self.inner.load(&scale_name, &Device::Cpu)?;

            let dequantized = dequantize_fp8_blockwise(&weight, &scale)?;
            dequantized.to_dtype(dtype)?.to_device(dev)
        } else {
            let view = self.inner.get(name)?;
            let file_dtype: DType = view.dtype().try_into()?;

            if file_dtype == DType::F8E4M3 {
                let tensor = self.inner.load(name, &Device::Cpu)?;
                tensor.to_dtype(dtype)?.to_device(dev)
            } else {
                self.inner.load(name, dev)?.to_dtype(dtype)
            }
        }
    }
}

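/// Builds a [`VarBuilder`] whose backend dequantizes FP8 weights on the fly,
/// so downstream model code sees ordinary tensors of `dtype`.
///
/// # Safety
///
/// The files are memory-mapped (see [`MmapedSafetensors::multi`]); the caller
/// must ensure they are not modified or truncated while the mapping is alive.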
pub unsafe fn load_fp8_var_builder<'a>(
    filenames: &[std::path::PathBuf],
    dtype: DType,
    device: &Device,
) -> anyhow::Result<VarBuilder<'a>> {
    let inner = MmapedSafetensors::multi(filenames)?;

    let fp8_count = inner
        .tensors()
        .iter()
        .filter(|(_, v)| v.dtype() == safetensors::tensor::Dtype::F8_E4M3)
        .count();
    log::info!(
        "FP8 model detected: {} tensors will be dequantized at load time",
        fp8_count
    );

    let backend: Box<dyn SimpleBackend> = Box::new(Fp8Backend { inner });
    Ok(VarBuilder::from_backend(backend, dtype, device.clone()))
}
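
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity check (added as a sketch, assuming the CPU device):
    // with a uniform scale of 2.0, blockwise dequantization should simply
    // double every element.
    #[test]
    fn dequantize_applies_per_block_scale() -> candle_core::Result<()> {
        let m = 2 * FP8_BLOCK_SIZE;
        let n = FP8_BLOCK_SIZE;
        let weight = Tensor::ones((m, n), DType::F32, &Device::Cpu)?;
        let scale_inv = Tensor::full(2.0f32, (2, 1), &Device::Cpu)?;
        let out = dequantize_fp8_blockwise(&weight, &scale_inv)?;
        let total = out.sum_all()?.to_scalar::<f32>()?;
        assert_eq!(total, (m * n) as f32 * 2.0);
        Ok(())
    }

    // Exercises the partial-block path: 200 is not a multiple of 128, so the
    // last row-block is zero-padded internally and trimmed again afterwards.
    #[test]
    fn dequantize_handles_partial_blocks() -> candle_core::Result<()> {
        let (m, n) = (200usize, FP8_BLOCK_SIZE);
        let weight = Tensor::ones((m, n), DType::F32, &Device::Cpu)?;
        let scale_inv = Tensor::full(3.0f32, (2, 1), &Device::Cpu)?;
        let out = dequantize_fp8_blockwise(&weight, &scale_inv)?;
        assert_eq!(out.dims2()?, (m, n));
        let total = out.sum_all()?.to_scalar::<f32>()?;
        assert_eq!(total, (m * n) as f32 * 3.0);
        Ok(())
    }
}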