use super::{build_pygmy_apr_with_config, build_pygmy_safetensors_with_config, PygmyConfig};
use crate::format::converter::{
apr_export, apr_import, ExportFormat, ExportOptions, ImportOptions,
};
use crate::format::v2::AprV2Reader;
use crate::serialization::safetensors::MappedSafeTensors;
use std::fs;
use std::path::{Path, PathBuf};
use tempfile::TempDir;
/// Comparison tolerances, one per tensor precision, carried by
/// `ConversionTestHarness`.
///
/// NOTE(review): `atol` presumably stands for "absolute tolerance"; the code
/// that applies these values is not part of this definition — confirm there.
#[derive(Debug, Clone, Copy)]
// All four fields deliberately share the `_atol` suffix, which is exactly the
// pattern `clippy::struct_field_names` reports.
#[allow(clippy::struct_field_names)]
pub(crate) struct ToleranceConfig {
    /// Tolerance for the `f32` case.
    pub(crate) f32_atol: f32,
    /// Tolerance for the `f16` case.
    pub(crate) f16_atol: f32,
    /// Tolerance for the `q8` case (presumably 8-bit quantized data).
    pub(crate) q8_atol: f32,
    /// Tolerance for the `q4` case (presumably 4-bit quantized data).
    pub(crate) q4_atol: f32,
}
impl Default for ToleranceConfig {
fn default() -> Self {
Self {
f32_atol: 1e-6,
f16_atol: 1e-3,
q8_atol: 0.1,
q4_atol: 0.5,
}
}
}
/// One discrepancy recorded for a single named tensor.
///
/// The `Display` implementation renders it as `tensor '<name>': <details>`.
#[derive(Debug)]
pub(crate) struct TensorMismatch {
    /// Name of the tensor the discrepancy refers to.
    pub(crate) tensor_name: String,
    /// What exactly differs for that tensor.
    pub(crate) kind: MismatchKind,
}
/// The ways a tensor in the output can disagree with the source.
#[derive(Debug)]
pub(crate) enum MismatchKind {
    /// The tensor is missing in the output.
    Missing,
    /// The tensor is present in the output but not in the source.
    Extra,
    /// The tensor exists on both sides but the shapes differ.
    ShapeMismatch {
        /// Shape that was expected.
        expected: Vec<usize>,
        /// Shape that was actually found.
        actual: Vec<usize>,
    },
    /// A data element differs from the expected value.
    DataMismatch {
        /// Position of the offending element in the tensor data.
        index: usize,
        /// Value that was expected at `index`.
        expected: f32,
        /// Value that was actually found at `index`.
        actual: f32,
        /// Tolerance reported alongside the mismatch.
        tolerance: f32,
    },
}
impl core::fmt::Display for TensorMismatch {
    /// Renders the mismatch as `tensor '<name>': <details>`, where the details
    /// depend on the [`MismatchKind`] variant.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Every variant starts with the same prefix, so emit it once.
        write!(f, "tensor '{}': ", self.tensor_name)?;

        match &self.kind {
            MismatchKind::Missing => f.write_str("missing in output"),
            MismatchKind::Extra => f.write_str("extra in output (not in source)"),
            MismatchKind::ShapeMismatch { expected, actual } => write!(
                f,
                "shape mismatch expected={:?} actual={:?}",
                expected, actual
            ),
            MismatchKind::DataMismatch {
                index,
                expected,
                actual,
                tolerance,
            } => write!(
                f,
                "data[{}] expected={} actual={} (tol={})",
                index, expected, actual, tolerance
            ),
        }
    }
}
/// Outcome of a verification run: the list of mismatches that were recorded.
///
/// An empty list means the verification passed.
#[derive(Debug)]
pub(crate) struct VerificationResult {
    /// Every mismatch that was recorded; empty when nothing differed.
    pub(crate) mismatches: Vec<TensorMismatch>,
}
impl VerificationResult {
    /// Asserts that no mismatches were recorded.
    ///
    /// Returns normally when the mismatch list is empty.
    ///
    /// # Panics
    ///
    /// Panics when at least one mismatch is present. The message states the
    /// number of mismatches and lists each one (via its `Display` output) on
    /// its own line.
    // `#[track_caller]` makes the panic location point at the test that called
    // this helper rather than at this line inside the harness.
    #[track_caller]
    pub(crate) fn assert_passed(&self) {
        if self.passed() {
            return;
        }

        let report = self
            .mismatches
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("\n ");
        panic!(
            "Verification failed with {} mismatch(es):\n {}",
            self.mismatches.len(),
            report
        );
    }

    /// Returns `true` when no mismatches were recorded.
    #[must_use]
    pub(crate) fn passed(&self) -> bool {
        self.mismatches.is_empty()
    }
}
/// State held by the conversion test harness.
///
/// Only the fields are declared here; the methods come from the files pulled
/// in with `include!` further down in this module.
pub(crate) struct ConversionTestHarness {
    /// Temporary directory owned by the harness.
    /// NOTE(review): presumably where input/output files are written — confirm
    /// against the harness implementation.
    dir: TempDir,
    /// Path of the input file, once one has been set.
    input_path: Option<PathBuf>,
    /// Path of the output file, once one has been set.
    output_path: Option<PathBuf>,
    /// Source tensors as `(String, Vec<f32>, Vec<usize>)` triples.
    /// NOTE(review): presumably `(name, data, shape)` — confirm against the
    /// code that fills this list.
    source_tensors: Vec<(String, Vec<f32>, Vec<usize>)>,
    /// Tolerances carried by this harness instance.
    pub(crate) tolerance: ToleranceConfig,
}
// The files below are textually included into this module, so their items
// share this module's scope (the imports and types declared above).
// NOTE(review): judging by the file names they hold the harness implementation,
// collection helpers and the strict/round-trip tests — confirm in those files.
include!("harness_impl.rs");
include!("collection.rs");
include!("harness_strict_tests.rs");
include!("harness_roundtrip_tests.rs");