#[must_use]
#[ensures(ret.0.len() == input_shape.len())]
#[ensures(ret.0.iter().product::<usize>() == input_shape.iter().product::<usize>())]
#[ensures(!ret.1)]
pub fn enforce_import_contract(
tensor_name: &str,
input_shape: &[usize],
_vocab_size: usize,
_hidden_dim: usize,
) -> (Vec<usize>, bool) {
let layout = contract();
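    // Resolve the tensor's layout contract: try the GGUF name first, then
    // fall back to the APR name.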
let tc = layout
.get_gguf_contract(tensor_name)
.or_else(|| layout.get_apr_contract(tensor_name));
let apr_shape = match tc {
Some(tc) => known_tensor_apr_shape(input_shape, tc.should_transpose),
None => unknown_tensor_apr_shape(input_shape),
};
(apr_shape, false)
}
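// Shape-normalization helpers: a tensor with a known contract is transposed
// only when `should_transpose` is set; a tensor with no contract entry gets
// the default axis swap for 2D shapes (1D and higher-rank shapes pass through
// unchanged in both paths).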
fn known_tensor_apr_shape(input_shape: &[usize], should_transpose: bool) -> Vec<usize> {
if should_transpose && input_shape.len() == 2 {
vec![input_shape[1], input_shape[0]]
} else {
input_shape.to_vec()
}
}
fn unknown_tensor_apr_shape(input_shape: &[usize]) -> Vec<usize> {
if input_shape.len() == 2 {
vec![input_shape[1], input_shape[0]]
} else {
input_shape.to_vec()
}
}
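// A minimal sketch of the import-shape invariants. The tensor name below is
// hypothetical and assumed absent from the layout contract, so the call takes
// the unknown-tensor path; the element count is preserved either way, per the
// `#[ensures]` clauses on `enforce_import_contract`.
#[cfg(test)]
mod import_shape_sketch_tests {
    use super::*;
    #[test]
    fn unknown_2d_tensor_is_axis_swapped() {
        let (shape, transposed) = enforce_import_contract("made.up.tensor", &[4, 8], 0, 0);
        assert_eq!(shape, vec![8, 4]);
        assert!(!transposed);
    }
}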
pub fn enforce_load_contract(
apr_name: &str,
apr_shape: &[usize],
vocab_size: usize,
hidden_dim: usize,
) -> Result<(), ContractError> {
let layout = contract();
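    // Load-time validation is deliberately narrow: only tensors the contract
    // marks `is_critical` are shape-checked; everything else loads as-is.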
if let Some(tc) = layout.get_apr_contract(apr_name) {
if tc.is_critical {
layout.validate_apr_shape(apr_name, apr_shape, vocab_size, hidden_dim)?;
}
}
Ok(())
}
pub fn enforce_embedding_contract(embedding_len: usize, vocab_size: usize, hidden_dim: usize) {
let expected_len = vocab_size * hidden_dim;
assert_eq!(
embedding_len, expected_len,
"CONTRACT VIOLATION: Embedding length {} != vocab({}) * hidden({}) = {}. \
This will cause garbage inference output. \
See: contracts/tensor-layout-v1.yaml",
embedding_len, vocab_size, hidden_dim, expected_len
);
}
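// A sketch of the failure mode the assert above guards against (the
// dimensions are hypothetical): 10 elements can never be vocab(4) * hidden(3) = 12.
#[cfg(test)]
mod embedding_contract_sketch_tests {
    use super::*;
    #[test]
    #[should_panic(expected = "CONTRACT VIOLATION")]
    fn mismatched_embedding_len_panics() {
        enforce_embedding_contract(10, 4, 3);
    }
}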
pub fn enforce_matmul_contract(
tensor_name: &str,
weight_shape: &[usize],
expected_out_dim: usize,
expected_in_dim: usize,
) {
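    // Kernel convention: shape[0] is the output dimension, shape[1] the input
    // dimension. A tensor left transposed at import trips these asserts
    // instead of producing silently wrong activations.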
assert_eq!(
weight_shape.len(),
2,
"CONTRACT VIOLATION: {} must be 2D, got {:?}",
tensor_name,
weight_shape
);
assert_eq!(
weight_shape[0], expected_out_dim,
"CONTRACT VIOLATION: {} shape[0]={} but kernel expects out_dim={}. \
See: contracts/tensor-layout-v1.yaml",
tensor_name, weight_shape[0], expected_out_dim
);
assert_eq!(
weight_shape[1], expected_in_dim,
"CONTRACT VIOLATION: {} shape[1]={} but kernel expects in_dim={}. \
See: contracts/tensor-layout-v1.yaml",
tensor_name, weight_shape[1], expected_in_dim
);
}
#[must_use]
fn required_tensor_pattern_pairs(
has_qk_norm: bool,
has_bias: bool,
) -> Vec<(&'static str, &'static str)> {
let mut pairs = vec![
(
"blk.{i}.attn_norm.weight",
"model.layers.{i}.input_layernorm.weight",
),
(
"blk.{i}.ffn_norm.weight",
"model.layers.{i}.post_attention_layernorm.weight",
),
(
"blk.{i}.attn_q.weight",
"model.layers.{i}.self_attn.q_proj.weight",
),
(
"blk.{i}.attn_k.weight",
"model.layers.{i}.self_attn.k_proj.weight",
),
(
"blk.{i}.attn_v.weight",
"model.layers.{i}.self_attn.v_proj.weight",
),
(
"blk.{i}.attn_output.weight",
"model.layers.{i}.self_attn.o_proj.weight",
),
(
"blk.{i}.ffn_gate.weight",
"model.layers.{i}.mlp.gate_proj.weight",
),
(
"blk.{i}.ffn_up.weight",
"model.layers.{i}.mlp.up_proj.weight",
),
(
"blk.{i}.ffn_down.weight",
"model.layers.{i}.mlp.down_proj.weight",
),
];
if has_qk_norm {
pairs.push((
"blk.{i}.attn_q_norm.weight",
"model.layers.{i}.self_attn.q_norm.weight",
));
pairs.push((
"blk.{i}.attn_k_norm.weight",
"model.layers.{i}.self_attn.k_norm.weight",
));
}
if has_bias {
pairs.push((
"blk.{i}.attn_q.bias",
"model.layers.{i}.self_attn.q_proj.bias",
));
pairs.push((
"blk.{i}.attn_k.bias",
"model.layers.{i}.self_attn.k_proj.bias",
));
pairs.push((
"blk.{i}.attn_v.bias",
"model.layers.{i}.self_attn.v_proj.bias",
));
}
pairs
}
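// Pattern expansion sketch: "{i}" is substituted per layer index, so for
// layer 3 the pair ("blk.{i}.attn_q.weight",
// "model.layers.{i}.self_attn.q_proj.weight") is satisfied by either
// "blk.3.attn_q.weight" or "model.layers.3.self_attn.q_proj.weight".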
#[provable_contracts_macros::contract(
"architecture-requirements-v1",
equation = "import_completeness_gate"
)]
pub fn enforce_architecture_completeness(
tensor_names: &[&str],
architecture: &str,
num_layers: usize,
) -> Result<(), ContractError> {
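    // Per-architecture feature flags as (has_qk_norm, has_bias): Qwen3 adds
    // per-head Q/K norm tensors, Qwen2.x and Phi add attention projection
    // biases, and unknown architectures fall back to the base tensor set.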
    let (has_qk_norm, has_bias) = match architecture {
        "qwen3" => (true, false),
        "qwen3_5" | "qwen3.5" => (false, false),
        "qwen2" | "qwen2.5" | "qwen" => (false, true),
        "phi" | "phi2" | "phi3" => (false, true),
        _ => (false, false),
    };
let pairs = required_tensor_pattern_pairs(has_qk_norm, has_bias);
for layer_idx in 0..num_layers {
for (apr_pat, hf_pat) in &pairs {
let apr_name = apr_pat.replace("{i}", &layer_idx.to_string());
let hf_name = hf_pat.replace("{i}", &layer_idx.to_string());
let found = tensor_names.iter().any(|n| *n == apr_name || *n == hf_name);
if !found {
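                // Reuses TransposeError as the generic per-tensor error
                // carrier; the message explains the actual failure.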
return Err(ContractError::TransposeError {
tensor: apr_name,
message: format!(
"GH-279: Missing required tensor for architecture '{}' \
(checked both APR and HF naming) \
— see contracts/architecture-requirements-v1.yaml",
architecture
),
});
}
}
}
Ok(())
}
#[cfg(test)]
mod architecture_completeness_tests {
use super::*;
#[test]
fn test_llama_base_complete() {
let owned: Vec<String> = (0..2)
.flat_map(|i| {
vec![
format!("blk.{i}.attn_norm.weight"),
format!("blk.{i}.ffn_norm.weight"),
format!("blk.{i}.attn_q.weight"),
format!("blk.{i}.attn_k.weight"),
format!("blk.{i}.attn_v.weight"),
format!("blk.{i}.attn_output.weight"),
format!("blk.{i}.ffn_gate.weight"),
format!("blk.{i}.ffn_up.weight"),
format!("blk.{i}.ffn_down.weight"),
]
})
.collect();
let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
assert!(enforce_architecture_completeness(&refs, "llama", 2).is_ok());
}
#[test]
fn test_llama_missing_ffn_gate() {
let owned: Vec<String> = (0..2)
.flat_map(|i| {
let mut v = vec![
format!("blk.{i}.attn_norm.weight"),
format!("blk.{i}.ffn_norm.weight"),
format!("blk.{i}.attn_q.weight"),
format!("blk.{i}.attn_k.weight"),
format!("blk.{i}.attn_v.weight"),
format!("blk.{i}.attn_output.weight"),
format!("blk.{i}.ffn_up.weight"),
format!("blk.{i}.ffn_down.weight"),
];
if i == 0 {
v.push(format!("blk.{i}.ffn_gate.weight"));
}
v
})
.collect();
let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
let result = enforce_architecture_completeness(&refs, "llama", 2);
assert!(result.is_err());
let err = result.unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("blk.1.ffn_gate.weight"),
"Error should name the missing tensor: {msg}"
);
}
#[test]
fn test_qwen3_requires_qk_norm() {
let owned: Vec<String> = (0..1)
.flat_map(|i| {
vec![
format!("blk.{i}.attn_norm.weight"),
format!("blk.{i}.ffn_norm.weight"),
format!("blk.{i}.attn_q.weight"),
format!("blk.{i}.attn_k.weight"),
format!("blk.{i}.attn_v.weight"),
format!("blk.{i}.attn_output.weight"),
format!("blk.{i}.ffn_gate.weight"),
format!("blk.{i}.ffn_up.weight"),
format!("blk.{i}.ffn_down.weight"),
]
})
.collect();
let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
let result = enforce_architecture_completeness(&refs, "qwen3", 1);
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("attn_q_norm"),
"Should require QK norm for Qwen3: {}",
msg
);
}
#[test]
fn test_qwen3_complete_with_qk_norm() {
let owned: Vec<String> = (0..1)
.flat_map(|i| {
vec![
format!("blk.{i}.attn_norm.weight"),
format!("blk.{i}.ffn_norm.weight"),
format!("blk.{i}.attn_q.weight"),
format!("blk.{i}.attn_k.weight"),
format!("blk.{i}.attn_v.weight"),
format!("blk.{i}.attn_output.weight"),
format!("blk.{i}.ffn_gate.weight"),
format!("blk.{i}.ffn_up.weight"),
format!("blk.{i}.ffn_down.weight"),
format!("blk.{i}.attn_q_norm.weight"),
format!("blk.{i}.attn_k_norm.weight"),
]
})
.collect();
let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
assert!(enforce_architecture_completeness(&refs, "qwen3", 1).is_ok());
}
#[test]
fn test_qwen2_requires_bias() {
let owned: Vec<String> = (0..1)
.flat_map(|i| {
vec![
format!("blk.{i}.attn_norm.weight"),
format!("blk.{i}.ffn_norm.weight"),
format!("blk.{i}.attn_q.weight"),
format!("blk.{i}.attn_k.weight"),
format!("blk.{i}.attn_v.weight"),
format!("blk.{i}.attn_output.weight"),
format!("blk.{i}.ffn_gate.weight"),
format!("blk.{i}.ffn_up.weight"),
format!("blk.{i}.ffn_down.weight"),
]
})
.collect();
let refs: Vec<&str> = owned.iter().map(String::as_str).collect();
let result = enforce_architecture_completeness(&refs, "qwen2", 1);
assert!(result.is_err());
let msg = format!("{}", result.unwrap_err());
assert!(
msg.contains("bias"),
"Should require bias for Qwen2: {}",
msg
);
}
}
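// FFN symmetry invariant: gate_proj and up_proj must share a shape (both map
// hidden -> intermediate), and a 2D down_proj must be the exact reverse
// (intermediate -> hidden). Non-2D shapes skip the reversal check, so 1D
// tensors only need the gate/up equality to hold.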
pub fn validate_ffn_shape_symmetry(
gate_shape: &[usize],
up_shape: &[usize],
down_shape: &[usize],
) -> Result<(), ContractError> {
if gate_shape != up_shape {
return Err(ContractError::ShapeMismatch {
tensor: "ffn_gate_proj/up_proj".to_string(),
expected: format!("gate_proj {:?} == up_proj {:?}", gate_shape, up_shape),
actual: up_shape.to_vec(),
});
}
if gate_shape.len() == 2
&& down_shape.len() == 2
&& (down_shape[0] != gate_shape[1] || down_shape[1] != gate_shape[0])
{
return Err(ContractError::ShapeMismatch {
tensor: "ffn_down_proj".to_string(),
expected: format!(
"down_proj [{}, {}] (reversed from gate [{}, {}])",
gate_shape[1], gate_shape[0], gate_shape[0], gate_shape[1]
),
actual: down_shape.to_vec(),
});
}
Ok(())
}
#[cfg(test)]
mod ffn_shape_tests {
use super::*;
#[test]
fn test_ffn_valid_shapes() {
let gate = [4864, 896];
let up = [4864, 896];
let down = [896, 4864];
assert!(validate_ffn_shape_symmetry(&gate, &up, &down).is_ok());
}
#[test]
fn test_ffn_gate_up_mismatch() {
let gate = [4864, 896];
        let up = [3072, 896];
        let down = [896, 4864];
let result = validate_ffn_shape_symmetry(&gate, &up, &down);
assert!(result.is_err());
}
#[test]
fn test_ffn_down_not_reversed() {
let gate = [4864, 896];
let up = [4864, 896];
        let down = [4864, 896];
        let result = validate_ffn_shape_symmetry(&gate, &up, &down);
assert!(result.is_err());
}
#[test]
fn test_ffn_1d_shapes_accepted() {
let gate = [4864];
let up = [4864];
let down = [896];
assert!(validate_ffn_shape_symmetry(&gate, &up, &down).is_ok());
}
}