unsigned-float 0.2.1

Unsigned floating-point formats for non-negative numeric domains.
Documentation
use std::env;
use std::fs;
use std::path::{Path, PathBuf};

#[path = "src/dispatch/soft.rs"]
mod soft;

fn main() {
    println!("cargo:rerun-if-changed=build.rs");

    let out_dir = PathBuf::from(env::var_os("OUT_DIR").expect("OUT_DIR is set by Cargo"));

    write_uf8_layout_tables(&out_dir, "uf8_e4m4", 4, 4, 7);
    write_uf8_layout_tables(&out_dir, "uf8_e5m3", 5, 3, 15);
    write_uf8_cross_layout_pow_tables(&out_dir, "uf8_e4m4", "uf8_e5m3", (4, 4, 7), (5, 3, 15));
    write_uf8_cross_layout_pow_tables(&out_dir, "uf8_e5m3", "uf8_e4m4", (5, 3, 15), (4, 4, 7));
}

fn write_uf8_layout_tables(
    out_dir: &Path,
    name: &str,
    exp_bits: u32,
    mantissa_bits: u32,
    bias: i32,
) {
    write_uf8_table(
        out_dir,
        &format!("{name}_add.bin"),
        exp_bits,
        mantissa_bits,
        bias,
        |a, b| a + b,
    );
    write_uf8_table(
        out_dir,
        &format!("{name}_sub.bin"),
        exp_bits,
        mantissa_bits,
        bias,
        |a, b| a - b,
    );
    write_uf8_table(
        out_dir,
        &format!("{name}_mul.bin"),
        exp_bits,
        mantissa_bits,
        bias,
        |a, b| a * b,
    );
    write_uf8_table(
        out_dir,
        &format!("{name}_div.bin"),
        exp_bits,
        mantissa_bits,
        bias,
        |a, b| a / b,
    );
    write_uf8_table(
        out_dir,
        &format!("{name}_pow.bin"),
        exp_bits,
        mantissa_bits,
        bias,
        powuf_f32,
    );
    write_uf8_table(
        out_dir,
        &format!("{name}_pow1m.bin"),
        exp_bits,
        mantissa_bits,
        bias,
        pow1muf_f32,
    );
}

fn write_uf8_table(
    out_dir: &Path,
    file_name: &str,
    exp_bits: u32,
    mantissa_bits: u32,
    bias: i32,
    op: fn(f32, f32) -> f32,
) {
    let mut table = vec![0; 256 * 256];

    for a in u8::MIN..=u8::MAX {
        for b in u8::MIN..=u8::MAX {
            let index = table_index(a, b);
            let a = uf8_to_f32(a, exp_bits, mantissa_bits, bias);
            let b = uf8_to_f32(b, exp_bits, mantissa_bits, bias);
            table[index] = f32_to_uf8(op(a, b), exp_bits, mantissa_bits, bias);
        }
    }

    fs::write(out_dir.join(file_name), table).expect("write UF8 lookup table");
}

fn write_uf8_cross_layout_pow_tables(
    out_dir: &Path,
    base_name: &str,
    exponent_name: &str,
    base_layout: (u32, u32, i32),
    exponent_layout: (u32, u32, i32),
) {
    write_uf8_cross_layout_table(
        out_dir,
        &format!("{base_name}_pow_{exponent_name}.bin"),
        base_layout,
        exponent_layout,
        powuf_f32,
    );
    write_uf8_cross_layout_table(
        out_dir,
        &format!("{base_name}_pow1m_{exponent_name}.bin"),
        base_layout,
        exponent_layout,
        pow1muf_f32,
    );
}

fn write_uf8_cross_layout_table(
    out_dir: &Path,
    file_name: &str,
    base_layout: (u32, u32, i32),
    exponent_layout: (u32, u32, i32),
    op: fn(f32, f32) -> f32,
) {
    let mut table = vec![0; 256 * 256];
    let (base_exp_bits, base_mantissa_bits, base_bias) = base_layout;
    let (exponent_exp_bits, exponent_mantissa_bits, exponent_bias) = exponent_layout;

    for a in u8::MIN..=u8::MAX {
        for b in u8::MIN..=u8::MAX {
            let index = table_index(a, b);
            let a = uf8_to_f32(a, base_exp_bits, base_mantissa_bits, base_bias);
            let b = uf8_to_f32(b, exponent_exp_bits, exponent_mantissa_bits, exponent_bias);
            table[index] = f32_to_uf8(op(a, b), base_exp_bits, base_mantissa_bits, base_bias);
        }
    }

    fs::write(out_dir.join(file_name), table).expect("write UF8 cross-layout lookup table");
}

fn table_index(a: u8, b: u8) -> usize {
    ((a as usize) << 8) | b as usize
}

fn uf8_to_f32(bits: u8, exp_bits: u32, mantissa_bits: u32, bias: i32) -> f32 {
    soft::decode_to_f64(bits as u64, exp_bits, mantissa_bits, bias) as f32
}

fn f32_to_uf8(value: f32, exp_bits: u32, mantissa_bits: u32, bias: i32) -> u8 {
    soft::encode_from_f64(value as f64, exp_bits, mantissa_bits, bias) as u8
}

fn powuf_f32(base: f32, exponent: f32) -> f32 {
    if exponent == 0.0 {
        1.0
    } else if exponent == 1.0 {
        base
    } else if exponent == 0.5 {
        libm::sqrtf(base)
    } else if let Some(integer) = small_integer_exponent_f32(exponent) {
        powi_u32_f32(base, integer)
    } else {
        libm::powf(base, exponent)
    }
}

fn pow1muf_f32(u: f32, exponent: f32) -> f32 {
    if exponent == 0.0 || u == 0.0 {
        1.0
    } else if u == 1.0 {
        0.0
    } else if !(0.0..=1.0).contains(&u) {
        powuf_f32(1.0 - u, exponent)
    } else if exponent == 0.5 {
        libm::sqrtf(1.0 - u)
    } else {
        libm::expm1f(exponent * libm::log1pf(-u)) + 1.0
    }
}

fn small_integer_exponent_f32(exponent: f32) -> Option<u32> {
    if !(2.0..=32.0).contains(&exponent) {
        return None;
    }

    let integer = exponent as u32;
    if integer as f32 == exponent {
        Some(integer)
    } else {
        None
    }
}

fn powi_u32_f32(mut base: f32, mut exponent: u32) -> f32 {
    let mut acc = 1.0;

    while exponent != 0 {
        if exponent & 1 == 1 {
            acc *= base;
        }

        exponent >>= 1;
        if exponent != 0 {
            base *= base;
        }
    }

    acc
}