Skip to main content

cake_core/utils/
fp8.rs

//! FP8 (float8_e4m3fn) dequantization support.
//!
//! Models like Qwen3.5-27B-FP8 store most weight tensors in F8_E4M3 format with
//! per-block scale factors (`weight_scale_inv`). This module provides a custom
//! VarBuilder backend that transparently dequantizes FP8 weights at load time,
//! allowing cake to run FP8-quantized models on any backend (CUDA, Metal, CPU).
//!
//! Dequantization formula (block size 128×128):
//!   bf16_weight[i*128..(i+1)*128, j*128..(j+1)*128]
//!     = cast(fp8_weight[...same...]) * scale_inv[i, j]
12use std::path::Path;
13
14use candle_core::{safetensors::MmapedSafetensors, DType, Device, Shape, Tensor};
15use candle_nn::{var_builder::SimpleBackend, Init, VarBuilder};
16
/// Side length of a square quantization block: every 128×128 tile of a weight
/// matrix shares one entry of the `weight_scale_inv` tensor.
const FP8_BLOCK_SIZE: usize = 128;
18
19/// Check whether a model uses FP8 block-wise quantization by looking at its config.
20pub fn is_fp8_quantized(config_path: &Path) -> bool {
21    let Ok(data) = std::fs::read_to_string(config_path) else {
22        return false;
23    };
24    let Ok(json) = serde_json::from_str::<serde_json::Value>(&data) else {
25        return false;
26    };
27    // Check top-level and nested text_config for quantization_config
28    for root in [&json, json.get("text_config").unwrap_or(&json)] {
29        let is_fp8 = root
30            .get("quantization_config")
31            .and_then(|qc| qc.get("quant_method"))
32            .and_then(|qm| qm.as_str())
33            .map(|s| s == "fp8")
34            .unwrap_or(false);
35        if is_fp8 {
36            return true;
37        }
38    }
39    false
40}
41
42/// Dequantize a 2-D FP8 weight tensor using its per-block scale factor.
43fn dequantize_fp8_blockwise(weight: &Tensor, scale_inv: &Tensor) -> candle_core::Result<Tensor> {
44    let (m, n) = weight.dims2()?;
45    let bm = FP8_BLOCK_SIZE;
46    let bn = FP8_BLOCK_SIZE;
47    let blocks_m = (m + bm - 1) / bm;
48    let blocks_n = (n + bn - 1) / bn;
49
50    // Cast FP8 → F32 on CPU (candle supports this on CPU)
51    let weight_f32 = weight.to_dtype(DType::F32)?;
52    let scale_f32 = scale_inv.to_dtype(DType::F32)?;
53
54    // Reshape for block-wise broadcast multiply:
55    //   weight: [M, N]         → [blocks_m, bm, blocks_n, bn]
56    //   scale:  [blocks_m, blocks_n] → [blocks_m, 1, blocks_n, 1]
57    let weight_blocked = weight_f32.reshape((blocks_m, bm, blocks_n, bn))?;
58    let scale_blocked = scale_f32.reshape((blocks_m, 1usize, blocks_n, 1usize))?;
59
60    let dequantized = weight_blocked.broadcast_mul(&scale_blocked)?;
61    dequantized.reshape((m, n))
62}
63
/// Custom VarBuilder backend that wraps MmapedSafetensors and transparently
/// dequantizes FP8-quantized weight tensors on CPU before moving to the target device.
struct Fp8Backend {
    // Memory-mapped view over one or more .safetensors shards; all tensor
    // lookups and loads go through this handle.
    inner: MmapedSafetensors,
}
69
70impl SimpleBackend for Fp8Backend {
71    fn get(
72        &self,
73        s: Shape,
74        name: &str,
75        _h: Init,
76        dtype: DType,
77        dev: &Device,
78    ) -> candle_core::Result<Tensor> {
79        let tensor = self.load_tensor(name, dtype, dev)?;
80        if tensor.shape() != &s {
81            Err(candle_core::Error::UnexpectedShape {
82                msg: format!("shape mismatch for {name}"),
83                expected: s,
84                got: tensor.shape().clone(),
85            }
86            .bt())?
87        }
88        Ok(tensor)
89    }
90
91    fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> candle_core::Result<Tensor> {
92        self.load_tensor(name, dtype, dev)
93    }
94
95    fn contains_tensor(&self, name: &str) -> bool {
96        self.inner.get(name).is_ok()
97    }
98}
99
100impl Fp8Backend {
101    fn load_tensor(
102        &self,
103        name: &str,
104        dtype: DType,
105        dev: &Device,
106    ) -> candle_core::Result<Tensor> {
107        let scale_name = format!("{name}_scale_inv");
108
109        if self.inner.get(&scale_name).is_ok() {
110            // FP8 quantized tensor — dequantize on CPU then move to device
111            let weight = self.inner.load(name, &Device::Cpu)?;
112            let scale = self.inner.load(&scale_name, &Device::Cpu)?;
113
114            let dequantized = dequantize_fp8_blockwise(&weight, &scale)?;
115            dequantized.to_dtype(dtype)?.to_device(dev)
116        } else {
117            // Non-quantized tensor — check if the on-file dtype needs CPU-side handling
118            let view = self.inner.get(name)?;
119            let file_dtype: DType = view.dtype().try_into()?;
120
121            if file_dtype == DType::F8E4M3 {
122                // FP8 without scale (shouldn't happen, but handle gracefully)
123                let tensor = self.inner.load(name, &Device::Cpu)?;
124                tensor.to_dtype(dtype)?.to_device(dev)
125            } else {
126                // Normal path — load directly on target device
127                self.inner.load(name, dev)?.to_dtype(dtype)
128            }
129        }
130    }
131}
132
133/// Create a VarBuilder that transparently dequantizes FP8 weights.
134///
135/// # Safety
136///
137/// Inherits the mmap safety requirements from `MmapedSafetensors`.
138pub unsafe fn load_fp8_var_builder<'a>(
139    filenames: &[std::path::PathBuf],
140    dtype: DType,
141    device: &Device,
142) -> anyhow::Result<VarBuilder<'a>> {
143    let inner = MmapedSafetensors::multi(filenames)?;
144
145    let fp8_count = inner
146        .tensors()
147        .iter()
148        .filter(|(_, v)| v.dtype() == safetensors::tensor::Dtype::F8_E4M3)
149        .count();
150    log::info!(
151        "FP8 model detected: {} tensors will be dequantized at load time",
152        fp8_count
153    );
154
155    let backend: Box<dyn SimpleBackend> = Box::new(Fp8Backend { inner });
156    Ok(VarBuilder::from_backend(backend, dtype, device.clone()))
157}