ferrum_testkit/op_diff/
transpose_head_to_token.rs1use super::{random_vec, OpUnderTest, Output};
7
/// Test-kit description of one `transpose_head_to_token` kernel invocation.
///
/// The three fields are forwarded unchanged as the trailing shape arguments of
/// the backend's `transpose_head_to_token` call, and their product is used as
/// the element count of both the source and the destination buffer.
///
/// The type is plain data (three `usize` values), so it derives the usual
/// value-type traits to make it printable in test failures, comparable, and
/// cheap to pass around.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TransposeHeadToTokenOp {
    /// Value passed as the kernel's `tokens` argument.
    pub tokens: usize,
    /// Value passed as the kernel's `heads` argument.
    pub heads: usize,
    /// Value passed as the kernel's `dim` argument.
    pub dim: usize,
}
13
14impl TransposeHeadToTokenOp {
15 fn elems(&self) -> usize {
16 self.tokens * self.heads * self.dim
17 }
18}
19
20impl OpUnderTest for TransposeHeadToTokenOp {
21 fn name(&self) -> &str {
22 "transpose_head_to_token"
23 }
24
25 fn run_cpu(&self, seed: u64) -> Output {
26 use ferrum_kernels::backend::cpu::CpuBackend;
27 use ferrum_kernels::backend::Backend;
28
29 let src = random_vec(self.elems(), -2.0, 2.0, seed);
30 let mut ctx = CpuBackend::new_context();
31 let src_buf = CpuBackend::from_slice(&src);
32 let mut dst = CpuBackend::alloc(self.elems());
33 CpuBackend::transpose_head_to_token(
34 &mut ctx,
35 &src_buf,
36 &mut dst,
37 self.tokens,
38 self.heads,
39 self.dim,
40 );
41 CpuBackend::sync(&mut ctx);
42 CpuBackend::to_vec(&dst, self.elems())
43 }
44
45 #[cfg(all(target_os = "macos", feature = "metal"))]
46 fn run_metal(&self, seed: u64) -> Output {
47 use ferrum_kernels::backend::metal::MetalBackend;
48 use ferrum_kernels::backend::Backend;
49
50 let src = random_vec(self.elems(), -2.0, 2.0, seed);
51 let mut ctx = MetalBackend::new_context();
52 let src_buf = MetalBackend::from_slice(&src);
53 let mut dst = MetalBackend::alloc(self.elems());
54 MetalBackend::transpose_head_to_token(
55 &mut ctx,
56 &src_buf,
57 &mut dst,
58 self.tokens,
59 self.heads,
60 self.dim,
61 );
62 MetalBackend::sync(&mut ctx);
63 MetalBackend::to_vec(&dst, self.elems())
64 }
65
66 #[cfg(feature = "cuda")]
67 fn run_cuda(&self, seed: u64) -> Output {
68 use ferrum_kernels::backend::cuda::CudaBackend;
69 use ferrum_kernels::backend::Backend;
70
71 let src = random_vec(self.elems(), -2.0, 2.0, seed);
72 let mut ctx = CudaBackend::new_context();
73 let src_buf = CudaBackend::from_slice(&src);
74 let mut dst = CudaBackend::alloc(self.elems());
75 CudaBackend::transpose_head_to_token(
76 &mut ctx,
77 &src_buf,
78 &mut dst,
79 self.tokens,
80 self.heads,
81 self.dim,
82 );
83 CudaBackend::sync(&mut ctx);
84 CudaBackend::to_vec(&dst, self.elems())
85 }
86}