Skip to main content

trueno/brick/quant_ops/
tests.rs

1#[allow(unused_imports)]
2use super::*;
3
// ===== BlockQ5K Tests =====
5
/// A Q5_K super-block must describe exactly 256 weights.
#[test]
fn test_block_q5k_size() {
    let expected = 256;
    assert_eq!(BlockQ5K::BLOCK_SIZE, expected);
}
10
/// Dequantizing a block built from "neutral" scale bytes with `dmin == 0`
/// must keep every output small in magnitude.
///
/// NOTE(review): this assumes scale bytes are stored with a +32 offset, so
/// that byte 32 means scale 0 — TODO confirm against `BlockQ5K::dequantize`.
/// The bound checked is only `< 1.0`, so outputs are not pinned to exactly 0.
#[test]
fn test_block_q5k_dequantize_basic() {
    let neutral = BlockQ5K {
        d: 0.1,
        dmin: 0.0,
        scales: [32; 12], // intended as scale 0 (32 - 32)
        qh: [0; 32],      // no high bits set
        qs: [0x88; 128],  // both nibbles = 8 (mid-range)
    };

    let mut dequantized = [0.0f32; 256];
    neutral.dequantize(&mut dequantized);

    for x in dequantized.iter() {
        assert!(x.abs() < 1.0, "Expected near-zero, got {}", x);
    }
}
29
/// With a unit scale and every quant bit set, dequantization must produce at
/// least one weight that is meaningfully non-zero.
#[test]
fn test_block_q5k_dequantize_with_scale() {
    let saturated = BlockQ5K {
        d: 1.0,
        dmin: 0.5,
        scales: [33; 12], // intended scale of 1 (33 - 32)
        qh: [0xFF; 32],   // every high bit set
        qs: [0xFF; 128],  // every low nibble = 15
    };

    let mut weights = [0.0f32; 256];
    saturated.dequantize(&mut weights);

    let has_non_zero = weights.iter().any(|w| w.abs() > 1e-6);
    assert!(has_non_zero, "Should have non-zero values");
}
47
/// Alternating high bits and alternating nibble patterns must still
/// dequantize to finite numbers everywhere.
#[test]
fn test_block_q5k_dequantize_alternating() {
    let mixed = BlockQ5K {
        d: 0.5,
        dmin: 0.1,
        scales: [34; 12], // intended scale of 2 (34 - 32)
        qh: [0xAA; 32],   // 1010_1010: every other high bit set
        qs: [0x55; 128],  // both nibbles = 5
    };

    let mut weights = [0.0f32; 256];
    mixed.dequantize(&mut weights);

    for w in weights.iter() {
        assert!(w.is_finite(), "Value should be finite");
    }
}
66
/// Exercises both nibble-extraction paths of Q5_K dequantization.
///
/// Every `qs` byte is 0x12 (low nibble 2, high nibble 1). The check passes
/// when the first two outputs differ, or when the first output is ~0.
/// NOTE(review): which output positions receive the low vs. high nibble
/// depends on `BlockQ5K::dequantize`; an "even index => low nibble,
/// odd index => high nibble" mapping is assumed — confirm.
#[test]
fn test_block_q5k_dequantize_odd_even_bytes() {
    let block = BlockQ5K {
        d: 1.0,
        dmin: 0.0,
        scales: [48; 12], // intended scale of 16 (48 - 32)
        qh: [0; 32],
        qs: [0x12; 128],
    };

    let mut out = [0.0f32; 256];
    block.dequantize(&mut out);

    let (first, second) = (out[0], out[1]);
    let values_differ = first != second;
    let first_is_zero = first.abs() < 1e-6;
    assert!(values_differ || first_is_zero);
}
86
// ===== BlockQ6K Tests =====
88
/// A Q6_K super-block must describe exactly 256 weights.
#[test]
fn test_block_q6k_size() {
    let expected = 256;
    assert_eq!(BlockQ6K::BLOCK_SIZE, expected);
}
93
/// With every sub-block scale equal to zero, every dequantized weight must
/// be (numerically) zero.
#[test]
fn test_block_q6k_dequantize_basic() {
    let zero_scaled = BlockQ6K {
        ql: [0; 128],
        qh: [0; 64],
        scales: [0; 16], // all sub-block scales zero
        d: 0.1,
    };

    let mut out = [0.0f32; 256];
    zero_scaled.dequantize(&mut out);

    // d * scale * (q6 - 32) collapses to 0 when scale == 0.
    for v in out.iter() {
        assert!(v.abs() < 1e-6, "Expected 0, got {}", v);
    }
}
111
/// Saturated quant bits with a positive scale must yield at least one
/// non-zero weight.
#[test]
fn test_block_q6k_dequantize_with_scale() {
    let saturated = BlockQ6K {
        ql: [0xFF; 128], // every low bit set
        qh: [0xFF; 64],  // every high bit set
        scales: [1; 16], // positive scale in every sub-block
        d: 0.5,
    };

    let mut weights = [0.0f32; 256];
    saturated.dequantize(&mut weights);

    let non_zero = weights.iter().filter(|w| w.abs() > 1e-6).count() > 0;
    assert!(non_zero, "Should have non-zero values");
}
128
/// Negative (signed) sub-block scales must still dequantize to finite values.
#[test]
fn test_block_q6k_dequantize_negative_scale() {
    let block = BlockQ6K {
        ql: [0x88; 128],
        qh: [0x55; 64],
        scales: [-1; 16], // every sub-block scale is -1
        d: 1.0,
    };

    let mut out = [0.0f32; 256];
    block.dequantize(&mut out);

    out.iter().for_each(|v| assert!(v.is_finite()));
}
146
/// Dequantizes a block whose 16 sub-block scales are all distinct
/// (+1..=+8, then -1..=-8) and verifies the whole output.
///
/// Two properties are checked:
/// - every one of the 256 outputs is finite, not just a few boundary indices;
/// - every 16-value sub-block was actually written, i.e. contains a non-zero
///   value.
///
/// The second check is needed because the output buffer starts zeroed, so
/// "finite" alone would pass even if a sub-block were skipped.
#[test]
fn test_block_q6k_dequantize_all_subblocks() {
    let block = BlockQ6K {
        ql: [0x12; 128], // low nibble 2, high nibble 1 — never 0
        // Same byte everywhere: only the lowest 2-bit field is set (0b11),
        // the other three 2-bit fields are 0.
        qh: [0x03; 64],
        scales: [1, 2, 3, 4, 5, 6, 7, 8, -1, -2, -3, -4, -5, -6, -7, -8],
        d: 0.1,
    };

    let mut output = [0.0f32; 256];
    block.dequantize(&mut output);

    for (i, val) in output.iter().enumerate() {
        assert!(val.is_finite(), "output[{i}] = {val} is not finite");
    }

    // Every scale is non-zero. The 6-bit quant can only be one of
    // {1, 2, 49, 50}, so it never equals the offset of 32.
    // (assumes q6 = ql_nibble | (qh_2bit << 4), offset by 32 — TODO confirm)
    // An all-zero chunk therefore means the sub-block was never processed.
    for (sb, chunk) in output.chunks_exact(16).enumerate() {
        assert!(
            chunk.iter().any(|v| *v != 0.0),
            "sub-block {sb} (outputs {}..{}) is all zeros",
            sb * 16,
            sb * 16 + 16
        );
    }
}
167
/// Feeds a `qh` byte whose four 2-bit fields are 0, 1, 2, 3 (LSB first) and
/// checks the first four outputs.
///
/// NOTE(review): the intent is to cover a per-position shift of
/// `(i % 4) * 2` bits, but only finiteness is asserted — the outputs are not
/// checked to differ from one another.
#[test]
fn test_block_q6k_qh_extraction() {
    let block = BlockQ6K {
        ql: [0; 128],
        qh: [0b11_10_01_00; 64], // 2-bit fields: 0, 1, 2, 3
        scales: [1; 16],
        d: 1.0,
    };

    let mut output = [0.0f32; 256];
    block.dequantize(&mut output);

    for v in &output[..4] {
        assert!(v.is_finite());
    }
}
190
// ===== DotQ5KOp Tests =====
192
/// 512 weights at 256 per Q5_K super-block should give 2 blocks.
#[test]
fn test_dot_q5k_new() {
    assert_eq!(DotQ5KOp::new(512).n_blocks, 2);
}
198
/// The Q5_K dot-product op reports its canonical name.
#[test]
fn test_dot_q5k_name() {
    assert_eq!(DotQ5KOp::new(256).name(), "dot_q5k");
}
204
/// With no blocks and no activations, the scalar dot product is zero.
#[test]
fn test_dot_q5k_empty() {
    let op = DotQ5KOp::new(256);
    let dot = op.execute((Vec::new(), Vec::new()), Backend::Scalar).unwrap();
    assert!(dot.abs() < 1e-6);
}
211
/// A weight block paired with an empty activation vector yields zero.
#[test]
fn test_dot_q5k_empty_activations() {
    let op = DotQ5KOp::new(256);
    let blocks = vec![BlockQ5K {
        d: 1.0,
        dmin: 0.0,
        scales: [32; 12],
        qh: [0; 32],
        qs: [0; 128],
    }];
    let activations = Vec::new();
    let dot = op.execute((blocks, activations), Backend::Scalar).unwrap();
    assert!(dot.abs() < 1e-6);
}
219
/// The token count comes from the op's configured size (512 = 2 blocks).
/// An empty input is passed to show the count does not depend on the data.
#[test]
fn test_dot_q5k_tokens() {
    let op = DotQ5KOp::new(512);
    let no_data = (Vec::new(), Vec::new());
    assert_eq!(op.tokens(&no_data), 512);
}
226
/// A single mid-range block dotted with all-ones produces a finite result on
/// the scalar backend.
#[test]
fn test_dot_q5k_scalar_execution() {
    let op = DotQ5KOp::new(256);
    let weights = BlockQ5K {
        d: 1.0,
        dmin: 0.0,
        scales: [33; 12], // intended scale of 1
        qh: [0; 32],
        qs: [0x88; 128], // mid-range nibbles
    };
    let ones = vec![1.0f32; 256];

    let dot = op.execute((vec![weights], ones), Backend::Scalar).unwrap();
    assert!(dot.is_finite());
}
241
/// Two identical blocks (512 weights) dotted on the scalar backend give a
/// finite result.
#[test]
fn test_dot_q5k_multiple_blocks() {
    let op = DotQ5KOp::new(512);
    let template =
        BlockQ5K { d: 0.5, dmin: 0.1, scales: [34; 12], qh: [0; 32], qs: [0x44; 128] };
    let blocks = vec![template; 2];
    let halves = vec![0.5f32; 512];

    let dot = op.execute((blocks, halves), Backend::Scalar).unwrap();
    assert!(dot.is_finite());
}
250
/// `Backend::Auto` (which may dispatch to AVX2 when available) succeeds and
/// returns a finite value.
#[test]
fn test_dot_q5k_auto_backend() {
    let op = DotQ5KOp::new(256);
    let zero_quant =
        BlockQ5K { d: 1.0, dmin: 0.0, scales: [32; 12], qh: [0; 32], qs: [0; 128] };
    let ones = vec![1.0f32; 256];

    let dot = op.execute((vec![zero_quant], ones), Backend::Auto).unwrap();
    assert!(dot.is_finite());
}
260
/// Explicitly requesting `Backend::Avx2` succeeds. It is expected to fall
/// back to scalar on hardware without AVX2.
#[test]
fn test_dot_q5k_avx2_backend() {
    let op = DotQ5KOp::new(256);
    let weights =
        BlockQ5K { d: 1.0, dmin: 0.0, scales: [33; 12], qh: [0; 32], qs: [0x11; 128] };
    let twos = vec![2.0f32; 256];

    let dot = op.execute((vec![weights], twos), Backend::Avx2).unwrap();
    assert!(dot.is_finite());
}
270
// ===== DotQ6KOp Tests =====
272
/// 768 weights at 256 per Q6_K super-block should give 3 blocks.
#[test]
fn test_dot_q6k_new() {
    assert_eq!(DotQ6KOp::new(768).n_blocks, 3);
}
278
/// The Q6_K dot-product op reports its canonical name.
#[test]
fn test_dot_q6k_name() {
    assert_eq!(DotQ6KOp::new(256).name(), "dot_q6k");
}
284
/// With no blocks and no activations, the scalar dot product is zero.
#[test]
fn test_dot_q6k_empty() {
    let op = DotQ6KOp::new(256);
    let dot = op.execute((Vec::new(), Vec::new()), Backend::Scalar).unwrap();
    assert!(dot.abs() < 1e-6);
}
291
/// A weight block paired with an empty activation vector yields zero.
#[test]
fn test_dot_q6k_empty_activations() {
    let op = DotQ6KOp::new(256);
    let blocks = vec![BlockQ6K { ql: [0; 128], qh: [0; 64], scales: [0; 16], d: 1.0 }];
    let activations = Vec::new();
    let dot = op.execute((blocks, activations), Backend::Scalar).unwrap();
    assert!(dot.abs() < 1e-6);
}
299
/// The token count comes from the op's configured size (768 = 3 blocks).
/// An empty input is passed to show the count does not depend on the data.
#[test]
fn test_dot_q6k_tokens() {
    let op = DotQ6KOp::new(768);
    let no_data = (Vec::new(), Vec::new());
    assert_eq!(op.tokens(&no_data), 768);
}
306
/// A single Q6_K block dotted with all-ones produces a finite result on the
/// scalar backend.
#[test]
fn test_dot_q6k_scalar_execution() {
    let op = DotQ6KOp::new(256);
    let weights = BlockQ6K { ql: [0x55; 128], qh: [0x55; 64], scales: [1; 16], d: 0.5 };
    let ones = vec![1.0f32; 256];

    let dot = op.execute((vec![weights], ones), Backend::Scalar).unwrap();
    assert!(dot.is_finite());
}
315
/// Two identical Q6_K blocks (512 weights) dotted on the scalar backend give
/// a finite result.
#[test]
fn test_dot_q6k_multiple_blocks() {
    let op = DotQ6KOp::new(512);
    let template = BlockQ6K { ql: [0x33; 128], qh: [0x33; 64], scales: [2; 16], d: 0.25 };
    let blocks = vec![template; 2];
    let halves = vec![0.5f32; 512];

    let dot = op.execute((blocks, halves), Backend::Scalar).unwrap();
    assert!(dot.is_finite());
}
324
/// `Backend::Auto` (which may dispatch to AVX2 when available) succeeds and
/// returns a finite value.
#[test]
fn test_dot_q6k_auto_backend() {
    let op = DotQ6KOp::new(256);
    let zero_quant = BlockQ6K { ql: [0; 128], qh: [0; 64], scales: [1; 16], d: 1.0 };
    let ones = vec![1.0f32; 256];

    let dot = op.execute((vec![zero_quant], ones), Backend::Auto).unwrap();
    assert!(dot.is_finite());
}
333
/// Explicitly requesting `Backend::Avx2` succeeds and returns a finite value.
#[test]
fn test_dot_q6k_avx2_backend() {
    let op = DotQ6KOp::new(256);
    let weights = BlockQ6K { ql: [0xAA; 128], qh: [0xAA; 64], scales: [3; 16], d: 0.1 };
    let twos = vec![2.0f32; 256];

    let dot = op.execute((vec![weights], twos), Backend::Avx2).unwrap();
    assert!(dot.is_finite());
}
342
// ===== Backend Equivalence Tests =====
344
/// Checks that the dispatched dot-product paths of `DotQ5KOp` agree with the
/// scalar reference within a relative tolerance of 1e-4.
///
/// Both `Backend::Auto` and an explicit `Backend::Avx2` request are compared.
/// The Avx2 request is expected to fall back to scalar when AVX2 is
/// unavailable. This means the explicitly requested SIMD path is
/// numerically checked, not just checked for finiteness.
#[test]
fn test_q5k_backend_equivalence() {
    let op = DotQ5KOp::new(256);
    let block = BlockQ5K { d: 0.5, dmin: 0.1, scales: [35; 12], qh: [0x55; 32], qs: [0x77; 128] };
    let x = vec![1.5f32; 256];

    let scalar = op.execute((vec![block.clone()], x.clone()), Backend::Scalar).unwrap();
    assert!(scalar.is_finite(), "scalar reference is not finite: {scalar}");

    for (label, backend) in [("auto", Backend::Auto), ("avx2", Backend::Avx2)] {
        let simd = op.execute((vec![block.clone()], x.clone()), backend).unwrap();

        // Allow small FP differences due to SIMD operation ordering. The
        // floor on the denominator keeps the ratio bounded near zero.
        let rel_diff = (scalar - simd).abs() / scalar.abs().max(1e-6);
        assert!(
            rel_diff < 1e-4,
            "{label}: scalar={scalar}, {label}={simd}, rel_diff={rel_diff}"
        );
    }
}
358
/// Checks that the dispatched dot-product paths of `DotQ6KOp` agree with the
/// scalar reference within a relative tolerance of 1e-4.
///
/// Both `Backend::Auto` and an explicit `Backend::Avx2` request are compared.
/// The Avx2 request is expected to fall back to scalar when AVX2 is
/// unavailable. This means the explicitly requested SIMD path is
/// numerically checked, not just checked for finiteness.
#[test]
fn test_q6k_backend_equivalence() {
    let op = DotQ6KOp::new(256);
    let block = BlockQ6K { ql: [0x66; 128], qh: [0x22; 64], scales: [4; 16], d: 0.2 };
    let x = vec![1.5f32; 256];

    let scalar = op.execute((vec![block.clone()], x.clone()), Backend::Scalar).unwrap();
    assert!(scalar.is_finite(), "scalar reference is not finite: {scalar}");

    for (label, backend) in [("auto", Backend::Auto), ("avx2", Backend::Avx2)] {
        let simd = op.execute((vec![block.clone()], x.clone()), backend).unwrap();

        // Allow small FP differences due to SIMD operation ordering. The
        // floor on the denominator keeps the ratio bounded near zero.
        let rel_diff = (scalar - simd).abs() / scalar.abs().max(1e-6);
        assert!(
            rel_diff < 1e-4,
            "{label}: scalar={scalar}, {label}={simd}, rel_diff={rel_diff}"
        );
    }
}
372
// ===== Clone/Debug Trait Tests =====
374
/// A cloned `BlockQ5K` has the same `Debug` rendering as its source.
#[test]
fn test_block_q5k_clone_debug() {
    let original = BlockQ5K { d: 1.0, dmin: 0.5, scales: [32; 12], qh: [0; 32], qs: [0; 128] };
    let copy = original.clone();
    let (lhs, rhs) = (format!("{:?}", original), format!("{:?}", copy));
    assert_eq!(lhs, rhs);
}
381
/// A cloned `BlockQ6K` has the same `Debug` rendering as its source.
#[test]
fn test_block_q6k_clone_debug() {
    let original = BlockQ6K { ql: [0; 128], qh: [0; 64], scales: [0; 16], d: 1.0 };
    let copy = original.clone();
    let (lhs, rhs) = (format!("{:?}", original), format!("{:?}", copy));
    assert_eq!(lhs, rhs);
}
388
/// A cloned `DotQ5KOp` has the same `Debug` rendering as its source.
#[test]
fn test_dot_q5k_op_clone_debug() {
    let original = DotQ5KOp::new(256);
    let copy = original.clone();
    let (lhs, rhs) = (format!("{:?}", original), format!("{:?}", copy));
    assert_eq!(lhs, rhs);
}
395
/// A cloned `DotQ6KOp` has the same `Debug` rendering as its source.
#[test]
fn test_dot_q6k_op_clone_debug() {
    let original = DotQ6KOp::new(256);
    let copy = original.clone();
    let (lhs, rhs) = (format!("{:?}", original), format!("{:?}", copy));
    assert_eq!(lhs, rhs);
}