use std::path::PathBuf;
use assert_cmd::Command;
use llama_rs::gguf::{GgmlType, GgufBuilder, MetadataValue, TensorToWrite};
use predicates::prelude::PredicateBooleanExt;
/// Writes a tiny, zero-filled F32 GGUF model for architecture `arch` to `path`.
///
/// The model declares 2 blocks, hidden size 8, 2 attention heads (and 2 KV heads),
/// a feed-forward width of 16 and a 100-entry token embedding. Tensors are emitted in
/// a fixed order: `token_embd.weight`, `output_norm.weight`, then the nine per-block
/// tensors of `blk.0` followed by those of `blk.1`.
///
/// Panics if the file cannot be written.
fn write_minimal_gguf(path: &PathBuf, arch: &str) {
    // Every tensor is stored as F32, i.e. 4 bytes per element.
    const BYTES_PER_F32: usize = 4;

    let hidden: u64 = 8;
    let heads: u64 = 2;
    let blocks: u64 = 2;
    let ffn_width: u64 = 16;
    let vocab_size: u64 = 100;
    // Q/K/V project to `heads * head_dim` columns (equal to `hidden` here).
    let qkv_cols = heads * (hidden / heads);

    // (tensor name, shape) in the exact order they are added to the file.
    let mut layout: Vec<(String, Vec<u64>)> = vec![
        ("token_embd.weight".to_string(), vec![hidden, vocab_size]),
        ("output_norm.weight".to_string(), vec![hidden]),
    ];
    for block in 0..blocks {
        let block_tensors: Vec<(&str, Vec<u64>)> = vec![
            ("attn_norm.weight", vec![hidden]),
            ("attn_q.weight", vec![hidden, qkv_cols]),
            ("attn_k.weight", vec![hidden, qkv_cols]),
            ("attn_v.weight", vec![hidden, qkv_cols]),
            ("attn_output.weight", vec![hidden, hidden]),
            ("ffn_norm.weight", vec![hidden]),
            ("ffn_gate.weight", vec![hidden, ffn_width]),
            ("ffn_up.weight", vec![hidden, ffn_width]),
            ("ffn_down.weight", vec![ffn_width, hidden]),
        ];
        layout.extend(
            block_tensors
                .into_iter()
                .map(|(suffix, shape)| (format!("blk.{block}.{suffix}"), shape)),
        );
    }

    let header = GgufBuilder::new()
        .architecture(arch)
        .metadata(
            format!("{arch}.embedding_length"),
            MetadataValue::Uint32(hidden as u32),
        )
        .metadata(
            format!("{arch}.attention.head_count"),
            MetadataValue::Uint32(heads as u32),
        )
        .metadata(
            format!("{arch}.attention.head_count_kv"),
            MetadataValue::Uint32(heads as u32),
        )
        .metadata(
            format!("{arch}.block_count"),
            MetadataValue::Uint32(blocks as u32),
        )
        .metadata(
            format!("{arch}.feed_forward_length"),
            MetadataValue::Uint32(ffn_width as u32),
        );

    let builder = layout.into_iter().fold(header, |acc, (name, shape)| {
        let element_count: u64 = shape.iter().product();
        let bytes = vec![0u8; element_count as usize * BYTES_PER_F32];
        acc.tensor(TensorToWrite::new(name, shape, GgmlType::F32, bytes))
    });
    builder.write_to_file(path).expect("write gguf");
}
/// Returns a fresh `Command` for the `llama-rs-inspect` binary that Cargo builds
/// for this integration-test run.
///
/// Panics with the binary's name if Cargo cannot locate it.
fn bin() -> Command {
    Command::cargo_bin("llama-rs-inspect")
        .expect("cargo should build the `llama-rs-inspect` binary for integration tests")
}
/// `list-profiles` must succeed and name every builtin profile.
///
/// Names are matched as whole tokens rather than substrings: a plain
/// `contains("qwen3")` is already satisfied by `qwen35moe`, so it could never
/// detect the `qwen3` profile going missing.
#[test]
fn list_profiles_prints_every_builtin() {
    let stdout = bin()
        .arg("list-profiles")
        .assert()
        .success()
        .get_output()
        .stdout
        .clone();
    let text = String::from_utf8(stdout).expect("list-profiles stdout should be UTF-8");
    for name in ["llama3", "qwen3", "qwen35moe"] {
        // Split on anything that cannot be part of a profile identifier.
        let listed = text
            .split(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .any(|token| token == name);
        assert!(listed, "profile `{name}` not listed in output:\n{text}");
    }
}
/// A well-formed minimal `llama` GGUF should satisfy the builtin profile: the
/// command succeeds and reports either `MATCHES` or `OPTIONAL EXTRAS`.
#[test]
fn diff_profile_matches_against_minimal_llama_gguf() {
    let dir = tempfile::tempdir().unwrap();
    let gguf = dir.path().join("mini.gguf");
    write_minimal_gguf(&gguf, "llama");

    let acceptable_verdict =
        predicates::str::contains("MATCHES").or(predicates::str::contains("OPTIONAL EXTRAS"));
    bin()
        .args(["diff-profile", gguf.to_str().unwrap()])
        .assert()
        .success()
        .stdout(acceptable_verdict);
}
/// `diff-profile --json` on a minimal `llama` GGUF must print a JSON document whose
/// `schema_version` is 1 and whose `verdict` is `matches` or `optional_extras`.
#[test]
fn diff_profile_json_output_is_valid_json_with_schema_version() {
    let dir = tempfile::tempdir().unwrap();
    let gguf = dir.path().join("mini.gguf");
    write_minimal_gguf(&gguf, "llama");

    let assertion = bin()
        .args(["diff-profile", gguf.to_str().unwrap(), "--json"])
        .assert()
        .success();
    let raw_stdout = assertion.get_output().stdout.clone();
    let text = String::from_utf8(raw_stdout).unwrap();

    let report: serde_json::Value = serde_json::from_str(&text).expect("valid JSON");
    assert_eq!(report["schema_version"], serde_json::json!(1));

    // A missing or non-string verdict collapses to "" and fails the check below.
    let verdict = report["verdict"].as_str().unwrap_or("");
    assert!(
        matches!(verdict, "matches" | "optional_extras"),
        "unexpected verdict: {verdict} ({report})"
    );
}
/// A GGUF whose architecture has no builtin profile makes `diff-profile` exit with
/// code 3 and print `UNKNOWN ARCHITECTURE`.
#[test]
fn diff_profile_unknown_arch_returns_exit_3() {
    let dir = tempfile::tempdir().unwrap();
    let gguf = dir.path().join("mystery.gguf");
    write_minimal_gguf(&gguf, "totally-unknown-arch");

    let expected_banner = predicates::str::contains("UNKNOWN ARCHITECTURE");
    bin()
        .args(["diff-profile", gguf.to_str().unwrap()])
        .assert()
        .code(3)
        .stdout(expected_banner);
}
/// A `llama` GGUF whose blocks carry `post_attention_norm.weight` in place of
/// `ffn_norm.weight` must be reported as a profile mismatch (exit code 2), and the
/// report must name the `post_attn_norm_rename` hypothesis.
///
/// Apart from that rename, the fixture uses the same tensor shapes as
/// `write_minimal_gguf`, so the rename is the only divergence the checker sees.
#[test]
fn diff_profile_detects_ffn_norm_rename_hypothesis() {
    let tmp = tempfile::tempdir().unwrap();
    let path = tmp.path().join("rename.gguf");
    let hidden: u64 = 8;
    let n_heads: u64 = 2;
    let n_layers: u64 = 2;
    let ff: u64 = 16;
    let vocab: u64 = 100;
    // Every tensor is F32: 4 bytes per element.
    let zero_vec = |numel: u64| vec![0u8; numel as usize * 4];

    let mut builder = GgufBuilder::new()
        .architecture("llama")
        .metadata("llama.embedding_length", MetadataValue::Uint32(hidden as u32))
        .metadata("llama.attention.head_count", MetadataValue::Uint32(n_heads as u32))
        .metadata("llama.attention.head_count_kv", MetadataValue::Uint32(n_heads as u32))
        .metadata("llama.block_count", MetadataValue::Uint32(n_layers as u32))
        .metadata("llama.feed_forward_length", MetadataValue::Uint32(ff as u32));

    for (name, dims) in [
        ("token_embd.weight", vec![hidden, vocab]),
        ("output_norm.weight", vec![hidden]),
    ] {
        let numel: u64 = dims.iter().product();
        builder = builder.tensor(TensorToWrite::new(
            name.to_string(),
            dims,
            GgmlType::F32,
            zero_vec(numel),
        ));
    }

    for l in 0..n_layers {
        let per_layer: Vec<(&str, Vec<u64>)> = vec![
            ("attn_norm.weight", vec![hidden]),
            ("attn_q.weight", vec![hidden, hidden]),
            ("attn_k.weight", vec![hidden, hidden]),
            ("attn_v.weight", vec![hidden, hidden]),
            ("attn_output.weight", vec![hidden, hidden]),
            ("ffn_gate.weight", vec![hidden, ff]),
            ("ffn_up.weight", vec![hidden, ff]),
            // Same shape `write_minimal_gguf` uses; giving it [hidden, ff] would add a
            // shape divergence on top of the rename under test.
            ("ffn_down.weight", vec![ff, hidden]),
            // The rename under test: stands in for `ffn_norm.weight`.
            ("post_attention_norm.weight", vec![hidden]),
        ];
        for (suffix, dims) in per_layer {
            let numel: u64 = dims.iter().product();
            builder = builder.tensor(TensorToWrite::new(
                format!("blk.{l}.{suffix}"),
                dims,
                GgmlType::F32,
                zero_vec(numel),
            ));
        }
    }
    builder.write_to_file(&path).unwrap();

    bin()
        .arg("diff-profile")
        .arg(path.to_str().unwrap())
        .assert()
        .code(2)
        .stdout(predicates::str::contains("PROFILE MISMATCH"))
        .stdout(predicates::str::contains("post_attn_norm_rename"));
}
/// Pointing `diff-profile` at a file that does not exist fails with exit code 1.
#[test]
fn invalid_path_returns_exit_1() {
    const MISSING_GGUF: &str = "/definitely/does/not/exist.gguf";
    bin().args(["diff-profile", MISSING_GGUF]).assert().code(1);
}
/// `generate-profile` must write a TOML profile that carries the requested `--name`
/// and the GGUF's architecture, and the file must re-parse as TOML.
#[test]
fn generate_profile_writes_parseable_toml() {
    let tmp = tempfile::tempdir().unwrap();
    let in_path = tmp.path().join("mini.gguf");
    let out_path = tmp.path().join("generated.toml");
    write_minimal_gguf(&in_path, "llama");

    bin()
        .arg("generate-profile")
        .arg(in_path.to_str().unwrap())
        .arg("--out")
        .arg(out_path.to_str().unwrap())
        .arg("--name")
        .arg("mini")
        .assert()
        .success();

    // Named `text` rather than `toml` so it does not shadow the `toml` crate below.
    let text = std::fs::read_to_string(&out_path)
        .unwrap_or_else(|e| panic!("read generated profile {}: {e}", out_path.display()));
    // Show the generated file on failure so the mismatch is diagnosable.
    assert!(
        text.contains("name = \"mini\""),
        "generated profile lacks `name = \"mini\"`:\n{text}"
    );
    assert!(
        text.contains("architecture = \"llama\""),
        "generated profile lacks `architecture = \"llama\"`:\n{text}"
    );
    let _parsed: toml::Value = toml::from_str(&text).expect("generated TOML must re-parse");
}