mistralrs-quant 0.8.1

Fast, flexible LLM inference.
Documentation
macro_rules! dequant_kernel {
    ($wq:ty, $scalar:ty, $postfix:tt) => {
        paste! {
            pub(crate) fn [< dequantize_ $postfix >](
                wq_packed: *const $wq,
                scale: *const $scalar,
                zero: *const $scalar,
                out: *const $scalar,
                h: i32,
                w: i32
            );
        }
    };
}

pub mod eight_bit {
    use half::{bf16, f16};
    use paste::paste;

    #[allow(dead_code)]
    extern "C" {
        dequant_kernel!(u8, f32, 8bit_u8_kernel_f32);
        dequant_kernel!(u8, f16, 8bit_u8_kernel_f16);
        dequant_kernel!(u8, bf16, 8bit_u8_kernel_bf16);
    }
}

pub mod four_bit {
    use half::{bf16, f16};
    use paste::paste;

    #[allow(dead_code)]
    extern "C" {
        dequant_kernel!(u8, f32, 4bit_u8_kernel_f32);
        dequant_kernel!(u8, f16, 4bit_u8_kernel_f16);
        dequant_kernel!(u8, bf16, 4bit_u8_kernel_bf16);
    }
}

pub mod three_bit {
    use half::{bf16, f16};
    use paste::paste;

    #[allow(dead_code)]
    extern "C" {
        dequant_kernel!(i32, f32, 3bit_32_kernel_f32);
        dequant_kernel!(i32, f16, 3bit_32_kernel_f16);
        dequant_kernel!(i32, bf16, 3bit_32_kernel_bf16);
    }
}

pub mod two_bit {
    use half::{bf16, f16};
    use paste::paste;

    #[allow(dead_code)]
    extern "C" {
        dequant_kernel!(u8, f32, 2bit_u8_kernel_f32);
        dequant_kernel!(u8, f16, 2bit_u8_kernel_f16);
        dequant_kernel!(u8, bf16, 2bit_u8_kernel_bf16);
    }
}

pub mod one_bit {
    use half::{bf16, f16};
    use paste::paste;

    #[allow(dead_code)]
    extern "C" {
        dequant_kernel!(u8, f32, 1bit_u8_kernel_f32);
        dequant_kernel!(u8, f16, 1bit_u8_kernel_f16);
        dequant_kernel!(u8, bf16, 1bit_u8_kernel_bf16);
    }
}