#![cfg(all(
target_os = "macos",
any(
feature = "model-qwen3-5-a17b",
feature = "model-qwen3-6-35b-a3b",
),
))]
use std::path::Path;
use std::time::Instant;
mod common;
use common::diff_helpers::{
argmax, artifacts_dir, cosine_sim, default_a3b_paths,
};
use moeflux::riir::RsCtx;
/// Backend surface driven by the differential tests in this file.
///
/// Every method takes borrowed (or owned) inputs and hands back freshly
/// allocated outputs, so a test never has to manage scratch buffers. No
/// method returns `Result`; the one implementation visible in this file
/// (`RsBackend`) panics via `expect` when the wrapped context reports an
/// error. Output sizes mentioned below describe that implementation.
pub trait DiffBackend {
/// Builds a backend from the given artifact paths and expert settings.
/// Returns `Self` directly, so implementations must panic on failure.
fn open(
weights: &Path,
manifest: &Path,
vocab: &Path,
experts_dir: &Path,
experts_per_tok: u32,
use_2bit: bool,
) -> Self;
/// Vocabulary size; `RsBackend` uses it to size logits buffers.
fn n_vocab(&self) -> usize;
/// Context length reported by the backend.
fn n_ctx(&self) -> usize;
/// End-of-sequence token id reported by the backend.
fn eos(&self) -> i32;
/// Static model name reported by the backend.
fn model_name(&self) -> &'static str;
/// Embedding lookup for `token_id`; `RsBackend` returns a vector of
/// `VARIANT.hidden_dim` elements.
fn embed(&self, token_id: i32) -> Vec<f32>;
/// CPU RMS norm of `x` using the weight tensor named `weight_name`;
/// `RsBackend` returns `VARIANT.hidden_dim` elements.
fn rms_norm_cpu(&self, weight_name: &str, x: &[f32]) -> Vec<f32>;
/// Applies the rotary embedding at position `pos` and returns the
/// transformed `(q, k)` pair. `RsBackend` copies the inputs and lets the
/// context transform the copies in place.
fn apply_rotary_emb(
&self,
pos: i32,
q: &[f32],
k: &[f32],
) -> (Vec<f32>, Vec<f32>);
/// Per-head CPU RMS norm; `RsBackend` copies `x` and normalizes the copy
/// in place, so the result has the same length as `x`.
fn rms_norm_per_head_cpu(
&self,
weight_name: &str,
num_heads: usize,
head_dim: usize,
x: &[f32],
) -> Vec<f32>;
/// CPU scaled-dot-product attention over the first `kv_len` cache
/// entries; `RsBackend` returns a vector with the same length as `q`.
fn sdpa_cpu(
&self,
kv_len: i32,
q: &[f32],
q_gate: &[f32],
k_cache: &[f32],
v_cache: &[f32],
) -> Vec<f32>;
/// CPU LM head; `RsBackend` returns `n_vocab()` values.
fn lm_head_cpu(&self, x: &[f32]) -> Vec<f32>;
/// CPU MoE router. Takes `scores` by value because the underlying call
/// needs a mutable buffer; returns `(indices, weights)`, each of length
/// `k` in `RsBackend`.
fn moe_router_cpu(&self, scores: Vec<f32>, k: usize) -> (Vec<i32>, Vec<f32>);
/// One CPU conv1d step; `RsBackend` returns `channels` values.
fn conv1d_step_cpu(
&self,
weight_name: &str,
channels: usize,
kernel_size: usize,
conv_state: &[f32],
new_input: &[f32],
) -> Vec<f32>;
/// CPU RMS norm without a weight tensor; output length equals `x.len()`
/// in `RsBackend`.
fn rms_norm_bare_cpu(&self, eps: f32, x: &[f32]) -> Vec<f32>;
/// CPU gated RMS norm of `x` with gate input `z`; output length equals
/// `x.len()` in `RsBackend`.
fn rms_norm_gated_cpu(
&self,
weight_name: &str,
eps: f32,
x: &[f32],
z: &[f32],
) -> Vec<f32>;
/// CPU gated-delta recurrence step. Consumes `ssm_state_in` and returns
/// `(state, out)`: the state buffer after the call and an output of
/// `v_heads * value_dim` values (sizes as in `RsBackend`).
#[allow(clippy::too_many_arguments)]
fn gated_delta_recurrence_cpu(
&self,
layer_idx: usize,
alpha: &[f32],
beta: &[f32],
q: &[f32],
k: &[f32],
v: &[f32],
v_heads: usize,
k_heads: usize,
key_dim: usize,
value_dim: usize,
ssm_state_in: Vec<f32>,
) -> (Vec<f32>, Vec<f32>);
/// Raw bytes of one expert; `RsBackend` sizes the buffer with
/// `VARIANT.expert_size_4bit()`.
fn load_expert_bytes(&self, layer_idx: i32, expert_idx: i32) -> Vec<u8>;
/// GPU fused RMS norm with the weight supplied as raw bytes
/// (`weight_bf16`); `RsBackend` returns `VARIANT.hidden_dim` values.
fn gpu_rms_norm_fused(
&mut self,
x: &[f32],
weight_bf16: &[u8],
) -> Vec<f32>;
/// GPU forward pass of a single expert; `RsBackend` returns
/// `VARIANT.hidden_dim` values.
fn gpu_expert_forward(
&mut self,
expert_data: &[u8],
h_post: &[f32],
) -> Vec<f32>;
/// GPU forward pass over `actual_k` experts in one call; `RsBackend`
/// returns `VARIANT.hidden_dim` values.
#[allow(clippy::too_many_arguments)]
fn gpu_batched_experts_forward(
&mut self,
actual_k: i32,
expert_data: &[u8],
h_post: &[f32],
h_mid: &[f32],
shared_out: &[f32],
expert_weights: &[f32],
shared_gate_score: f32,
) -> Vec<f32>;
/// Batched attention scores; `RsBackend` returns `num_heads * seq_len`
/// values.
#[allow(clippy::too_many_arguments)]
fn attn_scores_batched(
&mut self,
num_heads: u32,
num_kv_heads: u32,
head_dim: u32,
seq_len: u32,
q: &[f32],
k_cache: &[f32],
scale: f32,
) -> Vec<f32>;
/// Batched attention softmax; `RsBackend` copies `scores_in` and
/// transforms the copy in place.
fn attn_softmax_batched(
&mut self,
num_heads: u32,
seq_len: u32,
scores_in: &[f32],
) -> Vec<f32>;
/// Batched attention value mix; `RsBackend` returns
/// `num_heads * head_dim` values.
#[allow(clippy::too_many_arguments)]
fn attn_values_batched(
&mut self,
num_heads: u32,
num_kv_heads: u32,
head_dim: u32,
seq_len: u32,
scores: &[f32],
v_cache: &[f32],
) -> Vec<f32>;
/// Sigmoid gate; `RsBackend` copies `x_in` and lets the context update
/// the copy in place using `gate`.
fn sigmoid_gate(
&mut self,
dim: u32,
gate: &[f32],
x_in: &[f32],
) -> Vec<f32>;
/// Starts a deferred experts computation; the result is fetched later
/// with `complete_deferred_experts` or dropped with
/// `discard_deferred_experts`.
#[allow(clippy::too_many_arguments)]
fn begin_deferred_experts(
&mut self,
actual_k: i32,
expert_data: &[u8],
h_post: &[f32],
h_mid: &[f32],
shared_out: &[f32],
expert_weights: &[f32],
shared_gate_score: f32,
);
/// Collects the result of the deferred experts computation;
/// `RsBackend` returns `VARIANT.hidden_dim` values.
fn complete_deferred_experts(&mut self) -> Vec<f32>;
/// Abandons the deferred experts computation without reading a result.
fn discard_deferred_experts(&mut self);
/// Runs layer `layer_idx` at position `pos` on `hidden_in` and returns
/// the resulting hidden state (`VARIANT.hidden_dim` values in
/// `RsBackend`).
fn layer_forward_dump(
&mut self,
layer_idx: i32,
pos: i32,
hidden_in: &[f32],
) -> Vec<f32>;
/// Evaluates `tokens` starting at `start_pos` and returns `n_vocab()`
/// logits.
fn eval_prompt(&mut self, tokens: &[i32], start_pos: usize) -> Vec<f32>;
/// Evaluates a single `token` at `pos` and returns `n_vocab()` logits.
fn eval_token(&mut self, token: i32, pos: usize) -> Vec<f32>;
/// Clears the backend's memory/state.
fn memory_clear(&mut self);
/// Removes positions `p0..p1` from the backend's memory; the `bool` is
/// whatever the backend reports.
/// NOTE(review): exact range semantics (inclusive/exclusive, negative
/// sentinels) are not visible in this file — confirm against `RsCtx`.
fn memory_seq_rm(&mut self, p0: i32, p1: i32) -> bool;
/// Highest position currently held in the backend's memory.
fn memory_seq_pos_max(&self) -> i32;
}
/// Newtype around the Rust-port context (`RsCtx`) that implements
/// `DiffBackend`. Tests in this file also reach through `.0` for context
/// APIs the trait does not expose (e.g. state save/load, prefetch control).
pub struct RsBackend(RsCtx);
/// `DiffBackend` implementation that forwards every call to the wrapped
/// `RsCtx`.
///
/// Each method allocates (or copies) the buffer the context API writes
/// into, forwards the call, and panics via `expect` — tagged with the
/// method name — if the context reports an error. Calls that take a
/// sequence id always pass `0`.
impl DiffBackend for RsBackend {
    /// Opens the underlying `RsCtx`; panics if that fails.
    fn open(
        weights: &Path,
        manifest: &Path,
        vocab: &Path,
        experts_dir: &Path,
        experts_per_tok: u32,
        use_2bit: bool,
    ) -> Self {
        Self(
            RsCtx::open(
                weights,
                manifest,
                vocab,
                experts_dir,
                experts_per_tok,
                use_2bit,
            )
            .expect("RsBackend RsCtx::open"),
        )
    }

    fn n_vocab(&self) -> usize {
        self.0.n_vocab()
    }

    fn n_ctx(&self) -> usize {
        self.0.n_ctx()
    }

    fn eos(&self) -> i32 {
        self.0.eos()
    }

    fn model_name(&self) -> &'static str {
        self.0.model_name()
    }

    /// Returns `VARIANT.hidden_dim` values for `token_id`.
    fn embed(&self, token_id: i32) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0.embed(token_id, &mut out).expect("RsBackend embed");
        out
    }

    /// Output is always `VARIANT.hidden_dim` long, independent of `x.len()`.
    fn rms_norm_cpu(&self, weight_name: &str, x: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0
            .rms_norm_cpu(weight_name, x, &mut out)
            .expect("RsBackend rms_norm_cpu");
        out
    }

    /// The context rotates in place, so work on copies of `q` and `k`.
    fn apply_rotary_emb(
        &self,
        pos: i32,
        q: &[f32],
        k: &[f32],
    ) -> (Vec<f32>, Vec<f32>) {
        let mut q_out = q.to_vec();
        let mut k_out = k.to_vec();
        self.0
            .apply_rotary_emb(pos, &mut q_out, &mut k_out)
            .expect("RsBackend apply_rotary_emb");
        (q_out, k_out)
    }

    /// The context normalizes in place, so work on a copy of `x`.
    fn rms_norm_per_head_cpu(
        &self,
        weight_name: &str,
        num_heads: usize,
        head_dim: usize,
        x: &[f32],
    ) -> Vec<f32> {
        let mut out = x.to_vec();
        self.0
            .rms_norm_per_head_cpu(weight_name, num_heads, head_dim, &mut out)
            .expect("RsBackend rms_norm_per_head_cpu");
        out
    }

    /// Output has the same length as `q`.
    fn sdpa_cpu(
        &self,
        kv_len: i32,
        q: &[f32],
        q_gate: &[f32],
        k_cache: &[f32],
        v_cache: &[f32],
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; q.len()];
        self.0
            .sdpa_cpu(kv_len, q, q_gate, k_cache, v_cache, &mut out)
            .expect("RsBackend sdpa_cpu");
        out
    }

    /// Output has `n_vocab()` entries.
    fn lm_head_cpu(&self, x: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0f32; self.0.n_vocab()];
        self.0
            .lm_head_cpu(x, &mut out)
            .expect("RsBackend lm_head_cpu");
        out
    }

    /// `scores` is taken by value because the context needs it mutable.
    /// Returns `(indices, weights)`, each of length `k`.
    fn moe_router_cpu(&self, scores: Vec<f32>, k: usize) -> (Vec<i32>, Vec<f32>) {
        let mut s = scores;
        let mut idx = vec![0i32; k];
        let mut w = vec![0.0f32; k];
        self.0
            .moe_router_cpu(&mut s, k, &mut idx, &mut w)
            .expect("RsBackend moe_router_cpu");
        (idx, w)
    }

    /// Output has `channels` entries.
    fn conv1d_step_cpu(
        &self,
        weight_name: &str,
        channels: usize,
        kernel_size: usize,
        conv_state: &[f32],
        new_input: &[f32],
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; channels];
        self.0
            .conv1d_step_cpu(
                weight_name,
                channels,
                kernel_size,
                conv_state,
                new_input,
                &mut out,
            )
            .expect("RsBackend conv1d_step_cpu");
        out
    }

    /// Output has the same length as `x`.
    fn rms_norm_bare_cpu(&self, eps: f32, x: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0f32; x.len()];
        self.0
            .rms_norm_bare_cpu(eps, x, &mut out)
            .expect("RsBackend rms_norm_bare_cpu");
        out
    }

    /// Output has the same length as `x`.
    fn rms_norm_gated_cpu(
        &self,
        weight_name: &str,
        eps: f32,
        x: &[f32],
        z: &[f32],
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; x.len()];
        self.0
            .rms_norm_gated_cpu(weight_name, eps, x, z, &mut out)
            .expect("RsBackend rms_norm_gated_cpu");
        out
    }

    /// Returns `(state, out)`: the caller's state buffer after the context
    /// has had mutable access to it, and `v_heads * value_dim` outputs.
    fn gated_delta_recurrence_cpu(
        &self,
        layer_idx: usize,
        alpha: &[f32],
        beta: &[f32],
        q: &[f32],
        k: &[f32],
        v: &[f32],
        v_heads: usize,
        k_heads: usize,
        key_dim: usize,
        value_dim: usize,
        ssm_state_in: Vec<f32>,
    ) -> (Vec<f32>, Vec<f32>) {
        let mut state = ssm_state_in;
        let mut out = vec![0.0f32; v_heads * value_dim];
        self.0
            .gated_delta_recurrence_cpu(
                layer_idx,
                alpha,
                beta,
                q,
                k,
                v,
                v_heads,
                k_heads,
                key_dim,
                value_dim,
                &mut state,
                &mut out,
            )
            .expect("RsBackend gated_delta_recurrence_cpu");
        (state, out)
    }

    /// Reads one expert into a buffer of `VARIANT.expert_size_4bit()` bytes.
    ///
    /// # Panics
    /// Panics if either index is negative, or if the context reports an
    /// error.
    ///
    /// NOTE(review): the buffer is sized for the 4-bit layout even when the
    /// backend was opened with `use_2bit = true` — confirm that is intended.
    fn load_expert_bytes(&self, layer_idx: i32, expert_idx: i32) -> Vec<u8> {
        // A bare `as usize` would turn a negative index into a huge one;
        // reject it up front with a message that names the real problem.
        assert!(
            layer_idx >= 0 && expert_idx >= 0,
            "RsBackend load_expert_bytes: negative index \
             (layer_idx={layer_idx}, expert_idx={expert_idx})"
        );
        let mut out = vec![0u8; moeflux::riir::VARIANT.expert_size_4bit()];
        self.0
            .load_expert_bytes(
                layer_idx as usize,
                expert_idx as usize,
                &mut out,
            )
            .expect("RsBackend load_expert_bytes");
        out
    }

    /// Output has `VARIANT.hidden_dim` entries.
    fn gpu_rms_norm_fused(
        &mut self,
        x: &[f32],
        weight_bf16: &[u8],
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0
            .gpu_rms_norm_fused(x, weight_bf16, &mut out)
            .expect("RsBackend gpu_rms_norm_fused");
        out
    }

    /// Output has `VARIANT.hidden_dim` entries.
    fn gpu_expert_forward(
        &mut self,
        expert_data: &[u8],
        h_post: &[f32],
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0
            .gpu_expert_forward(expert_data, h_post, &mut out)
            .expect("RsBackend gpu_expert_forward");
        out
    }

    /// Output has `VARIANT.hidden_dim` entries.
    fn gpu_batched_experts_forward(
        &mut self,
        actual_k: i32,
        expert_data: &[u8],
        h_post: &[f32],
        h_mid: &[f32],
        shared_out: &[f32],
        expert_weights: &[f32],
        shared_gate_score: f32,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0
            .gpu_batched_experts_forward(
                actual_k,
                expert_data,
                h_post,
                h_mid,
                shared_out,
                expert_weights,
                shared_gate_score,
                &mut out,
            )
            .expect("RsBackend gpu_batched_experts_forward");
        out
    }

    /// Output has `num_heads * seq_len` entries.
    fn attn_scores_batched(
        &mut self,
        num_heads: u32,
        num_kv_heads: u32,
        head_dim: u32,
        seq_len: u32,
        q: &[f32],
        k_cache: &[f32],
        scale: f32,
    ) -> Vec<f32> {
        // Widen before multiplying: a u32 * u32 product can exceed u32::MAX,
        // which would panic in debug builds and wrap in release builds.
        let out_len = (num_heads as usize) * (seq_len as usize);
        let mut out = vec![0.0f32; out_len];
        self.0
            .attn_scores_batched(
                num_heads, num_kv_heads, head_dim, seq_len, q, k_cache,
                scale, &mut out,
            )
            .expect("RsBackend attn_scores_batched");
        out
    }

    /// The context normalizes in place, so work on a copy of `scores_in`.
    fn attn_softmax_batched(
        &mut self,
        num_heads: u32,
        seq_len: u32,
        scores_in: &[f32],
    ) -> Vec<f32> {
        let mut out = scores_in.to_vec();
        self.0
            .attn_softmax_batched(num_heads, seq_len, &mut out)
            .expect("RsBackend attn_softmax_batched");
        out
    }

    /// Output has `num_heads * head_dim` entries.
    fn attn_values_batched(
        &mut self,
        num_heads: u32,
        num_kv_heads: u32,
        head_dim: u32,
        seq_len: u32,
        scores: &[f32],
        v_cache: &[f32],
    ) -> Vec<f32> {
        // Widen before multiplying to avoid u32 overflow (see
        // `attn_scores_batched`).
        let out_len = (num_heads as usize) * (head_dim as usize);
        let mut out = vec![0.0f32; out_len];
        self.0
            .attn_values_batched(
                num_heads, num_kv_heads, head_dim, seq_len, scores, v_cache,
                &mut out,
            )
            .expect("RsBackend attn_values_batched");
        out
    }

    /// The context gates in place, so work on a copy of `x_in`.
    fn sigmoid_gate(
        &mut self,
        dim: u32,
        gate: &[f32],
        x_in: &[f32],
    ) -> Vec<f32> {
        let mut out = x_in.to_vec();
        self.0
            .sigmoid_gate(dim, gate, &mut out)
            .expect("RsBackend sigmoid_gate");
        out
    }

    /// Starts a deferred experts computation on the context.
    fn begin_deferred_experts(
        &mut self,
        actual_k: i32,
        expert_data: &[u8],
        h_post: &[f32],
        h_mid: &[f32],
        shared_out: &[f32],
        expert_weights: &[f32],
        shared_gate_score: f32,
    ) {
        self.0
            .begin_deferred_experts(
                actual_k,
                expert_data,
                h_post,
                h_mid,
                shared_out,
                expert_weights,
                shared_gate_score,
                // NOTE(review): extra context-only argument, fixed at -1.
                // Presumably a "no layer" sentinel — confirm against RsCtx.
                -1,
            )
            .expect("RsBackend begin_deferred_experts");
    }

    /// Output has `VARIANT.hidden_dim` entries.
    fn complete_deferred_experts(&mut self) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0
            .complete_deferred_experts(&mut out)
            .expect("RsBackend complete_deferred_experts");
        out
    }

    fn discard_deferred_experts(&mut self) {
        self.0.discard_deferred_experts();
    }

    /// Output has `VARIANT.hidden_dim` entries.
    fn layer_forward_dump(
        &mut self,
        layer_idx: i32,
        pos: i32,
        hidden_in: &[f32],
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; moeflux::riir::VARIANT.hidden_dim];
        self.0
            .layer_forward_dump(layer_idx, pos, hidden_in, &mut out)
            .expect("RsBackend layer_forward_dump");
        out
    }

    /// Returns `n_vocab()` logits; the context's third argument is fixed
    /// at `0`.
    fn eval_prompt(&mut self, tokens: &[i32], start_pos: usize) -> Vec<f32> {
        let mut logits = vec![0.0f32; self.0.n_vocab()];
        self.0
            .eval_prompt(tokens, start_pos, 0, &mut logits)
            .expect("RsBackend eval_prompt");
        logits
    }

    /// Returns `n_vocab()` logits; the context's third argument is fixed
    /// at `0`.
    fn eval_token(&mut self, token: i32, pos: usize) -> Vec<f32> {
        let mut logits = vec![0.0f32; self.0.n_vocab()];
        self.0
            .eval_token(token, pos, 0, &mut logits)
            .expect("RsBackend eval_token");
        logits
    }

    fn memory_clear(&mut self) {
        self.0.memory_clear()
    }

    /// Forwards to the context with its leading argument fixed at `0`.
    fn memory_seq_rm(&mut self, p0: i32, p1: i32) -> bool {
        self.0.memory_seq_rm(0, p0, p1)
    }

    /// Forwards to the context with its argument fixed at `0`.
    fn memory_seq_pos_max(&self) -> i32 {
        self.0.memory_seq_pos_max(0)
    }
}
/// Opens a backend of type `B` with the paths and settings returned by
/// `default_a3b_paths()`.
///
/// The defaults' `root` field is what gets passed as `experts_dir`.
pub fn open_backend<B>() -> B
where
    B: DiffBackend,
{
    let defaults = default_a3b_paths();
    let experts_dir = &defaults.root;
    B::open(
        &defaults.weights,
        &defaults.manifest,
        &defaults.vocab,
        experts_dir,
        defaults.experts_per_tok,
        defaults.use_2bit,
    )
}
/// Opens `model_weights.bin` / `model_weights.json` from the artifacts
/// directory and checks the tensor count plus the metadata and byte length
/// of `model.embed_tokens.weight`.
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn weight_file_loads_a3b() {
    const EMBED_TENSOR: &str = "model.embed_tokens.weight";

    let dir = artifacts_dir();
    let bin_path = dir.join("model_weights.bin");
    let manifest_path = dir.join("model_weights.json");
    let weights = moeflux::riir::WeightFile::open(&bin_path, &manifest_path)
        .expect("WeightFile::open");

    let size_gb = weights.file_size() as f64 / 1e9;
    eprintln!(
        "[diff:weight_file] {} tensors in {:.2} GB",
        weights.len(),
        size_gb,
    );
    assert_eq!(weights.len(), 1397, "tensor count drifted from C");

    // Metadata for the embedding table must exist and carry a dtype.
    let info = weights
        .tensor_info(EMBED_TENSOR)
        .expect("model.embed_tokens.weight");
    assert!(!info.dtype.is_empty(), "embed_tokens dtype empty");
    eprintln!(
        "[diff:weight_file] embed_tokens dtype={} shape={:?} bits={} size={}",
        info.dtype, info.shape, info.bits, info.size,
    );

    // The byte slice must be exactly as long as the manifest claims.
    let raw = weights.tensor_bytes(EMBED_TENSOR).expect("embed bytes");
    assert_eq!(raw.len() as u64, info.size);
}
/// Runs `layer_forward_dump` for layer 0 / position 0 five times, calling
/// `memory_clear` before each run, and requires every run to be finite,
/// non-trivial in magnitude, and bit-identical to the first.
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn layer_forward_dump_back_to_back_no_deferred_leak() {
    let mut backend: RsBackend = open_backend();
    let hidden_dim = moeflux::riir::VARIANT.hidden_dim;
    let hidden_in = backend.embed(1);
    assert_eq!(hidden_in.len(), hidden_dim);

    let (layer_idx, pos) = (0i32, 0i32);
    let n_iters = 5usize;

    // Collect one output per iteration, sanity-checking each as it arrives.
    let outs: Vec<Vec<f32>> = (0..n_iters)
        .map(|i| {
            backend.memory_clear();
            let out = backend.layer_forward_dump(layer_idx, pos, &hidden_in);
            assert_eq!(out.len(), hidden_dim, "iter {i}: output length");
            assert!(
                out.iter().all(|x| x.is_finite()),
                "iter {i}: output has NaN/Inf — likely stale deferred state"
            );
            let max_abs = out.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
            assert!(
                max_abs > 1e-6,
                "iter {i}: output magnitude {max_abs:.3e} too small — drain \
                 likely reading from wrong buffer or hitting AlreadyActive"
            );
            out
        })
        .collect();

    // Every later iteration must match the first one exactly.
    let baseline = &outs[0];
    for (i, later) in outs.iter().enumerate().skip(1) {
        let drift_max = baseline
            .iter()
            .zip(later)
            .fold(0.0f32, |acc, (a, b)| acc.max((a - b).abs()));
        assert_eq!(
            drift_max, 0.0,
            "iter 0 vs iter {i} differ by max_abs_diff={drift_max:.3e} — \
             deferred-experts state leaked across calls or memory_clear \
             did not reset all recurrence"
        );
    }

    eprintln!(
        "[diff:layer_forward_dump_back_to_back] {n_iters} iterations \
         bit-identical (max_abs_diff=0)"
    );
}
/// Saves the context state after a prompt, clears memory, reloads the
/// snapshot, and checks that the next-token logits agree with an
/// uninterrupted reference run (same argmax, cosine >= 0.9999).
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn state_round_trip_rust() {
    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    let prompt: [i32; 4] = [1, 200, 600, 1100];
    let next_token = 7i32;
    let next_pos = prompt.len();

    // Reference: prompt followed directly by the next token.
    let mut reference: RsBackend = open_backend();
    let _ = reference.eval_prompt(&prompt, 0);
    let ref_logits = reference.eval_token(next_token, next_pos);

    // Subject: prompt, snapshot, wipe, restore, then the same next token.
    let mut subject: RsBackend = open_backend();
    let _ = subject.eval_prompt(&prompt, 0);
    let snap_size = subject.0.state_size();
    let mut snapshot = vec![0u8; snap_size];
    let written = subject
        .0
        .state_save(&mut snapshot)
        .expect("Rust state_save");
    assert_eq!(written, snap_size, "state_save wrote unexpected length");
    subject.memory_clear();
    subject.0.state_load(&snapshot).expect("Rust state_load");
    let test_logits = subject.eval_token(next_token, next_pos);

    assert_eq!(test_logits.len(), ref_logits.len());
    let drift_max = max_abs_diff(&ref_logits, &test_logits);
    let cos = cosine_sim(&ref_logits, &test_logits);
    eprintln!(
        "[diff:state_round_trip_rust] snap_bytes={snap_size} \
         max_abs_diff={drift_max:.3e} cosine={cos:.7}"
    );

    assert_eq!(
        argmax(&ref_logits),
        argmax(&test_logits),
        "round-trip changed argmax"
    );
    assert!(
        cos >= 0.9999,
        "round-trip cosine {cos:.7} below 0.9999"
    );
}
/// Compares next-token logits from a normal run against a run where
/// `clear_prefetch_predictions` is called right before the token is
/// evaluated; the two must be bit-identical.
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn prefetch_hit_miss_equivalence_rust() {
    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    let prompt: [i32; 4] = [1, 200, 600, 1100];
    let next_token = 7i32;
    let next_pos = prompt.len();

    // Baseline: predictions left as the prompt evaluation produced them.
    let mut with_predictions: RsBackend = open_backend();
    let _ = with_predictions.eval_prompt(&prompt, 0);
    let normal_logits = with_predictions.eval_token(next_token, next_pos);

    // Variant: predictions wiped before the same token.
    let mut without_predictions: RsBackend = open_backend();
    let _ = without_predictions.eval_prompt(&prompt, 0);
    without_predictions.0.clear_prefetch_predictions();
    let miss_logits = without_predictions.eval_token(next_token, next_pos);

    assert_eq!(normal_logits.len(), miss_logits.len());
    let drift_max = max_abs_diff(&normal_logits, &miss_logits);
    let cos = cosine_sim(&normal_logits, &miss_logits);
    eprintln!(
        "[diff:prefetch_hit_miss_equivalence] \
         max_abs_diff={drift_max:.3e} cosine={cos:.7} \
         argmax(normal)={a} argmax(miss)={b}",
        a = argmax(&normal_logits),
        b = argmax(&miss_logits),
    );

    assert_eq!(
        argmax(&normal_logits),
        argmax(&miss_logits),
        "prefetch hit and all-miss paths produced different argmax"
    );
    assert_eq!(
        drift_max, 0.0,
        "prefetch hit and all-miss paths should be bit-identical, \
         got drift {drift_max:.3e}"
    );
}
/// Evaluates prompt A, calls `memory_clear`, then evaluates prompt B and a
/// follow-up token; the logits must agree with a fresh context that only
/// ever saw prompt B (same argmax, cosine >= 0.9999).
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn memory_clear_cancels_prefetch_no_leak() {
    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    let prompt_a: [i32; 4] = [1, 200, 600, 1100];
    let prompt_b: [i32; 4] = [2, 300, 700, 1200];
    let next_token = 7i32;
    let next_pos = prompt_b.len();

    // Reference: a context that only evaluates prompt B.
    let mut clean: RsBackend = open_backend();
    let _ = clean.eval_prompt(&prompt_b, 0);
    let ref_logits = clean.eval_token(next_token, next_pos);

    // Subject: prompt A first, reset, then the same prompt B sequence.
    let mut reused: RsBackend = open_backend();
    let _ = reused.eval_prompt(&prompt_a, 0);
    reused.memory_clear();
    let _ = reused.eval_prompt(&prompt_b, 0);
    let test_logits = reused.eval_token(next_token, next_pos);

    assert_eq!(test_logits.len(), ref_logits.len());
    let drift_max = max_abs_diff(&ref_logits, &test_logits);
    let cos = cosine_sim(&ref_logits, &test_logits);
    eprintln!(
        "[diff:memory_clear_cancels_prefetch] \
         max_abs_diff={drift_max:.3e} cosine={cos:.7}"
    );

    assert_eq!(
        argmax(&ref_logits),
        argmax(&test_logits),
        "memory_clear leaked prefetch state across reset"
    );
    assert!(
        cos >= 0.9999,
        "memory_clear leak: cosine {cos:.7} below 0.9999"
    );
}
/// Evaluates two consecutive tokens after a prompt, once normally and once
/// with `clear_prefetch_predictions` called before each token, and checks
/// that the second token's logits agree (same argmax, cosine >= 0.9999).
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn slot_reuse_race_regression_rust() {
    let prompt: [i32; 4] = [1, 200, 600, 1100];
    let (token_t1, token_t2) = (7i32, 42i32);
    let pos_t1 = prompt.len();
    let pos_t2 = pos_t1 + 1;

    // Reference: two back-to-back token evaluations, untouched.
    let mut reference: RsBackend = open_backend();
    let _ = reference.eval_prompt(&prompt, 0);
    let _ = reference.eval_token(token_t1, pos_t1);
    let ref_t2 = reference.eval_token(token_t2, pos_t2);

    // Subject: same sequence, predictions cleared ahead of each token.
    let mut subject: RsBackend = open_backend();
    let _ = subject.eval_prompt(&prompt, 0);
    subject.0.clear_prefetch_predictions();
    let _ = subject.eval_token(token_t1, pos_t1);
    subject.0.clear_prefetch_predictions();
    let test_t2 = subject.eval_token(token_t2, pos_t2);

    let drift_max = ref_t2
        .iter()
        .zip(&test_t2)
        .fold(0.0f32, |acc, (a, b)| acc.max((a - b).abs()));
    let cos = cosine_sim(&ref_t2, &test_t2);
    eprintln!(
        "[diff:slot_reuse_race_regression] \
         max_abs_diff={drift_max:.3e} cosine={cos:.7}"
    );

    assert_eq!(
        argmax(&ref_t2),
        argmax(&test_t2),
        "slot-reuse race: argmax changed across consecutive evals"
    );
    assert!(
        cos >= 0.9999,
        "slot-reuse race regression: cosine {cos:.7} below 0.9999"
    );
}
/// Checks `eval_prompt` over a 16-token prompt against an oracle that feeds
/// the same tokens one by one through `eval_token`, comparing both the
/// last-prompt-token logits and the logits of one continuation token.
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn eval_prompt_matches_per_token_oracle() {
    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    let prompt: [i32; 16] = [
        1, 200, 600, 1100, 2, 300, 700, 1200, 3, 400, 800, 1300, 4, 500,
        900, 1400,
    ];
    let next_token = 7i32;
    let next_pos = prompt.len();

    // Oracle: one `eval_token` call per prompt token; the buffer ends up
    // holding the logits written by the last call.
    let mut oracle: RsBackend = open_backend();
    let n_vocab = oracle.0.n_vocab();
    let mut ref_prompt_logits = vec![0.0f32; n_vocab];
    for (pos, tok) in prompt.iter().copied().enumerate() {
        oracle
            .0
            .eval_token(tok, pos, 0, &mut ref_prompt_logits)
            .expect("oracle eval_token");
    }
    let ref_continuation = oracle.eval_token(next_token, next_pos);

    // Subject: the whole prompt in a single `eval_prompt` call.
    let mut subject: RsBackend = open_backend();
    let mut prompt_logits = vec![0.0f32; n_vocab];
    subject
        .0
        .eval_prompt(&prompt, 0, 0, &mut prompt_logits)
        .expect("canonical eval_prompt");
    let test_continuation = subject.eval_token(next_token, next_pos);

    let prompt_cos = cosine_sim(&ref_prompt_logits, &prompt_logits);
    let prompt_drift = max_abs_diff(&ref_prompt_logits, &prompt_logits);
    let cont_cos = cosine_sim(&ref_continuation, &test_continuation);
    let cont_drift = max_abs_diff(&ref_continuation, &test_continuation);
    eprintln!(
        "[diff:eval_prompt_matches_per_token_oracle] \
         prompt cosine={prompt_cos:.7} max_abs_diff={prompt_drift:.3e} | \
         continuation cosine={cont_cos:.7} max_abs_diff={cont_drift:.3e}"
    );

    assert_eq!(
        argmax(&ref_prompt_logits),
        argmax(&prompt_logits),
        "eval_prompt last-token argmax diverged from oracle"
    );
    assert_eq!(
        argmax(&ref_continuation),
        argmax(&test_continuation),
        "post-prompt continuation argmax diverged from oracle"
    );
    assert!(
        prompt_cos >= 0.9999,
        "eval_prompt last-token cosine {prompt_cos:.7} below 0.9999"
    );
    assert!(
        cont_cos >= 0.9999,
        "post-prompt continuation cosine {cont_cos:.7} below 0.9999"
    );
}
/// Same comparison as the per-token oracle test, but with the batched chunk
/// size overridden to 4 while the subject context runs.
///
/// The override is set through `set_batched_chunk_size_for_test` and reset
/// to `None` even if the subject run panics (the panic is caught, the
/// override cleared, and the panic re-raised).
/// NOTE(review): the override looks process-global, so running this
/// concurrently with other tests may interfere — confirm.
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn eval_prompt_chunked_matches_eval_prompt_whole_prompt() {
    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    let prompt: [i32; 16] = [
        1, 200, 600, 1100, 2, 300, 700, 1200, 3, 400, 800, 1300, 4, 500,
        900, 1400,
    ];
    let next_token = 7i32;
    let next_pos = prompt.len();

    // Oracle runs before the override is installed.
    let mut oracle: RsBackend = open_backend();
    let n_vocab = oracle.0.n_vocab();
    let mut ref_prompt_logits = vec![0.0f32; n_vocab];
    for (pos, tok) in prompt.iter().copied().enumerate() {
        oracle
            .0
            .eval_token(tok, pos, 0, &mut ref_prompt_logits)
            .expect("oracle eval_token");
    }
    let ref_continuation = oracle.eval_token(next_token, next_pos);

    // Subject runs with chunk size 4; catch any panic so the override can
    // be cleared before the panic continues.
    moeflux::riir::set_batched_chunk_size_for_test(Some(4));
    let chunked_run = || {
        let mut subject: RsBackend = open_backend();
        let mut logits = vec![0.0f32; n_vocab];
        subject
            .0
            .eval_prompt(&prompt, 0, 0, &mut logits)
            .expect("chunked eval_prompt");
        let continuation = subject.eval_token(next_token, next_pos);
        (logits, continuation)
    };
    let outcome =
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(chunked_run));
    moeflux::riir::set_batched_chunk_size_for_test(None);
    let (prompt_logits, test_continuation) =
        outcome.unwrap_or_else(|payload| std::panic::resume_unwind(payload));

    let prompt_cos = cosine_sim(&ref_prompt_logits, &prompt_logits);
    let prompt_drift = max_abs_diff(&ref_prompt_logits, &prompt_logits);
    let cont_cos = cosine_sim(&ref_continuation, &test_continuation);
    let cont_drift = max_abs_diff(&ref_continuation, &test_continuation);
    eprintln!(
        "[diff:eval_prompt_chunked_matches_eval_prompt_whole_prompt] \
         chunk=4 prompt cosine={prompt_cos:.7} max_abs={prompt_drift:.3e} | \
         continuation cosine={cont_cos:.7} max_abs={cont_drift:.3e}"
    );

    assert_eq!(
        argmax(&ref_prompt_logits),
        argmax(&prompt_logits),
        "chunked eval_prompt last-token argmax diverged from oracle"
    );
    assert_eq!(
        argmax(&ref_continuation),
        argmax(&test_continuation),
        "post-chunked-prompt continuation argmax diverged from oracle"
    );
    assert!(
        prompt_cos >= 0.9999,
        "chunked eval_prompt cosine {prompt_cos:.7} below 0.9999"
    );
    assert!(
        cont_cos >= 0.9999,
        "post-chunked continuation cosine {cont_cos:.7} below 0.9999"
    );
}
/// Diagnostic variant of the chunked comparison with the batched chunk size
/// overridden to 1; only the cosine similarities are asserted.
///
/// The override is cleared again even if the subject run panics.
/// NOTE(review): the override looks process-global, so running this
/// concurrently with other tests may interfere — confirm.
#[test]
#[ignore = "long running; needs moeflux artifacts; diagnostic"]
fn diag_b2_eval_prompt_chunk_1() {
    let prompt: [i32; 16] = [
        1, 200, 600, 1100, 2, 300, 700, 1200, 3, 400, 800, 1300, 4, 500,
        900, 1400,
    ];
    let next_token = 7i32;
    let next_pos = prompt.len();

    // Oracle: per-token evaluation, run before the override is installed.
    let mut oracle: RsBackend = open_backend();
    let n_vocab = oracle.0.n_vocab();
    let mut ref_logits = vec![0.0f32; n_vocab];
    for (pos, tok) in prompt.iter().copied().enumerate() {
        oracle
            .0
            .eval_token(tok, pos, 0, &mut ref_logits)
            .expect("oracle eval_token");
    }
    let ref_cont = oracle.eval_token(next_token, next_pos);

    // Subject: `eval_prompt` with chunk size forced to 1.
    moeflux::riir::set_batched_chunk_size_for_test(Some(1));
    let chunked_run = || {
        let mut subject: RsBackend = open_backend();
        let mut logits = vec![0.0f32; n_vocab];
        subject
            .0
            .eval_prompt(&prompt, 0, 0, &mut logits)
            .expect("chunked eval_prompt @ chunk=1");
        let continuation = subject.eval_token(next_token, next_pos);
        (logits, continuation)
    };
    let outcome =
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(chunked_run));
    moeflux::riir::set_batched_chunk_size_for_test(None);
    let (prompt_logits, test_cont) =
        outcome.unwrap_or_else(|payload| std::panic::resume_unwind(payload));

    let prompt_cos = cosine_sim(&ref_logits, &prompt_logits);
    let cont_cos = cosine_sim(&ref_cont, &test_cont);
    eprintln!(
        "[diag:b2_chunk_1] prompt_cos={prompt_cos:.7} cont_cos={cont_cos:.7}"
    );
    assert!(prompt_cos >= 0.9999, "chunk_size=1 prompt cosine {prompt_cos:.7}");
    assert!(cont_cos >= 0.9999, "chunk_size=1 cont cosine {cont_cos:.7}");
}
/// Directional benchmark: times 256 per-token `eval_token` calls against a
/// single `eval_prompt` call over the same synthetic prompt, prints both
/// throughputs, and sanity-checks that the final logits still agree
/// (cosine >= 0.99).
#[test]
#[ignore = "long running; needs moeflux artifacts; directional only"]
fn bench_batched_eval_prompt_vs_per_token() {
    const N: usize = 256;

    // Deterministic synthetic token ids in 1..=50000.
    let mut prompt: Vec<i32> = Vec::with_capacity(N);
    prompt.extend((0..N).map(|i| ((i * 37 + 5) % 50000 + 1) as i32));

    // Per-token path; only the evaluation loop is timed.
    let mut per_token: RsBackend = open_backend();
    let n_vocab = per_token.0.n_vocab();
    let mut oracle_logits = vec![0.0f32; n_vocab];
    let oracle_start = Instant::now();
    for (pos, tok) in prompt.iter().copied().enumerate() {
        per_token
            .0
            .eval_token(tok, pos, 0, &mut oracle_logits)
            .expect("oracle eval_token");
    }
    let oracle_elapsed = oracle_start.elapsed();

    // Batched path; only the single `eval_prompt` call is timed.
    let mut batched: RsBackend = open_backend();
    let mut batched_logits = vec![0.0f32; n_vocab];
    let batched_start = Instant::now();
    batched
        .0
        .eval_prompt(&prompt, 0, 0, &mut batched_logits)
        .expect("batched eval_prompt");
    let batched_elapsed = batched_start.elapsed();

    let tokens = N as f64;
    let oracle_tok_s = tokens / oracle_elapsed.as_secs_f64();
    let batched_tok_s = tokens / batched_elapsed.as_secs_f64();
    let speedup = batched_tok_s / oracle_tok_s;
    eprintln!(
        "[bench:eval_prompt_vs_per_token N={N}] \
         per-token: {oracle_elapsed:?} ({oracle_tok_s:.2} tok/s) | \
         batched: {batched_elapsed:?} ({batched_tok_s:.2} tok/s) | \
         speedup: {speedup:.2}×"
    );

    let cos = cosine_sim(&oracle_logits, &batched_logits);
    eprintln!(
        "[bench:eval_prompt_vs_per_token] sanity cosine={cos:.7}"
    );
    assert!(cos >= 0.99, "bench cosine {cos:.7} below sanity floor");
}
/// Directional benchmark for greedy decoding: after a 32-token warm-up,
/// decodes 32 tokens via `eval_token` and, on a second context, via
/// `eval_prompt` with single-token slices. Prints both throughputs and
/// sanity-checks the final logits (cosine >= 0.99).
#[test]
#[ignore = "long running; needs moeflux artifacts; directional only"]
fn bench_decode_per_token_vs_batched_n1() {
    const PROMPT_LEN: usize = 32;
    const DECODE_N: usize = 32;

    // Deterministic synthetic token ids in 1..=50000.
    let mut prompt: Vec<i32> = Vec::with_capacity(PROMPT_LEN);
    prompt.extend((0..PROMPT_LEN).map(|i| ((i * 37 + 5) % 50000 + 1) as i32));

    // Per-token context: untimed warm-up over the prompt.
    let mut per_token: RsBackend = open_backend();
    let n_vocab = per_token.0.n_vocab();
    let mut prompt_logits = vec![0.0f32; n_vocab];
    for (pos, tok) in prompt.iter().copied().enumerate() {
        per_token
            .0
            .eval_token(tok, pos, 0, &mut prompt_logits)
            .expect("oracle warm-up");
    }

    // Timed greedy decode; each step feeds back the argmax of the last
    // logits.
    let mut last_logits = prompt_logits.clone();
    let oracle_start = Instant::now();
    for step in 0..DECODE_N {
        let next_tok = argmax(&last_logits) as i32;
        per_token
            .0
            .eval_token(next_tok, PROMPT_LEN + step, 0, &mut last_logits)
            .expect("oracle decode");
    }
    let oracle_elapsed = oracle_start.elapsed();
    let oracle_decode_tok_s = DECODE_N as f64 / oracle_elapsed.as_secs_f64();

    // Batched context: untimed warm-up with one `eval_prompt` call.
    let mut batched: RsBackend = open_backend();
    let mut prompt_logits_b = vec![0.0f32; n_vocab];
    batched
        .0
        .eval_prompt(&prompt, 0, 0, &mut prompt_logits_b)
        .expect("batched warm-up");

    // Timed greedy decode through `eval_prompt` with one token per call.
    let mut last_logits_b = prompt_logits_b.clone();
    let batched_start = Instant::now();
    for step in 0..DECODE_N {
        let single = [argmax(&last_logits_b) as i32];
        batched
            .0
            .eval_prompt(&single, PROMPT_LEN + step, 0, &mut last_logits_b)
            .expect("batched decode N=1");
    }
    let batched_elapsed = batched_start.elapsed();
    let batched_decode_tok_s =
        DECODE_N as f64 / batched_elapsed.as_secs_f64();

    // Positive means the batched-N1 path is slower than per-token.
    let regression = (oracle_decode_tok_s - batched_decode_tok_s)
        / oracle_decode_tok_s
        * 100.0;
    eprintln!(
        "[bench:decode_per_token_vs_batched_n1] kv_start={PROMPT_LEN} \
         decode_n={DECODE_N} | per-token: {oracle_elapsed:?} \
         ({oracle_decode_tok_s:.2} tok/s) | batched-N1: \
         {batched_elapsed:?} ({batched_decode_tok_s:.2} tok/s) | \
         regression: {regression:.1}%"
    );

    let cos = cosine_sim(&last_logits, &last_logits_b);
    eprintln!(
        "[bench:decode_per_token_vs_batched_n1] final-logit cos={cos:.7}"
    );
    assert!(
        cos >= 0.99,
        "decode bench cosine {cos:.7} below sanity floor — \
         per-token and batched-N1 diverged greedily"
    );
}
/// Checks that resuming from a saved state and evaluating a prompt suffix
/// at a non-zero `start_pos` matches evaluating the full prompt in one go.
///
/// Control: `eval_prompt(prefix + suffix)` from position 0, then one
/// continuation token. Subject: `eval_prompt(prefix)`, `state_save`,
/// `memory_clear`, `state_load`, `eval_prompt(suffix)` starting at
/// `prefix.len()`, then the same continuation token. Both the last-prompt
/// logits and the continuation logits must agree (same argmax, cosine
/// >= 0.9999).
#[test]
#[ignore = "long running; needs moeflux artifacts"]
fn prompt_cache_start_pos_nonzero_matches() {
    /// Largest element-wise absolute difference between two slices.
    fn max_abs_diff(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter()
            .zip(rhs)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max)
    }

    let prefix: [i32; 4] = [1, 200, 600, 1100];
    let suffix: [i32; 5] = [2, 300, 700, 1200, 3];
    let next_token = 7i32;
    let full_pos = prefix.len() + suffix.len();

    // Control: the whole prompt in a single call from position 0.
    let mut rs_ctrl: RsBackend = open_backend();
    let mut full_prompt = Vec::with_capacity(full_pos);
    full_prompt.extend_from_slice(&prefix);
    full_prompt.extend_from_slice(&suffix);
    let n_vocab = rs_ctrl.0.n_vocab();
    let mut ctrl_prompt_logits = vec![0.0f32; n_vocab];
    rs_ctrl
        .0
        .eval_prompt(&full_prompt, 0, 0, &mut ctrl_prompt_logits)
        .expect("control eval_prompt");
    let ctrl_continuation = rs_ctrl.eval_token(next_token, full_pos);

    // Subject: prefix, snapshot, wipe, restore, then the suffix.
    let mut rs: RsBackend = open_backend();
    // Prefix logits are required by the API but not inspected.
    let mut prefix_logits = vec![0.0f32; n_vocab];
    rs.0.eval_prompt(&prefix, 0, 0, &mut prefix_logits)
        .expect("prefix eval_prompt");
    let snap_size = rs.0.state_size();
    let mut snap = vec![0u8; snap_size];
    let written = rs.0.state_save(&mut snap).expect("state_save");
    // A short write would leave zero padding in the snapshot and show up
    // later as a confusing load or logit mismatch; fail here instead.
    assert_eq!(written, snap_size, "state_save wrote unexpected length");
    rs.memory_clear();
    rs.0.state_load(&snap).expect("state_load");
    let mut test_prompt_logits = vec![0.0f32; n_vocab];
    rs.0.eval_prompt(
        &suffix,
        prefix.len(),
        0,
        &mut test_prompt_logits,
    )
    .expect("suffix eval_prompt at start_pos != 0");
    let test_continuation = rs.eval_token(next_token, full_pos);

    let prompt_cos = cosine_sim(&ctrl_prompt_logits, &test_prompt_logits);
    let prompt_drift = max_abs_diff(&ctrl_prompt_logits, &test_prompt_logits);
    let cont_cos = cosine_sim(&ctrl_continuation, &test_continuation);
    let cont_drift = max_abs_diff(&ctrl_continuation, &test_continuation);
    eprintln!(
        "[diff:prompt_cache_start_pos_nonzero_matches] \
         prompt cosine={prompt_cos:.7} max_abs={prompt_drift:.3e} | \
         continuation cosine={cont_cos:.7} max_abs={cont_drift:.3e}"
    );

    assert_eq!(
        argmax(&ctrl_prompt_logits),
        argmax(&test_prompt_logits),
        "prompt-cache last-token argmax diverged from control"
    );
    // Mirrors the sibling prompt/continuation tests, which assert argmax
    // agreement for the continuation as well as for the prompt.
    assert_eq!(
        argmax(&ctrl_continuation),
        argmax(&test_continuation),
        "prompt-cache continuation argmax diverged from control"
    );
    assert!(
        prompt_cos >= 0.9999,
        "prompt-cache prompt cosine {prompt_cos:.7} below 0.9999"
    );
    assert!(
        cont_cos >= 0.9999,
        "prompt-cache continuation cosine {cont_cos:.7} below 0.9999"
    );
}