use crate::gguf::test_helpers::create_test_model_with_config;
use crate::gguf::GGUFConfig;
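/// Minimal single-layer llama-style config (hidden_dim 64, 4 heads, 4 KV heads)
/// shared by the flash-attention tests below.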
fn small_config() -> GGUFConfig {
    GGUFConfig {
        architecture: "llama".to_string(),
        constraints: crate::gguf::ArchConstraints::from_architecture("llama"),
        hidden_dim: 64,
        intermediate_dim: 128,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 1,
        vocab_size: 100,
        context_length: 256,
        rope_theta: 10000.0,
        eps: 1e-5,
        rope_type: 0,
        explicit_head_dim: None,
        bos_token_id: None,
        eos_token_id: None,
    }
}
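
// With an empty KV cache the current token is the only attention target, so the
// softmax collapses to a single weight of 1.0 and the output should reproduce
// current_v within tolerance.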
#[test]
fn test_flash_attention_tiled_no_cache() {
    let model = create_test_model_with_config(&small_config());
    let hidden_dim = model.config.hidden_dim;
    let q = vec![1.0f32; hidden_dim];
    let k_cache: Vec<f32> = vec![];
    let v_cache: Vec<f32> = vec![];
    let current_k = vec![0.5f32; hidden_dim];
    let current_v = vec![0.3f32; hidden_dim];
    let output = model.flash_attention_tiled(&q, &k_cache, &v_cache, &current_k, &current_v, 64);
    assert_eq!(output.len(), hidden_dim);
    assert!(output.iter().all(|x| x.is_finite()));
    for (o, &expected) in output.iter().zip(current_v.iter()) {
        let diff = (o - expected).abs();
        assert!(diff < 1e-3, "output={o} expected={expected} diff={diff}");
    }
}
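
// With three cached positions plus the current token, every output element is a
// convex combination of the cached 0.2 values and current_v's 0.5, so it must
// land strictly between 0 and 1.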
#[test]
fn test_flash_attention_tiled_with_cache() {
    let model = create_test_model_with_config(&small_config());
    let hidden_dim = model.config.hidden_dim;
    let q = vec![1.0f32; hidden_dim];
    let cache_len = 3;
    let k_cache = vec![0.1f32; cache_len * hidden_dim];
    let v_cache = vec![0.2f32; cache_len * hidden_dim];
    let current_k = vec![1.0f32; hidden_dim];
    let current_v = vec![0.5f32; hidden_dim];
    let output = model.flash_attention_tiled(&q, &k_cache, &v_cache, &current_k, &current_v, 64);
    assert_eq!(output.len(), hidden_dim);
    assert!(output.iter().all(|x| x.is_finite()));
    for val in &output {
        assert!(*val > 0.0 && *val < 1.0, "val={val} out of expected range");
    }
}
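
// A tile size of 2 against 5 cached positions plus the current token forces the
// kernel through multiple tiles, exercising the cross-tile accumulation path.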
#[test]
fn test_flash_attention_tiled_multi_tile() {
    let model = create_test_model_with_config(&small_config());
    let hidden_dim = model.config.hidden_dim;
    let cache_len = 5;
    let k_cache = vec![0.1f32; cache_len * hidden_dim];
    let v_cache = vec![0.2f32; cache_len * hidden_dim];
    let q = vec![1.0f32; hidden_dim];
    let current_k = vec![0.5f32; hidden_dim];
    let current_v = vec![0.3f32; hidden_dim];
    let output = model.flash_attention_tiled(&q, &k_cache, &v_cache, &current_k, &current_v, 2);
    assert_eq!(output.len(), hidden_dim);
    assert!(output.iter().all(|x| x.is_finite()));
}
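
// Parity check: the tiled kernel should match the reference attention_with_cache
// output element-wise to within 1e-3.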
#[test]
fn test_flash_attention_tiled_parity_with_standard() {
    let model = create_test_model_with_config(&small_config());
    let hidden_dim = model.config.hidden_dim;
    let q = vec![1.0f32; hidden_dim];
    let cache_len = 2;
    let k_cache = vec![0.3f32; cache_len * hidden_dim];
    let v_cache = vec![0.4f32; cache_len * hidden_dim];
    let current_k = vec![0.7f32; hidden_dim];
    let current_v = vec![0.6f32; hidden_dim];
    let flash_output =
        model.flash_attention_tiled(&q, &k_cache, &v_cache, &current_k, &current_v, 64);
    let standard_output =
        model.attention_with_cache(&q, &k_cache, &v_cache, &current_k, &current_v);
    assert_eq!(flash_output.len(), standard_output.len());
    for (f, s) in flash_output.iter().zip(standard_output.iter()) {
        let diff = (f - s).abs();
        assert!(
            diff < 1e-3,
            "flash={f} standard={s} diff={diff} (parity failure)"
        );
    }
}
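
// Stress the kernel with a 100-position cache of varied (non-uniform) keys and
// values; only shape and finiteness are asserted here.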
#[test]
fn test_flash_attention_tiled_large_cache() {
    let model = create_test_model_with_config(&small_config());
    let hidden_dim = model.config.hidden_dim;
    let cache_len = 100;
    let k_cache: Vec<f32> = (0..cache_len * hidden_dim)
        .map(|i| ((i % 7) as f32) * 0.1)
        .collect();
    let v_cache: Vec<f32> = (0..cache_len * hidden_dim)
        .map(|i| ((i % 11) as f32) * 0.05)
        .collect();
    let q = vec![0.5f32; hidden_dim];
    let current_k = vec![0.3f32; hidden_dim];
    let current_v = vec![0.2f32; hidden_dim];
    let output = model.flash_attention_tiled(&q, &k_cache, &v_cache, &current_k, &current_v, 32);
    assert_eq!(output.len(), hidden_dim);
    assert!(output.iter().all(|x| x.is_finite()));
}
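
// Grouped-query attention: 8 query heads share 2 KV heads, so query heads must be
// mapped onto the correct KV group inside the kernel.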
#[test]
fn test_flash_attention_tiled_gqa() {
    let gqa_config = GGUFConfig {
        num_heads: 8,
        num_kv_heads: 2,
        ..small_config()
    };
    let model = create_test_model_with_config(&gqa_config);
    let hidden_dim = model.config.hidden_dim;
    let q = vec![1.0f32; hidden_dim];
    let k_cache = vec![0.2f32; 3 * hidden_dim];
    let v_cache = vec![0.3f32; 3 * hidden_dim];
    let current_k = vec![0.5f32; hidden_dim];
    let current_v = vec![0.4f32; hidden_dim];
    let output = model.flash_attention_tiled(&q, &k_cache, &v_cache, &current_k, &current_v, 64);
    assert_eq!(output.len(), hidden_dim);
    assert!(output.iter().all(|x| x.is_finite()));
}