dot_product/
dot_product.rs

use aligned_vec::avec;
use core::iter;
use diol::prelude::*;
use pulp::{Arch, Simd};

// dot product
// sum x * y
pub fn dot_product_scalar(x: &[f32], y: &[f32]) -> f32 {
	let mut acc = 0.0;
	for (x, y) in iter::zip(x, y) {
		acc += x * y;
	}
	acc
}
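
// the same reduction can also be written with iterators; a minimal sketch
// (not part of the benchmarks), equivalent to the loop above:
#[allow(dead_code)]
pub fn dot_product_scalar_iter(x: &[f32], y: &[f32]) -> f32 {
	iter::zip(x, y).map(|(x, y)| x * y).sum()
}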
pub fn bench_dot_scalar(bencher: Bencher, PlotArg(n): PlotArg) {
	let x = &*vec![1.0_f32; n];
	let y = &*vec![1.0_f32; n];

	bencher.bench(|| dot_product_scalar(x, y))
}

#[allow(dead_code)]
mod inline_examples {
	// inline hint: compiler pls inline this
	#[inline]
	fn foo0(x: i32, y: i32) -> i32 {
		x + y
	}

	// strong inline hint: compiler i really need you to inline this
	// needed for simd
	#[inline(always)]
	fn foo1(x: i32, y: i32) -> i32 {
		x + y
	}

	// no_inline hint: compiler pls don't inline this
	#[inline(never)]
	fn foo2(x: i32, y: i32) -> i32 {
		x + y
	}

	fn bar(x: i32, y: i32) -> i32 {
		foo0(x, y) + foo1(x, 16)
	}
}
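
// why the hints matter for simd: a minimal sketch (illustrative only, not used
// by the benchmarks below). vector intrinsics only compile to single
// instructions when they end up inside a function compiled with the target
// feature enabled, so the kernel must be #[inline(always)] to be pulled into
// the #[target_feature] caller.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
mod target_feature_example {
	// must be inlined into `with_avx` to inherit the avx feature at codegen time
	#[inline(always)]
	fn kernel(x: f32, y: f32) -> f32 {
		x * y + 1.0
	}

	// safety contract: the caller must ensure the cpu actually supports avx
	#[target_feature(enable = "avx")]
	unsafe fn with_avx(x: f32, y: f32) -> f32 {
		kernel(x, y)
	}
}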

#[cfg(target_arch = "x86_64")]
mod x86 {
	use super::*;

	use pulp::x86::V3;
	use pulp::{cast, f32x8};

	// x86/x86_64
	// - V3 simd uses 256bit registers (f32x8)   => 16 registers
	// - V4 simd uses 512bit registers (f32x16)  => 32 registers (for consumer cpus, only available
	//   on 11th gen intel and on amd zen4 onwards)
	//
	// arm/aarch64
	// typically register size of 128bit (f32x4) => 32 registers

	pub fn dot_product_simd_v3(simd: V3, x: &[f32], y: &[f32]) -> f32 {
		// essentially emulating a closure because closures don't always respect
		// #[inline(always)], which we need for the compiler to inline the simd intrinsics
		struct Impl<'a> {
			simd: V3,
			x: &'a [f32],
			y: &'a [f32],
		}

		impl pulp::NullaryFnOnce for Impl<'_> {
			type Output = f32;

			// PERF: must be #[inline(always)]
			#[inline(always)]
			fn call(self) -> Self::Output {
				let Self { simd, x, y } = self;

				// x0 x1 x2.. x100
				// [[x0 x1 x2 x3 x4 x5 x6 x7] [x8 x9 x10.. x15] [..x95]] | [x96 x97 x98 x99 x100]
				let (x8, x1) = pulp::as_arrays::<8, _>(x);
				let (y8, y1) = pulp::as_arrays::<8, _>(y);

				let mut acc = 0.0;
				for (x, y) in iter::zip(x8, y8) {
					let x: f32x8 = cast(*x);
					let y: f32x8 = cast(*y);

					acc += simd.reduce_sum_f32s(simd.mul_f32x8(x, y));
				}
				for (x, y) in iter::zip(x1, y1) {
					acc += x * y;
				}

				acc
			}
		}

		simd.vectorize(Impl { simd, x, y })
	}
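
	// note: this version reduces to a scalar inside the loop. the horizontal
	// reduce_sum costs several shuffle+add steps per 8 elements and serializes
	// on the scalar `acc`, which mostly defeats the vectorization; the next
	// version keeps the accumulator in vector form and reduces once at the end.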

	pub fn dot_product_simd_extract_reduce_v3(simd: V3, x: &[f32], y: &[f32]) -> f32 {
		struct Impl<'a> {
			simd: V3,
			x: &'a [f32],
			y: &'a [f32],
		}

		impl pulp::NullaryFnOnce for Impl<'_> {
			type Output = f32;

			#[inline(always)]
			fn call(self) -> Self::Output {
				let Self { simd, x, y } = self;

				// x0 x1 x2.. x100
				// [[x0 x1 x2 x3 x4 x5 x6 x7] [x8 x9 x10.. x15] [..x95]] | [x96 x97 x98 x99 x100]
				let (x8, x1) = pulp::as_arrays::<8, _>(x);
				let (y8, y1) = pulp::as_arrays::<8, _>(y);

				// sum (x * y)
				// sum (reduce_sum(X * Y))
				// reduce_sum(sum (X * Y))

				// [0.0; 8]
				let mut acc = simd.splat_f32x8(0.0);
				for (x, y) in iter::zip(x8, y8) {
					let x: f32x8 = cast(*x);
					let y: f32x8 = cast(*y);

					acc = simd.add_f32x8(acc, simd.mul_f32x8(x, y));
				}
				// reduce_sum_f32s
				// f32x8 -> f32x4 + f32x4
				// f32x4 -> f32x2 + f32x2
				// f32x2 -> f32 + f32
				let mut acc = simd.reduce_sum_f32s(acc);

				for (x, y) in iter::zip(x1, y1) {
					acc += x * y;
				}

				acc
			}
		}

		simd.vectorize(Impl { simd, x, y })
	}
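
	// note: keeping the sum in vector lanes reassociates the additions, so the
	// result can differ from the sequential scalar sum in the last bits
	// (floating-point addition is not associative); for a dot product this is
	// normally acceptable.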

	// ilp: instruction-level parallelism
	// each `acc = acc + x * y` depends on the previous iteration's `acc`, so a
	// single accumulator serializes on the add latency; with several independent
	// accumulators, the out-of-order core can keep multiple adds in flight
	pub fn dot_product_simd_extract_reduce_ilp_v3(simd: V3, x: &[f32], y: &[f32]) -> f32 {
		struct Impl<'a> {
			simd: V3,
			x: &'a [f32],
			y: &'a [f32],
		}

		impl pulp::NullaryFnOnce for Impl<'_> {
			type Output = f32;

			#[inline(always)]
			fn call(self) -> Self::Output {
				let Self { simd, x, y } = self;

				// x0 x1 x2.. x100
				// [[x0 x1 x2 x3 x4 x5 x6 x7] [x8 x9 x10.. x15] [..x95]] | [x96 x97 x98 x99 x100]
				let (x8, x1) = pulp::as_arrays::<8, _>(x);
				let (y8, y1) = pulp::as_arrays::<8, _>(y);

				// sum (x * y)
				// sum (reduce_sum(X * Y))
				// reduce_sum(sum (X * Y))

				let mut acc0 = simd.splat_f32x8(0.0);
				let mut acc1 = simd.splat_f32x8(0.0);
				let mut acc2 = simd.splat_f32x8(0.0);
				let mut acc3 = simd.splat_f32x8(0.0);

				// 12 registers are being used
				// 4 for the accumulators + 2×4 inside the loop for x0..x3 and y0..y3
				let (x8_4, x8_1) = pulp::as_arrays::<4, _>(x8);
				let (y8_4, y8_1) = pulp::as_arrays::<4, _>(y8);

				for ([x0, x1, x2, x3], [y0, y1, y2, y3]) in iter::zip(x8_4, y8_4) {
					let x0: f32x8 = cast(*x0);
					let y0: f32x8 = cast(*y0);
					let x1: f32x8 = cast(*x1);
					let y1: f32x8 = cast(*y1);
					let x2: f32x8 = cast(*x2);
					let y2: f32x8 = cast(*y2);
					let x3: f32x8 = cast(*x3);
					let y3: f32x8 = cast(*y3);

					acc0 = simd.add_f32x8(acc0, simd.mul_f32x8(x0, y0));
					acc1 = simd.add_f32x8(acc1, simd.mul_f32x8(x1, y1));
					acc2 = simd.add_f32x8(acc2, simd.mul_f32x8(x2, y2));
					acc3 = simd.add_f32x8(acc3, simd.mul_f32x8(x3, y3));
				}

				for (x0, y0) in iter::zip(x8_1, y8_1) {
					let x0: f32x8 = cast(*x0);
					let y0: f32x8 = cast(*y0);
					acc0 = simd.add_f32x8(acc0, simd.mul_f32x8(x0, y0));
				}

				// reduce the four accumulators pairwise, then reduce_sum_f32s:
				// f32x8 -> f32x4 + f32x4
				// f32x4 -> f32x2 + f32x2
				// f32x2 -> f32 + f32
				acc0 = simd.add_f32x8(acc0, acc1);
				acc2 = simd.add_f32x8(acc2, acc3);

				acc0 = simd.add_f32x8(acc0, acc2);

				let mut acc = simd.reduce_sum_f32s(acc0);

				for (x, y) in iter::zip(x1, y1) {
					acc += x * y;
				}

				acc
			}
		}

		simd.vectorize(Impl { simd, x, y })
	}
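
	// back-of-the-envelope (assuming, for illustration, a 4-cycle vector add
	// latency and 2 adds issued per cycle): one dependency chain sustains 1 add
	// per 4 cycles, while 4 independent chains sustain 4 adds per 4 cycles;
	// fully saturating the units would take latency × throughput = 4 × 2 = 8
	// accumulators, at the cost of more registers and a longer epilogue.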

	// fma: fused multiply-add, computing a * b + c in one instruction with a
	// single rounding step
	pub fn dot_product_simd_extract_reduce_ilp_fma_v3(simd: V3, x: &[f32], y: &[f32]) -> f32 {
		struct Impl<'a> {
			simd: V3,
			x: &'a [f32],
			y: &'a [f32],
		}

		impl pulp::NullaryFnOnce for Impl<'_> {
			type Output = f32;

			#[inline(always)]
			fn call(self) -> Self::Output {
				let Self { simd, x, y } = self;

				// x0 x1 x2.. x100
				// [[x0 x1 x2 x3 x4 x5 x6 x7] [x8 x9 x10.. x15] [..x95]] | [x96 x97 x98 x99 x100]
				let (x8, x1) = pulp::as_arrays::<8, _>(x);
				let (y8, y1) = pulp::as_arrays::<8, _>(y);

				// sum (x * y)
				// sum (reduce_sum(X * Y))
				// reduce_sum(sum (X * Y))

				let mut acc0 = simd.splat_f32x8(0.0);
				let mut acc1 = simd.splat_f32x8(0.0);
				let mut acc2 = simd.splat_f32x8(0.0);
				let mut acc3 = simd.splat_f32x8(0.0);

				// 12 registers are being used
				// 4 for the accumulators + 2×4 inside the loop for x0..x3 and y0..y3
				let (x8_4, x8_1) = pulp::as_arrays::<4, _>(x8);
				let (y8_4, y8_1) = pulp::as_arrays::<4, _>(y8);

				for ([x0, x1, x2, x3], [y0, y1, y2, y3]) in iter::zip(x8_4, y8_4) {
					let x0 = cast(*x0);
					let y0 = cast(*y0);
					let x1 = cast(*x1);
					let y1 = cast(*y1);
					let x2 = cast(*x2);
					let y2 = cast(*y2);
					let x3 = cast(*x3);
					let y3 = cast(*y3);

					acc0 = simd.mul_add_f32x8(x0, y0, acc0);
					acc1 = simd.mul_add_f32x8(x1, y1, acc1);
					acc2 = simd.mul_add_f32x8(x2, y2, acc2);
					acc3 = simd.mul_add_f32x8(x3, y3, acc3);
				}

				for (x0, y0) in iter::zip(x8_1, y8_1) {
					let x0 = cast(*x0);
					let y0 = cast(*y0);
					acc0 = simd.mul_add_f32x8(x0, y0, acc0);
				}

				// reduce the four accumulators pairwise, then reduce_sum_f32s:
				// f32x8 -> f32x4 + f32x4
				// f32x4 -> f32x2 + f32x2
				// f32x2 -> f32 + f32
				acc0 = simd.add_f32x8(acc0, acc1);
				acc2 = simd.add_f32x8(acc2, acc3);

				acc0 = simd.add_f32x8(acc0, acc2);

				let mut acc = simd.reduce_sum_f32s(acc0);

				for (x, y) in iter::zip(x1, y1) {
					acc += x * y;
				}

				acc
			}
		}

		simd.vectorize(Impl { simd, x, y })
	}
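
	// note: fma computes a * b + c exactly and rounds once, so results can
	// differ in the last bits from a separate mul then add; this is expected,
	// and usually slightly more accurate.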

	pub fn bench_dot_simd(bencher: Bencher, PlotArg(n): PlotArg) {
		let x = &*vec![1.0_f32; n];
		let y = &*vec![1.0_f32; n];

		if let Some(simd) = V3::try_new() {
			bencher.bench(|| dot_product_simd_v3(simd, x, y))
		} else {
			bencher.skip();
		}
	}
	pub fn bench_dot_simd_extract_reduce(bencher: Bencher, PlotArg(n): PlotArg) {
		let x = &*vec![1.0_f32; n];
		let y = &*vec![1.0_f32; n];

		if let Some(simd) = V3::try_new() {
			bencher.bench(|| dot_product_simd_extract_reduce_v3(simd, x, y))
		} else {
			bencher.skip();
		}
	}

	pub fn bench_dot_simd_extract_reduce_ilp(bencher: Bencher, PlotArg(n): PlotArg) {
		let x = &*vec![1.0_f32; n];
		let y = &*vec![1.0_f32; n];

		if let Some(simd) = V3::try_new() {
			bencher.bench(|| dot_product_simd_extract_reduce_ilp_v3(simd, x, y))
		} else {
			bencher.skip();
		}
	}

	pub fn bench_dot_simd_extract_reduce_ilp_fma(bencher: Bencher, PlotArg(n): PlotArg) {
		let x = &*vec![1.0_f32; n];
		let y = &*vec![1.0_f32; n];

		// print the allocations' offsets from 32-byte alignment; `vec!` only
		// guarantees the element type's alignment, so these vary between runs
		dbg!(x.as_ptr().addr() % 32);
		dbg!(y.as_ptr().addr() % 32);

		if let Some(simd) = V3::try_new() {
			bencher.bench(|| dot_product_simd_extract_reduce_ilp_fma_v3(simd, x, y))
		} else {
			bencher.skip();
		}
	}

	pub fn bench_dot_simd_extract_reduce_ilp_fma_misaligned(bencher: Bencher, PlotArg(n): PlotArg) {
		// allocate one extra element in an aligned vec, then skip it, so the
		// data is deliberately misaligned with respect to the simd width
		let x = &avec![1.0_f32; n + 1][1..];
		let y = &avec![1.0_f32; n + 1][1..];

		if let Some(simd) = V3::try_new() {
			bencher.bench(|| dot_product_simd_extract_reduce_ilp_fma_v3(simd, x, y))
		} else {
			bencher.skip();
		}
	}
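
	// with the one-element offset, the 32-byte loads are no longer 32-byte
	// aligned, so some of them straddle cache-line boundaries; on most x86_64
	// cores those split loads cost extra cycles.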

	pub fn bench_dot_simd_extract_reduce_ilp_fma_aligned(bencher: Bencher, PlotArg(n): PlotArg) {
		// aligned memory is more efficient for simd loads and stores
		// always use this for benchmarks
		//
		// we didn't use it in the previous benchmarks just to showcase the difference
		let x = &*avec![1.0_f32; n];
		let y = &*avec![1.0_f32; n];

		if let Some(simd) = V3::try_new() {
			bencher.bench(|| dot_product_simd_extract_reduce_ilp_fma_v3(simd, x, y))
		} else {
			bencher.skip();
		}
	}
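
	// sanity-check sketch (hypothetical; assumes this file is also built as a
	// test target): every simd variant should agree with the scalar version up
	// to floating-point reassociation error.
	#[cfg(test)]
	mod tests {
		#[test]
		fn simd_matches_scalar() {
			if let Some(simd) = pulp::x86::V3::try_new() {
				// odd length so the epilogue path is exercised too
				let x: Vec<f32> = (0..103).map(|i| i as f32 * 0.5).collect();
				let y: Vec<f32> = (0..103).map(|i| 1.0 - i as f32 * 0.25).collect();

				let expected = crate::dot_product_scalar(&x, &y);
				for got in [
					super::dot_product_simd_v3(simd, &x, &y),
					super::dot_product_simd_extract_reduce_v3(simd, &x, &y),
					super::dot_product_simd_extract_reduce_ilp_v3(simd, &x, &y),
					super::dot_product_simd_extract_reduce_ilp_fma_v3(simd, &x, &y),
				] {
					// reassociated/fused sums differ from the sequential sum
					// only in the low bits
					assert!((got - expected).abs() <= 1e-3 * expected.abs().max(1.0));
				}
			}
		}
	}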
}

// fma: fused multiply add
pub fn dot_product_simd_extract_reduce_ilp_fma_generic<S: Simd>(
	simd: S,
	x: &[f32],
	y: &[f32],
) -> f32 {
	struct Impl<'a, S> {
		simd: S,
		x: &'a [f32],
		y: &'a [f32],
	}

	impl<S: Simd> pulp::NullaryFnOnce for Impl<'_, S> {
		type Output = f32;

		#[inline(always)]
		fn call(self) -> Self::Output {
			let Self { simd, x, y } = self;

			// x0 x1 x2.. x100
			// [[x0 x1 x2 x3 x4 x5 x6 x7] [x8 x9 x10.. x15] [..x95]] | [x96 x97 x98 x99 x100]
			let (xs, x1) = S::as_simd_f32s(x);
			let (ys, y1) = S::as_simd_f32s(y);

			// sum (x * y)
			// sum (reduce_sum(X * Y))
			// reduce_sum(sum (X * Y))

			let mut acc0 = simd.splat_f32s(0.0);
			let mut acc1 = simd.splat_f32s(0.0);
			let mut acc2 = simd.splat_f32s(0.0);
			let mut acc3 = simd.splat_f32s(0.0);

			// 12 registers are being used
			// 4 for the accumulators + 2×4 inside the loop for x0..x3 and y0..y3
			let (xs_4, xs_1) = pulp::as_arrays::<4, _>(xs);
			let (ys_4, ys_1) = pulp::as_arrays::<4, _>(ys);

			for ([x0, x1, x2, x3], [y0, y1, y2, y3]) in iter::zip(xs_4, ys_4) {
				acc0 = simd.mul_add_f32s(*x0, *y0, acc0);
				acc1 = simd.mul_add_f32s(*x1, *y1, acc1);
				acc2 = simd.mul_add_f32s(*x2, *y2, acc2);
				acc3 = simd.mul_add_f32s(*x3, *y3, acc3);
			}

			for (x0, y0) in iter::zip(xs_1, ys_1) {
				acc0 = simd.mul_add_f32s(*x0, *y0, acc0);
			}

			// reduce the four accumulators pairwise, then reduce_sum_f32s:
			// f32x8 -> f32x4 + f32x4
			// f32x4 -> f32x2 + f32x2
			// f32x2 -> f32 + f32
			acc0 = simd.add_f32s(acc0, acc1);
			acc2 = simd.add_f32s(acc2, acc3);

			acc0 = simd.add_f32s(acc0, acc2);

			let mut acc = simd.reduce_sum_f32s(acc0);

			for (x, y) in iter::zip(x1, y1) {
				acc += x * y;
			}

			acc
		}
	}

	simd.vectorize(Impl { simd, x, y })
}
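
// note: `as_simd_f32s` splits the slice according to S's register width, so the
// same code compiles to f32x8 on avx2, f32x16 on avx512, or plain f32 on the
// scalar fallback.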

// fma: fused multiply add
pub fn dot_product_simd_extract_reduce_ilp_fma_epilogue_generic<S: Simd>(
	simd: S,
	x: &[f32],
	y: &[f32],
) -> f32 {
	struct Impl<'a, S> {
		simd: S,
		x: &'a [f32],
		y: &'a [f32],
	}

	impl<S: Simd> pulp::NullaryFnOnce for Impl<'_, S> {
		type Output = f32;

		#[inline(always)]
		fn call(self) -> Self::Output {
			let Self { simd, x, y } = self;

			// x0 x1 x2.. x100
			// [[x0 x1 x2 x3 x4 x5 x6 x7] [x8 x9 x10.. x15] [..x95]] | [x96 x97 x98 x99 x100]
			let (xs, x1) = S::as_simd_f32s(x);
			let (ys, y1) = S::as_simd_f32s(y);

			// sum (x * y)
			// sum (reduce_sum(X * Y))
			// reduce_sum(sum (X * Y))

			let mut acc0 = simd.splat_f32s(0.0);
			let mut acc1 = simd.splat_f32s(0.0);
			let mut acc2 = simd.splat_f32s(0.0);
			let mut acc3 = simd.splat_f32s(0.0);

			// 12 registers are being used
			// 4 for the accumulators + 2×4 inside the loop for x0..x3 and y0..y3
			let (xs_4, xs_1) = pulp::as_arrays::<4, _>(xs);
			let (ys_4, ys_1) = pulp::as_arrays::<4, _>(ys);

			for ([x0, x1, x2, x3], [y0, y1, y2, y3]) in iter::zip(xs_4, ys_4) {
				acc0 = simd.mul_add_f32s(*x0, *y0, acc0);
				acc1 = simd.mul_add_f32s(*x1, *y1, acc1);
				acc2 = simd.mul_add_f32s(*x2, *y2, acc2);
				acc3 = simd.mul_add_f32s(*x3, *y3, acc3);
			}

			for (x0, y0) in iter::zip(xs_1, ys_1) {
				acc0 = simd.mul_add_f32s(*x0, *y0, acc0);
			}

			// reduce the four accumulators pairwise, then reduce_sum_f32s:
			// f32x8 -> f32x4 + f32x4
			// f32x4 -> f32x2 + f32x2
			// f32x2 -> f32 + f32
			acc0 = simd.add_f32s(acc0, acc1);
			acc2 = simd.add_f32s(acc2, acc3);

			acc0 = simd.add_f32s(acc0, acc2);

			// vectorized epilogue: fold the tail into the vector accumulator
			// instead of running a scalar loop
			if !x1.is_empty() {
				acc0 =
					simd.mul_add_f32s(simd.partial_load_f32s(x1), simd.partial_load_f32s(y1), acc0);
			}

			simd.reduce_sum_f32s(acc0)
		}
	}

	simd.vectorize(Impl { simd, x, y })
}
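
// note: the partial load conceptually fills only the first `tail.len()` lanes
// and zeroes the rest, e.g. a 5-element tail with 8 lanes loads as
// [t0 t1 t2 t3 t4 0 0 0]; since the zero lanes contribute nothing to the fma,
// the whole tail folds into the accumulator in a single step.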

pub fn bench_dot_simd_extract_reduce_ilp_fma_aligned_runtime_dispatch(
	bencher: Bencher,
	PlotArg(n): PlotArg,
) {
	let x = &*avec![1.0_f32; n];
	let y = &*avec![1.0_f32; n];

	let arch = Arch::new();

	struct Impl<'a> {
		x: &'a [f32],
		y: &'a [f32],
	}

	impl pulp::WithSimd for Impl<'_> {
		type Output = f32;

		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
			let Self { x, y } = self;
			dot_product_simd_extract_reduce_ilp_fma_generic(simd, x, y)
		}
	}

	bencher.bench(|| arch.dispatch(Impl { x, y }));
}

pub fn bench_dot_simd_extract_reduce_ilp_fma_epilogue_aligned_runtime_dispatch(
	bencher: Bencher,
	PlotArg(n): PlotArg,
) {
	let x = &*avec![1.0_f32; n];
	let y = &*avec![1.0_f32; n];

	let arch = Arch::new();

	struct Impl<'a> {
		x: &'a [f32],
		y: &'a [f32],
	}

	impl pulp::WithSimd for Impl<'_> {
		type Output = f32;

		#[inline(always)]
		fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
			let Self { x, y } = self;
			dot_product_simd_extract_reduce_ilp_fma_epilogue_generic(simd, x, y)
		}
	}

	bencher.bench(|| arch.dispatch(Impl { x, y }));
}
fn main() -> std::io::Result<()> {
	let mut bench = Bench::new(BenchConfig::from_args()?);

	// sizes: 1..=16 to probe the epilogue, then multiples of 16 and of 256
	// for progressively larger inputs
	let mut params = vec![];
	for i in 1..=16 {
		params.push(i);
	}
	for i in 2..=16 {
		params.push(16 * i);
	}
	for i in 2..=16 {
		params.push(256 * i);
	}

	#[cfg(target_arch = "x86_64")]
	bench.register_many(
		list![
			bench_dot_scalar,
			x86::bench_dot_simd,
			x86::bench_dot_simd_extract_reduce,
			x86::bench_dot_simd_extract_reduce_ilp,
			x86::bench_dot_simd_extract_reduce_ilp_fma,
			x86::bench_dot_simd_extract_reduce_ilp_fma_misaligned,
			x86::bench_dot_simd_extract_reduce_ilp_fma_aligned,
			bench_dot_simd_extract_reduce_ilp_fma_aligned_runtime_dispatch,
			bench_dot_simd_extract_reduce_ilp_fma_epilogue_aligned_runtime_dispatch,
		],
		params.iter().copied().map(PlotArg),
	);

	#[cfg(not(target_arch = "x86_64"))]
	bench.register_many(
		list![
			bench_dot_scalar,
			bench_dot_simd_extract_reduce_ilp_fma_aligned_runtime_dispatch,
			bench_dot_simd_extract_reduce_ilp_fma_epilogue_aligned_runtime_dispatch,
		],
		params.iter().copied().map(PlotArg),
	);

	bench.run()?;
	Ok(())
}