// Unit tests for the items exported by `crate::brick`: token budgets and
// results, `BrickError` display, `BrickAssertion` checks, `BrickVerification`
// aggregation, and configuration of the RmsNorm / Qkv / Attention bricks.
#[cfg(test)]
mod tests {
    use crate::brick::*;

    /// Unwraps the fields of a `BrickError::AssertionFailed`.
    ///
    /// Returns `(name, expected, actual)`. Panics with a descriptive message if
    /// `result` is `Ok` or carries any other error variant, so a test can never
    /// silently skip its field assertions because of an unexpected outcome.
    fn expect_assertion_failed<T>(result: Result<T, BrickError>) -> (String, String, String) {
        match result {
            Err(BrickError::AssertionFailed {
                name,
                expected,
                actual,
            }) => (name, expected, actual),
            Err(other) => panic!("expected BrickError::AssertionFailed, got {:?}", other),
            Ok(_) => panic!("expected BrickError::AssertionFailed, got Ok"),
        }
    }

    // ---- TokenBudget -------------------------------------------------------

    #[test]
    fn token_budget_default_values() {
        // Default: 100 µs/token, i.e. 10k tokens/s, unbatched.
        let budget = TokenBudget::default();
        assert!((budget.us_per_token - 100.0).abs() < 0.001);
        assert!((budget.tokens_per_sec - 10_000.0).abs() < 1.0);
        assert_eq!(budget.batch_size, 1);
    }

    #[test]
    fn token_budget_with_batch_size() {
        // Setting the batch size must not disturb the latency target.
        let budget = TokenBudget::from_latency(50.0).with_batch_size(4);
        assert_eq!(budget.batch_size, 4);
        assert!((budget.us_per_token - 50.0).abs() < 0.001);
    }

    #[test]
    fn token_budget_extreme_values() {
        let fast = TokenBudget::from_latency(1.0);
        assert!((fast.tokens_per_sec - 1_000_000.0).abs() < 1.0);
        let slow = TokenBudget::from_throughput(100.0);
        assert!((slow.us_per_token - 10_000.0).abs() < 0.1);
    }

    #[test]
    fn token_budget_boundary_is_met() {
        // The limit itself counts as meeting the budget (inclusive boundary).
        let budget = TokenBudget::from_latency(100.0);
        assert!(budget.is_met(100.0));
        assert!(budget.is_met(99.999));
        assert!(!budget.is_met(100.001));
    }

    // ---- TokenResult -------------------------------------------------------

    #[test]
    fn token_result_default() {
        let result: TokenResult<Vec<f32>> = TokenResult::default();
        assert!(result.output.is_empty());
        assert_eq!(result.tokens_processed, 0);
        assert_eq!(result.us_per_token, 0.0);
        assert_eq!(result.tokens_per_sec, 0.0);
        assert!(result.budget_met);
    }

    #[test]
    fn token_result_zero_tokens_handled() {
        // With zero tokens, per-token latency equals the total elapsed time
        // (presumably the token count is clamped to 1 to avoid dividing by zero).
        let budget = TokenBudget::from_latency(100.0);
        let result: TokenResult<Vec<f32>> = TokenResult::new(vec![], 0, 100.0, &budget);
        assert!((result.us_per_token - 100.0).abs() < 0.001);
    }

    #[test]
    fn token_result_zero_elapsed_time() {
        // Zero elapsed time yields zero rates rather than infinities/NaN.
        let budget = TokenBudget::from_latency(100.0);
        let result: TokenResult<Vec<f32>> = TokenResult::new(vec![], 10, 0.0, &budget);
        assert_eq!(result.us_per_token, 0.0);
        assert_eq!(result.tokens_per_sec, 0.0);
        assert!(result.budget_met);
    }

    // ---- BrickError --------------------------------------------------------

    #[test]
    fn brick_error_display_assertion_failed() {
        let err = BrickError::AssertionFailed {
            name: "test_assertion".to_string(),
            expected: "42".to_string(),
            actual: "0".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("test_assertion"), "message: {:?}", msg);
        assert!(msg.contains("42"), "message: {:?}", msg);
        assert!(msg.contains("0"), "message: {:?}", msg);
    }

    #[test]
    fn brick_error_display_budget_exceeded() {
        let err = BrickError::BudgetExceeded {
            limit_us: 10.0,
            actual_us: 20.0,
        };
        let msg = format!("{}", err);
        assert!(msg.contains("10"), "message: {:?}", msg);
        assert!(msg.contains("20"), "message: {:?}", msg);
    }

    #[test]
    fn brick_error_display_compute_error() {
        let err = BrickError::ComputeError("test error message".to_string());
        let msg = format!("{}", err);
        assert!(msg.contains("test error message"), "message: {:?}", msg);
    }

    #[test]
    fn brick_error_display_invalid_input() {
        let err = BrickError::InvalidInput("bad input".to_string());
        let msg = format!("{}", err);
        assert!(msg.contains("bad input"), "message: {:?}", msg);
    }

    #[test]
    fn brick_error_is_std_error() {
        // Compile-time check: BrickError implements std::error::Error.
        let err = BrickError::InvalidInput("test".to_string());
        let _: &dyn std::error::Error = &err;
    }

    // ---- BrickAssertion ----------------------------------------------------

    #[test]
    fn assertion_equiv_scalar_description() {
        let assertion = BrickAssertion::equiv_scalar(0.001);
        assert_eq!(assertion.name, "equiv_scalar");
        assert!(assertion.description.contains("0.001"));
    }

    #[test]
    fn assertion_bounds_check_pass() {
        let assertion = BrickAssertion::bounds(-10.0, 10.0);
        let data = [-5.0f32, 0.0, 5.0];
        assert!(assertion.check_f32(&data, true).is_ok());
    }

    #[test]
    fn assertion_bounds_check_fail_low() {
        let assertion = BrickAssertion::bounds(-10.0, 10.0);
        let data = [-15.0f32, 0.0, 5.0];
        let (_, _, actual) = expect_assertion_failed(assertion.check_f32(&data, true));
        assert!(actual.contains("-15"), "actual: {:?}", actual);
        assert!(actual.contains("index 0"), "actual: {:?}", actual);
    }

    #[test]
    fn assertion_bounds_check_fail_high() {
        let assertion = BrickAssertion::bounds(-10.0, 10.0);
        let data = [0.0f32, 0.0, 15.0];
        let (_, _, actual) = expect_assertion_failed(assertion.check_f32(&data, true));
        assert!(actual.contains("15"), "actual: {:?}", actual);
        assert!(actual.contains("index 2"), "actual: {:?}", actual);
    }

    #[test]
    fn assertion_no_inf_check_pass() {
        // Very large but finite values must pass.
        let assertion = BrickAssertion::no_inf();
        let data = [1e30f32, -1e30, 0.0];
        assert!(assertion.check_f32(&data, true).is_ok());
    }

    #[test]
    fn assertion_no_inf_check_fail_positive() {
        let assertion = BrickAssertion::no_inf();
        let data = [1.0f32, f32::INFINITY, 3.0];
        let (_, _, actual) = expect_assertion_failed(assertion.check_f32(&data, true));
        assert!(actual.contains("index 1"), "actual: {:?}", actual);
    }

    #[test]
    fn assertion_no_inf_check_fail_negative() {
        let assertion = BrickAssertion::no_inf();
        let data = [f32::NEG_INFINITY, 2.0, 3.0];
        let result = assertion.check_f32(&data, true);
        assert!(result.is_err());
    }

    #[test]
    fn assertion_budget_met_pass() {
        let assertion = BrickAssertion::budget_met();
        assert!(assertion.check_f32(&[1.0, 2.0, 3.0], true).is_ok());
    }

    #[test]
    fn assertion_budget_met_fail() {
        // The data is irrelevant here; only the `budget_met` flag is checked.
        let assertion = BrickAssertion::budget_met();
        let (_, expected, actual) =
            expect_assertion_failed(assertion.check_f32(&[1.0, 2.0, 3.0], false));
        assert!(expected.contains("budget met"), "expected: {:?}", expected);
        assert!(actual.contains("budget exceeded"), "actual: {:?}", actual);
    }

    // ---- BrickVerification -------------------------------------------------

    #[test]
    fn verification_pass() {
        let v = BrickVerification::pass();
        assert!(v.is_valid);
        assert!(v.results.is_empty());
    }

    #[test]
    fn verification_fail() {
        // Each result entry is a (name, passed, message) tuple.
        let v = BrickVerification::fail("test_brick", "test failure reason");
        assert!(!v.is_valid);
        assert_eq!(v.results.len(), 1);
        assert_eq!(v.results[0].0, "test_brick");
        assert!(!v.results[0].1);
        assert_eq!(v.results[0].2, "test failure reason");
    }

    #[test]
    fn verification_add_passed() {
        let mut v = BrickVerification::pass();
        v.add("check1", true, "ok");
        assert!(v.is_valid);
        assert_eq!(v.results.len(), 1);
    }

    #[test]
    fn verification_add_failed() {
        let mut v = BrickVerification::pass();
        v.add("check1", false, "failed");
        assert!(!v.is_valid);
        assert_eq!(v.results.len(), 1);
    }

    #[test]
    fn verification_multiple_adds() {
        // A single failure invalidates the verification; later passing checks
        // must not flip it back to valid.
        let mut v = BrickVerification::pass();
        v.add("check1", true, "ok");
        v.add("check2", true, "ok");
        v.add("check3", false, "failed");
        v.add("check4", true, "ok");
        assert!(!v.is_valid);
        assert_eq!(v.results.len(), 4);
    }

    // ---- RmsNormBrick ------------------------------------------------------

    #[test]
    fn rmsnorm_input_length_mismatch() {
        // Weight length 4 vs input length 8: the error must mention both sizes.
        let brick = RmsNormBrick::new(vec![1.0; 4], 1e-5);
        let input = vec![1.0; 8];
        match brick.run(&input) {
            Err(BrickError::InvalidInput(msg)) => {
                assert!(msg.contains("8"), "message: {:?}", msg);
                assert!(msg.contains("4"), "message: {:?}", msg);
            }
            Err(other) => panic!("expected BrickError::InvalidInput, got {:?}", other),
            Ok(_) => panic!("expected BrickError::InvalidInput for mismatched length, got Ok"),
        }
    }

    #[test]
    fn rmsnorm_verify_passes() {
        let brick = RmsNormBrick::new(vec![1.0; 4], 1e-5);
        let v = brick.verify();
        assert!(v.is_valid);
    }

    #[test]
    fn rmsnorm_can_run() {
        let brick = RmsNormBrick::new(vec![1.0; 4], 1e-5);
        assert!(brick.can_run());
    }

    #[test]
    fn rmsnorm_with_custom_budget() {
        let brick =
            RmsNormBrick::new(vec![1.0; 4], 1e-5).with_budget(TokenBudget::from_latency(50.0));
        assert!((brick.budget().us_per_token - 50.0).abs() < 0.001);
    }

    // ---- QkvBrick ----------------------------------------------------------

    #[test]
    fn qkv_total_out_dim() {
        let brick = QkvBrick::new(128, 64, 32, 32);
        assert_eq!(brick.total_out_dim(), 128);
    }

    #[test]
    fn qkv_with_bias() {
        let brick = QkvBrick::new(128, 64, 32, 32).with_bias();
        assert!(brick.has_bias);
    }

    #[test]
    fn qkv_without_bias() {
        let brick = QkvBrick::new(128, 64, 32, 32);
        assert!(!brick.has_bias);
    }

    // ---- AttentionBrick: group size = query heads per KV head ---------------

    #[test]
    fn attention_gqa_group_size() {
        let brick = AttentionBrick::new(8, 2, 64);
        assert_eq!(brick.group_size(), 4);
    }

    #[test]
    fn attention_gqa_group_size_mha() {
        // One KV head per query head (multi-head attention).
        let brick = AttentionBrick::new(8, 8, 64);
        assert_eq!(brick.group_size(), 1);
    }

    #[test]
    fn attention_gqa_group_size_mqa() {
        // A single shared KV head (multi-query attention).
        let brick = AttentionBrick::new(8, 1, 64);
        assert_eq!(brick.group_size(), 8);
    }

    #[test]
    fn attention_gqa_group_size_zero_kv() {
        // Degenerate config: zero KV heads must not panic (e.g. divide by zero);
        // the group size falls back to the query head count.
        let brick = AttentionBrick::new(8, 0, 64);
        assert_eq!(brick.group_size(), 8);
    }
}