// provable_contracts/kernels/embedding.rs
/// Scalar reference embedding lookup (row gather).
///
/// `weight` is a row-major table of `vocab_size` rows, each `dim` values wide.
/// For every position `i`, row `token_ids[i]` is copied into
/// `output[i * dim..(i + 1) * dim]`.
///
/// # Panics
///
/// - if `vocab_size * dim` or `token_ids.len() * dim` overflows `usize`;
/// - if `weight.len() != vocab_size * dim` ("weight dimension mismatch");
/// - if `output.len() != token_ids.len() * dim` ("output dimension mismatch");
/// - if any token id is `>= vocab_size`.
pub fn embedding_scalar(
    weight: &[f32],
    token_ids: &[u32],
    vocab_size: usize,
    dim: usize,
    output: &mut [f32],
) {
    // Checked products: a wrapping multiply (release builds) could otherwise
    // let a mis-sized buffer slip past the shape assertions below.
    let weight_len = vocab_size
        .checked_mul(dim)
        .expect("vocab_size * dim overflows usize");
    let output_len = token_ids
        .len()
        .checked_mul(dim)
        .expect("token_ids.len() * dim overflows usize");
    assert_eq!(weight.len(), weight_len, "weight dimension mismatch");
    assert_eq!(output.len(), output_len, "output dimension mismatch");

    for (i, &tid) in token_ids.iter().enumerate() {
        let tid = tid as usize;
        assert!(
            tid < vocab_size,
            "token_id {tid} >= vocab_size {vocab_size}"
        );
        // In bounds: tid < vocab_size implies (tid + 1) * dim <= weight.len().
        let src = &weight[tid * dim..(tid + 1) * dim];
        output[i * dim..(i + 1) * dim].copy_from_slice(src);
    }
}

48#[cfg(target_arch = "x86_64")]
57#[target_feature(enable = "avx2")]
58pub unsafe fn embedding_avx2(
59 weight: &[f32],
60 token_ids: &[u32],
61 vocab_size: usize,
62 dim: usize,
63 output: &mut [f32],
64) {
65 embedding_scalar(weight, token_ids, vocab_size, dim, output);
66}
67
/// PTX source for the `embedding_kernel` GPU entry point (PTX ISA 8.5, `sm_90`).
///
/// Launch layout, as implied by the kernel's indexing:
/// - One CTA per token. `%ctaid.x` indexes `TOKEN_IDS` with no upper-bound
///   check, so the grid's x-dimension must equal the number of tokens.
/// - The CTA's x-threads stride over the `DIM` columns of that token's row
///   (`col = tid.x, tid.x + ntid.x, ...`). Any block size >= 1 therefore covers
///   every column, including `DIM` larger than the maximum block size.
///
/// Parameters: `WEIGHT` is a row-major `VOCAB_SIZE x DIM` f32 table and
/// `TOKEN_IDS` is a u32 array. `OUT` receives `DIM` f32 values per token.
///
/// A token id `>= VOCAB_SIZE` is skipped, and its output row is left
/// untouched. The scalar reference panics instead, so callers that need
/// identical behavior should validate ids on the host.
pub fn embedding_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64

// One CTA per token (%ctaid.x indexes TOKEN_IDS); the CTA's x-threads
// stride over the DIM columns of that token's embedding row.
.visible .entry embedding_kernel(
    .param .u64 WEIGHT,
    .param .u64 TOKEN_IDS,
    .param .u64 OUT,
    .param .u32 VOCAB_SIZE,
    .param .u32 DIM
) {
    .reg .u32 %tx, %cta, %stride, %dim, %vocab_size, %token_id, %col;
    .reg .u64 %w_ptr, %t_ptr, %out_ptr, %w_row, %o_row, %addr, %off64, %col64;
    .reg .f32 %val;
    .reg .pred %p_bound;

    mov.u32 %tx, %tid.x;
    mov.u32 %cta, %ctaid.x;
    mov.u32 %stride, %ntid.x;

    ld.param.u32 %dim, [DIM];
    ld.param.u32 %vocab_size, [VOCAB_SIZE];
    ld.param.u64 %w_ptr, [WEIGHT];
    ld.param.u64 %t_ptr, [TOKEN_IDS];
    ld.param.u64 %out_ptr, [OUT];

    // Threads with no column to copy leave before touching memory.
    setp.ge.u32 %p_bound, %tx, %dim;
    @%p_bound bra EXIT;

    // token_id = TOKEN_IDS[cta]
    mul.wide.u32 %off64, %cta, 4;
    add.u64 %addr, %t_ptr, %off64;
    ld.global.u32 %token_id, [%addr];

    // Out-of-range ids would read past the end of WEIGHT: skip the row.
    setp.ge.u32 %p_bound, %token_id, %vocab_size;
    @%p_bound bra EXIT;

    // Row base pointers in 64-bit arithmetic so token_id * dim cannot wrap.
    mul.wide.u32 %off64, %token_id, %dim;
    shl.b64 %off64, %off64, 2;
    add.u64 %w_row, %w_ptr, %off64;
    mul.wide.u32 %off64, %cta, %dim;
    shl.b64 %off64, %off64, 2;
    add.u64 %o_row, %out_ptr, %off64;

    mov.u32 %col, %tx;
COPY_LOOP:
    // OUT[cta * dim + col] = WEIGHT[token_id * dim + col]
    mul.wide.u32 %col64, %col, 4;
    add.u64 %addr, %w_row, %col64;
    ld.global.f32 %val, [%addr];
    add.u64 %addr, %o_row, %col64;
    st.global.f32 [%addr], %val;

    add.u32 %col, %col, %stride;
    setp.lt.u32 %p_bound, %col, %dim;
    @%p_bound bra COPY_LOOP;

EXIT:
    ret;
}
"#
}

#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_embedding_basic() {
        // 4-token vocabulary, dim = 2: row r is [2r + 1, 2r + 2].
        let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let ids = [0, 2, 1];
        let mut output = [0.0f32; 6];
        embedding_scalar(&weight, &ids, 4, 2, &mut output);

        assert_eq!(&output[0..2], &[1.0, 2.0]);
        assert_eq!(&output[2..4], &[5.0, 6.0]);
        assert_eq!(&output[4..6], &[3.0, 4.0]);
    }

    #[test]
    fn test_embedding_single() {
        let weight = [10.0, 20.0, 30.0];
        let ids = [2];
        let mut output = [0.0f32; 1];

        embedding_scalar(&weight, &ids, 3, 1, &mut output);
        assert_eq!(output[0], 30.0);
    }

    #[test]
    #[should_panic(expected = "token_id 5 >= vocab_size 3")]
    fn test_embedding_oob() {
        let weight = [0.0f32; 6];
        let ids = [5];
        let mut output = [0.0f32; 2];
        embedding_scalar(&weight, &ids, 3, 2, &mut output);
    }

    #[test]
    fn test_embedding_zero_dim() {
        // dim == 0: all buffers are empty but ids are still range-checked.
        let weight: [f32; 0] = [];
        let ids = [0, 1];
        let mut output: [f32; 0] = [];
        embedding_scalar(&weight, &ids, 2, 0, &mut output);
    }

    #[test]
    fn test_embedding_deterministic() {
        let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let ids = [1, 0, 2];
        let mut out1 = [0.0f32; 6];
        let mut out2 = [0.0f32; 6];
        embedding_scalar(&weight, &ids, 3, 2, &mut out1);
        embedding_scalar(&weight, &ids, 3, 2, &mut out2);
        assert_eq!(out1, out2);
    }

    proptest! {
        #[test]
        fn prop_embedding_output_finite(
            vocab_size in 2usize..8,
            dim in 1usize..5,
            seq_len in 1usize..6,
        ) {
            let weight: Vec<f32> = (0..vocab_size * dim)
                .map(|i| (i as f32) * 0.1)
                .collect();
            let ids: Vec<u32> = (0..seq_len)
                .map(|i| (i % vocab_size) as u32)
                .collect();
            let mut output = vec![0.0f32; seq_len * dim];

            embedding_scalar(&weight, &ids, vocab_size, dim, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }

        // The actual gather contract: each output row is an exact copy of
        // the weight row selected by its (arbitrary, in-range) token id.
        #[test]
        fn prop_embedding_copies_rows(
            vocab_size in 2usize..8,
            dim in 1usize..5,
            raw_ids in prop::collection::vec(0u32..1000, 1..6),
        ) {
            let weight: Vec<f32> = (0..vocab_size * dim)
                .map(|i| (i as f32) * 0.1)
                .collect();
            let ids: Vec<u32> = raw_ids
                .iter()
                .map(|&r| r % vocab_size as u32)
                .collect();
            let mut output = vec![0.0f32; ids.len() * dim];

            embedding_scalar(&weight, &ids, vocab_size, dim, &mut output);

            for (row, &tid) in output.chunks_exact(dim).zip(&ids) {
                let t = tid as usize;
                prop_assert_eq!(row, &weight[t * dim..(t + 1) * dim]);
            }
        }
    }

    #[test]
    fn test_embedding_ptx_structure() {
        let ptx = embedding_ptx();
        assert!(ptx.contains(".entry embedding_kernel"));
        assert!(ptx.contains("ret;"));
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_embedding_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let ids = [0, 3, 1];
        let mut scalar_out = [0.0f32; 6];
        let mut avx2_out = [0.0f32; 6];
        embedding_scalar(&weight, &ids, 4, 2, &mut scalar_out);
        unsafe { embedding_avx2(&weight, &ids, 4, 2, &mut avx2_out) };
        assert_eq!(scalar_out, avx2_out);
    }
}