#![cfg(target_os = "macos")]
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use moeflux::Ctx;
const PROMPT_TOKENS: [i32; 4] = [760, 3841, 13477, 37550];
const TOP_K: usize = 20;
const MIN_OVERLAP: usize = 19;
struct Golden {
argmax: u32,
top: Vec<(u32, f32)>, }
fn parse_golden(path: &Path) -> Option<Golden> {
let text = std::fs::read_to_string(path).ok()?;
Some(parse_golden_text(&text))
}
fn parse_golden_text(text: &str) -> Golden {
let mut argmax = None;
let mut top = Vec::new();
for line in text.lines() {
if let Some(rest) = line.strip_prefix("# variant=") {
for field in rest.split_whitespace() {
if let Some(v) = field.strip_prefix("argmax=") {
argmax = Some(v.parse::<u32>().expect("argmax parse"));
}
}
continue;
}
if line.starts_with('#') || line.is_empty() {
continue;
}
let mut parts = line.split('|');
let _rank = parts.next().expect("rank");
let tok: u32 = parts.next().expect("tok").parse().expect("tok int");
let logit: f32 = parts.next().expect("logit").parse().expect("logit f32");
top.push((tok, logit));
}
Golden {
argmax: argmax.expect("golden header missing argmax"),
top,
}
}
fn load_golden_or_skip(path: &Path, variant: &str) -> Golden {
match parse_golden(path) {
Some(g) => g,
None => {
eprintln!(
"[{variant}] SKIP fixture not found at {path:?} — \
regenerate with tools/mlx_reference/generate_goldens.py"
);
std::process::exit(0);
}
}
}
fn fixtures_dir() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("fixtures")
}
fn run_moeflux(art: &Path, root: &Path, experts_per_tok: u32) -> Vec<f32> {
let mut ctx = Ctx::open(
&art.join("model_weights.bin"),
&art.join("model_weights.json"),
&art.join("vocab.bin"),
root,
experts_per_tok,
false,
)
.expect("Ctx::open");
let mut logits = vec![0.0f32; ctx.n_vocab()];
ctx.eval_prompt(&PROMPT_TOKENS, 0usize, 0, &mut logits)
.expect("eval_prompt");
logits
}
fn argmax_of(logits: &[f32]) -> u32 {
let (idx, _) = logits
.iter()
.enumerate()
.fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
if v > bv { (i, v) } else { (bi, bv) }
});
idx as u32
}
fn top_k_set(logits: &[f32], k: usize) -> HashSet<u32> {
let mut pairs: Vec<(f32, u32)> = logits
.iter()
.enumerate()
.map(|(i, &v)| (v, i as u32))
.collect();
pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
pairs.into_iter().take(k).map(|(_, i)| i).collect()
}
fn assert_matches_golden(logits: &[f32], golden: &Golden, variant: &str) {
let mf_argmax = argmax_of(logits);
assert_eq!(
mf_argmax, golden.argmax,
"[{variant}] moeflux argmax {} != MLX argmax {}",
mf_argmax, golden.argmax
);
let mf_top: HashSet<u32> = top_k_set(logits, TOP_K);
let mx_top: HashSet<u32> = golden.top.iter().take(TOP_K).map(|(id, _)| *id).collect();
let overlap = mf_top.intersection(&mx_top).count();
assert!(
overlap >= MIN_OVERLAP,
"[{variant}] top-{TOP_K} overlap {overlap}/{TOP_K} below threshold {MIN_OVERLAP}.\n \
moeflux: {:?}\n mlx: {:?}\n mf-only: {:?}\n mx-only: {:?}",
{
let mut v: Vec<u32> = mf_top.iter().copied().collect();
v.sort();
v
},
{
let mut v: Vec<u32> = mx_top.iter().copied().collect();
v.sort();
v
},
{
let mut v: Vec<u32> = mf_top.difference(&mx_top).copied().collect();
v.sort();
v
},
{
let mut v: Vec<u32> = mx_top.difference(&mf_top).copied().collect();
v.sort();
v
},
);
let mf_top_logit = logits[mf_argmax as usize];
let mx_top_logit = golden.top[0].1;
eprintln!(
"[{variant}] PASS argmax={mf_argmax} overlap={overlap}/{TOP_K} \
top-1 logit moeflux={mf_top_logit:.4} mlx={mx_top_logit:.4}"
);
}
#[cfg(feature = "model-qwen3-6-35b-a3b")]
#[test]
#[ignore]
fn mlx_regression_a3b() {
let art = PathBuf::from(
std::env::var("MOEFLUX_SMOKE_ARTIFACTS")
.unwrap_or("/Volumes/Temp Backup/models/moeflux/qwen3-6-35b-a3b-artifacts".into()),
);
let root = PathBuf::from(
std::env::var("MOEFLUX_SMOKE_ROOT")
.unwrap_or("/Volumes/Temp Backup/models/moeflux/qwen3-6-35b-a3b-root".into()),
);
let golden = load_golden_or_skip(
&fixtures_dir().join("mlx_golden_qwen3-6-35b-a3b.txt"),
"qwen3-6-35b-a3b",
);
let logits = run_moeflux(&art, &root, 8);
assert_matches_golden(&logits, &golden, "qwen3-6-35b-a3b");
}
#[cfg(feature = "model-qwen3-5-a17b")]
#[test]
#[ignore]
fn mlx_regression_a17b() {
let art = PathBuf::from(
std::env::var("MOEFLUX_SMOKE_ARTIFACTS")
.unwrap_or("/Volumes/Temp Backup/models/moeflux/qwen3-5-a17b-artifacts".into()),
);
let root = PathBuf::from(
std::env::var("MOEFLUX_SMOKE_ROOT")
.unwrap_or("/Volumes/Temp Backup/models/moeflux/qwen3-5-a17b-root".into()),
);
let golden = load_golden_or_skip(
&fixtures_dir().join("mlx_golden_qwen3-5-a17b.txt"),
"qwen3-5-a17b",
);
let logits = run_moeflux(&art, &root, 10);
assert_matches_golden(&logits, &golden, "qwen3-5-a17b");
}