// ferrum_testkit/op_diff/embedding_lookup.rs

//! `embedding_lookup` op-diff harness — see `crate::op_diff`.
//!
//! Gathers `tokens` rows from a `[vocab, dim]` table by token id. The op is a
//! pure copy (no arithmetic), so the NMSE of accelerator backends against the
//! CPU reference should be ~0.

6use super::{random_vec, OpUnderTest, Output};
7
/// Op-diff case for `embedding_lookup`: gather `tokens` rows of width `dim`
/// from a `[vocab, dim]` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmbeddingLookupOp {
    /// Number of rows in the embedding table; token ids are drawn from `0..vocab`.
    pub vocab: usize,
    /// Width of each embedding row (elements per token).
    pub dim: usize,
    /// Number of token ids looked up, i.e. rows in the output.
    pub tokens: usize,
}

14impl EmbeddingLookupOp {
15    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<u32>) {
16        let table = random_vec(self.vocab * self.dim, -1.0, 1.0, seed);
17        // Deterministic ids in [0, vocab) without rand: hash of seed+index.
18        let ids: Vec<u32> = (0..self.tokens)
19            .map(|i| {
20                let h = (seed.wrapping_add(i as u64).wrapping_mul(2654435761)) as u32;
21                h % self.vocab as u32
22            })
23            .collect();
24        (table, ids)
25    }
26}
27
28impl OpUnderTest for EmbeddingLookupOp {
29    fn name(&self) -> &str {
30        "embedding_lookup"
31    }
32
33    fn run_cpu(&self, seed: u64) -> Output {
34        use ferrum_kernels::backend::cpu::CpuBackend;
35        use ferrum_kernels::backend::Backend;
36
37        let (table, ids) = self.build_input(seed);
38        let mut ctx = CpuBackend::new_context();
39        let table_buf = CpuBackend::from_slice(&table);
40        let mut out = CpuBackend::alloc(self.tokens * self.dim);
41        CpuBackend::embedding_lookup(&mut ctx, &table_buf, &ids, &mut out, self.dim);
42        CpuBackend::sync(&mut ctx);
43        CpuBackend::to_vec(&out, self.tokens * self.dim)
44    }
45
46    #[cfg(all(target_os = "macos", feature = "metal"))]
47    fn run_metal(&self, seed: u64) -> Output {
48        use ferrum_kernels::backend::metal::MetalBackend;
49        use ferrum_kernels::backend::Backend;
50
51        let (table, ids) = self.build_input(seed);
52        let mut ctx = MetalBackend::new_context();
53        let table_buf = MetalBackend::from_slice(&table);
54        let mut out = MetalBackend::alloc(self.tokens * self.dim);
55        MetalBackend::embedding_lookup(&mut ctx, &table_buf, &ids, &mut out, self.dim);
56        MetalBackend::sync(&mut ctx);
57        MetalBackend::to_vec(&out, self.tokens * self.dim)
58    }
59
60    #[cfg(feature = "cuda")]
61    fn run_cuda(&self, seed: u64) -> Output {
62        use ferrum_kernels::backend::cuda::CudaBackend;
63        use ferrum_kernels::backend::Backend;
64
65        let (table, ids) = self.build_input(seed);
66        let mut ctx = CudaBackend::new_context();
67        let table_buf = CudaBackend::from_slice(&table);
68        let mut out = CudaBackend::alloc(self.tokens * self.dim);
69        CudaBackend::embedding_lookup(&mut ctx, &table_buf, &ids, &mut out, self.dim);
70        CudaBackend::sync(&mut ctx);
71        CudaBackend::to_vec(&out, self.tokens * self.dim)
72    }
73}