Skip to main content

ferrum_testkit/op_diff/
transpose_head_to_token.rs

//! `transpose_head_to_token` op-diff harness — see `crate::op_diff`.
//!
//! Reorders `[tokens, heads, dim]` head-major data to token-major layout.
//! Pure data movement (no arithmetic), so accelerator output should match the CPU
//! reference exactly and NMSE should be ~0.

6use super::{random_vec, OpUnderTest, Output};
7
/// Op-diff case for `transpose_head_to_token` over a `[tokens, heads, dim]` tensor.
///
/// Every backend run builds its input from the same seed, so accelerator results can be
/// compared against the CPU reference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeHeadToTokenOp {
    /// Extent of the `tokens` axis.
    pub tokens: usize,
    /// Extent of the `heads` axis.
    pub heads: usize,
    /// Extent of the `dim` axis.
    pub dim: usize,
}
13
14impl TransposeHeadToTokenOp {
15    fn elems(&self) -> usize {
16        self.tokens * self.heads * self.dim
17    }
18}
19
20impl OpUnderTest for TransposeHeadToTokenOp {
21    fn name(&self) -> &str {
22        "transpose_head_to_token"
23    }
24
25    fn run_cpu(&self, seed: u64) -> Output {
26        use ferrum_kernels::backend::cpu::CpuBackend;
27        use ferrum_kernels::backend::Backend;
28
29        let src = random_vec(self.elems(), -2.0, 2.0, seed);
30        let mut ctx = CpuBackend::new_context();
31        let src_buf = CpuBackend::from_slice(&src);
32        let mut dst = CpuBackend::alloc(self.elems());
33        CpuBackend::transpose_head_to_token(
34            &mut ctx,
35            &src_buf,
36            &mut dst,
37            self.tokens,
38            self.heads,
39            self.dim,
40        );
41        CpuBackend::sync(&mut ctx);
42        CpuBackend::to_vec(&dst, self.elems())
43    }
44
45    #[cfg(all(target_os = "macos", feature = "metal"))]
46    fn run_metal(&self, seed: u64) -> Output {
47        use ferrum_kernels::backend::metal::MetalBackend;
48        use ferrum_kernels::backend::Backend;
49
50        let src = random_vec(self.elems(), -2.0, 2.0, seed);
51        let mut ctx = MetalBackend::new_context();
52        let src_buf = MetalBackend::from_slice(&src);
53        let mut dst = MetalBackend::alloc(self.elems());
54        MetalBackend::transpose_head_to_token(
55            &mut ctx,
56            &src_buf,
57            &mut dst,
58            self.tokens,
59            self.heads,
60            self.dim,
61        );
62        MetalBackend::sync(&mut ctx);
63        MetalBackend::to_vec(&dst, self.elems())
64    }
65
66    #[cfg(feature = "cuda")]
67    fn run_cuda(&self, seed: u64) -> Output {
68        use ferrum_kernels::backend::cuda::CudaBackend;
69        use ferrum_kernels::backend::Backend;
70
71        let src = random_vec(self.elems(), -2.0, 2.0, seed);
72        let mut ctx = CudaBackend::new_context();
73        let src_buf = CudaBackend::from_slice(&src);
74        let mut dst = CudaBackend::alloc(self.elems());
75        CudaBackend::transpose_head_to_token(
76            &mut ctx,
77            &src_buf,
78            &mut dst,
79            self.tokens,
80            self.heads,
81            self.dim,
82        );
83        CudaBackend::sync(&mut ctx);
84        CudaBackend::to_vec(&dst, self.elems())
85    }
86}