1use super::ops;
12
13pub fn linear_scalar(
28 x: &[f32],
29 weight: &[f32],
30 bias: &[f32],
31 batch: usize,
32 in_features: usize,
33 out_features: usize,
34 output: &mut [f32],
35) {
36 assert_eq!(x.len(), batch * in_features, "x dimension mismatch");
37 assert_eq!(
38 weight.len(),
39 out_features * in_features,
40 "weight dimension mismatch"
41 );
42 assert_eq!(
43 output.len(),
44 batch * out_features,
45 "output dimension mismatch"
46 );
47 assert!(
48 bias.is_empty() || bias.len() == out_features,
49 "bias must be empty or out_features={out_features}, got {}",
50 bias.len()
51 );
52
53 for b in 0..batch {
55 let x_row = &x[b * in_features..(b + 1) * in_features];
56 for o in 0..out_features {
57 let w_row = &weight[o * in_features..(o + 1) * in_features];
58 let mut val = ops::dot(x_row, w_row);
59 if !bias.is_empty() {
60 val += bias[o];
61 }
62 output[b * out_features + o] = val;
63 }
64 }
65}
66
/// AVX2 entry point for the dense layer; same arguments as [`linear_scalar`].
///
/// The body forwards directly to [`linear_scalar`]; it contains no hand-written
/// intrinsics. NOTE(review): presumably the `avx2` target feature is enabled so
/// the compiler may vectorize the scalar code when it is inlined here — confirm.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (for example by
/// checking `is_x86_feature_detected!("avx2")`) before calling this function.
///
/// # Panics
///
/// Panics under the same conditions as [`linear_scalar`], i.e. when a slice
/// length does not match the given dimensions.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn linear_avx2(
    x: &[f32],
    weight: &[f32],
    bias: &[f32],
    batch: usize,
    in_features: usize,
    out_features: usize,
    output: &mut [f32],
) {
    linear_scalar(x, weight, bias, batch, in_features, out_features, output);
}
88
/// Returns the PTX source of `linear_kernel`, the GPU counterpart of the scalar
/// dense layer.
///
/// Kernel parameters, in order: device pointers `X`, `W`, `BIAS`, `OUT`, then
/// the `u32` values `BATCH`, `IN_FEAT`, `OUT_FEAT` and `HAS_BIAS` (zero means
/// "skip the bias add"; `BIAS` is not read in that case).
///
/// Work mapping: the block index (`%ctaid.x`) selects the batch row and the
/// thread index (`%tid.x`) selects the output feature, so each thread produces
/// one output element. Threads whose row or feature index is out of range exit
/// without touching memory. Element offsets are computed with 32-bit
/// arithmetic before being widened to byte offsets.
///
/// The special registers are read straight into `%b_idx` / `%o_idx`: PTX
/// predefines `%tid`, so the kernel must not declare a register of that name.
pub fn linear_ptx() -> &'static str {
    r#".version 8.5
.target sm_90
.address_size 64
.visible .entry linear_kernel(
    .param .u64 X,
    .param .u64 W,
    .param .u64 BIAS,
    .param .u64 OUT,
    .param .u32 BATCH,
    .param .u32 IN_FEAT,
    .param .u32 OUT_FEAT,
    .param .u32 HAS_BIAS
) {
    .reg .u32 %batch, %in_feat, %out_feat, %has_bias;
    .reg .u32 %b_idx, %o_idx, %k, %tmp32;
    .reg .u64 %x_ptr, %w_ptr, %bias_ptr, %out_ptr, %addr, %off64;
    .reg .f32 %acc, %x_val, %w_val, %bias_val;
    .reg .pred %p_k, %p_bias, %p_bound;

    ld.param.u32 %batch, [BATCH];
    ld.param.u32 %in_feat, [IN_FEAT];
    ld.param.u32 %out_feat, [OUT_FEAT];
    ld.param.u32 %has_bias, [HAS_BIAS];
    ld.param.u64 %x_ptr, [X];
    ld.param.u64 %w_ptr, [W];
    ld.param.u64 %bias_ptr, [BIAS];
    ld.param.u64 %out_ptr, [OUT];

    // ctaid.x = batch index, tid.x = output feature index
    mov.u32 %b_idx, %ctaid.x;
    mov.u32 %o_idx, %tid.x;

    // Skip threads that fall outside the (batch, out_feat) output grid
    setp.ge.u32 %p_bound, %b_idx, %batch;
    @%p_bound bra EXIT;
    setp.ge.u32 %p_bound, %o_idx, %out_feat;
    @%p_bound bra EXIT;

    // acc = dot(x[b_idx], w[o_idx])
    mov.f32 %acc, 0f00000000;
    mov.u32 %k, 0;
DOT_LOOP:
    setp.ge.u32 %p_k, %k, %in_feat;
    @%p_k bra DOT_DONE;

    // x[b_idx * in_feat + k]
    mad.lo.u32 %tmp32, %b_idx, %in_feat, %k;
    mul.wide.u32 %off64, %tmp32, 4;
    add.u64 %addr, %x_ptr, %off64;
    ld.global.f32 %x_val, [%addr];

    // w[o_idx * in_feat + k]
    mad.lo.u32 %tmp32, %o_idx, %in_feat, %k;
    mul.wide.u32 %off64, %tmp32, 4;
    add.u64 %addr, %w_ptr, %off64;
    ld.global.f32 %w_val, [%addr];

    fma.rn.f32 %acc, %x_val, %w_val, %acc;
    add.u32 %k, %k, 1;
    bra DOT_LOOP;
DOT_DONE:

    // Add bias if present
    setp.eq.u32 %p_bias, %has_bias, 0;
    @%p_bias bra STORE;
    mul.wide.u32 %off64, %o_idx, 4;
    add.u64 %addr, %bias_ptr, %off64;
    ld.global.f32 %bias_val, [%addr];
    add.f32 %acc, %acc, %bias_val;

STORE:
    mad.lo.u32 %tmp32, %b_idx, %out_feat, %o_idx;
    mul.wide.u32 %off64, %tmp32, 4;
    add.u64 %addr, %out_ptr, %off64;
    st.global.f32 [%addr], %acc;

EXIT:
    ret;
}
"#
}
179
#[cfg(test)]
mod tests {
    // Only the AVX2 parity test uses the ULP helper, so the import is gated the
    // same way to avoid an unused-import warning on other architectures.
    #[cfg(target_arch = "x86_64")]
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    /// One row, two outputs, with bias: [1,2]·[3,4]+10 = 21 and [1,2]·[5,6]+20 = 37.
    #[test]
    fn test_linear_basic_with_bias() {
        let x = [1.0, 2.0];
        let w = [3.0, 4.0, 5.0, 6.0];
        let b = [10.0, 20.0];
        let mut output = [0.0f32; 2];

        linear_scalar(&x, &w, &b, 1, 2, 2, &mut output);
        assert!((output[0] - 21.0).abs() < 1e-5);
        assert!((output[1] - 37.0).abs() < 1e-5);
    }

    /// An empty bias slice means no bias is added; identity weights pass `x` through.
    #[test]
    fn test_linear_no_bias() {
        let x = [1.0, 0.0];
        let w = [1.0, 0.0, 0.0, 1.0];
        let mut output = [0.0f32; 2];

        linear_scalar(&x, &w, &[], 1, 2, 2, &mut output);
        assert!((output[0] - 1.0).abs() < 1e-5);
        assert!((output[1] - 0.0).abs() < 1e-5);
    }

    /// With an all-zero input the output must equal the bias.
    #[test]
    fn test_linear_zero_input_returns_bias() {
        let x = [0.0, 0.0, 0.0];
        let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = [7.0, 8.0];
        let mut output = [0.0f32; 2];

        linear_scalar(&x, &w, &b, 1, 3, 2, &mut output);
        assert!((output[0] - 7.0).abs() < 1e-5);
        assert!((output[1] - 8.0).abs() < 1e-5);
    }

    /// Two batch rows share one weight row: [1,2]·[1,1] = 3 and [3,4]·[1,1] = 7.
    #[test]
    fn test_linear_batch() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let w = [1.0, 1.0];
        let mut output = [0.0f32; 2];

        linear_scalar(&x, &w, &[], 2, 2, 1, &mut output);
        assert!((output[0] - 3.0).abs() < 1e-5);
        assert!((output[1] - 7.0).abs() < 1e-5);
    }

    /// The same bias entry is applied to every batch row.
    #[test]
    fn test_linear_batch_with_bias() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let w = [1.0, 1.0];
        let b = [0.5];
        let mut output = [0.0f32; 2];

        linear_scalar(&x, &w, &b, 2, 2, 1, &mut output);
        assert!((output[0] - 3.5).abs() < 1e-5);
        assert!((output[1] - 7.5).abs() < 1e-5);
    }

    /// A mis-sized input slice is rejected up front instead of being read out of range.
    #[test]
    #[should_panic(expected = "x dimension mismatch")]
    fn test_linear_rejects_mismatched_input() {
        let x = [1.0];
        let w = [1.0, 1.0];
        let mut output = [0.0f32; 1];

        linear_scalar(&x, &w, &[], 1, 2, 1, &mut output);
    }

    /// Without bias the layer is linear: f(2x) == 2 * f(x).
    #[test]
    fn test_linear_linearity() {
        let x1 = [1.0, 2.0, 3.0];
        let x2: Vec<f32> = x1.iter().map(|v| v * 2.0).collect();
        let w = [0.5, 0.3, 0.1, 0.2, 0.4, 0.6];
        let mut out1 = [0.0f32; 2];
        let mut out2 = [0.0f32; 2];

        linear_scalar(&x1, &w, &[], 1, 3, 2, &mut out1);
        linear_scalar(&x2, &w, &[], 1, 3, 2, &mut out2);

        for i in 0..2 {
            assert!(
                (out2[i] - 2.0 * out1[i]).abs() < 1e-5,
                "linearity violated at {i}: f(2x)={} vs 2*f(x)={}",
                out2[i],
                2.0 * out1[i]
            );
        }
    }

    proptest! {
        // Small finite inputs of any supported shape must yield finite outputs.
        #[test]
        fn prop_linear_output_finite(
            batch in 1usize..3,
            in_f in 1usize..5,
            out_f in 1usize..5,
        ) {
            let x: Vec<f32> = (0..batch * in_f).map(|i| (i as f32) * 0.1).collect();
            let w: Vec<f32> = (0..out_f * in_f).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..out_f).map(|i| (i as f32) * 0.01).collect();
            let mut output = vec![0.0f32; batch * out_f];

            linear_scalar(&x, &w, &b, batch, in_f, out_f, &mut output);

            for (idx, &val) in output.iter().enumerate() {
                prop_assert!(val.is_finite(), "output[{idx}] = {val} not finite");
            }
        }
    }

    /// Smoke-checks that the PTX text contains the kernel entry, an FMA and a return.
    #[test]
    fn test_linear_ptx_structure() {
        let ptx = linear_ptx();
        assert!(ptx.contains(".entry linear_kernel"));
        assert!(ptx.contains("fma.rn.f32"));
        assert!(ptx.contains("ret;"));
    }

    /// The AVX2 entry point must match the scalar path bit for bit (0 ULP).
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_linear_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let x = [1.0, 2.0, 3.0, 4.0];
        let w = [0.5, 0.3, 0.1, 0.2, 0.4, 0.6, 0.7, 0.8];
        let b = [1.0, 2.0];
        let mut scalar_out = [0.0f32; 2];
        let mut avx2_out = [0.0f32; 2];

        linear_scalar(&x, &w, &b, 1, 4, 2, &mut scalar_out);
        // SAFETY: AVX2 support was verified by the runtime feature check above.
        unsafe { linear_avx2(&x, &w, &b, 1, 4, 2, &mut avx2_out) };
        assert_ulp_eq(&scalar_out, &avx2_out, 0);
    }
}