/// `bottleneck()` on a layer whose QKV stage has the largest timing is
/// expected to report that stage's name ("qkv") and its time (10 µs).
#[test]
fn test_layer_timing_bottleneck() {
    let timing = LayerTiming {
        attn_norm_us: 1.0,
        qkv_us: 10.0,
        rope_us: 2.0,
        attention_us: 5.0,
        o_proj_us: 3.0,
        ffn_norm_us: 1.0,
        ffn_us: 8.0,
        // Equal to the sum of the per-stage timings above.
        total_us: 30.0,
    };

    let (slowest_stage, slowest_us) = timing.bottleneck();

    assert_eq!(slowest_stage, "qkv");
    assert_eq!(slowest_us, 10.0);
}
/// A K=256, N=4 brick must reject a 128-element activation vector with
/// `BrickError::InvalidInput`, and the message should name both the actual
/// (128) and the expected (256) length.
#[test]
#[cfg(feature = "cuda")]
fn t001_coalesced_dp4a_invalid_input_length() {
    let brick = CoalescedDp4aBrick::new(256, 4);
    let result = brick.forward(
        &[1i8; 128], // 128 elements, but the brick was built with K = 256
        1.0,
        &[0x88u8; 512],
        &[0.1f32; 4],
    );
    // Match every outcome: with a bare `if let`, a different error variant
    // would silently skip the message checks and the test would still pass.
    match result {
        Err(BrickError::InvalidInput(msg)) => {
            assert!(msg.contains("128"), "Error should mention actual size");
            assert!(msg.contains("256"), "Error should mention expected size");
        }
        Err(_) => panic!("expected BrickError::InvalidInput for a short input"),
        Ok(_) => panic!("forward() should reject a 128-element input when K = 256"),
    }
}
/// A K=256, N=4 brick must reject a 256-byte packed-weight buffer with
/// `BrickError::InvalidInput` whose message mentions "Weights".
#[test]
#[cfg(feature = "cuda")]
fn t002_coalesced_dp4a_invalid_weights() {
    let brick = CoalescedDp4aBrick::new(256, 4);
    let result = brick.forward(
        &[1i8; 256],
        1.0,
        &[0x88u8; 256], // neighbouring tests pass 512 bytes for this K/N
        &[0.1f32; 4],
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("Weights")),
        Err(_) => panic!("expected BrickError::InvalidInput for undersized weights"),
        Ok(_) => panic!("forward() should reject an undersized weight buffer"),
    }
}
/// A K=256, N=4 brick must reject a 2-element scale slice with
/// `BrickError::InvalidInput` whose message mentions "scale".
/// (Presumably one scale per output row is expected, i.e. N = 4 — confirm
/// against the brick's validation.)
#[test]
#[cfg(feature = "cuda")]
fn t003_coalesced_dp4a_invalid_scales() {
    let brick = CoalescedDp4aBrick::new(256, 4);
    let result = brick.forward(
        &[1i8; 256],
        1.0,
        &[0x88u8; 512],
        &[0.1f32; 2], // too few scales for N = 4
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("scale")),
        Err(_) => panic!("expected BrickError::InvalidInput for too few scales"),
        Ok(_) => panic!("forward() should reject a short scale slice"),
    }
}
/// `can_run()` must be false for shapes the coalesced DP4A brick rejects:
/// K not a multiple of 256, K = 0, and N = 0.
#[test]
#[cfg(feature = "cuda")]
fn t004_coalesced_dp4a_can_run_invalid() {
    let rejected_shapes = [
        (128, 4, "K=128 not multiple of 256 should not run"),
        (0, 4, "K=0 should not run"),
        (256, 0, "N=0 should not run"),
    ];
    for &(k, n, reason) in &rejected_shapes {
        assert!(!CoalescedDp4aBrick::new(k, n).can_run(), "{}", reason);
    }
}
/// Sanity bounds on the reported arithmetic intensity for a 1024x256 brick:
/// strictly positive and below 1000.
#[test]
#[cfg(feature = "cuda")]
fn t005_coalesced_dp4a_arithmetic_intensity() {
    let intensity = CoalescedDp4aBrick::new(1024, 256).arithmetic_intensity();

    assert!(intensity > 0.0, "Arithmetic intensity should be positive");
    assert!(intensity < 1000.0, "Arithmetic intensity should be reasonable");
}
/// A `FusedFfnBrick::new(4, 8)` must reject an 8-element input with
/// `BrickError::InvalidInput` whose message mentions "Input".
/// (Assumes the first constructor argument is the expected input length —
/// confirm against `FusedFfnBrick`.)
#[test]
#[cfg(feature = "cuda")]
fn t006_fused_ffn_invalid_input() {
    let brick = FusedFfnBrick::new(4, 8);
    let result = brick.forward(
        &[1.0f32; 8], // wrong input length
        &[0.1f32; 32],
        &[0.2f32; 32],
        &[0.1f32; 32],
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("Input")),
        Err(_) => panic!("expected BrickError::InvalidInput for a wrong-length input"),
        Ok(_) => panic!("forward() should reject a wrong-length input"),
    }
}
/// A `FusedFfnBrick::new(4, 8)` must reject a 16-element gate weight slice
/// with `BrickError::InvalidInput` whose message mentions "Gate/Up".
#[test]
#[cfg(feature = "cuda")]
fn t007_fused_ffn_invalid_gate_up() {
    let brick = FusedFfnBrick::new(4, 8);
    let result = brick.forward(
        &[1.0f32; 4],
        &[0.1f32; 16], // gate weights too short (the valid up/down slices are 32)
        &[0.2f32; 32],
        &[0.1f32; 32],
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("Gate/Up")),
        Err(_) => panic!("expected BrickError::InvalidInput for short gate weights"),
        Ok(_) => panic!("forward() should reject short gate weights"),
    }
}
/// A `FusedFfnBrick::new(4, 8)` must reject a 16-element down-projection
/// weight slice with `BrickError::InvalidInput` whose message mentions "Down".
#[test]
#[cfg(feature = "cuda")]
fn t008_fused_ffn_invalid_down() {
    let brick = FusedFfnBrick::new(4, 8);
    let result = brick.forward(
        &[1.0f32; 4],
        &[0.1f32; 32],
        &[0.2f32; 32],
        &[0.1f32; 16], // down weights too short
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("Down")),
        Err(_) => panic!("expected BrickError::InvalidInput for short down weights"),
        Ok(_) => panic!("forward() should reject short down weights"),
    }
}
/// `can_run()` must be false when either FFN dimension is zero.
#[test]
#[cfg(feature = "cuda")]
fn t009_fused_ffn_can_run_invalid() {
    let degenerate_dims = [
        (0, 8, "Zero hidden_dim should not run"),
        (4, 0, "Zero intermediate_dim should not run"),
    ];
    for &(hidden, intermediate, reason) in &degenerate_dims {
        assert!(!FusedFfnBrick::new(hidden, intermediate).can_run(), "{}", reason);
    }
}
/// `FlashAttentionBrick::new(4, 2, 8)` must reject a 16-element query with
/// `BrickError::InvalidInput` whose message mentions "Query".
/// (Presumably the expected query length is num_heads * head_dim = 32 —
/// confirm against the brick's validation.)
#[test]
fn t010_flash_attention_invalid_query() {
    let brick = FlashAttentionBrick::new(4, 2, 8);
    let result = brick.forward(
        &[1.0f32; 16], // query too short
        &[0.5f32; 64],
        &[0.25f32; 64],
        4,
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("Query")),
        Err(_) => panic!("expected BrickError::InvalidInput for a short query"),
        Ok(_) => panic!("forward() should reject a short query"),
    }
}
/// `FlashAttentionBrick::new(4, 2, 8)` must reject a 32-element key slice at
/// `seq_len = 4` with `BrickError::InvalidInput` whose message mentions "KV".
#[test]
fn t011_flash_attention_invalid_kv() {
    let brick = FlashAttentionBrick::new(4, 2, 8);
    let seq_len = 4;
    let result = brick.forward(
        &[1.0f32; 32],
        &[0.5f32; 32], // keys too short (the value slice has 64 elements)
        &[0.25f32; 64],
        seq_len,
    );
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("KV")),
        Err(_) => panic!("expected BrickError::InvalidInput for a short key cache"),
        Ok(_) => panic!("forward() should reject a short key cache"),
    }
}
/// A brick built with all-zero dimensions must fail `forward` with
/// `BrickError::InvalidInput` whose message mentions "Zero".
#[test]
fn t012_flash_attention_zero_dim() {
    let brick = FlashAttentionBrick::new(0, 0, 0);
    let result = brick.forward(&[], &[], &[], 0);
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("Zero")),
        Err(_) => panic!("expected BrickError::InvalidInput for zero dimensions"),
        Ok(_) => panic!("forward() should reject a zero-dimension brick"),
    }
}
/// The arithmetic intensity reported for a sequence length of 512 must be
/// strictly positive.
#[test]
fn t013_flash_attention_arithmetic_intensity() {
    let seq_len = 512;
    let intensity = FlashAttentionBrick::new(8, 2, 64).arithmetic_intensity(seq_len);
    assert!(intensity > 0.0, "Arithmetic intensity should be positive");
}
/// Quantizing an empty activation with a zero-dimension brick must fail.
#[test]
fn t014_activation_quant_zero_dim_quantize() {
    let outcome = ActivationQuantBrick::new(0).quantize(&[]);
    assert!(outcome.is_err());
}
/// Dequantizing 16 quantized values with a brick built for dimension 32 must
/// fail with `BrickError::InvalidInput` whose message mentions "16".
#[test]
fn t015_activation_quant_dequantize_mismatch() {
    let brick = ActivationQuantBrick::new(32);
    let result = brick.dequantize(&[0i8; 16], &[0.1f32; 1]);
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::InvalidInput(msg)) => assert!(msg.contains("16")),
        Err(_) => panic!("expected BrickError::InvalidInput for a length mismatch"),
        Ok(_) => panic!("dequantize() should reject 16 values for a 32-wide brick"),
    }
}
/// Round-tripping an all-zero activation through quantization must report
/// exactly zero error.
#[test]
fn t016_activation_quant_measure_error_zero() {
    let brick = ActivationQuantBrick::new(32);
    let zeros = vec![0.0f32; 32];

    let (quantized, per_block_scales) = brick.quantize(&zeros).expect("quantize failed");
    let roundtrip_error = brick
        .measure_error(&zeros, &quantized, &per_block_scales)
        .expect("measure failed");

    assert_eq!(roundtrip_error, 0.0, "Error should be 0 for zero input");
}
/// An RMSNorm brick with 32 weights must reject a 16-element input with
/// `BrickError::InvalidInput` whose message names both lengths.
#[test]
fn t017_rmsnorm_length_mismatch() {
    let brick =
        RmsNormBrick::new(vec![1.0; 32], 1e-5).with_budget(TokenBudget::from_latency(1000.0));
    let result = brick.run(&[1.0f32; 16]);
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message checks.
    match result {
        Err(BrickError::InvalidInput(msg)) => {
            assert!(msg.contains("16"));
            assert!(msg.contains("32"));
        }
        Err(_) => panic!("expected BrickError::InvalidInput for a length mismatch"),
        Ok(_) => panic!("run() should reject a 16-element input for 32 weights"),
    }
}
/// Replaying a freshly constructed (never captured) CUDA graph brick must
/// fail with `BrickError::ComputeError` whose message mentions "not captured".
#[test]
fn t018_cuda_graph_replay_not_captured() {
    let brick = CudaGraphBrick::new(28, 1536);
    let result = brick.replay();
    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the message check.
    match result {
        Err(BrickError::ComputeError(msg)) => assert!(msg.contains("not captured")),
        Err(_) => panic!("expected BrickError::ComputeError for an uncaptured graph"),
        Ok(_) => panic!("replay() should fail before a graph has been captured"),
    }
}
/// `can_run()` must be false when either the layer count or the hidden
/// dimension is zero.
#[test]
fn t019_cuda_graph_can_run_invalid() {
    let degenerate_configs = [
        (0, 1536, "Zero layers should not run"),
        (28, 0, "Zero hidden_dim should not run"),
    ];
    for &(layers, hidden, reason) in &degenerate_configs {
        assert!(!CudaGraphBrick::new(layers, hidden).can_run(), "{}", reason);
    }
}
/// A transformer layer built from a small configuration must pass its own
/// `verify()` check.
#[test]
fn t020_transformer_layer_verify() {
    let passed = TransformerLayerBrick::from_config(0, 64, 4, 2, 256, 1e-5, 10000.0, 0)
        .verify()
        .is_valid;
    assert!(passed, "Valid layer should pass verification");
}
/// The `Display` rendering of a bottleneck report must include the brick
/// name, the layer index, both timings, and the gap factor to two decimals.
#[test]
fn t021_bottleneck_report_display() {
    let report = BottleneckReport {
        layer_idx: 5,
        brick_name: "attention",
        actual_us: 15.0,
        budget_us: 10.0,
        gap_factor: 1.5,
    };
    let rendered = report.to_string();

    for &fragment in &["attention", "layer 5", "15.0", "10.0", "1.50"] {
        assert!(rendered.contains(fragment));
    }
}
/// Compile-time check that `BrickError` implements `std::error::Error`.
#[test]
fn t022_brick_error_std_error() {
    // Only type-checks if the argument implements `std::error::Error`.
    fn require_std_error<E: std::error::Error>(_: &E) {}

    require_std_error(&BrickError::InvalidInput("test".to_string()));
}
/// Building a `TokenResult` with a latency of 0.0 must report a throughput
/// of 0.0 tokens/sec rather than a division artefact.
#[test]
fn t023_token_result_zero_latency() {
    let budget = TokenBudget::from_latency(100.0);
    let token_count = 1;
    let outcome: TokenResult<Vec<f32>> = TokenResult::new(Vec::new(), token_count, 0.0, &budget);

    let throughput = outcome.tokens_per_sec;
    assert_eq!(throughput, 0.0);
}
/// A brick built with an explicit tile size keeps that size, and a
/// 512-long sequence splits into 512 / 64 = 8 tiles.
#[test]
fn t024_flash_attention_custom_tile() {
    let tiled = FlashAttentionBrick::with_tile_size(8, 2, 64, 64);
    let tiles_for_512 = tiled.num_tiles(512);

    assert_eq!(tiled.tile_size, 64);
    assert_eq!(tiles_for_512, 8);
}
/// Forcing `tile_size` to zero after construction must make the brick
/// refuse to run.
#[test]
fn t025_flash_attention_zero_tile() {
    let zero_tiled = {
        let mut candidate = FlashAttentionBrick::new(8, 2, 64);
        candidate.tile_size = 0;
        candidate
    };
    assert!(!zero_tiled.can_run(), "Zero tile_size should not run");
}
/// Edge case: with a zero second constructor argument (presumably the KV
/// head count — confirm), `group_size()` is expected to be 8, the first
/// argument.
#[test]
fn t026_attention_group_size_edge() {
    let group = AttentionBrick::new(8, 0, 64).group_size();
    assert_eq!(group, 8);
}
/// Benchmarking with no warmup and a single sample that always measures
/// 10 µs must yield a mean of 10.0 and a standard deviation of 0.0.
///
/// NOTE(review): despite the test name, this runs exactly one sample,
/// not zero.
#[test]
fn t027_benchmark_empty_samples() {
    let brick = RmsNormBrick::new(vec![1.0; 4], 1e-5);
    let single_sample = BenchmarkConfig {
        warmup: 0,
        samples: 1,
        max_cv: 1.0,
    };

    // The closure stands in for a timing measurement and always returns 10.0.
    let report = benchmark_brick(&brick, || 10.0, &single_sample);

    assert_eq!(report.mean_us, 10.0);
    assert_eq!(report.std_us, 0.0);
}
/// The `budget_met` assertion passes when the budget flag is true and fails
/// with `BrickError::AssertionFailed` when it is false, naming itself and
/// describing the expected ("met") and actual ("exceeded") states.
#[test]
fn t028_assertion_budget_not_met() {
    let assertion = BrickAssertion::budget_met();
    let data = &[1.0f32, 2.0, 3.0];
    assert!(assertion.check_f32(data, true).is_ok());

    // Match every outcome so a different error variant (or `Ok`) fails the
    // test instead of silently skipping the field checks.
    match assertion.check_f32(data, false) {
        Err(BrickError::AssertionFailed {
            name,
            expected,
            actual,
        }) => {
            assert_eq!(name, "budget_met");
            assert!(expected.contains("met"));
            assert!(actual.contains("exceeded"));
        }
        Err(_) => panic!("expected BrickError::AssertionFailed when the budget is not met"),
        Ok(_) => panic!("budget_met should fail when the budget flag is false"),
    }
}
/// The `no_inf` assertion accepts finite data and rejects data containing
/// either positive or negative infinity.
#[test]
fn t029_assertion_inf_check() {
    let no_inf = BrickAssertion::no_inf();
    assert!(no_inf.check_f32(&[1.0, 2.0, 3.0], true).is_ok());

    for &infinite in &[f32::INFINITY, f32::NEG_INFINITY] {
        assert!(no_inf.check_f32(&[1.0, infinite, 3.0], true).is_err());
    }
}