pulp 0.21.1

Safe generic simd
Documentation
use std::path::Path;
use std::{env, fs};

/// Builds the instructions that load f32 lanes `start..end` of one 128-bit
/// group into `xmm{out}`.
///
/// `start` and `end` are clamped to `0..=4`. The lanes are read starting at
/// `[rax + offset + 4 * start]`; when `start > 0`, a `vpslldq` by `4 * start`
/// bytes is appended after the load.
///
/// # Panics
///
/// Panics when the clamped range holds no lanes (`end <= start`).
fn load_f32x4(out: usize, offset: usize, start: i32, end: i32) -> String {
	let first = start.clamp(0, 4);
	let last = end.clamp(0, 4);
	let reg = out;

	// Byte distance of the first requested lane from the start of the group.
	let byte_shift = (first as usize) * 4;
	let addr = offset + byte_shift;

	// Pick the load sequence by the number of lanes to read.
	let mut asm = match last - first {
		4 => format!("vmovups xmm{reg}, [rax + {addr}]\n"),
		3 => {
			let low_pair = format!("vmovsd xmm{reg}, [rax + {addr}]\n");
			let third = format!("vpinsrd xmm{reg}, xmm{reg}, [rax + {addr} + 8], 0x2\n");
			low_pair + &third
		},
		2 => format!("vmovsd xmm{reg}, [rax + {addr}]\n"),
		1 => format!("vmovss xmm{reg}, [rax + {addr}]\n"),

		_ => panic!(),
	};

	if byte_shift != 0 {
		asm.push_str(&format!("vpslldq xmm{reg}, xmm{reg}, {byte_shift}\n"));
	}

	asm
}

/// Builds the instructions that store f32 lanes `start..end` of `xmm{inp}`
/// to memory.
///
/// `start` and `end` are clamped to `0..=4`. When `start > 0`, the register
/// is first shifted with `vpsrldq` by `4 * start` bytes, and the lanes are
/// then written starting at `[rax + offset + 4 * start]`. The emitted code
/// modifies `xmm{inp}` whenever it contains a `vpsrldq`.
///
/// # Panics
///
/// Panics when the clamped range holds no lanes (`end <= start`).
fn store_f32x4(inp: usize, offset: usize, start: i32, end: i32) -> String {
	let lo = start.clamp(0, 4);
	let hi = end.clamp(0, 4);
	let reg = inp;

	// Bytes occupied by the lanes that are skipped at the bottom of the group.
	let skipped = (lo as usize) * 4;
	let addr = offset + skipped;

	let mut asm = String::new();
	if lo > 0 {
		asm.push_str(&format!("vpsrldq xmm{reg}, xmm{reg}, {skipped}\n"));
	}

	// Pick the store sequence by the number of lanes to write.
	let stores = match hi - lo {
		4 => format!("vmovups [rax + {addr}], xmm{reg}\n"),
		3 => [
			format!("vmovsd [rax + {addr}], xmm{reg}\n"),
			format!("vpsrldq xmm{reg}, xmm{reg}, 8\n"),
			format!("vmovss [rax + {addr} + 8], xmm{reg}\n"),
		]
		.concat(),
		2 => format!("vmovsd [rax + {addr}], xmm{reg}\n"),
		1 => format!("vmovss [rax + {addr}], xmm{reg}\n"),

		_ => panic!(),
	};
	asm.push_str(&stores);
	asm
}

/// Builds the instructions that load f32 lanes `start..end` (out of 16)
/// from memory at `[rax]` into `zmm0`.
///
/// The range is covered with the widest loads available: a single 512-bit
/// or 256-bit `vmovups` when the range starts at lane 0, whole 128-bit or
/// 256-bit inserts straight from memory where a group is fully covered, and
/// `load_f32x4` into the scratch register `xmm1` followed by an insert for
/// partially covered 128-bit groups. If nothing is loaded into the lowest
/// 128-bit group, `xmm0` is cleared with `vxorps` before any insert.
///
/// An empty range (`end <= start`) yields only the `vxorps`.
fn load_f32x16(start: i32, end: i32) -> String {
	// `start` is used as a cursor: after each stage it is advanced to the
	// upper bound of the lanes handled so far.
	let mut start = start;
	// Whether the low part of zmm0 has been written by a load.
	let mut init = false;

	let mut f = String::new();
	if start == 0 && end == 16 {
		f += "vmovups zmm0, [rax]\n";
		init = true;
		start = 16;
	}

	if start == 0 && end >= 8 {
		f += "vmovups ymm0, [rax]\n";
		init = true;
		start = 8;
	}
	// Lanes 0..4 (partial or full); `load_f32x4` clamps `end` to 4.
	if start < 4 && end > start {
		f += &load_f32x4(0, 0, start, end);
		init = true;
		start = 4;
	}
	if !init {
		f += "vxorps xmm0, xmm0, xmm0\n";
	}

	// Lanes 4..8: whole group straight from memory, otherwise via xmm1.
	if start == 4 && end >= 8 {
		f += "vinsertf128 ymm0, ymm0, [rax + 16], 0x1\n";
		start = 8;
	}
	if start < 8 && end > start {
		f += &load_f32x4(1, 16, start - 4, end - 4);
		f += "vinsertf128 ymm0, ymm0, xmm1, 0x1\n";
		start = 8;
	}

	// Lanes 8..16 in one insert when the whole upper half is requested.
	if start == 8 && end == 16 {
		f += "vinsertf64x4 zmm0, zmm0, [rax + 32], 0x1\n";
		start = 16;
	}
	// Lanes 8..12.
	if start == 8 && end == 12 {
		f += "vinsertf64x2 zmm0, zmm0, [rax + 32], 0x2\n";
		start = 12;
	}
	if start < 12 && end > start {
		f += &load_f32x4(1, 32, start - 8, end - 8);
		f += "vinsertf64x2 zmm0, zmm0, xmm1, 0x2\n";
		start = 12;
	}

	// Lanes 12..16.
	if start == 12 && end == 16 {
		f += "vinsertf64x2 zmm0, zmm0, [rax + 48], 0x3\n";
		start = 16;
	}
	if start < 16 && end > start {
		f += &load_f32x4(1, 48, start - 12, end - 12);
		f += "vinsertf64x2 zmm0, zmm0, xmm1, 0x3\n";
		// This stage finishes lanes 12..16, so the cursor ends at 16
		// (it was previously reset to 12; the value is not read afterwards).
		start = 16;
	}

	_ = start;

	f
}

/// Builds the instructions that store f32 lanes `start..end` (out of 16)
/// of `zmm0` to memory at `[rax]`.
///
/// The range is written from the top down. A full 512-bit, upper 256-bit
/// or lower 256-bit `vmovups` is used where the range covers that span;
/// each remaining 128-bit group is extracted into the scratch register
/// `xmm1` and handed to `store_f32x4`, and the lowest group is stored
/// directly from `xmm0`.
///
/// An empty range (`end <= start`) yields an empty string.
fn store_f32x16(start: i32, end: i32) -> String {
	let mut asm = String::new();
	// Exclusive upper bound of the lanes that still have to be written;
	// lowered after every stage.
	let mut pending = end;

	if start == 0 && pending == 16 {
		asm.push_str("vmovups [rax], zmm0\n");
		pending = 0;
	}
	if start <= 8 && pending == 16 {
		asm.push_str("vextractf64x4 ymm1, zmm0, 0x1\n");
		asm.push_str("vmovups [rax + 32], ymm1\n");
		pending = 8;
	}

	// Upper two 128-bit groups (lanes 12..16, then 8..12), each via xmm1.
	for (base, extract) in [
		(12, "vextractf64x2 xmm1, zmm0, 0x3\n"),
		(8, "vextractf64x2 xmm1, zmm0, 0x2\n"),
	] {
		if pending > base && pending > start {
			asm.push_str(extract);
			asm.push_str(&store_f32x4(1, 4 * base as usize, start - base, pending - base));
			pending = base;
		}
	}

	if start == 0 && pending == 8 {
		asm.push_str("vmovups [rax], ymm0\n");
		pending = 0;
	}
	// Lanes 4..8 via xmm1.
	if pending > 4 && pending > start {
		asm.push_str("vextractf128 xmm1, ymm0, 0x1\n");
		asm.push_str(&store_f32x4(1, 16, start - 4, pending - 4));
		pending = 4;
	}
	// Lanes 0..4 directly from xmm0.
	if pending > start {
		asm.push_str(&store_f32x4(0, 0, start, pending));
	}

	asm
}

fn main() {
	let out_dir = env::var_os("OUT_DIR").unwrap();
	let dest_path = Path::new(&out_dir).join("x86_64_asm.rs");

	let mut f = String::new();

	f += "core::arch::global_asm! {\"\n";

	let max = 16;
	let mut wrote_empty = false;

	let mut names = vec![];
	let ver_major = env::var("CARGO_PKG_VERSION_MAJOR").unwrap();
	let ver_minor = env::var("CARGO_PKG_VERSION_MINOR").unwrap();
	let ver_patch = env::var("CARGO_PKG_VERSION_PATCH").unwrap();
	let ver = format!("v{ver_major}_{ver_minor}_{ver_patch}");

	for end in 1..=max {
		for start in 0..max {
			let mut mask = 0u16;
			for i in 0..16 {
				if i >= start && i < end {
					mask |= 1 << (15 - i);
				}
			}

			if mask != 0 || !wrote_empty {
				let ld = format!("libpulp_{ver}_ld_b32s_{mask:0>16b}");
				f += &format!(".global {ld}\n");
				f += &format!("{ld}:\n");
				f += &load_f32x16(start, end);
				f += "jmp rcx\n";
				names.push(ld);

				let st = format!("libpulp_{ver}_st_b32s_{mask:0>16b}");
				f += &format!(".global {st}\n");
				f += &format!("{st}:\n");
				f += &store_f32x16(start, end);
				f += "jmp rcx\n";
				names.push(st);
			}

			if mask == 0 {
				wrote_empty = true;
			}
		}
	}
	f += "\"}\n";

	f += "unsafe extern \"C\" {\n";
	for name in &names {
		f += &format!("#[link_name = \"\\x01{name}\"] fn {name}();\n");
	}
	f += "}\n";

	f += &format!("static LD_ST: [unsafe extern \"C\" fn(); 2 * (({max} + 1) * {max})] = [\n");

	for end in 0..=max {
		for start in 0..max {
			let mut mask = 0u16;
			for i in 0..16 {
				if i >= start && i < end {
					mask |= 1 << (15 - i);
				}
			}

			let ld = format!("libpulp_{ver}_ld_b32s_{mask:0>16b}");
			let st = format!("libpulp_{ver}_st_b32s_{mask:0>16b}");
			f += &format!("{ld}, {st},\n");
		}
	}
	f += "];\n";

	fs::write(&dest_path, &f).unwrap();

	println!("cargo::rerun-if-changed=build.rs");
}