use super::format::{FormatResult, NumRS2Model};
use crate::error::NumRs2Error;
use oxicode::{decode_from_slice, encode_to_vec};
use scirs2_core::ndarray::{Array1, Array2};
use serde_json;
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
/// Helpers for encoding and linearly quantizing dense `f64` weight matrices.
pub struct ModelCompression;

impl ModelCompression {
    /// Serializes a weight matrix to bytes.
    ///
    /// Despite the name, no compression is applied: the matrix is encoded as
    /// JSON via `serde_json`, which may well be larger than the raw `f64` data.
    /// Reverse it with [`ModelCompression::decompress_weights`].
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::SerializationError`] if JSON encoding fails.
    pub fn compress_weights(weights: &Array2<f64>) -> FormatResult<Vec<u8>> {
        serde_json::to_vec(weights).map_err(|e| {
            NumRs2Error::SerializationError(format!("Failed to compress weights: {}", e))
        })
    }

    /// Deserializes a weight matrix produced by [`ModelCompression::compress_weights`].
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::DeserializationError`] if `data` is not a valid
    /// JSON encoding of an `Array2<f64>`.
    pub fn decompress_weights(data: &[u8]) -> FormatResult<Array2<f64>> {
        serde_json::from_slice(data).map_err(|e| {
            NumRs2Error::DeserializationError(format!("Failed to decompress weights: {}", e))
        })
    }

    /// Linearly quantizes `weights` to unsigned integers that are `bits` wide.
    ///
    /// With `scale = (max - min) / (2^bits - 1)`, each value `w` becomes the
    /// level `round((w - min) / scale)`, clamped to `[0, 2^bits - 1]`. Levels are
    /// written in the array's logical iteration order as little-endian integers
    /// of `bits / 8` bytes each.
    ///
    /// Returns `(bytes, scale, zero_point)`. A value is reconstructed as
    /// `(q - zero_point) * scale` (see [`ModelCompression::dequantize_weights`]).
    ///
    /// Edge cases:
    /// - an empty matrix yields no bytes, with `scale = 1.0` and `zero_point = 0.0`;
    /// - a constant matrix (zero range) uses `scale = 1.0`, so every value maps to
    ///   level 0 and dequantizes back exactly to the constant.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ValueError`] in three cases: `bits` is not 8, 16
    /// or 32; any weight is NaN or infinite; or `max - min` overflows to infinity.
    pub fn quantize_weights(weights: &Array2<f64>, bits: u8) -> FormatResult<(Vec<u8>, f64, f64)> {
        if !matches!(bits, 8 | 16 | 32) {
            return Err(NumRs2Error::ValueError(
                "Quantization bits must be 8, 16, or 32".to_string(),
            ));
        }
        if weights.iter().any(|w| !w.is_finite()) {
            return Err(NumRs2Error::ValueError(
                "Cannot quantize non-finite weights (NaN or infinity)".to_string(),
            ));
        }
        if weights.is_empty() {
            return Ok((Vec::new(), 1.0, 0.0));
        }

        let min_val = weights.iter().copied().fold(f64::INFINITY, f64::min);
        let max_val = weights.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let range = max_val - min_val;
        if !range.is_finite() {
            return Err(NumRs2Error::ValueError(
                "Weight range is too large to quantize".to_string(),
            ));
        }

        // Computed in u64: for bits == 32, `2_u32.pow(32)` would overflow u32.
        let qmax = ((1_u64 << bits) - 1) as f64;
        // A zero range would make `scale` zero and every division NaN/inf. Any
        // positive scale is valid then, because all values map to level 0.
        let scale = if range > 0.0 { range / qmax } else { 1.0 };
        let zero_point = 0.0 - min_val / scale;

        let bytes_per_value = usize::from(bits / 8);
        let mut quantized = Vec::with_capacity(weights.len() * bytes_per_value);
        for &w in weights.iter() {
            // Algebraically equal to `w / scale + zero_point`. Subtracting first
            // avoids cancellation between two potentially large terms.
            let level = ((w - min_val) / scale).round().clamp(0.0, qmax) as u32;
            // The level fits in `bytes_per_value` bytes, and little-endian order
            // puts exactly those low bytes first.
            quantized.extend_from_slice(&level.to_le_bytes()[..bytes_per_value]);
        }
        Ok((quantized, scale, zero_point))
    }

    /// Reconstructs an `f64` matrix of `shape` from quantized little-endian levels.
    ///
    /// `data` must hold `shape.0 * shape.1` levels of `bits / 8` bytes each, in
    /// row-major order. Each level `q` becomes `(q - zero_point) * scale`.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::ValueError`] in three cases: `bits` is not 8, 16
    /// or 32; the expected byte count overflows `usize`; or `data` has the
    /// wrong length for `shape`.
    pub fn dequantize_weights(
        data: &[u8],
        shape: (usize, usize),
        bits: u8,
        scale: f64,
        zero_point: f64,
    ) -> FormatResult<Array2<f64>> {
        let bytes_per_value = match bits {
            8 | 16 | 32 => usize::from(bits / 8),
            _ => {
                return Err(NumRs2Error::ValueError(
                    "Unsupported quantization bits".to_string(),
                ))
            }
        };
        // Checked arithmetic: an oversized shape must not wrap around and
        // accidentally match a short buffer.
        let expected_len = shape
            .0
            .checked_mul(shape.1)
            .and_then(|n| n.checked_mul(bytes_per_value))
            .ok_or_else(|| {
                NumRs2Error::ValueError(format!(
                    "Shape {:?} is too large for {}-bit quantized data",
                    shape, bits
                ))
            })?;
        if data.len() != expected_len {
            return Err(NumRs2Error::ValueError(format!(
                "Data size mismatch for {}-bit quantization",
                bits
            )));
        }

        let values: Vec<f64> = data
            .chunks_exact(bytes_per_value)
            .map(|chunk| {
                // Zero-extend the little-endian level to a u32.
                let mut word = [0u8; 4];
                word[..bytes_per_value].copy_from_slice(chunk);
                (f64::from(u32::from_le_bytes(word)) - zero_point) * scale
            })
            .collect();

        Array2::from_shape_vec(shape, values).map_err(|e| {
            NumRs2Error::ValueError(format!("Failed to create array from quantized data: {}", e))
        })
    }
}
/// Content fingerprints (hash and checksum) computed over a model's JSON serialization.
pub struct ModelFingerprint;

impl ModelFingerprint {
    /// Serializes `model` to JSON bytes, the common input of every fingerprint.
    ///
    /// `purpose` completes the error message ("... for hashing", "... for checksum").
    fn serialized_bytes(model: &NumRS2Model, purpose: &str) -> FormatResult<Vec<u8>> {
        serde_json::to_vec(model).map_err(|e| {
            NumRs2Error::SerializationError(format!(
                "Failed to serialize model for {}: {}",
                purpose, e
            ))
        })
    }

    /// Hashes the JSON serialization of `model` and returns it as lowercase hex.
    ///
    /// The hex string is not zero-padded, so its length can vary (at most 16 chars).
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is explicitly unspecified and
    /// may change between Rust releases. Hashes persisted to disk may therefore
    /// stop verifying after a toolchain upgrade. This also assumes
    /// `NumRS2Model` serializes deterministically (e.g. no `HashMap` fields);
    /// TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::SerializationError`] if the model cannot be serialized.
    pub fn compute_hash(model: &NumRS2Model) -> FormatResult<String> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let bytes = Self::serialized_bytes(model, "hashing")?;
        let mut hasher = DefaultHasher::new();
        bytes.hash(&mut hasher);
        Ok(format!("{:x}", hasher.finish()))
    }

    /// Returns whether `model` hashes to `expected_hash`.
    ///
    /// The comparison ignores ASCII case and surrounding whitespace. Hex digits
    /// are case-insensitive, and hashes read back from files often carry a
    /// trailing newline.
    ///
    /// # Errors
    ///
    /// Propagates serialization errors from [`ModelFingerprint::compute_hash`].
    pub fn verify_hash(model: &NumRS2Model, expected_hash: &str) -> FormatResult<bool> {
        let computed_hash = Self::compute_hash(model)?;
        Ok(computed_hash.eq_ignore_ascii_case(expected_hash.trim()))
    }

    /// Sums the bytes of the model's JSON serialization, modulo 2^32.
    ///
    /// This is a weak additive checksum. It does not notice reordered bytes,
    /// so use [`ModelFingerprint::compute_hash`] where stronger detection matters.
    ///
    /// # Errors
    ///
    /// Returns [`NumRs2Error::SerializationError`] if the model cannot be serialized.
    pub fn compute_checksum(model: &NumRS2Model) -> FormatResult<u32> {
        let bytes = Self::serialized_bytes(model, "checksum")?;
        Ok(bytes
            .iter()
            .fold(0_u32, |sum, &byte| sum.wrapping_add(u32::from(byte))))
    }
}
pub struct FormatDetector;
impl FormatDetector {
pub fn detect<P: AsRef<Path>>(path: P) -> FormatResult<String> {
let file = File::open(path.as_ref()).map_err(|e| {
NumRs2Error::IOError(format!("Failed to open file for format detection: {}", e))
})?;
let mut reader = BufReader::new(file);
let mut magic = [0u8; 8];
reader
.read_exact(&mut magic)
.map_err(|e| NumRs2Error::IOError(format!("Failed to read magic bytes: {}", e)))?;
if magic == *b"NUMRS2\x00\x00" {
return Ok("numrs2".to_string());
}
if magic[0..6] == *b"\x93NUMPY" {
return Ok("npy".to_string());
}
if magic[0..4] == *b"PK\x03\x04" {
return Ok("npz".to_string());
}
reader = BufReader::new(
File::open(path.as_ref())
.map_err(|e| NumRs2Error::IOError(format!("Failed to reopen file: {}", e)))?,
);
let mut first_byte = [0u8; 1];
reader
.read_exact(&mut first_byte)
.map_err(|e| NumRs2Error::IOError(format!("Failed to read first byte: {}", e)))?;
if first_byte[0] == b'{' || first_byte[0] == b'[' {
return Ok("json".to_string());
}
Ok("unknown".to_string())
}
pub fn is_numrs2_format<P: AsRef<Path>>(path: P) -> bool {
Self::detect(path).is_ok_and(|fmt| fmt == "numrs2")
}
}
/// Save/load entry points named for streaming use.
///
/// Both methods currently forward straight to `ModelSerializer` in the sibling
/// `serialize` module. No incremental or chunked I/O happens here.
pub struct StreamingSerializer;

impl StreamingSerializer {
    /// Writes `model` to `path` by forwarding to `ModelSerializer::save`.
    ///
    /// # Errors
    ///
    /// Whatever `ModelSerializer::save` returns is passed through unchanged.
    pub fn save_streaming<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
        use super::serialize::ModelSerializer;
        ModelSerializer::save(model, path)
    }

    /// Reads a model from `path` by forwarding to `ModelSerializer::load`.
    ///
    /// # Errors
    ///
    /// Whatever `ModelSerializer::load` returns is passed through unchanged.
    pub fn load_streaming<P: AsRef<Path>>(path: P) -> FormatResult<NumRS2Model> {
        use super::serialize::ModelSerializer;
        ModelSerializer::load(path)
    }
}
/// Free-function form of [`ModelCompression::compress_weights`].
///
/// The bytes are a JSON encoding of the matrix rather than a compressed format.
///
/// # Errors
///
/// Propagates any error from [`ModelCompression::compress_weights`].
pub fn compress_weights(weights: &Array2<f64>) -> FormatResult<Vec<u8>> {
    ModelCompression::compress_weights(weights)
}
/// Free-function form of [`ModelCompression::decompress_weights`].
///
/// Inverts [`compress_weights`].
///
/// # Errors
///
/// Propagates any error from [`ModelCompression::decompress_weights`].
pub fn decompress_weights(data: &[u8]) -> FormatResult<Array2<f64>> {
    ModelCompression::decompress_weights(data)
}
/// Free-function form of [`ModelCompression::quantize_weights`].
///
/// Returns `(bytes, scale, zero_point)`. Pass them to
/// `ModelCompression::dequantize_weights` to reconstruct the matrix.
///
/// # Errors
///
/// Propagates any error from [`ModelCompression::quantize_weights`].
pub fn quantize_weights(weights: &Array2<f64>, bits: u8) -> FormatResult<(Vec<u8>, f64, f64)> {
    ModelCompression::quantize_weights(weights, bits)
}
/// Free-function form of [`ModelFingerprint::compute_hash`].
///
/// # Errors
///
/// Propagates any error from [`ModelFingerprint::compute_hash`].
pub fn compute_model_hash(model: &NumRS2Model) -> FormatResult<String> {
    ModelFingerprint::compute_hash(model)
}
/// Free-function form of [`FormatDetector::detect`].
///
/// # Errors
///
/// Propagates any error from [`FormatDetector::detect`].
pub fn detect_format<P>(path: P) -> FormatResult<String>
where
    P: AsRef<Path>,
{
    FormatDetector::detect(path)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::new_modules::model_io::format::{LayerData, ModelMetadata};
    use scirs2_core::ndarray::Array2;
    use std::env;
    use std::fs;
    use std::path::PathBuf;

    /// Builds a process-unique temp path, so concurrent test runs (e.g. two
    /// checkouts running `cargo test` at once) do not clobber each other's files.
    fn temp_path(name: &str) -> PathBuf {
        env::temp_dir().join(format!("{}_{}", std::process::id(), name))
    }

    /// Builds a single dense-layer model named `name` with the given weights.
    fn sample_model(name: &'static str, weights: Array2<f64>) -> NumRS2Model {
        let metadata = ModelMetadata::builder()
            .name(name)
            .build()
            .expect("test: valid metadata build");
        let layer = LayerData::dense("layer1", weights, None);
        NumRS2Model::new(metadata, vec![layer])
    }

    #[test]
    fn test_compress_decompress_weights() {
        let weights = Array2::from_shape_fn((10, 5), |(i, j)| (i * 5 + j) as f64);
        let compressed_data =
            ModelCompression::compress_weights(&weights).expect("test: valid compression");
        assert!(!compressed_data.is_empty());
        let recovered = ModelCompression::decompress_weights(&compressed_data)
            .expect("test: valid decompression");
        assert_eq!(recovered.shape(), weights.shape());
        // Small integers survive a JSON round trip exactly.
        assert_eq!(recovered, weights);
    }

    #[test]
    fn test_quantize_weights_8bit() {
        let weights = Array2::from_shape_fn((5, 4), |(i, j)| (i * 4 + j) as f64);
        let (quantized, scale, zero_point) =
            ModelCompression::quantize_weights(&weights, 8).expect("test: valid quantization");
        assert_eq!(quantized.len(), 5 * 4);
        assert!(scale > 0.0);
        let dequantized =
            ModelCompression::dequantize_weights(&quantized, (5, 4), 8, scale, zero_point);
        assert!(dequantized.is_ok());
    }

    #[test]
    fn test_quantize_weights_16bit() {
        let weights = Array2::from_shape_fn((3, 3), |(i, j)| (i * 3 + j) as f64 * 0.1);
        let (quantized, scale, zero_point) =
            ModelCompression::quantize_weights(&weights, 16).expect("test: valid quantization");
        assert_eq!(quantized.len(), 3 * 3 * 2);
        let dequantized =
            ModelCompression::dequantize_weights(&quantized, (3, 3), 16, scale, zero_point);
        assert!(dequantized.is_ok());
    }

    #[test]
    fn test_quantize_invalid_bits() {
        let weights = Array2::ones((5, 5));
        let result = ModelCompression::quantize_weights(&weights, 7);
        assert!(result.is_err());
    }

    #[test]
    fn test_compute_hash() {
        let model = sample_model("test_model", Array2::ones((10, 5)));
        let hash_str =
            ModelFingerprint::compute_hash(&model).expect("test: valid hash computation");
        assert!(!hash_str.is_empty());
    }

    #[test]
    fn test_verify_hash() {
        let model = sample_model("test_model", Array2::ones((10, 5)));
        let hash = ModelFingerprint::compute_hash(&model).expect("test: valid hash computation");
        let verified = ModelFingerprint::verify_hash(&model, &hash)
            .expect("test: valid hash verification");
        assert!(verified);
        let verified_wrong = ModelFingerprint::verify_hash(&model, "wrong_hash")
            .expect("test: valid hash verification (wrong hash)");
        assert!(!verified_wrong);
    }

    #[test]
    fn test_compute_checksum() {
        let model = sample_model("test_model", Array2::ones((10, 5)));
        let checksum =
            ModelFingerprint::compute_checksum(&model).expect("test: valid checksum computation");
        assert!(checksum > 0);
    }

    #[test]
    fn test_detect_format_numrs2() {
        let path = temp_path("test_detect.numrs2");
        let model = sample_model("test_model", Array2::ones((10, 5)));
        super::super::serialize::ModelSerializer::save(&model, &path)
            .expect("test: valid model save");
        let format = FormatDetector::detect(&path).expect("test: valid format detection");
        assert_eq!(format, "numrs2");
        assert!(FormatDetector::is_numrs2_format(&path));
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_detect_format_unknown() {
        let path = temp_path("test_unknown.bin");
        fs::write(&path, b"UNKNOWN_FORMAT_DATA").expect("test: valid file write");
        let format = FormatDetector::detect(&path).expect("test: valid format detection");
        assert_eq!(format, "unknown");
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_streaming_serializer() {
        let path = temp_path("test_streaming.numrs2");
        let model = sample_model("test_model", Array2::ones((100, 50)));
        let result = StreamingSerializer::save_streaming(&model, &path);
        assert!(result.is_ok());
        let loaded = StreamingSerializer::load_streaming(&path);
        assert!(loaded.is_ok());
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_convenience_functions() {
        let weights = Array2::ones((5, 3));
        let compressed = compress_weights(&weights).expect("test: valid compression");
        assert!(decompress_weights(&compressed).is_ok());
        assert!(quantize_weights(&weights, 8).is_ok());
        let model = sample_model("test", weights);
        assert!(compute_model_hash(&model).is_ok());
    }

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let original = Array2::from_shape_fn((4, 3), |(i, j)| (i * 3 + j) as f64 * 0.5);
        let (quantized, scale, zero_point) =
            ModelCompression::quantize_weights(&original, 8).expect("test: valid quantization");
        let recovered =
            ModelCompression::dequantize_weights(&quantized, (4, 3), 8, scale, zero_point)
                .expect("test: valid dequantization");
        assert_eq!(recovered.shape(), original.shape());
        // Rounding to the nearest level errs by at most scale / 2; `scale`
        // leaves headroom for floating-point error.
        for (orig, rec) in original.iter().zip(recovered.iter()) {
            assert!(
                (orig - rec).abs() <= scale,
                "orig={} rec={} scale={}",
                orig,
                rec,
                scale
            );
        }
    }

    #[test]
    fn test_hash_consistency() {
        let model = sample_model("test_model", Array2::ones((10, 5)));
        let hash1 =
            ModelFingerprint::compute_hash(&model).expect("test: valid model hash computation");
        let hash2 =
            ModelFingerprint::compute_hash(&model).expect("test: valid model hash computation");
        assert_eq!(hash1, hash2);
    }
}