Skip to main content

provable_contracts/kernels/
embedding.rs

1//! Embedding lookup kernel.
2//!
3//! Matches `embedding-lookup-v1.yaml`.
4//! `output[i] = W[token_ids[i]]` — table lookup with bounds checking.
5//!
6//! The kernel is implemented by three backends, one function each:
7//! - `fn embedding_scalar(...)` -- Pure Rust scalar reference (ground truth)
8//! - `unsafe fn embedding_avx2(...)` -- AVX2 SIMD implementation
9//! - `fn embedding_ptx() -> &'static str` -- PTX assembly source string
10
11// ────────────────────────────────────────────────────────────────────────────
12// Scalar implementation
13// ────────────────────────────────────────────────────────────────────────────
14
/// Embedding table lookup (scalar reference).
///
/// Copies row `token_ids[i]` of `weight` into row `i` of `output`.
///
/// `weight` is `vocab_size x dim` (row-major), `token_ids` is `seq_len` indices,
/// `output` is `seq_len x dim`.
///
/// # Panics
/// Panics if any token_id >= vocab_size, or if `weight.len()` / `output.len()`
/// do not equal `vocab_size * dim` / `token_ids.len() * dim`. This includes the
/// case where either product overflows `usize`.
pub fn embedding_scalar(
    weight: &[f32],
    token_ids: &[u32],
    vocab_size: usize,
    dim: usize,
    output: &mut [f32],
) {
    // `checked_mul` keeps the intended panic message when the product overflows.
    // A plain `*` panics with a generic overflow error in debug builds, and in
    // release builds it wraps, so the wrapped value could spuriously match.
    assert_eq!(
        Some(weight.len()),
        vocab_size.checked_mul(dim),
        "weight dimension mismatch"
    );
    assert_eq!(
        Some(output.len()),
        token_ids.len().checked_mul(dim),
        "output dimension mismatch"
    );

    for (i, &tid) in token_ids.iter().enumerate() {
        let tid = tid as usize;
        assert!(
            tid < vocab_size,
            "token_id {tid} >= vocab_size {vocab_size}"
        );
        // No overflow is possible here: tid < vocab_size and i < token_ids.len(),
        // and both products with `dim` were shown to fit in usize above.
        let src = &weight[tid * dim..(tid + 1) * dim];
        let dst = &mut output[i * dim..(i + 1) * dim];
        dst.copy_from_slice(src);
    }
}
47
48// ────────────────────────────────────────────────────────────────────────────
49// AVX2 implementation
50// ────────────────────────────────────────────────────────────────────────────
51
52/// AVX2 embedding lookup -- delegates to scalar (memory-bound, no compute).
53///
54/// # Safety
55/// Requires AVX2 support.
56#[cfg(target_arch = "x86_64")]
57#[target_feature(enable = "avx2")]
58pub unsafe fn embedding_avx2(
59    weight: &[f32],
60    token_ids: &[u32],
61    vocab_size: usize,
62    dim: usize,
63    output: &mut [f32],
64) {
65    embedding_scalar(weight, token_ids, vocab_size, dim, output);
66}
67
68// ────────────────────────────────────────────────────────────────────────────
69// PTX implementation
70// ────────────────────────────────────────────────────────────────────────────
71
/// PTX assembly for embedding lookup.
///
/// Launch with one CTA per token (`ctaid.x` = token index, so the grid is
/// `seq_len` blocks). Threads cover the embedding dimension starting at
/// `tid.x` and advancing by `ntid.x`, so any `dim` is handled regardless of
/// block size. Each thread copies f32 values from `WEIGHT` to `OUT`.
///
/// Bounds handling:
/// - Token ids `>= VOCAB_SIZE` do not read the table; their output row is
///   filled with `0.0`. A GPU kernel cannot panic, so this differs from the
///   scalar reference, which panics.
/// - Weight and output offsets are computed in 64 bits, so tables with
///   `vocab_size * dim >= 2^32` elements are addressed correctly.
/// - The kernel has no `seq_len` parameter; the launcher must not launch more
///   CTAs than there are tokens.
pub fn embedding_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry embedding_kernel(
    .param .u64 WEIGHT,
    .param .u64 TOKEN_IDS,
    .param .u64 OUT,
    .param .u32 VOCAB_SIZE,
    .param .u32 DIM
) {
    .reg .u32 %col, %row, %step, %dim, %vocab_size, %token_id;
    .reg .u64 %w_ptr, %t_ptr, %out_ptr, %w_row, %out_row, %addr, %off64;
    .reg .f32 %val;
    .reg .pred %p_bound, %p_oob;

    mov.u32 %col, %tid.x;
    mov.u32 %row, %ctaid.x;
    mov.u32 %step, %ntid.x;

    ld.param.u32 %dim, [DIM];
    ld.param.u32 %vocab_size, [VOCAB_SIZE];
    ld.param.u64 %w_ptr, [WEIGHT];
    ld.param.u64 %t_ptr, [TOKEN_IDS];
    ld.param.u64 %out_ptr, [OUT];

    // row = token index (one CTA per token); col = dimension index,
    // advanced by the block size so every column < dim is covered.
    setp.ge.u32 %p_bound, %col, %dim;
    @%p_bound bra EXIT;

    // token_id = TOKEN_IDS[row]
    mul.wide.u32 %off64, %row, 4;
    add.u64 %addr, %t_ptr, %off64;
    ld.global.u32 %token_id, [%addr];

    // out_row = &OUT[row * dim], 64-bit offset
    mul.wide.u32 %off64, %row, %dim;
    shl.b64 %off64, %off64, 2;
    add.u64 %out_row, %out_ptr, %off64;

    // Out-of-range token ids write zeros instead of reading past the table.
    setp.ge.u32 %p_oob, %token_id, %vocab_size;
    @%p_oob bra ZERO_FILL;

    // w_row = &WEIGHT[token_id * dim], 64-bit: vocab_size * dim may exceed 2^32
    mul.wide.u32 %off64, %token_id, %dim;
    shl.b64 %off64, %off64, 2;
    add.u64 %w_row, %w_ptr, %off64;

COPY_LOOP:
    // OUT[row * dim + col] = WEIGHT[token_id * dim + col]
    mul.wide.u32 %off64, %col, 4;
    add.u64 %addr, %w_row, %off64;
    ld.global.f32 %val, [%addr];
    add.u64 %addr, %out_row, %off64;
    st.global.f32 [%addr], %val;
    add.u32 %col, %col, %step;
    setp.lt.u32 %p_bound, %col, %dim;
    @%p_bound bra COPY_LOOP;
    bra EXIT;

ZERO_FILL:
    mov.f32 %val, 0f00000000;
ZERO_LOOP:
    mul.wide.u32 %off64, %col, 4;
    add.u64 %addr, %out_row, %off64;
    st.global.f32 [%addr], %val;
    add.u32 %col, %col, %step;
    setp.lt.u32 %p_bound, %col, %dim;
    @%p_bound bra ZERO_LOOP;

EXIT:
    ret;
}
"#
}
127
128// ────────────────────────────────────────────────────────────────────────────
129// Tests
130// ────────────────────────────────────────────────────────────────────────────
131
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_embedding_basic() {
        // 3 tokens, dim=2, vocab=4
        let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // 4x2
        let ids = [0, 2, 1];
        let mut output = [0.0f32; 6]; // 3x2

        embedding_scalar(&weight, &ids, 4, 2, &mut output);

        assert_eq!(&output[0..2], &[1.0, 2.0]); // token 0
        assert_eq!(&output[2..4], &[5.0, 6.0]); // token 2
        assert_eq!(&output[4..6], &[3.0, 4.0]); // token 1
    }

    #[test]
    fn test_embedding_single() {
        let weight = [10.0, 20.0, 30.0];
        let ids = [2];
        let mut output = [0.0f32; 1];

        embedding_scalar(&weight, &ids, 3, 1, &mut output);
        assert_eq!(output[0], 30.0);
    }

    #[test]
    #[should_panic(expected = "token_id 5 >= vocab_size 3")]
    fn test_embedding_oob() {
        let weight = [0.0f32; 6];
        let ids = [5];
        let mut output = [0.0f32; 2];
        embedding_scalar(&weight, &ids, 3, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "weight dimension mismatch")]
    fn test_embedding_weight_mismatch() {
        let weight = [0.0f32; 5]; // should be 3x2 = 6
        let ids = [0];
        let mut output = [0.0f32; 2];
        embedding_scalar(&weight, &ids, 3, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "output dimension mismatch")]
    fn test_embedding_output_mismatch() {
        let weight = [0.0f32; 6];
        let ids = [0, 1];
        let mut output = [0.0f32; 3]; // should be 2x2 = 4
        embedding_scalar(&weight, &ids, 3, 2, &mut output);
    }

    #[test]
    fn test_embedding_empty_sequence() {
        let weight = [1.0f32, 2.0];
        let ids: [u32; 0] = [];
        let mut output: [f32; 0] = [];
        embedding_scalar(&weight, &ids, 2, 1, &mut output);
        assert!(output.is_empty());
    }

    #[test]
    fn test_embedding_deterministic() {
        let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let ids = [1, 0, 2];
        let mut out1 = [0.0f32; 6];
        let mut out2 = [0.0f32; 6];
        embedding_scalar(&weight, &ids, 3, 2, &mut out1);
        embedding_scalar(&weight, &ids, 3, 2, &mut out2);
        assert_eq!(out1, out2);
    }

    proptest! {
        #[test]
        fn prop_embedding_output_finite(
            vocab_size in 2usize..8,
            dim in 1usize..5,
            seq_len in 1usize..6,
        ) {
            let weight: Vec<f32> = (0..vocab_size * dim)
                .map(|i| (i as f32) * 0.1)
                .collect();
            let ids: Vec<u32> = (0..seq_len)
                .map(|i| (i % vocab_size) as u32)
                .collect();
            let mut output = vec![0.0f32; seq_len * dim];

            embedding_scalar(&weight, &ids, vocab_size, dim, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }

        /// Every output row must equal the weight row selected by its token id.
        /// Weights are distinct per element, so a wrong row or offset is detected.
        #[test]
        fn prop_embedding_rows_match_weight(
            vocab_size in 1usize..8,
            dim in 1usize..5,
            seq_len in 0usize..6,
            stride in 1usize..8,
        ) {
            let weight: Vec<f32> = (0..vocab_size * dim).map(|i| i as f32).collect();
            let ids: Vec<u32> = (0..seq_len)
                .map(|i| ((i * stride) % vocab_size) as u32)
                .collect();
            let mut output = vec![0.0f32; seq_len * dim];

            embedding_scalar(&weight, &ids, vocab_size, dim, &mut output);

            for (row, &tid) in output.chunks_exact(dim).zip(&ids) {
                let tid = tid as usize;
                prop_assert_eq!(row, &weight[tid * dim..(tid + 1) * dim]);
            }
        }
    }

    #[test]
    fn test_embedding_ptx_structure() {
        let ptx = embedding_ptx();
        assert!(ptx.contains(".entry embedding_kernel"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_embedding_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let ids = [0, 3, 1];
        let mut scalar_out = [0.0f32; 6];
        let mut avx2_out = [0.0f32; 6];
        embedding_scalar(&weight, &ids, 4, 2, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { embedding_avx2(&weight, &ids, 4, 2, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }
}