use candle_core::{DType, Result, Tensor};
use crate::scalar_fp8::ops::fp8_to_dtype;
/// Dequantizes a per-tensor FP8-quantized weight back to `out_dtype`.
///
/// The FP8 weight is first widened to F32 via [`fp8_to_dtype`], multiplied
/// (with broadcasting, so `scale_inv` may be a scalar tensor) by the
/// F32-cast `scale_inv`, then cast to the requested output dtype.
///
/// NOTE(review): despite the name, `scale_inv` is used as the *multiplier*
/// during dequantization — confirm this matches the checkpoint's convention.
pub fn fp8_pertensor_dequantize(
    weight: &Tensor,
    scale_inv: &Tensor,
    out_dtype: DType,
) -> Result<Tensor> {
    let widened = fp8_to_dtype(weight, DType::F32)?;
    widened
        .broadcast_mul(&scale_inv.to_dtype(DType::F32)?)?
        .to_dtype(out_dtype)
}
/// Quantizes `x` to FP8 (E4M3) using a per-tensor scale.
///
/// Both inputs are cast to F32, `x` is divided (with broadcasting, so
/// `scale` may be a scalar tensor) by `scale`, the result is clamped to
/// the E4M3 representable range, and finally cast to `F8E4M3`.
// Currently unused outside this module; kept for the quantization path.
#[allow(dead_code)]
pub fn fp8_pertensor_quantize(x: &Tensor, scale: &Tensor) -> Result<Tensor> {
    // Largest finite magnitude representable in FP8 E4M3 (OCP FP8 spec).
    const F8E4M3_MAX: f64 = 448.0;
    let scaled = x
        .to_dtype(DType::F32)?
        .broadcast_div(&scale.to_dtype(DType::F32)?)?;
    // Clamp before the narrowing cast so out-of-range values saturate
    // instead of overflowing to NaN/inf in the FP8 conversion.
    scaled
        .clamp(-F8E4M3_MAX, F8E4M3_MAX)?
        .to_dtype(DType::F8E4M3)
}