use crate::arch_requirements::{required_roles, WeightRole};
use crate::error::RealizarError;
use crate::gguf::ArchConstraints;
use std::fmt;
pub use trueno::contracts::{
self as kernel_contracts, validate_f32_buffer, validate_gemv_shapes, validate_weight_buffer,
QuantFormat, TensorLayout, WeightBufferError, STACK_LAYOUT,
};
/// Proof token issued when every GH-279 model-load contract gate passes.
///
/// Holding a `ModelLoadProof` certifies that the originating
/// [`ModelLoadConfig`] cleared dimension-plausibility, architecture lookup,
/// and (when weight roles were supplied) completeness validation.
#[derive(Debug, Clone)]
pub struct ModelLoadProof {
    // Architecture name copied from the validated config (e.g. "llama").
    architecture: String,
    // Layer count recorded at validation time.
    num_layers: usize,
}
impl ModelLoadProof {
    /// Architecture name this proof was issued for.
    #[must_use]
    pub fn architecture(&self) -> &str {
        &self.architecture
    }
    /// Number of transformer layers this proof was issued for.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}
/// Input to [`validate_model_load`]: model hyperparameters plus an optional
/// inventory of weight roles found in the model file.
#[derive(Debug, Clone)]
pub struct ModelLoadConfig {
    /// Architecture name as reported by the model (e.g. "llama", "qwen3").
    pub architecture: String,
    /// Number of transformer layers; must be non-zero.
    pub num_layers: usize,
    /// Hidden (embedding) dimension; must be non-zero and divisible by `num_heads`.
    pub hidden_dim: usize,
    /// Attention head count; must be non-zero.
    pub num_heads: usize,
    /// KV head count (grouped-query attention); must be non-zero and <= `num_heads`.
    pub num_kv_heads: usize,
    /// Feed-forward intermediate dimension; must be non-zero.
    pub intermediate_dim: usize,
    /// Vocabulary size; must be non-zero.
    pub vocab_size: usize,
    /// Weight roles present in the model. An empty vec means "not enumerated"
    /// and skips the completeness gate entirely.
    pub present_roles: Vec<WeightRole>,
}
/// Error produced when a GH-279 contract gate rejects a model load.
#[derive(Debug, Clone)]
pub struct ModelLoadError {
    /// Name of the gate that failed (e.g. "dimension_plausibility").
    pub gate: &'static str,
    /// Human-readable explanation of the failure.
    pub reason: String,
}
impl fmt::Display for ModelLoadError {
    /// Renders as `GH-279 contract gate '<gate>' failed: <reason>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { gate, reason } = self;
        write!(f, "GH-279 contract gate '{gate}' failed: {reason}")
    }
}
// Marker impl so ModelLoadError works as a `dyn Error` / with `?` conversions.
impl std::error::Error for ModelLoadError {}
impl From<ModelLoadError> for RealizarError {
fn from(e: ModelLoadError) -> Self {
RealizarError::UnsupportedOperation {
operation: format!("contract_gate::{}", e.gate),
reason: e.reason,
}
}
}
/// Runs every GH-279 contract gate against `config`, in order:
/// dimension plausibility, architecture lookup, then role completeness.
///
/// The completeness gate is skipped when `present_roles` is empty, since an
/// empty set means the caller did not enumerate roles at all.
///
/// # Errors
/// Returns the first [`ModelLoadError`] produced by a failing gate.
pub fn validate_model_load(
    config: &ModelLoadConfig,
) -> std::result::Result<ModelLoadProof, ModelLoadError> {
    validate_dimensions(config)?;
    let constraints = validate_architecture(&config.architecture)?;
    match config.present_roles.as_slice() {
        [] => {}
        roles => validate_completeness(&constraints, roles, &config.architecture)?,
    }
    Ok(ModelLoadProof {
        architecture: config.architecture.clone(),
        num_layers: config.num_layers,
    })
}
/// Convenience wrapper around [`validate_model_load`] for callers with only
/// scalar hyperparameters. `present_roles` is left empty, so the
/// completeness gate is skipped.
///
/// # Errors
/// Propagates any [`ModelLoadError`] from [`validate_model_load`].
pub fn validate_model_load_basic(
    architecture: &str,
    num_layers: usize,
    hidden_dim: usize,
    num_heads: usize,
    num_kv_heads: usize,
    intermediate_dim: usize,
    vocab_size: usize,
) -> std::result::Result<ModelLoadProof, ModelLoadError> {
    let config = ModelLoadConfig {
        architecture: architecture.to_string(),
        num_layers,
        hidden_dim,
        num_heads,
        num_kv_heads,
        intermediate_dim,
        vocab_size,
        present_roles: Vec::new(),
    };
    validate_model_load(&config)
}
pub fn gate_error(e: ModelLoadError) -> RealizarError {
e.into()
}
/// Rejects a zero-valued hyperparameter under the "dimension_plausibility"
/// gate; any positive value passes.
fn require_nonzero(field_name: &str, value: usize) -> std::result::Result<(), ModelLoadError> {
    if value > 0 {
        Ok(())
    } else {
        Err(ModelLoadError {
            gate: "dimension_plausibility",
            reason: format!("{field_name} is 0"),
        })
    }
}
/// Dimension-plausibility gate: all hyperparameters must be non-zero,
/// `hidden_dim` must split evenly across `num_heads`, and grouped-query
/// attention requires `num_kv_heads <= num_heads`.
///
/// Checks run in a fixed order so error precedence stays stable for tests.
fn validate_dimensions(config: &ModelLoadConfig) -> std::result::Result<(), ModelLoadError> {
    let fail = |reason: String| ModelLoadError {
        gate: "dimension_plausibility",
        reason,
    };
    require_nonzero("hidden_dim", config.hidden_dim)?;
    require_nonzero("num_heads", config.num_heads)?;
    // Remainder is safe here: num_heads was just checked non-zero.
    if config.hidden_dim % config.num_heads != 0 {
        return Err(fail(format!(
            "hidden_dim ({}) is not divisible by num_heads ({})",
            config.hidden_dim, config.num_heads
        )));
    }
    require_nonzero("vocab_size", config.vocab_size)?;
    require_nonzero("num_kv_heads", config.num_kv_heads)?;
    if config.num_kv_heads > config.num_heads {
        return Err(fail(format!(
            "num_kv_heads ({}) > num_heads ({})",
            config.num_kv_heads, config.num_heads
        )));
    }
    require_nonzero("intermediate_dim", config.intermediate_dim)?;
    require_nonzero("num_layers", config.num_layers)?;
    Ok(())
}
/// Architecture gate: resolves the named architecture to its constraint set.
/// Unknown names fall back to base constraints (see the
/// `test_unknown_architecture_uses_base` test), so this gate currently
/// never fails; the `Result` keeps the signature uniform with other gates.
fn validate_architecture(arch_name: &str) -> std::result::Result<ArchConstraints, ModelLoadError> {
    Ok(ArchConstraints::from_architecture(arch_name))
}
/// Completeness gate: every weight role required by `arch` must appear in
/// `present`; otherwise the error lists each missing role by field name.
fn validate_completeness(
    arch: &ArchConstraints,
    present: &[WeightRole],
    arch_name: &str,
) -> std::result::Result<(), ModelLoadError> {
    let required = required_roles(arch);
    let missing: Vec<_> = required
        .iter()
        .copied()
        .filter(|role| !present.contains(role))
        .map(|role| role.field_name())
        .collect();
    if missing.is_empty() {
        return Ok(());
    }
    Err(ModelLoadError {
        gate: "architecture_completeness",
        reason: format!(
            "Architecture '{}' requires {} weights but model is missing: [{}]",
            arch_name,
            required.len(),
            missing.join(", "),
        ),
    })
}
/// Resource-limit gate: rejects an F32 dequantization whose estimated peak
/// memory (mapped file + dequantized f32 copies) would exceed 80% of system
/// RAM.
///
/// `tensor_entries` holds `(byte_size, dtype)` pairs per tensor; dtype ids
/// follow the encoding understood by `estimate_elements`. When total system
/// memory cannot be determined, `mem_total` falls back to `u64::MAX` and the
/// gate is effectively unlimited.
///
/// # Errors
/// Returns a "resource_limits" [`ModelLoadError`] when the estimate exceeds
/// the threshold.
pub fn validate_f32_dequant_limits(
    tensor_entries: &[(usize, u8)],
    file_size: u64,
) -> std::result::Result<(), ModelLoadError> {
    // Saturate rather than wrap: multi-hundred-GB inputs must not overflow
    // the estimate into a small number that silently passes the gate.
    let estimated_f32_bytes: u64 = tensor_entries
        .iter()
        .map(|&(byte_size, dtype)| (estimate_elements(byte_size, dtype) as u64).saturating_mul(4))
        .fold(0u64, u64::saturating_add);
    let estimated_peak = file_size.saturating_add(estimated_f32_bytes);
    let mem_total = system_memory_bytes().unwrap_or(u64::MAX);
    // Divide before multiplying: `mem_total * 80` overflows u64 whenever
    // meminfo is unavailable and `mem_total` is the u64::MAX sentinel
    // (panic in debug builds, wrapped-around garbage threshold in release).
    let threshold = mem_total / 100 * 80;
    if estimated_peak > threshold {
        return Err(ModelLoadError {
            gate: "resource_limits",
            reason: format!(
                "F32 dequant would use ~{} GB (file {} GB + dequant {} GB), \
                exceeds 80% of system RAM ({} GB). Use quantized inference path.",
                estimated_peak / (1 << 30),
                file_size / (1 << 30),
                estimated_f32_bytes / (1 << 30),
                mem_total / (1 << 30),
            ),
        });
    }
    Ok(())
}
/// Rough element count for a tensor occupying `byte_size` bytes stored in
/// dtype `dtype`, computed from (block_bytes, elements_per_block) ratios.
///
/// NOTE(review): ids 12 (144 B / 256 elems) and 14 (210 B / 256 elems)
/// match GGUF Q4_K / Q6_K super-blocks, and 1/30 are 2-byte f16/bf16.
/// The ratios for ids 2 and 8/9 look like coarse approximations — confirm
/// against the GGUF quantization tables before relying on exact sizing.
fn estimate_elements(byte_size: usize, dtype: u8) -> usize {
    let (block_bytes, elements_per_block) = match dtype {
        12 => (144, 256),
        14 => (210, 256),
        2 => (36, 32),
        1 | 30 => (2, 1),
        8 | 9 => (5, 4),
        _ => (4, 1), // default: assume 4-byte f32 elements
    };
    byte_size / block_bytes * elements_per_block
}
/// Quick pre-check: would F32-dequantizing a `file_size`-byte model likely
/// exceed 80% of system RAM? Uses a fixed 8x expansion heuristic instead of
/// per-tensor estimates (see `validate_f32_dequant_limits` for the precise
/// gate). A zero `file_size` never exceeds the limit.
pub fn exceeds_f32_dequant_estimate(file_size: u64) -> bool {
    if file_size == 0 {
        return false;
    }
    let estimated_peak = file_size.saturating_mul(8);
    let mem_total = system_memory_bytes().unwrap_or(u64::MAX);
    // Divide before multiplying: `mem_total * 80` overflows u64 when
    // `mem_total` is the u64::MAX "unknown" sentinel (debug panic, or a
    // wrapped-around threshold in release that would falsely flag files).
    estimated_peak > mem_total / 100 * 80
}
/// Total system RAM in bytes, read from `/proc/meminfo` (Linux only).
///
/// Returns `None` when the file is unreadable (e.g. non-Linux hosts) or the
/// `MemTotal:` line is absent or unparseable.
pub fn system_memory_bytes() -> Option<u64> {
    let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;
    let total_line = meminfo.lines().find(|line| line.starts_with("MemTotal:"))?;
    // Line format: "MemTotal:       16384000 kB" — second token is the kB count.
    let kb: u64 = total_line.split_whitespace().nth(1)?.parse().ok()?;
    Some(kb * 1024)
}
/// Transposes a `rows x cols` f32 matrix into a freshly allocated buffer,
/// delegating the actual data movement to trueno's BLIS transpose kernel.
///
/// # Panics
/// Panics if `data.len() != rows * cols`.
#[must_use]
pub fn transpose_f32(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let expected = rows * cols;
    assert_eq!(
        data.len(),
        expected,
        "transpose_f32: data.len()={} != rows*cols={}",
        data.len(),
        expected
    );
    let mut transposed = vec![0.0f32; expected];
    trueno::blis::transpose::transpose(rows, cols, data, &mut transposed)
        .expect("transpose_f32: dimension mismatch (should be impossible after assert)");
    transposed
}
// GH-279 contract-gate tests: each gate is exercised via a known-good
// baseline config plus targeted single-field mutations.
#[cfg(test)]
mod tests {
    use super::*;

    // Known-good baseline (Llama-7B-like shapes); tests clone and mutate it.
    fn valid_config() -> ModelLoadConfig {
        ModelLoadConfig {
            architecture: "llama".to_string(),
            num_layers: 32,
            hidden_dim: 4096,
            num_heads: 32,
            num_kv_heads: 8,
            intermediate_dim: 11008,
            vocab_size: 32000,
            present_roles: Vec::new(),
        }
    }

    #[test]
    fn test_valid_model_passes() {
        let proof = validate_model_load(&valid_config()).expect("should pass");
        assert_eq!(proof.architecture(), "llama");
        assert_eq!(proof.num_layers(), 32);
    }

    // --- dimension_plausibility gate ---

    #[test]
    fn test_zero_hidden_dim_fails() {
        let mut config = valid_config();
        config.hidden_dim = 0;
        let err = validate_model_load(&config).unwrap_err();
        assert_eq!(err.gate, "dimension_plausibility");
        assert!(err.reason.contains("hidden_dim"));
    }

    #[test]
    fn test_zero_num_heads_fails() {
        let mut config = valid_config();
        config.num_heads = 0;
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("num_heads"));
    }

    #[test]
    fn test_hidden_not_divisible_by_heads() {
        let mut config = valid_config();
        // 4097 is not a multiple of 32 heads.
        config.hidden_dim = 4097;
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("not divisible"));
    }

    #[test]
    fn test_kv_heads_greater_than_heads() {
        let mut config = valid_config();
        config.num_kv_heads = 64;
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("num_kv_heads"));
    }

    #[test]
    fn test_zero_vocab_fails() {
        let mut config = valid_config();
        config.vocab_size = 0;
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("vocab_size"));
    }

    #[test]
    fn test_zero_layers_fails() {
        let mut config = valid_config();
        config.num_layers = 0;
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("num_layers"));
    }

    #[test]
    fn test_zero_intermediate_fails() {
        let mut config = valid_config();
        config.intermediate_dim = 0;
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("intermediate_dim"));
    }

    #[test]
    fn test_basic_convenience() {
        let proof =
            validate_model_load_basic("qwen2", 28, 1536, 12, 2, 8960, 151936).expect("should pass");
        assert_eq!(proof.architecture(), "qwen2");
    }

    // --- architecture_completeness gate ---

    #[test]
    fn test_completeness_llama_all_present() {
        let mut config = valid_config();
        config.present_roles = vec![
            WeightRole::AttnNorm,
            WeightRole::FfnNorm,
            WeightRole::QProj,
            WeightRole::KProj,
            WeightRole::VProj,
            WeightRole::OProj,
            WeightRole::FfnGate,
            WeightRole::FfnUp,
            WeightRole::FfnDown,
        ];
        assert!(validate_model_load(&config).is_ok());
    }

    #[test]
    fn test_completeness_llama_missing_gate() {
        let mut config = valid_config();
        // FFN weights omitted on purpose; the error must name a missing role.
        config.present_roles = vec![
            WeightRole::AttnNorm,
            WeightRole::FfnNorm,
            WeightRole::QProj,
            WeightRole::KProj,
            WeightRole::VProj,
            WeightRole::OProj,
        ];
        let err = validate_model_load(&config).unwrap_err();
        assert_eq!(err.gate, "architecture_completeness");
        assert!(err.reason.contains("ffn_gate"));
    }

    #[test]
    fn test_completeness_qwen3_needs_qk_norm() {
        let mut config = valid_config();
        config.architecture = "qwen3".to_string();
        // Full llama role set but no q/k norm weights, which qwen3 requires.
        config.present_roles = vec![
            WeightRole::AttnNorm,
            WeightRole::FfnNorm,
            WeightRole::QProj,
            WeightRole::KProj,
            WeightRole::VProj,
            WeightRole::OProj,
            WeightRole::FfnGate,
            WeightRole::FfnUp,
            WeightRole::FfnDown,
        ];
        let err = validate_model_load(&config).unwrap_err();
        assert!(err.reason.contains("attn_q_norm"));
    }

    #[test]
    fn test_completeness_qwen3_with_qk_norm_passes() {
        let mut config = valid_config();
        config.architecture = "qwen3".to_string();
        config.present_roles = vec![
            WeightRole::AttnNorm,
            WeightRole::FfnNorm,
            WeightRole::QProj,
            WeightRole::KProj,
            WeightRole::VProj,
            WeightRole::OProj,
            WeightRole::FfnGate,
            WeightRole::FfnUp,
            WeightRole::FfnDown,
            WeightRole::AttnQNorm,
            WeightRole::AttnKNorm,
        ];
        assert!(validate_model_load(&config).is_ok());
    }

    #[test]
    fn test_no_roles_skips_completeness() {
        // Empty role inventory means "not enumerated": gate must be skipped.
        let config = valid_config();
        assert!(config.present_roles.is_empty());
        assert!(validate_model_load(&config).is_ok());
    }

    #[test]
    fn test_unknown_architecture_uses_base() {
        let proof = validate_model_load_basic("unknown_future_arch", 1, 128, 4, 4, 512, 1000)
            .expect("unknown arch should pass with base constraints");
        assert_eq!(proof.architecture(), "unknown_future_arch");
    }

    // --- error formatting and conversion ---

    #[test]
    fn test_error_display() {
        let err = ModelLoadError {
            gate: "test_gate",
            reason: "test reason".to_string(),
        };
        let msg = format!("{err}");
        assert!(msg.contains("GH-279"));
        assert!(msg.contains("test_gate"));
        assert!(msg.contains("test reason"));
    }

    #[test]
    fn test_error_converts_to_realizar_error() {
        let err = ModelLoadError {
            gate: "test",
            reason: "test".to_string(),
        };
        let r_err: RealizarError = err.into();
        match r_err {
            RealizarError::UnsupportedOperation { operation, .. } => {
                assert!(operation.contains("contract_gate"));
            },
            _ => panic!("expected UnsupportedOperation"),
        }
    }

    // --- resource estimation helpers ---

    #[test]
    fn test_estimate_elements_f32() {
        // Unknown dtype falls through to the 4-byte f32 default.
        assert_eq!(estimate_elements(400, 0), 100);
    }

    #[test]
    fn test_estimate_elements_q4k() {
        // Q4_K super-block: 144 bytes -> 256 elements.
        assert_eq!(estimate_elements(144, 12), 256);
        assert_eq!(estimate_elements(288, 12), 512);
    }

    #[test]
    fn test_estimate_elements_q6k() {
        // Q6_K super-block: 210 bytes -> 256 elements.
        assert_eq!(estimate_elements(210, 14), 256);
    }

    #[test]
    fn test_estimate_elements_f16() {
        assert_eq!(estimate_elements(200, 1), 100);
    }

    #[test]
    fn test_estimate_elements_bf16() {
        assert_eq!(estimate_elements(200, 30), 100);
    }

    #[test]
    fn test_small_model_passes_resource_check() {
        // ~144 KB of Q4_K data; the verdict depends on host RAM, so this
        // only checks the call completes without panicking.
        let tensors: Vec<(usize, u8)> = vec![
            (144 * 1000, 12), ];
        let result = validate_f32_dequant_limits(&tensors, 4_000_000_000);
        let _ = result;
    }

    #[test]
    fn test_system_memory_bytes_returns_some_on_linux() {
        if cfg!(target_os = "linux") {
            let mem = system_memory_bytes();
            assert!(
                mem.is_some(),
                "system_memory_bytes should return Some on Linux"
            );
            assert!(mem.unwrap() > 0, "system memory should be > 0");
        }
    }

    #[test]
    fn test_exceeds_f32_dequant_estimate_zero_file() {
        assert!(!exceeds_f32_dequant_estimate(0));
    }
}