use crate::cli::logging::log;
use crate::cli::LogLevel;
use crate::config::{QuantMethod, QuantizeArgs};
use crate::quant::{quantize_tensor, QuantGranularity, QuantMode, QuantizedTensor};
use safetensors::SafeTensors;
use std::collections::HashMap;
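
/// Read the raw bytes of the safetensors model file at `args.model`.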
fn load_safetensors(args: &QuantizeArgs) -> Result<Vec<u8>, String> {
std::fs::read(&args.model).map_err(|e| format!("Failed to read model file: {e}"))
}
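
/// Serialize the quantized tensors as pretty-printed JSON and write them to `args.output`.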
fn save_quantized_json(
quantized_tensors: &HashMap<String, QuantizedTensor>,
args: &QuantizeArgs,
) -> Result<(), String> {
let output_data = serde_json::to_vec_pretty(quantized_tensors)
.map_err(|e| format!("Failed to serialize: {e}"))?;
std::fs::write(&args.output, &output_data)
.map_err(|e| format!("Failed to write output: {e}"))?;
Ok(())
}
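
/// Write the quantized tensors to `args.output` in the safetensors format.
///
/// Each quantized tensor is stored as an `I8` tensor under its original name,
/// with its scales stored as a companion `F32` tensor named `<name>.__scale`.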
fn save_quantized_safetensors(
quantized_tensors: &HashMap<String, QuantizedTensor>,
args: &QuantizeArgs,
) -> Result<(), String> {
use safetensors::tensor::{Dtype, TensorView};
let mut i8_buffers: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
let mut scale_buffers: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
for (name, qt) in quantized_tensors {
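// Reinterpret each i8 value as a u8 byte: the cast preserves the bit
// pattern, and readers recover the signed values via the `I8` dtype.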
let i8_bytes: Vec<u8> = qt.data.iter().map(|&v| v as u8).collect();
i8_buffers.push((name.clone(), i8_bytes, qt.shape.clone()));
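// Persist the scales as a companion F32 tensor. Note that only scales are
// written in this format; any other quantization parameters (e.g. asymmetric
// zero points) are currently preserved only by the JSON output path.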
let scale_name = format!("{name}.__scale");
let scale_bytes: Vec<u8> = qt.params.scales.iter().flat_map(|s| s.to_le_bytes()).collect();
let scale_shape = vec![qt.params.scales.len()];
scale_buffers.push((scale_name, scale_bytes, scale_shape));
}
let mut views: Vec<(&str, TensorView<'_>)> = Vec::new();
for (name, bytes, shape) in &i8_buffers {
let view = TensorView::new(Dtype::I8, shape.clone(), bytes)
.map_err(|e| format!("Failed to create I8 TensorView for {name}: {e}"))?;
views.push((name.as_str(), view));
}
for (name, bytes, shape) in &scale_buffers {
let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
.map_err(|e| format!("Failed to create F32 TensorView for {name}: {e}"))?;
views.push((name.as_str(), view));
}
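// Record the quantization settings in the safetensors header metadata so
// consumers can tell how the file was produced.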
let mut metadata = HashMap::new();
metadata.insert("quantization".to_string(), format!("int{}", args.bits));
metadata.insert("method".to_string(), format!("{:?}", args.method).to_lowercase());
metadata.insert("num_tensors".to_string(), quantized_tensors.len().to_string());
let safetensor_bytes = safetensors::serialize(views, Some(metadata))
.map_err(|e| format!("SafeTensors serialization failed: {e}"))?;
std::fs::write(&args.output, safetensor_bytes)
.map_err(|e| format!("Failed to write output: {e}"))?;
Ok(())
}
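
/// Validate `args` and translate the CLI options into quantizer settings.
/// Rejects bit widths other than 4 or 8.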
fn resolve_quant_params(args: &QuantizeArgs) -> Result<(QuantMode, QuantGranularity), String> {
if args.bits != 4 && args.bits != 8 {
return Err(format!("Unsupported bit width: {}. Use 4 or 8.", args.bits));
}
let mode = match args.method {
QuantMethod::Symmetric => QuantMode::Symmetric,
QuantMethod::Asymmetric => QuantMode::Asymmetric,
};
let granularity =
if args.per_channel { QuantGranularity::PerChannel } else { QuantGranularity::PerTensor };
Ok((mode, granularity))
}
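
/// Running totals of original and quantized byte counts, used to report the
/// overall compression ratio.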
struct ByteAccumulator {
original: usize,
quantized: usize,
}
impl ByteAccumulator {
fn new() -> Self {
Self { original: 0, quantized: 0 }
}
fn compression_ratio(&self) -> f64 {
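// Report 1.0 when nothing was quantized to avoid division by zero.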
if self.quantized > 0 {
self.original as f64 / self.quantized as f64
} else {
1.0
}
}
}
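
/// Quantize one F32 tensor, returning the quantized tensor together with the
/// size in bytes of the original F32 data.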
fn quantize_single_tensor(
tensor: &safetensors::tensor::TensorView<'_>,
granularity: QuantGranularity,
mode: QuantMode,
bits: u8,
) -> (QuantizedTensor, usize) {
let shape: Vec<usize> = tensor.shape().to_vec();
let num_elements: usize = shape.iter().product();
let original_bytes = num_elements * 4; // F32 storage: four bytes per element.
let bytes = tensor.data();
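// Decode the raw little-endian bytes into f32 values.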
let values: Vec<f32> = bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
let quantized = quantize_tensor(&values, &shape, granularity, mode, bits);
(quantized, original_bytes)
}
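
/// Log the quantization options at verbose level.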
fn log_quant_args(args: &QuantizeArgs, level: LogLevel) {
log(level, LogLevel::Verbose, &format!(" Method: {:?}", args.method));
log(level, LogLevel::Verbose, &format!(" Per-channel: {}", args.per_channel));
log(level, LogLevel::Verbose, &format!(" Output: {}", args.output.display()));
}
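
/// Entry point for the `quantize` subcommand: load the model, quantize every
/// F32 tensor, and write the result as JSON or safetensors.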
pub fn run_quantize(args: QuantizeArgs, level: LogLevel) -> Result<(), String> {
log(
level,
LogLevel::Normal,
&format!("Quantizing {} to {}-bit", args.model.display(), args.bits),
);
log_quant_args(&args, level);
let (mode, granularity) = resolve_quant_params(&args)?;
let data = load_safetensors(&args)?;
let tensors =
SafeTensors::deserialize(&data).map_err(|e| format!("Failed to parse safetensors: {e}"))?;
let mut quantized_tensors: HashMap<String, QuantizedTensor> = HashMap::new();
let mut bytes = ByteAccumulator::new();
for name in tensors.names() {
let tensor =
tensors.tensor(name).map_err(|e| format!("Failed to get tensor {name}: {e}"))?;
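// Only F32 tensors are quantized; tensors of any other dtype are skipped
// and therefore omitted from the output file.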
if tensor.dtype() != safetensors::tensor::Dtype::F32 {
log(level, LogLevel::Verbose, &format!(" Skipping {name} (not F32)"));
continue;
}
let (quantized, original_bytes) =
quantize_single_tensor(&tensor, granularity, mode, args.bits);
bytes.original += original_bytes;
bytes.quantized += quantized.memory_bytes();
log(
level,
LogLevel::Verbose,
&format!(
" Quantized {}: {:?} -> {} bytes",
name,
tensor.shape(),
quantized.memory_bytes()
),
);
quantized_tensors.insert((*name).to_string(), quantized);
}
if args.safetensors {
save_quantized_safetensors(&quantized_tensors, &args)?;
} else {
save_quantized_json(&quantized_tensors, &args)?;
}
log(
level,
LogLevel::Normal,
&format!(
"Quantization complete: {} tensors, {:.1}x compression",
quantized_tensors.len(),
bytes.compression_ratio()
),
);
log(level, LogLevel::Normal, &format!(" Output: {}", args.output.display()));
Ok(())
}

#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use safetensors::tensor::{Dtype, TensorView};
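/// Write an 8x8 F32 safetensors file containing a single `test_weight` tensor.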
fn create_test_safetensors(path: &std::path::Path) {
let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
let view = TensorView::new(Dtype::F32, vec![8, 8], &bytes).unwrap();
let views = vec![("test_weight", view)];
let serialized = safetensors::serialize(views, None::<HashMap<String, String>>).unwrap();
std::fs::write(path, serialized).unwrap();
}

#[test]
fn test_wasm002_quantize_safetensors_int8_output() {
let dir = tempfile::tempdir().unwrap();
let input_path = dir.path().join("model.safetensors");
let output_path = dir.path().join("model_int8.safetensors");
create_test_safetensors(&input_path);
let args = QuantizeArgs {
model: input_path,
output: output_path.clone(),
bits: 8,
method: crate::config::QuantMethod::Symmetric,
per_channel: false,
calibration_data: None,
safetensors: true,
};
run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize should succeed");
let data = std::fs::read(&output_path).unwrap();
let tensors = SafeTensors::deserialize(&data).unwrap();
let names: Vec<&str> = tensors.names().into_iter().collect();
assert!(names.contains(&"test_weight"), "Must contain weight tensor");
assert!(names.contains(&"test_weight.__scale"), "Must contain scale tensor");
let weight = tensors.tensor("test_weight").unwrap();
assert_eq!(weight.dtype(), Dtype::I8);
assert_eq!(weight.shape(), &[8, 8]);
assert_eq!(weight.data().len(), 64);
let scale = tensors.tensor("test_weight.__scale").unwrap();
assert_eq!(scale.dtype(), Dtype::F32);
}

#[test]
fn test_wasm002_quantize_safetensors_compression() {
let dir = tempfile::tempdir().unwrap();
let input_path = dir.path().join("model.safetensors");
let output_path = dir.path().join("model_int8.safetensors");
create_test_safetensors(&input_path);
let args = QuantizeArgs {
model: input_path.clone(),
output: output_path.clone(),
bits: 8,
method: crate::config::QuantMethod::Symmetric,
per_channel: false,
calibration_data: None,
safetensors: true,
};
run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize");
let input_size = std::fs::metadata(&input_path).unwrap().len();
let output_size = std::fs::metadata(&output_path).unwrap().len();
assert!(
output_size < input_size,
"Int8 output ({output_size}) must be smaller than F32 input ({input_size})"
);
}

#[test]
fn test_wasm002_quantize_json_still_works() {
let dir = tempfile::tempdir().unwrap();
let input_path = dir.path().join("model.safetensors");
let output_path = dir.path().join("model_int8.json");
create_test_safetensors(&input_path);
let args = QuantizeArgs {
model: input_path,
output: output_path.clone(),
bits: 8,
method: crate::config::QuantMethod::Symmetric,
per_channel: false,
calibration_data: None,
safetensors: false,
};
run_quantize(args, crate::cli::LogLevel::Quiet).expect("quantize");
let json = std::fs::read_to_string(&output_path).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert!(parsed.is_object());
assert!(parsed.get("test_weight").is_some());
}
}