#![allow(clippy::box_default)]
use liquid_core::Runtime;
use liquid_core::{Display_filter, Filter, FilterReflection, ParseFilter};
use liquid_core::{Value, ValueView};
use std::{env, ffi, fs, path};
/// Reads the environment variable `k`, which Cargo is expected to provide to
/// build scripts.
///
/// # Panics
/// Panics if the variable is unset or not valid Unicode; the panic message
/// names the variable so the build failure is actionable.
fn var(k: &str) -> String {
    env::var(k).unwrap_or_else(|e| panic!("environment variable {} is not usable: {}", k, e))
}
/// Returns true when the target environment is MSVC and the host triple is a
/// Windows one, i.e. when assembly must go through the MASM toolchain
/// (ml64.exe / lib.exe) instead of a GNU-style assembler.
fn use_masm() -> bool {
    let targets_msvc = matches!(env::var("CARGO_CFG_TARGET_ENV").as_deref(), Ok("msvc"));
    // HOST is only consulted once the target is known to be MSVC.
    targets_msvc && var("HOST").contains("-windows-")
}
/// Collects, in file order, the text that follows the first `jump_to:` on
/// every line of `src/frame/mmm/fuse.rs` carrying a `// jump_to:` marker.
/// The text is not trimmed. The file is registered with cargo so edits
/// trigger a rebuild.
fn jump_table() -> Vec<String> {
    const FUSE_SOURCE: &str = "src/frame/mmm/fuse.rs";
    println!("cargo:rerun-if-changed={}", FUSE_SOURCE);
    let source = std::fs::read_to_string(FUSE_SOURCE).unwrap();
    let mut labels = Vec::new();
    for line in source.lines() {
        if !line.contains("// jump_to:") {
            continue;
        }
        // The `contains` check above guarantees at least one separator.
        let label = line.split("jump_to:").nth(1).unwrap();
        labels.push(label.to_owned());
    }
    labels
}
/// A candidate toolchain configuration for assembling the arm64 fp16 kernels:
/// extra compiler flags plus whether the pragma variant of the sources is used.
#[derive(Clone, Debug)]
struct ConfigForHalf {
    /// Extra flags passed to the compiler driver for every fp16 source.
    extra_flags: Vec<String>,
    /// Selects the `dummy_fmla_pragma.S` probe; also forwarded to the
    /// templates as the `needs_pragma` global.
    needs_pragma: bool,
}

impl ConfigForHalf {
    fn new(extra_flags: Vec<String>, needs_pragma: bool) -> ConfigForHalf {
        ConfigForHalf { extra_flags, needs_pragma }
    }

    /// Every candidate configuration, in probing order: for each flag set
    /// (none, `-march=armv8.2-a`, `-mcpu=cortex-a55`), first without and then
    /// with the pragma.
    fn all() -> Vec<ConfigForHalf> {
        let flag_sets: [&[&str]; 3] = [&[], &["-march=armv8.2-a"], &["-mcpu=cortex-a55"]];
        flag_sets
            .iter()
            .flat_map(|flags| {
                let flags: Vec<String> = flags.iter().map(|f| f.to_string()).collect();
                vec![ConfigForHalf::new(flags.clone(), false), ConfigForHalf::new(flags, true)]
            })
            .collect()
    }

    /// A fresh `cc::Build` carrying this configuration's extra flags.
    fn cc(&self) -> cc::Build {
        let mut cc = cc::Build::new();
        for flag in &self.extra_flags {
            cc.flag(flag);
        }
        cc
    }

    /// Tries to compile a small fp16 probe under this configuration.
    fn works(&self) -> bool {
        let filename = if self.needs_pragma {
            "arm64/arm64fp16/dummy_fmla_pragma.S"
        } else {
            "arm64/arm64fp16/dummy_fmla_no_pragma.S"
        };
        self.cc()
            // The probe library is throwaway: without this, a successful
            // probe emits cargo link directives and `libdummy.a` gets linked
            // into the crate.
            .cargo_metadata(false)
            .static_flag(true)
            .file(filename)
            .try_compile("dummy")
            .is_ok()
    }

    /// The first configuration (in `all()` order) whose probe compiles.
    pub fn probe() -> Option<ConfigForHalf> {
        Self::all().into_iter().find(|c| c.works())
    }
}
/// Build script entry point: generates the `extern_kernel!` macro and
/// preprocesses and compiles the assembly kernels for the target architecture.
/// Architectures other than x86_64, arm/armv7 and aarch64 get no assembly.
fn main() {
    let target = var("TARGET");
    let arch = var("CARGO_CFG_TARGET_ARCH");
    let os = var("CARGO_CFG_TARGET_OS");
    let out_dir = path::PathBuf::from(var("OUT_DIR"));
    // Crate version with '-' and '.' mapped to '_', used to version kernel
    // symbol names (see make_extern_kernel_decl_macro).
    let suffix = env!("CARGO_PKG_VERSION").replace('-', "_").replace('.', "_");
    make_extern_kernel_decl_macro(&out_dir, &suffix);
    match arch.as_ref() {
        "x86_64" => build_x86_64(&target, &os, &out_dir, &suffix),
        "arm" | "armv7" => build_arm32(&suffix),
        "aarch64" => build_aarch64(&suffix),
        _ => {}
    }
}

/// Builds the `x86_64/fma` kernels, through MASM on Windows/MSVC and through
/// `cc` with `-mfma` elsewhere.
fn build_x86_64(target: &str, os: &str, out_dir: &path::Path, suffix: &str) {
    let files = preprocess_files("x86_64/fma", &[], suffix, false);
    if os == "windows" && use_masm() {
        build_x86_64_masm(target, out_dir, files);
    } else {
        cc::Build::new().files(files).flag("-mfma").static_flag(true).compile("x86_64_fma");
        if os == "windows" {
            // Best-effort removal of these files from the working directory
            // on non-MSVC Windows builds; absence is not an error.
            for stray in [
                "fma_mmm_f32_16x6.asm",
                "fma_mmm_i32_8x8.asm",
                "fma_sigmoid_f32.asm",
                "fma_tanh_f32.asm",
            ] {
                let _ = fs::remove_file(stray);
            }
        }
    }
}

/// Assembles each file with ml64.exe, archives the objects into
/// `x86_64_fma.lib` in `out_dir` with lib.exe, and tells cargo to link it.
///
/// # Panics
/// Panics if a tool cannot be found or fails; on an assembler failure the
/// offending preprocessed source is printed with 1-based line numbers first.
fn build_x86_64_masm(target: &str, out_dir: &path::Path, files: Vec<path::PathBuf>) {
    let mut lib_exe =
        cc::windows_registry::find(target, "lib.exe").expect("Could not find lib.exe");
    lib_exe.arg(format!("/out:{}", out_dir.join("x86_64_fma.lib").to_str().unwrap()));
    for f in files {
        let mut obj = f.clone();
        obj.set_extension("o");
        let mut ml_exe =
            cc::windows_registry::find(target, "ml64.exe").expect("Could not find ml64.exe");
        let assembled =
            ml_exe.arg("/Fo").arg(&obj).arg("/c").arg(&f).status().unwrap().success();
        if !assembled {
            // Number lines from 1 so they match the assembler's diagnostics.
            for (i, l) in std::fs::read_to_string(&f).unwrap().lines().enumerate() {
                println!("{:8} {}", i + 1, l);
            }
            panic!("ml64.exe failed to assemble {}", f.display());
        }
        lib_exe.arg(obj);
    }
    assert!(lib_exe.status().unwrap().success());
    println!("cargo:rustc-link-search=native={}", out_dir.to_str().unwrap());
    println!("cargo:rustc-link-lib=static=x86_64_fma");
}

/// Builds the 32-bit ARM kernels: VFPv2 ones, then NEON ones expanded for
/// each `core` variant.
fn build_arm32(suffix: &str) {
    let files = preprocess_files("arm32/armvfpv2", &[], suffix, false);
    cc::Build::new()
        .files(files)
        .flag("-marm")
        .flag("-mfpu=vfp")
        .static_flag(true)
        .compile("armvfpv2");
    let files = preprocess_files(
        "arm32/armv7neon",
        &[("core", vec!["cortexa7", "cortexa9", "generic"])],
        suffix,
        false,
    );
    cc::Build::new()
        .files(files)
        .flag("-marm")
        .flag("-mfpu=neon")
        .static_flag(true)
        .compile("armv7neon");
}

/// Builds the arm64 SIMD kernels and, unless the `no_fp16` feature is on, the
/// fp16 kernels using the first toolchain configuration that can assemble them.
fn build_aarch64(suffix: &str) {
    let files =
        preprocess_files("arm64/arm64simd", &[("core", vec!["a53", "a55", "gen"])], suffix, false);
    cc::Build::new().files(files).static_flag(true).compile("arm64simd");
    if std::env::var("CARGO_FEATURE_NO_FP16").is_err() {
        let config = ConfigForHalf::probe().expect("No configuration found for fp16 support");
        let files = preprocess_files(
            "arm64/arm64fp16",
            &[("core", vec!["a55", "gen"])],
            suffix,
            config.needs_pragma,
        );
        config.cc().files(files).static_flag(true).compile("arm64fp16");
    }
}
/// A template variable: its name and the list of values it expands to.
type Variant = (&'static str, Vec<&'static str>);

/// Expands every `*.tmpl` file directly under `input` into `.S` files in
/// `OUT_DIR`, returning the generated paths in sorted template order.
///
/// Only the variants whose key occurs in a template's file name apply to that
/// template. One output is produced per combination of their values. In each
/// output the key is replaced by the value in the file name, and the
/// (key, value) pairs are passed to the template as globals.
fn preprocess_files(
    input: impl AsRef<path::Path>,
    variants: &[Variant],
    suffix: &str,
    needs_pragma: bool,
) -> Vec<path::PathBuf> {
    let out_dir = path::PathBuf::from(var("OUT_DIR"));
    let mut files = vec![];
    let dir_entries = {
        let mut dir_entries: Vec<fs::DirEntry> =
            input.as_ref().read_dir().unwrap().map(|f| f.unwrap()).collect();
        dir_entries.sort_by_key(|a| a.path());
        dir_entries
    };
    for f in dir_entries {
        if f.path().extension() == Some(ffi::OsStr::new("tmpl")) {
            let tmpl_file = f.path().file_name().unwrap().to_str().unwrap().to_owned();
            let concerned_variants: Vec<&Variant> =
                variants.iter().filter(|v| tmpl_file.contains(v.0)).collect();
            let expanded_variants: usize =
                concerned_variants.iter().map(|pair| pair.1.len()).product();
            for v in 0..expanded_variants {
                let mut tmpl_file = tmpl_file.clone();
                let mut id = v;
                let mut globals = vec![];
                // Decode `v` as a mixed-radix number over the concerned
                // variants only. Using every variant here would mix in radices
                // of variants absent from this file name. That yields duplicate
                // and missing combinations and injects unrelated globals.
                for variable in &concerned_variants {
                    let key = variable.0;
                    let value = variable.1[id % variable.1.len()];
                    globals.push((key, value));
                    tmpl_file = tmpl_file.replace(key, value);
                    id /= variable.1.len();
                }
                let mut file = out_dir.join(tmpl_file);
                file.set_extension("S");
                preprocess_file(f.path(), &file, &globals, suffix, needs_pragma);
                files.push(file);
            }
        }
    }
    files
}
/// For MASM (`msvc == true`), rewrites every `//` into MASM's `;` comment
/// marker, line by line, rejoining with `\n` (no trailing newline).
/// Otherwise returns `s` unchanged.
fn strip_comments(s: String, msvc: bool) -> String {
    if !msvc {
        return s;
    }
    let mut rewritten = String::with_capacity(s.len());
    for (index, line) in s.lines().enumerate() {
        if index > 0 {
            rewritten.push('\n');
        }
        rewritten.push_str(&line.replace("//", ";"));
    }
    rewritten
}
/// Renders the liquid template at `template` into `output`.
///
/// Globals available to the template: target description (`msvc`, `family`,
/// `os`), assembler dialect helpers (`L` label prefix, `G` global-symbol
/// prefix, `long` data directive), `suffix`, `needs_pragma`, `jump_table`, plus
/// every `(key, value)` in `variants`. Partials are loaded from the template's
/// directory and the `float16` filter is registered.
///
/// # Panics
/// Panics with the template path and liquid error if parsing or rendering fails.
fn preprocess_file(
    template: impl AsRef<path::Path>,
    output: impl AsRef<path::Path>,
    variants: &[(&'static str, &'static str)],
    suffix: &str,
    needs_pragma: bool,
) {
    println!("cargo:rerun-if-changed={}", template.as_ref().to_string_lossy());
    let family = var("CARGO_CFG_TARGET_FAMILY");
    let os = var("CARGO_CFG_TARGET_OS");
    let msvc = use_masm();
    let mut input = fs::read_to_string(&template).unwrap();
    input = strip_comments(input, msvc);
    let l = if os == "macos" {
        "L"
    } else if family == "windows" {
        ""
    } else {
        "."
    }
    .to_owned();
    let long = if msvc { "dd" } else { ".long" };
    let g = if os == "macos" || os == "ios" { "_" } else { "" };
    let mut globals = liquid::object!({
        "msvc": msvc,
        "needs_pragma": needs_pragma,
        "family": family,
        "os": os,
        "L": l,
        "G": g,
        "suffix": suffix,
        "long": long,
        "jump_table": jump_table(),
    });
    for (k, v) in variants {
        globals.insert(k.to_string().into(), liquid::model::Value::scalar(*v));
    }
    let partials = load_partials(template.as_ref().parent().unwrap(), msvc);
    if let Err(e) = liquid::ParserBuilder::with_stdlib()
        .partials(liquid::partials::LazyCompiler::new(partials))
        .filter(F16)
        .build()
        .and_then(|p| p.parse(&input))
        .and_then(|r| r.render_to(&mut fs::File::create(&output).unwrap(), &globals))
    {
        panic!("Processing {}: {}", template.as_ref().to_string_lossy(), e);
    }
}
/// Loads every liquid partial found recursively under `p`.
///
/// `.tmpliq` files are used as-is. `.tmpli` files first have `{{` / `}}`
/// collapsed to `{` / `}`. Both then go through `strip_comments`. Each partial
/// is keyed by its path relative to `p` with `/` separators and registered
/// with cargo for rebuild tracking. Files with other extensions are skipped
/// without being read.
fn load_partials(p: &path::Path, msvc: bool) -> liquid::partials::InMemorySource {
    let mut mem = liquid::partials::InMemorySource::new();
    for f in walkdir::WalkDir::new(p) {
        let f = f.unwrap();
        if f.path().is_dir() {
            continue;
        }
        let ext = f.path().extension().map(|s| s.to_string_lossy()).unwrap_or_else(|| "".into());
        // Filter by extension before reading. Templates, assembly or any
        // non-UTF-8 file living in the same tree is never loaded (it used to
        // be read, and a non-UTF-8 file made the build panic).
        let text = match ext.as_ref() {
            "tmpli" => std::fs::read_to_string(f.path())
                .unwrap()
                .replace("{{", "{")
                .replace("}}", "}"),
            "tmpliq" => std::fs::read_to_string(f.path()).unwrap(),
            _ => continue,
        };
        let text = strip_comments(text, msvc);
        let key =
            f.path().strip_prefix(p).unwrap().to_str().unwrap().to_owned().replace('\\', "/");
        println!("cargo:rerun-if-changed={}", f.path().to_string_lossy().replace('\\', "/"));
        mem.add(key, text);
    }
    mem
}
fn make_extern_kernel_decl_macro(out_dir: &path::Path, suffix: &str) {
let macro_decl = r#"
macro_rules! extern_kernel {
(fn $name: ident($($par_name:ident : $par_type: ty ),*) -> $rv: ty) => {
paste! {
extern "C" { pub fn [<$name _ _suffix>]($(par_name: $par_type),*) -> $rv; }
pub use [<$name _ _suffix>] as $name;
}
}
}"#
.replace("_suffix", suffix);
std::fs::write(out_dir.join("extern_kernel_macro.rs"), macro_decl).unwrap();
}
#[derive(Clone, ParseFilter, FilterReflection)]
#[filter(
name = "float16",
description = "Write a float16 constant with the .float16 directive in gcc, or as short in clang",
parsed(F16Filter)
)]
pub struct F16;
#[derive(Debug, Default, Display_filter)]
#[name = "float16"]
struct F16Filter;
impl Filter for F16Filter {
fn evaluate(
&self,
input: &dyn ValueView,
_runtime: &dyn Runtime,
) -> liquid_core::Result<Value> {
let input: f32 = input.as_scalar().unwrap().to_float().unwrap() as f32;
let value = half::f16::from_f32(input);
let bits = value.to_bits();
Ok(format!(".short {bits}").to_value())
}
}