#[cfg(test)]
mod tests {
    // Edge-case tests for the brick types re-exported from `crate::brick`:
    // runnability checks, input-length validation, budget builders, layer-timing
    // bottleneck selection and the `Display` output of the report types.
    use crate::brick::*;

    // ---------------------------------------------------------------------
    // FlashAttentionBrick
    //
    // NOTE(review): the constructor arguments look like
    // (heads, kv heads, head dim) judging by the test names below and by
    // `group_size()` being 8 / 2 = 4 — confirm against the constructor.
    // ---------------------------------------------------------------------

    /// A brick built from all-non-zero arguments reports that it can run.
    #[test]
    fn flash_attention_can_run() {
        let brick = FlashAttentionBrick::new(8, 2, 64);
        assert!(brick.can_run());
    }

    /// A zero first argument makes the brick refuse to run.
    #[test]
    fn flash_attention_cannot_run_zero_heads() {
        let brick = FlashAttentionBrick::new(0, 2, 64);
        assert!(!brick.can_run());
    }

    /// A zero third argument makes the brick refuse to run.
    #[test]
    fn flash_attention_cannot_run_zero_dim() {
        let brick = FlashAttentionBrick::new(8, 2, 0);
        assert!(!brick.can_run());
    }

    /// `forward` rejects a query buffer of 16 values for this configuration.
    #[test]
    fn flash_attention_forward_invalid_query_len() {
        let brick = FlashAttentionBrick::new(4, 2, 8);
        // Presumably too short for 4 heads x 8 dims — confirm against the
        // length check in `forward`.
        let query = vec![1.0f32; 16];
        let keys = vec![0.5f32; 64];
        let values = vec![0.25f32; 64];

        let result = brick.forward(&query, &keys, &values, 4);

        assert!(result.is_err());
    }

    /// `forward` rejects a keys buffer of 32 values (half the values buffer).
    #[test]
    fn flash_attention_forward_invalid_keys_len() {
        let brick = FlashAttentionBrick::new(4, 2, 8);
        let query = vec![1.0f32; 32];
        // Only the keys buffer differs from the otherwise plausible shapes.
        let keys = vec![0.5f32; 32];
        let values = vec![0.25f32; 64];

        let result = brick.forward(&query, &keys, &values, 4);

        assert!(result.is_err());
    }

    /// Arithmetic intensity for an argument of 512 is strictly positive.
    #[test]
    fn flash_attention_arithmetic_intensity() {
        let brick = FlashAttentionBrick::new(8, 2, 64);
        let intensity = brick.arithmetic_intensity(512);
        assert!(intensity > 0.0);
    }

    /// `group_size()` is 4 for a brick built with 8 and 2.
    #[test]
    fn flash_attention_group_size() {
        let brick = FlashAttentionBrick::new(8, 2, 64);
        assert_eq!(brick.group_size(), 4);
    }

    /// `with_budget` installs the supplied token budget.
    #[test]
    fn flash_attention_with_budget() {
        let budget = TokenBudget::from_latency(10.0);
        let brick = FlashAttentionBrick::new(8, 2, 64).with_budget(budget);
        let us_per_token = brick.budget().us_per_token;
        assert!((us_per_token - 10.0).abs() < 0.001);
    }

    /// An all-zero brick fails `forward` even with empty buffers.
    #[test]
    fn flash_attention_forward_zero_dimension() {
        let brick = FlashAttentionBrick::new(0, 0, 0);
        let result = brick.forward(&[], &[], &[], 0);
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------
    // ActivationQuantBrick
    // ---------------------------------------------------------------------

    /// `quantize` rejects 64 inputs on a brick constructed with 32.
    #[test]
    fn activation_quant_quantize_invalid_input_len() {
        let brick = ActivationQuantBrick::new(32);
        let input = vec![1.0f32; 64];

        let result = brick.quantize(&input);

        assert!(result.is_err());
    }

    /// `dequantize` rejects 64 quantized values on a brick constructed with 32.
    #[test]
    fn activation_quant_dequantize_invalid_len() {
        let brick = ActivationQuantBrick::new(32);
        let quants = vec![0i8; 64];
        let scales = vec![1.0f32];

        let result = brick.dequantize(&quants, &scales);

        assert!(result.is_err());
    }

    /// A brick constructed with 0 cannot run and fails to quantize empty input.
    #[test]
    fn activation_quant_zero_dimension() {
        let brick = ActivationQuantBrick::new(0);
        assert!(!brick.can_run());

        let result = brick.quantize(&[]);
        assert!(result.is_err());
    }

    /// Quantizing near-zero inputs succeeds and the measured error stays < 0.1.
    #[test]
    fn activation_quant_measure_error_near_zero() {
        let brick = ActivationQuantBrick::new(32);
        let input = vec![1e-12f32; 32];

        let (quants, scales) = brick
            .quantize(&input)
            .expect("quantize should accept 32 near-zero inputs");
        let error = brick
            .measure_error(&input, &quants, &scales)
            .expect("measure_error should accept the freshly quantized data");

        assert!(error < 0.1);
    }

    /// `with_budget` installs the supplied token budget.
    #[test]
    fn activation_quant_with_custom_budget() {
        let budget = TokenBudget::from_latency(2.0);
        let brick = ActivationQuantBrick::new(64).with_budget(budget);
        let us_per_token = brick.budget().us_per_token;
        assert!((us_per_token - 2.0).abs() < 0.001);
    }

    /// `with_per_channel` sets the `per_channel` flag.
    #[test]
    fn activation_quant_per_channel() {
        let brick = ActivationQuantBrick::with_per_channel(64);
        assert!(brick.per_channel);
    }

    // ---------------------------------------------------------------------
    // CudaGraphBrick
    // ---------------------------------------------------------------------

    /// Replaying before capture fails with a `BrickError::ComputeError` whose
    /// message mentions "not captured".
    #[test]
    fn cuda_graph_replay_not_captured() {
        let brick = CudaGraphBrick::new(28, 1536);
        let result = brick.replay();

        assert!(result.is_err());
        // Check the variant and message unconditionally: a conditional
        // `if let` would let any other error variant pass without ever
        // inspecting the message.
        assert!(
            matches!(&result, Err(BrickError::ComputeError(msg)) if msg.contains("not captured")),
            "expected BrickError::ComputeError mentioning \"not captured\""
        );
    }

    /// `set_captured(true)` flips the `captured` flag and enables replay.
    #[test]
    fn cuda_graph_set_captured() {
        let mut brick = CudaGraphBrick::new(28, 1536);
        assert!(!brick.captured);

        brick.set_captured(true);

        assert!(brick.captured);
        assert!(brick.can_replay());
    }

    /// A brick built from non-zero arguments reports that it can run.
    #[test]
    fn cuda_graph_can_run() {
        let brick = CudaGraphBrick::new(28, 1536);
        assert!(brick.can_run());
    }

    /// A zero first argument makes the brick refuse to run.
    #[test]
    fn cuda_graph_cannot_run_zero_layers() {
        let brick = CudaGraphBrick::new(0, 1536);
        assert!(!brick.can_run());
    }

    /// A zero second argument makes the brick refuse to run.
    #[test]
    fn cuda_graph_cannot_run_zero_hidden() {
        let brick = CudaGraphBrick::new(28, 0);
        assert!(!brick.can_run());
    }

    /// `with_budget` installs the supplied token budget.
    #[test]
    fn cuda_graph_with_budget() {
        let budget = TokenBudget::from_latency(15.0);
        let brick = CudaGraphBrick::new(28, 1536).with_budget(budget);
        let us_per_token = brick.budget().us_per_token;
        assert!((us_per_token - 15.0).abs() < 0.001);
    }

    // ---------------------------------------------------------------------
    // LayerTiming
    // ---------------------------------------------------------------------

    /// With default timings the bottleneck still has a name and costs 0 us.
    #[test]
    fn layer_timing_bottleneck_all_zero() {
        let timing = LayerTiming::default();
        let (name, us) = timing.bottleneck();
        assert!(!name.is_empty());
        // Exact comparison is intended: the expected value is exactly zero.
        assert_eq!(us, 0.0);
    }

    /// When two stages tie for slowest, either of them may be reported.
    #[test]
    fn layer_timing_bottleneck_tie() {
        let timing = LayerTiming {
            attn_norm_us: 10.0,
            qkv_us: 10.0,
            rope_us: 5.0,
            attention_us: 5.0,
            o_proj_us: 5.0,
            ffn_norm_us: 5.0,
            ffn_us: 5.0,
            total_us: 45.0,
        };

        let (name, us) = timing.bottleneck();

        // The tie-breaking order is deliberately not pinned.
        assert!(name == "attn_norm" || name == "qkv");
        assert!((us - 10.0).abs() < 0.001);
    }

    /// The stage with the largest timing (`ffn`, 15 us) is the bottleneck.
    #[test]
    fn layer_timing_bottleneck_ffn() {
        let timing = LayerTiming {
            attn_norm_us: 1.0,
            qkv_us: 2.0,
            rope_us: 1.0,
            attention_us: 5.0,
            o_proj_us: 2.0,
            ffn_norm_us: 1.0,
            ffn_us: 15.0,
            total_us: 27.0,
        };

        let (name, us) = timing.bottleneck();

        assert_eq!(name, "ffn");
        assert!((us - 15.0).abs() < 0.001);
    }

    // ---------------------------------------------------------------------
    // TransformerLayerBrick
    // ---------------------------------------------------------------------

    /// `from_config` records the layer index and starts without timing data.
    #[test]
    fn transformer_layer_from_config() {
        let layer = TransformerLayerBrick::from_config(0, 896, 14, 2, 4864, 1e-5, 1e6, 2);
        assert_eq!(layer.layer_idx, 0);
        assert!(layer.last_timing.is_none());
    }

    /// A layer built by `from_config` passes its own verification.
    #[test]
    fn transformer_layer_verify() {
        let layer = TransformerLayerBrick::from_config(0, 896, 14, 2, 4864, 1e-5, 1e6, 2);
        let verification = layer.verify();
        assert!(verification.is_valid);
    }

    /// The total budget of a configured layer is strictly positive.
    #[test]
    fn transformer_layer_total_budget() {
        let layer = TransformerLayerBrick::from_config(0, 896, 14, 2, 4864, 1e-5, 1e6, 2);
        let total_us = layer.total_budget_us();
        assert!(total_us > 0.0);
    }

    // ---------------------------------------------------------------------
    // Benchmark and bottleneck reports
    // ---------------------------------------------------------------------

    /// Default benchmark config: 10 warmup runs, 100 samples, max CV of 0.05.
    #[test]
    fn benchmark_config_default() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.warmup, 10);
        assert_eq!(config.samples, 100);
        assert!((config.max_cv - 0.05).abs() < 0.001);
    }

    /// A report that met its budget renders the brick name and "PASS".
    #[test]
    fn benchmark_report_display() {
        let report = BenchmarkReport {
            brick_name: "test_brick".to_string(),
            mean_us: 50.0,
            std_us: 5.0,
            cv: 0.1,
            p50_us: 48.0,
            p99_us: 65.0,
            tokens_per_sec: 20000.0,
            budget_us: 100.0,
            budget_met: true,
            statistically_valid: true,
        };

        let rendered = report.to_string();

        assert!(rendered.contains("test_brick"));
        assert!(rendered.contains("PASS"));
    }

    /// A report that missed its budget renders "FAIL".
    #[test]
    fn benchmark_report_display_fail() {
        let report = BenchmarkReport {
            brick_name: "test_brick".to_string(),
            mean_us: 150.0,
            std_us: 15.0,
            cv: 0.1,
            p50_us: 145.0,
            p99_us: 180.0,
            tokens_per_sec: 6667.0,
            budget_us: 100.0,
            budget_met: false,
            statistically_valid: true,
        };

        let rendered = report.to_string();

        assert!(rendered.contains("FAIL"));
    }

    /// The bottleneck report renders its brick name, layer, timings and gap.
    #[test]
    fn bottleneck_report_display() {
        let report = BottleneckReport {
            layer_idx: 5,
            brick_name: "ffn",
            actual_us: 20.0,
            budget_us: 15.0,
            gap_factor: 1.33,
        };

        let rendered = report.to_string();

        assert!(rendered.contains("ffn"));
        assert!(rendered.contains("layer 5"));
        assert!(rendered.contains("20"));
        assert!(rendered.contains("15"));
        assert!(rendered.contains("1.33"));
    }

    // ---------------------------------------------------------------------
    // RopeBrick / FfnBrick / OProjBrick
    // ---------------------------------------------------------------------

    /// `RopeBrick::new` stores its four arguments in the matching fields.
    #[test]
    fn rope_brick_creation() {
        let brick = RopeBrick::new(64, 8, 10000.0, 0);
        assert_eq!(brick.head_dim, 64);
        assert_eq!(brick.num_heads, 8);
        assert!((brick.theta - 10000.0).abs() < 0.1);
        assert_eq!(brick.rope_type, 0);
    }

    /// `with_budget` installs the supplied token budget.
    #[test]
    fn rope_brick_with_budget() {
        let budget = TokenBudget::from_latency(5.0);
        let brick = RopeBrick::new(64, 8, 10000.0, 0).with_budget(budget);
        let us_per_token = brick.budget().us_per_token;
        assert!((us_per_token - 5.0).abs() < 0.001);
    }

    /// `FfnBrick::new` stores the hidden and intermediate dimensions.
    #[test]
    fn ffn_brick_creation() {
        let brick = FfnBrick::new(1024, 4096);
        assert_eq!(brick.hidden_dim, 1024);
        assert_eq!(brick.intermediate_dim, 4096);
    }

    /// `with_budget` installs the supplied token budget.
    #[test]
    fn ffn_brick_with_budget() {
        let budget = TokenBudget::from_latency(20.0);
        let brick = FfnBrick::new(1024, 4096).with_budget(budget);
        let us_per_token = brick.budget().us_per_token;
        assert!((us_per_token - 20.0).abs() < 0.001);
    }

    /// `OProjBrick::new` stores the input and output dimensions.
    #[test]
    fn oproj_brick_creation() {
        let brick = OProjBrick::new(512, 128);
        assert_eq!(brick.in_dim, 512);
        assert_eq!(brick.out_dim, 128);
    }

    /// `with_budget` installs the supplied token budget.
    #[test]
    fn oproj_brick_with_budget() {
        let budget = TokenBudget::from_latency(5.0);
        let brick = OProjBrick::new(512, 128).with_budget(budget);
        let us_per_token = brick.budget().us_per_token;
        assert!((us_per_token - 5.0).abs() < 0.001);
    }
}