#[cfg(feature = "safetensors-compare")]
pub mod safetensors;
#[cfg(feature = "safetensors-compare")]
pub use safetensors::{BatchComparison, HfSafetensors, TensorComparison, TensorData};
use std::collections::HashMap;
use std::fmt;
use std::time::Duration;
/// Aggregated outcome of inspecting a model file.
///
/// Bundles the header and metadata inspections with optional weight
/// statistics, an optional quality score, the recorded duration, and any
/// warnings or errors that were collected.
#[derive(Debug, Clone)]
pub struct InspectionResult {
    /// Inspection of the file header.
    pub header: HeaderInspection,
    /// Inspection of the model metadata.
    pub metadata: MetadataInspection,
    /// Weight statistics, when available.
    pub weights: Option<WeightStats>,
    /// Optional quality score; the scale is not defined in this part of the file.
    pub quality_score: Option<u32>,
    /// Duration associated with the inspection ([`Duration::ZERO`] after `new`).
    pub duration: Duration,
    /// Warnings collected so far; these do not affect [`InspectionResult::is_valid`].
    pub warnings: Vec<InspectionWarning>,
    /// Errors collected so far; any entry makes the result invalid.
    pub errors: Vec<InspectionError>,
}

impl InspectionResult {
    /// Builds a result from the header and metadata inspections.
    ///
    /// Every other field starts empty: no weights, no quality score, a zero
    /// duration, and empty warning and error lists.
    #[must_use]
    pub fn new(header: HeaderInspection, metadata: MetadataInspection) -> Self {
        let warnings = Vec::new();
        let errors = Vec::new();
        Self {
            header,
            metadata,
            weights: None,
            quality_score: None,
            duration: Duration::ZERO,
            warnings,
            errors,
        }
    }

    /// Returns `true` when at least one warning or error has been recorded.
    #[must_use]
    pub fn has_issues(&self) -> bool {
        let nothing_recorded = self.warnings.is_empty() && self.errors.is_empty();
        !nothing_recorded
    }

    /// Returns `true` when no errors have been recorded (warnings are ignored).
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let recorded_errors = &self.errors;
        recorded_errors.is_empty()
    }

    /// Returns the combined number of warnings and errors.
    #[must_use]
    pub fn issue_count(&self) -> usize {
        let warning_total = self.warnings.len();
        let error_total = self.errors.len();
        warning_total + error_total
    }
}
/// Decoded view of a model file header together with its validity flags.
#[derive(Debug, Clone)]
pub struct HeaderInspection {
    /// The four magic bytes of the header.
    pub magic: [u8; 4],
    /// Format version as a pair, rendered as `a.b` by [`HeaderInspection::version_string`].
    pub version: (u8, u8),
    /// Numeric model type identifier; the meaning of each code is not defined here.
    pub model_type: u16,
    /// Header feature flags.
    pub flags: HeaderFlags,
    /// Compressed size used as the divisor in [`HeaderInspection::compression_ratio`].
    pub compressed_size: u64,
    /// Uncompressed size used as the dividend in [`HeaderInspection::compression_ratio`].
    pub uncompressed_size: u64,
    /// Checksum value carried by the header; the algorithm is not visible here.
    pub checksum: u32,
    /// Whether the magic bytes were judged valid.
    pub magic_valid: bool,
    /// Whether the version was judged supported.
    pub version_supported: bool,
}

impl HeaderInspection {
    /// Creates a header inspection with the default values.
    ///
    /// The defaults are magic `APRN`, version `1.0`, zeroed numeric fields,
    /// default flags, and both validity flags set to `true`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the magic bytes as text, replacing invalid UTF-8 with U+FFFD.
    #[must_use]
    pub fn magic_string(&self) -> String {
        let decoded = String::from_utf8_lossy(&self.magic);
        decoded.into_owned()
    }

    /// Returns the version formatted as `"<first>.<second>"`.
    #[must_use]
    pub fn version_string(&self) -> String {
        let (first, second) = self.version;
        format!("{}.{}", first, second)
    }

    /// Returns `uncompressed_size / compressed_size` as a float.
    ///
    /// A `compressed_size` of zero yields `1.0` instead of dividing by zero.
    #[must_use]
    pub fn compression_ratio(&self) -> f64 {
        match self.compressed_size {
            0 => 1.0,
            compressed => self.uncompressed_size as f64 / compressed as f64,
        }
    }

    /// Returns `true` only when both `magic_valid` and `version_supported` are set.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let Self {
            magic_valid,
            version_supported,
            ..
        } = self;
        *magic_valid && *version_supported
    }
}

impl Default for HeaderInspection {
    /// Produces the same value as [`HeaderInspection::new`].
    fn default() -> Self {
        Self {
            magic: *b"APRN",
            version: (1, 0),
            model_type: 0,
            flags: HeaderFlags::default(),
            compressed_size: 0,
            uncompressed_size: 0,
            checksum: 0,
            magic_valid: true,
            version_supported: true,
        }
    }
}
/// Boolean feature flags decoded from a single header byte.
///
/// Bit layout used by [`HeaderFlags::from_byte`] and [`HeaderFlags::to_byte`]:
/// `0x01` compressed, `0x02` signed, `0x04` encrypted, `0x08` streaming,
/// `0x10` licensed, `0x20` quantized. Bits `0x40` and `0x80` are not mapped.
#[derive(Debug, Clone, Copy, Default)]
// Each bool mirrors one independent bit of the flag byte, so a struct of
// bools is the intended representation here.
#[allow(clippy::struct_excessive_bools)]
pub struct HeaderFlags {
    /// Bit `0x01`.
    pub compressed: bool,
    /// Bit `0x02`.
    pub signed: bool,
    /// Bit `0x04`.
    pub encrypted: bool,
    /// Bit `0x08`.
    pub streaming: bool,
    /// Bit `0x10`.
    pub licensed: bool,
    /// Bit `0x20`.
    pub quantized: bool,
}

impl HeaderFlags {
    /// Decodes the flags from `byte`; the two highest bits are ignored.
    #[must_use]
    pub fn from_byte(byte: u8) -> Self {
        let is_set = |mask: u8| (byte & mask) != 0;
        Self {
            compressed: is_set(0x01),
            signed: is_set(0x02),
            encrypted: is_set(0x04),
            streaming: is_set(0x08),
            licensed: is_set(0x10),
            quantized: is_set(0x20),
        }
    }

    /// Packs the flags back into a byte using the same bit layout as `from_byte`.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let bit_table: [(bool, u8); 6] = [
            (self.compressed, 0x01),
            (self.signed, 0x02),
            (self.encrypted, 0x04),
            (self.streaming, 0x08),
            (self.licensed, 0x10),
            (self.quantized, 0x20),
        ];
        bit_table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(0u8, |packed, &(_, mask)| packed | mask)
    }

    /// Returns the upper-case names of the flags that are set, in bit order.
    #[must_use]
    pub fn flag_list(&self) -> Vec<&'static str> {
        let name_table: [(bool, &'static str); 6] = [
            (self.compressed, "COMPRESSED"),
            (self.signed, "SIGNED"),
            (self.encrypted, "ENCRYPTED"),
            (self.streaming, "STREAMING"),
            (self.licensed, "LICENSED"),
            (self.quantized, "QUANTIZED"),
        ];
        name_table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .map(|&(_, label)| label)
            .collect()
    }
}
/// Descriptive metadata gathered about a model.
#[derive(Debug, Clone)]
pub struct MetadataInspection {
    /// Name of the model type, as supplied to [`MetadataInspection::new`].
    pub model_type_name: String,
    /// Parameter count (zero after `new`).
    pub n_parameters: u64,
    /// Feature count (zero after `new`).
    pub n_features: u32,
    /// Output count (zero after `new`).
    pub n_outputs: u32,
    /// Hyperparameters as string key/value pairs.
    pub hyperparameters: HashMap<String, String>,
    /// Training details, when present.
    pub training_info: Option<TrainingInfo>,
    /// License details, when present.
    pub license_info: Option<LicenseInfo>,
    /// Additional string key/value pairs.
    pub custom: HashMap<String, String>,
}

impl MetadataInspection {
    /// Creates metadata for the given model type name.
    ///
    /// All counts start at zero, both maps start empty, and no training or
    /// license information is attached.
    #[must_use]
    pub fn new(model_type_name: impl Into<String>) -> Self {
        let model_type_name: String = model_type_name.into();
        let hyperparameters = HashMap::new();
        let custom = HashMap::new();
        Self {
            model_type_name,
            n_parameters: 0,
            n_features: 0,
            n_outputs: 0,
            hyperparameters,
            training_info: None,
            license_info: None,
            custom,
        }
    }

    /// Returns `true` when `training_info` is present.
    #[must_use]
    pub fn has_training_info(&self) -> bool {
        let training = self.training_info.as_ref();
        training.is_some()
    }

    /// Returns `true` when `license_info` is present.
    #[must_use]
    pub fn is_licensed(&self) -> bool {
        let license = self.license_info.as_ref();
        license.is_some()
    }
}
/// Optional details about how a model was trained.
///
/// Every field is optional; [`TrainingInfo::new`] and `Default` both produce
/// a value with all fields unset.
#[derive(Debug, Clone, Default)]
pub struct TrainingInfo {
    /// When training happened, as a string; the format is not defined here.
    pub trained_at: Option<String>,
    /// How long training took.
    pub duration: Option<Duration>,
    /// Name of the training dataset.
    pub dataset_name: Option<String>,
    /// Number of training samples.
    pub n_samples: Option<u64>,
    /// Final loss value.
    pub final_loss: Option<f64>,
    /// Name of the framework.
    pub framework: Option<String>,
    /// Version of the framework.
    pub framework_version: Option<String>,
}

impl TrainingInfo {
    /// Creates an empty record with every field set to `None`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
/// License details attached to a model.
#[derive(Debug, Clone)]
pub struct LicenseInfo {
    /// License type, as supplied to [`LicenseInfo::new`].
    pub license_type: String,
    /// Licensee, when known.
    pub licensee: Option<String>,
    /// Expiry, as a string, when set; the format is not defined here.
    pub expires_at: Option<String>,
    /// Restriction entries; an empty list means none are recorded.
    pub restrictions: Vec<String>,
}

impl LicenseInfo {
    /// Creates license info of the given type.
    ///
    /// The licensee and expiry start unset and the restriction list is empty.
    #[must_use]
    pub fn new(license_type: impl Into<String>) -> Self {
        let license_type: String = license_type.into();
        let restrictions = Vec::new();
        Self {
            license_type,
            licensee: None,
            expires_at: None,
            restrictions,
        }
    }

    /// Returns `true` when at least one restriction is recorded.
    #[must_use]
    pub fn has_restrictions(&self) -> bool {
        let unrestricted = self.restrictions.is_empty();
        !unrestricted
    }
}
/// Summary statistics describing a set of model weights.
///
/// This is a plain data record; nothing in this part of the file computes the
/// values. NOTE(review): the per-field notes below are inferred from the
/// field names alone — confirm them against the code that fills this struct
/// (presumably in the included `weight_stats.rs`).
#[derive(Debug, Clone)]
pub struct WeightStats {
/// Presumably the number of weight values summarised.
pub count: u64,
/// Presumably the smallest weight value.
pub min: f64,
/// Presumably the largest weight value.
pub max: f64,
/// Presumably the arithmetic mean of the weights.
pub mean: f64,
/// Presumably the standard deviation; population vs. sample is not visible here.
pub std: f64,
/// Presumably how many weights are zero.
pub zero_count: u64,
/// Presumably how many weights are NaN.
pub nan_count: u64,
/// Presumably how many weights are infinite.
pub inf_count: u64,
/// Presumably the share of zero weights; exact definition and range — TODO confirm.
pub sparsity: f64,
/// Presumably the L1 norm (sum of absolute values) of the weights.
pub l1_norm: f64,
/// Presumably the L2 norm (Euclidean norm) of the weights.
pub l2_norm: f64,
}
include!("weight_stats.rs");
include!("tests.rs");