use crate::error::Result;
use std::collections::HashMap;
use std::path::Path;
/// Severity of a lint finding.
///
/// Variants are declared from least to most severe, so the derived ordering is
/// `Info < Warn < Error`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum LintLevel {
    /// Informational note; `LintReport::passed` does not count these.
    Info,
    /// Warning; causes `LintReport::passed` to return `false`.
    Warn,
    /// Error; causes `LintReport::passed` to return `false`.
    Error,
}

impl LintLevel {
    /// Upper-case label used when the level is printed (e.g. `"WARN"`).
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            LintLevel::Info => "INFO",
            LintLevel::Warn => "WARN",
            LintLevel::Error => "ERROR",
        }
    }
}

impl std::fmt::Display for LintLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes the label verbatim; width/fill flags are not applied.
        f.write_str(self.as_str())
    }
}
/// Area of the model that a lint finding concerns.
///
/// `Hash` is derived so the category can key `LintReport::by_category`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum LintCategory {
    /// Model-level metadata (license, model card, provenance).
    Metadata,
    /// Tensor naming conventions.
    Naming,
    /// Storage efficiency.
    Efficiency,
    /// Tensor layout/shape contract.
    Layout,
}

impl LintCategory {
    /// Human-readable title of the category, used by `Display`.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            LintCategory::Metadata => "Metadata",
            LintCategory::Naming => "Tensor Naming",
            LintCategory::Efficiency => "Efficiency",
            LintCategory::Layout => "Layout Contract",
        }
    }
}

impl std::fmt::Display for LintCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emits the title as-is; formatter width/fill flags are not applied.
        f.write_str(self.name())
    }
}
/// One finding produced by the linter.
#[derive(Clone, Debug)]
pub struct LintIssue {
    /// How severe the finding is.
    pub level: LintLevel,
    /// Which area of the model the finding belongs to.
    pub category: LintCategory,
    /// Description of the problem.
    pub message: String,
    /// Optional remediation hint; printed after the message by `Display`.
    pub suggestion: Option<String>,
}

impl LintIssue {
    /// Builds an issue with the given level, category and message and no suggestion.
    #[must_use]
    pub fn new(level: LintLevel, category: LintCategory, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            level,
            category,
            message,
            suggestion: None,
        }
    }

    /// Consumes the issue and returns it with `suggestion` attached (replacing any previous one).
    #[must_use]
    pub fn with_suggestion(self, suggestion: impl Into<String>) -> Self {
        Self {
            suggestion: Some(suggestion.into()),
            ..self
        }
    }

    /// Shorthand for a `Warn`-level `Metadata` issue.
    #[must_use]
    pub fn metadata_warn(message: impl Into<String>) -> Self {
        Self::new(LintLevel::Warn, LintCategory::Metadata, message)
    }

    /// Shorthand for an `Info`-level `Naming` issue.
    #[must_use]
    pub fn naming_info(message: impl Into<String>) -> Self {
        Self::new(LintLevel::Info, LintCategory::Naming, message)
    }

    /// Shorthand for a `Warn`-level `Naming` issue.
    #[must_use]
    pub fn naming_warn(message: impl Into<String>) -> Self {
        Self::new(LintLevel::Warn, LintCategory::Naming, message)
    }

    /// Shorthand for an `Info`-level `Efficiency` issue.
    #[must_use]
    pub fn efficiency_info(message: impl Into<String>) -> Self {
        Self::new(LintLevel::Info, LintCategory::Efficiency, message)
    }

    /// Shorthand for a `Warn`-level `Layout` issue.
    #[must_use]
    pub fn layout_warn(message: impl Into<String>) -> Self {
        Self::new(LintLevel::Warn, LintCategory::Layout, message)
    }

    /// Shorthand for an `Error`-level `Layout` issue.
    #[must_use]
    pub fn layout_error(message: impl Into<String>) -> Self {
        Self::new(LintLevel::Error, LintCategory::Layout, message)
    }
}

impl std::fmt::Display for LintIssue {
    /// Renders `[LEVEL] Category: message`, plus ` (suggestion: ...)` when a hint is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "[{}] {}: {}", self.level, self.category, self.message)?;
        match &self.suggestion {
            Some(suggestion) => write!(f, " (suggestion: {})", suggestion),
            None => Ok(()),
        }
    }
}
#[derive(Debug, Clone)]
pub struct LintReport {
pub issues: Vec<LintIssue>,
pub by_category: HashMap<LintCategory, Vec<usize>>,
pub info_count: usize,
pub warn_count: usize,
pub error_count: usize,
}
impl LintReport {
#[must_use]
pub fn new() -> Self {
Self {
issues: Vec::new(),
by_category: HashMap::new(),
info_count: 0,
warn_count: 0,
error_count: 0,
}
}
pub fn add_issue(&mut self, issue: LintIssue) {
let idx = self.issues.len();
let category = issue.category;
match issue.level {
LintLevel::Info => self.info_count += 1,
LintLevel::Warn => self.warn_count += 1,
LintLevel::Error => self.error_count += 1,
}
self.by_category.entry(category).or_default().push(idx);
self.issues.push(issue);
}
#[must_use]
pub fn passed(&self) -> bool {
self.warn_count == 0 && self.error_count == 0
}
#[must_use]
pub fn passed_strict(&self) -> bool {
self.issues.is_empty()
}
#[must_use]
pub fn total_issues(&self) -> usize {
self.issues.len()
}
#[must_use]
pub fn issues_at_level(&self, level: LintLevel) -> Vec<&LintIssue> {
self.issues.iter().filter(|i| i.level == level).collect()
}
#[must_use]
pub fn issues_in_category(&self, category: LintCategory) -> Vec<&LintIssue> {
match self.by_category.get(&category) {
Some(indices) => indices.iter().map(|&i| &self.issues[i]).collect(),
None => Vec::new(),
}
}
}
impl Default for LintReport {
fn default() -> Self {
Self::new()
}
}
/// Model-level facts that `lint_model` inspects.
#[derive(Clone, Debug, Default)]
pub struct ModelLintInfo {
    /// Whether a license field is present; a missing one yields a metadata warning.
    pub has_license: bool,
    /// Whether a model card is present; a missing one yields a metadata warning.
    pub has_model_card: bool,
    /// Whether provenance information is present; a missing one yields a metadata warning.
    pub has_provenance: bool,
    /// Per-tensor details used by the naming and layout checks.
    pub tensors: Vec<TensorLintInfo>,
    /// Whether the model as a whole is compressed.
    pub is_compressed: bool,
    /// Vocabulary size; the layout-contract check is skipped unless this and `hidden_dim` are set.
    pub vocab_size: Option<usize>,
    /// Hidden dimension; the layout-contract check is skipped unless this and `vocab_size` are set.
    pub hidden_dim: Option<usize>,
}

/// Facts about a single tensor.
#[derive(Clone, Debug)]
pub struct TensorLintInfo {
    /// Tensor name, checked against naming conventions and used for layout-contract lookup.
    pub name: String,
    /// Stored size in bytes.
    pub size_bytes: usize,
    /// Alignment of the tensor data (NOTE(review): presumably in bytes — confirm with producer).
    pub alignment: usize,
    /// Whether this tensor's data is compressed.
    pub is_compressed: bool,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
}
/// Canonical tensor names, each paired with a short description.
const CANONICAL_PATTERNS: &[(&str, &str)] = &[
    // Encoder tensors.
    ("encoder.conv1.weight", "Initial convolution weight"),
    ("encoder.conv1.bias", "Initial convolution bias"),
    ("encoder.conv2.weight", "Second convolution weight"),
    ("encoder.conv2.bias", "Second convolution bias"),
    ("encoder.positional_embedding", "Encoder position embedding"),
    ("encoder.layer_norm.weight", "Encoder final layer norm weight"),
    ("encoder.layer_norm.bias", "Encoder final layer norm bias"),
    // Decoder tensors.
    ("decoder.token_embedding", "Token embedding"),
    ("decoder.positional_embedding", "Decoder position embedding"),
    ("decoder.layer_norm.weight", "Decoder final layer norm weight"),
    ("decoder.layer_norm.bias", "Decoder final layer norm bias"),
    // Output projection.
    ("proj_out.weight", "Output projection weight"),
];
/// Abbreviated name fragments paired with the spelled-out replacement to suggest.
///
/// Entries are checked in this order by `check_tensor_naming`, so the order also fixes the
/// order in which naming issues are reported for a tensor.
const ABBREVIATION_SUGGESTIONS: &[(&str, &str)] = &[
    // Weight / bias suffixes.
    (".w", ".weight"),
    (".b", ".bias"),
    ("_w", ".weight"),
    ("_b", ".bias"),
    (".wt", ".weight"),
    (".bs", ".bias"),
    // Sub-module prefixes.
    ("attn_", "self_attn."),
    ("ffn_", "fc"),
    ("ln_", "layer_norm."),
    ("emb_", "embedding"),
    ("embed_", "embedding"),
];
/// Runs every lint check against `info` and returns the collected report.
///
/// Checks run in a fixed order (metadata, tensor naming, efficiency, layout contract),
/// and their issues are appended to the report in that order.
#[must_use]
pub fn lint_model(info: &ModelLintInfo) -> LintReport {
    let mut report = LintReport::default();
    check_metadata(&mut report, info);
    check_tensor_naming(&mut report, info);
    check_efficiency(&mut report, info);
    check_layout_contract(&mut report, info);
    report
}
/// Adds a `Warn`-level metadata issue for each of license, model card and provenance
/// that `info` reports as absent (checked in that order).
fn check_metadata(report: &mut LintReport, info: &ModelLintInfo) {
    let required = [
        (info.has_license, "Missing 'license' field"),
        (info.has_model_card, "Missing 'model_card'"),
        (info.has_provenance, "Missing 'provenance' information"),
    ];
    for (present, message) in required {
        if !present {
            report.add_issue(LintIssue::metadata_warn(message));
        }
    }
}
fn check_tensor_naming(report: &mut LintReport, info: &ModelLintInfo) {
for tensor in &info.tensors {
for (abbrev, full) in ABBREVIATION_SUGGESTIONS {
if is_abbreviated(&tensor.name, abbrev, full) {
let suggested = tensor.name.replace(abbrev, full);
report.add_issue(
LintIssue::naming_info(format!(
"'{}' should be '{}' for auto-mapping",
tensor.name, suggested
))
.with_suggestion(format!("Rename to '{}'", suggested)),
);
}
}
if is_nonstandard_pattern(&tensor.name) {
report.add_issue(
LintIssue::naming_warn(format!(
"'{}' does not follow canonical naming schema",
tensor.name
))
.with_suggestion("See APR-SPEC §10.8 for canonical tensor naming"),
);
}
}
}
/// Returns `true` when byte offset `position` is at or past the end of `name`, or when the
/// byte at `position` is one of the separators `.`, `_` or `-`.
///
/// Callers pass a byte offset (a `find`/`match_indices` result plus a pattern's byte length).
/// The previous `chars().nth(position)` lookup treated that offset as a character count,
/// which lands on the wrong character for names containing multi-byte UTF-8 and costs O(n).
fn is_at_word_boundary(name: &str, position: usize) -> bool {
    // Every separator is ASCII, and bytes belonging to multi-byte UTF-8 sequences are never
    // ASCII values, so comparing raw bytes cannot misread part of a character as a separator.
    match name.as_bytes().get(position) {
        None => true,
        Some(byte) => matches!(byte, b'.' | b'_' | b'-'),
    }
}
/// Returns `true` when `name` uses the abbreviation `abbrev` as a complete token.
///
/// A name that already contains the spelled-out form `full` anywhere is never reported.
/// Otherwise every occurrence of `abbrev` is examined, not just the first one. The name is
/// reported if any occurrence is immediately followed by a word boundary (see
/// `is_at_word_boundary`). This means `enc.wq.w` is caught through its trailing `.w`, even
/// though its first `.w` is part of `.wq`.
fn is_abbreviated(name: &str, abbrev: &str, full: &str) -> bool {
    if name.contains(full) {
        return false;
    }
    name.match_indices(abbrev)
        .any(|(start, _)| is_at_word_boundary(name, start + abbrev.len()))
}
/// Name fragments that explain the presence of digits in a tensor name, such as layer
/// indices (`layers.`, `blk.`) or numbered sub-modules (`conv1`, `fc2`).
const STANDARD_NUMERIC_PATTERNS: &[&str] = &[
    "layers.",
    "blk.",
    "conv1",
    "conv2",
    "fc1",
    "fc2",
];
/// Returns `true` if `name` contains any fragment listed in `STANDARD_NUMERIC_PATTERNS`.
fn has_standard_numbering(name: &str) -> bool {
    STANDARD_NUMERIC_PATTERNS
        .iter()
        .copied()
        .any(|pattern| name.contains(pattern))
}
/// Returns `true` if `name` contains a doubled separator: `__`, `--` or `..`.
fn has_unusual_separators(name: &str) -> bool {
    const DOUBLED: [&str; 3] = ["__", "--", ".."];
    DOUBLED.iter().any(|&pair| name.contains(pair))
}
/// Returns `true` if `name` looks non-canonical, meaning any of the following holds:
/// * it is non-empty but shorter than five bytes;
/// * it contains ASCII digits without any fragment from `STANDARD_NUMERIC_PATTERNS`;
/// * it contains a doubled separator (see `has_unusual_separators`).
fn is_nonstandard_pattern(name: &str) -> bool {
    // Lengths 1..=4 (in bytes) count as too short; an empty name is not flagged here.
    let too_short = (1..5).contains(&name.len());
    // ASCII digits are single bytes, so scanning bytes finds the same digits as scanning chars.
    let stray_digits =
        name.bytes().any(|byte| byte.is_ascii_digit()) && !has_standard_numbering(name);
    too_short || stray_digits || has_unusual_separators(name)
}
/// Validates tensors against the layout contract.
///
/// Nothing is checked unless both `vocab_size` and `hidden_dim` are known. Otherwise, each
/// tensor for which the contract returns an entry (via `get_apr_contract`) is passed to
/// `validate_critical_tensor_shape` and `validate_transpose_dimensions`. Tensors without a
/// contract entry are skipped.
fn check_layout_contract(report: &mut LintReport, info: &ModelLintInfo) {
    use crate::format::layout_contract::contract;

    let (Some(vocab_size), Some(hidden_dim)) = (info.vocab_size, info.hidden_dim) else {
        return;
    };

    let layout = contract();
    for tensor in &info.tensors {
        if let Some(tc) = layout.get_apr_contract(&tensor.name) {
            validate_critical_tensor_shape(report, &layout, tc, tensor, vocab_size, hidden_dim);
            validate_transpose_dimensions(report, tc, tensor, vocab_size);
        }
    }
}
include!("lint.rs");