ferrum_testkit/op_diff/
embedding_lookup.rs1use super::{random_vec, OpUnderTest, Output};
7
/// Op-diff case for the `embedding_lookup` kernel.
///
/// The case is fully described by three sizes: a `vocab x dim` table and
/// `tokens` ids into it. The output compared across backends holds
/// `tokens * dim` values.
pub struct EmbeddingLookupOp {
    /// Number of table rows; `build_input` reduces every generated id modulo this.
    pub vocab: usize,
    /// Elements per table row (and per output row).
    pub dim: usize,
    /// Number of ids looked up in one run.
    pub tokens: usize,
}
13
14impl EmbeddingLookupOp {
15 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<u32>) {
16 let table = random_vec(self.vocab * self.dim, -1.0, 1.0, seed);
17 let ids: Vec<u32> = (0..self.tokens)
19 .map(|i| {
20 let h = (seed.wrapping_add(i as u64).wrapping_mul(2654435761)) as u32;
21 h % self.vocab as u32
22 })
23 .collect();
24 (table, ids)
25 }
26}
27
28impl OpUnderTest for EmbeddingLookupOp {
29 fn name(&self) -> &str {
30 "embedding_lookup"
31 }
32
33 fn run_cpu(&self, seed: u64) -> Output {
34 use ferrum_kernels::backend::cpu::CpuBackend;
35 use ferrum_kernels::backend::Backend;
36
37 let (table, ids) = self.build_input(seed);
38 let mut ctx = CpuBackend::new_context();
39 let table_buf = CpuBackend::from_slice(&table);
40 let mut out = CpuBackend::alloc(self.tokens * self.dim);
41 CpuBackend::embedding_lookup(&mut ctx, &table_buf, &ids, &mut out, self.dim);
42 CpuBackend::sync(&mut ctx);
43 CpuBackend::to_vec(&out, self.tokens * self.dim)
44 }
45
46 #[cfg(all(target_os = "macos", feature = "metal"))]
47 fn run_metal(&self, seed: u64) -> Output {
48 use ferrum_kernels::backend::metal::MetalBackend;
49 use ferrum_kernels::backend::Backend;
50
51 let (table, ids) = self.build_input(seed);
52 let mut ctx = MetalBackend::new_context();
53 let table_buf = MetalBackend::from_slice(&table);
54 let mut out = MetalBackend::alloc(self.tokens * self.dim);
55 MetalBackend::embedding_lookup(&mut ctx, &table_buf, &ids, &mut out, self.dim);
56 MetalBackend::sync(&mut ctx);
57 MetalBackend::to_vec(&out, self.tokens * self.dim)
58 }
59
60 #[cfg(feature = "cuda")]
61 fn run_cuda(&self, seed: u64) -> Output {
62 use ferrum_kernels::backend::cuda::CudaBackend;
63 use ferrum_kernels::backend::Backend;
64
65 let (table, ids) = self.build_input(seed);
66 let mut ctx = CudaBackend::new_context();
67 let table_buf = CudaBackend::from_slice(&table);
68 let mut out = CudaBackend::alloc(self.tokens * self.dim);
69 CudaBackend::embedding_lookup(&mut ctx, &table_buf, &ids, &mut out, self.dim);
70 CudaBackend::sync(&mut ctx);
71 CudaBackend::to_vec(&out, self.tokens * self.dim)
72 }
73}