import pytest
import quantize_rs
import numpy as np
import os
import tempfile
# Path of the ONNX model exercised by these tests; each test skips if absent.
MODEL_PATH = "mnist.onnx"
@pytest.fixture
def temp_output():
    """Yield a path to a throwaway ``.onnx`` file, removed after the test.

    The handle is closed *before* yielding: the original kept the
    NamedTemporaryFile open across the test body, which breaks on Windows
    where an open temp file cannot be reopened by another writer.
    ``delete=False`` so closing does not discard the path.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".onnx", delete=False)
    path = tmp.name
    tmp.close()  # release the OS handle so the quantizer can write to `path`
    try:
        yield path
    finally:
        # Teardown: remove the file even if the test failed.
        if os.path.exists(path):
            os.remove(path)
def test_model_info():
    """model_info should expose a string name plus non-empty node/IO lists."""
    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    info = quantize_rs.model_info(MODEL_PATH)

    assert isinstance(info.name, str)
    # Every structural count must be positive for a real model.
    for count in (info.num_nodes, len(info.inputs), len(info.outputs)):
        assert count > 0
    print(f"Model: {info.name}, Nodes: {info.num_nodes}")
def test_basic_quantize(temp_output):
    """Plain 8-bit per-tensor quantization must shrink the model file."""
    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    quantize_rs.quantize(
        input_path=MODEL_PATH,
        output_path=temp_output,
        bits=8,
        per_channel=False,
    )

    assert os.path.exists(temp_output)
    original_size = os.path.getsize(MODEL_PATH)
    quantized_size = os.path.getsize(temp_output)
    # Output must be non-empty and strictly smaller than the source model.
    assert 0 < quantized_size < original_size
    print(f"Compression: {original_size / quantized_size:.2f}×")
def test_quantize_with_calibration_random(temp_output):
    """Min-max calibration with library-generated random samples succeeds."""
    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    options = dict(
        input_path=MODEL_PATH,
        output_path=temp_output,
        calibration_data=None,  # None -> library synthesizes random samples
        bits=8,
        method="minmax",
        num_samples=10,
        sample_shape=[1, 28, 28],
    )
    quantize_rs.quantize_with_calibration(**options)

    assert os.path.exists(temp_output)
    assert os.path.getsize(temp_output) > 0
def test_quantize_with_calibration_npy(temp_output):
    """Percentile calibration driven by a ``.npy`` file of sample inputs.

    The calibration file is created via ``tempfile.mkstemp`` instead of a
    fixed ``test_calibration.npy`` in the working directory: the fixed name
    collided between parallel test runs (pytest-xdist, concurrent CI jobs).
    """
    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    calib_data = np.random.randn(10, 1, 28, 28).astype(np.float32)
    fd, calib_path = tempfile.mkstemp(suffix=".npy")
    os.close(fd)  # np.save reopens the path itself; avoid a leaked fd
    np.save(calib_path, calib_data)
    try:
        quantize_rs.quantize_with_calibration(
            input_path=MODEL_PATH,
            output_path=temp_output,
            calibration_data=calib_path,
            bits=8,
            method="percentile",
        )
        assert os.path.exists(temp_output)
        assert os.path.getsize(temp_output) > 0
    finally:
        # Remove the calibration file even when quantization raises.
        if os.path.exists(calib_path):
            os.remove(calib_path)
def test_int4_quantization(temp_output):
    """4-bit, per-channel quantization should still emit an output file."""
    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    settings = dict(
        input_path=MODEL_PATH,
        output_path=temp_output,
        bits=4,
        per_channel=True,
    )
    quantize_rs.quantize(**settings)

    assert os.path.exists(temp_output)
def test_invalid_model_path():
    """Quantizing a nonexistent model must raise a load-failure error."""
    with pytest.raises(Exception) as exc_info:
        quantize_rs.quantize(input_path="nonexistent.onnx", output_path="output.onnx")
    # The error message should identify the load step as the failure point.
    assert "Failed to load model" in str(exc_info.value)
def test_invalid_calibration_method(temp_output):
    """An unrecognized calibration method must be rejected with an error."""
    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    with pytest.raises(Exception) as exc_info:
        quantize_rs.quantize_with_calibration(
            input_path=MODEL_PATH,
            output_path=temp_output,
            method="invalid_method",
        )
    assert "Unknown method" in str(exc_info.value)
def test_onnxruntime_compatibility(temp_output):
    """A quantized model must load and run a forward pass under onnxruntime.

    Fix: ORT reports dynamic dimensions as strings (e.g. "batch") or None;
    the original passed ``input_shape`` straight to ``np.random.randn``,
    which raises TypeError on non-int dims. Non-int dims are concretized
    to 1 before building the dummy input.
    """
    pytest.importorskip("onnxruntime")
    import onnxruntime as ort

    if not os.path.exists(MODEL_PATH):
        pytest.skip(f"{MODEL_PATH} not found")

    quantize_rs.quantize(input_path=MODEL_PATH, output_path=temp_output, bits=8)

    session = ort.InferenceSession(temp_output)
    assert len(session.get_inputs()) > 0
    assert len(session.get_outputs()) > 0

    model_input = session.get_inputs()[0]
    input_name = model_input.name
    # Replace symbolic/unknown dims with 1 so randn receives plain ints.
    concrete_shape = [d if isinstance(d, int) else 1 for d in model_input.shape]
    dummy_input = np.random.randn(*concrete_shape).astype(np.float32)

    output = session.run(None, {input_name: dummy_input})
    assert len(output) > 0
    print(f"✓ Inference works! Output shape: {output[0].shape}")
if __name__ == "__main__":
    # Allow running this file directly: verbose output, short tracebacks.
    pytest.main([__file__, "-v", "--tb=short"])