llama-cpp-sys-4 0.3.1

Low Level Bindings to llama.cpp
Documentation
#ifdef U32_DEQUANT_HELPERS
#define SRC0_TYPE u32

// Returns byte `b` of `v` as an unsigned value in 0..255, where byte 0 is the
// least-significant 8 bits of the word.
fn byte_of(v: u32, b: u32) -> u32 {
    let bit_shift = 8u * b;
    let shifted = v >> bit_shift;
    return shifted & 255u;
}

// Returns byte `b` of `v` reinterpreted as a signed 8-bit value (-128..127),
// where byte 0 is the least-significant 8 bits of the word.
fn sbyte_of(v: u32, b: u32) -> i32 {
    let unsigned_byte = i32((v >> (8u * b)) & 255u);
    // Values with the top bit set map to the negative range.
    if (unsigned_byte >= 128) {
        return unsigned_byte - 256;
    }
    return unsigned_byte;
}
#endif

// Map the generic source type names onto SRC0_INNER_TYPE / SRC1_INNER_TYPE,
// which are defined outside this section.
#define SRC0_TYPE SRC0_INNER_TYPE
#define SRC1_TYPE SRC1_INNER_TYPE

// Per-family block geometry: BLOCK_SIZE is the number of elements in one
// quantized block, THREADS_PER_BLOCK the number of threads that share a block.
#ifdef LEGACY_QUANTS
#define BLOCK_SIZE 32
#define THREADS_PER_BLOCK 4
#elif K_QUANTS
#define BLOCK_SIZE 256
#define THREADS_PER_BLOCK 16
#endif

// Elements handled by one thread within a block: 8 for the legacy family,
// 16 for the K family.
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
// Elements per src1q block; the K-quant path divides element positions by
// this value to find the matching src1q index.
#define Q8_BLOCK_SIZE 32

#ifdef MUL_ACC_Q4_0
#define BLOCK_SIZE_BYTES 18
#define B_DS_TYPE vec2<f32>
// Loads the 4 quant bytes for lane `inner_id` (quant data starts 2 bytes into
// the block) and splits each byte into its two nibbles.
// Returns (.x = low nibbles, .y = high nibbles), one nibble per byte lane.
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
    let nibble_mask = 0x0F0F0F0Fu;
    let word_addr = block_byte_base + 2u + 4u * inner_id;
    let qs_word = load_u32_at_src0(word_addr);
    let lo_nibbles = qs_word & nibble_mask;
    let hi_nibbles = (qs_word >> 4u) & nibble_mask;
    return vec2<u32>(lo_nibbles, hi_nibbles);
}
// Fetches the two quant words of src1q block `block` used by lane `inner_id`:
// word `inner_id` and word `inner_id + 4`.
fn repack_b_qs(block: u32, inner_id: u32) -> vec2<u32> {
    let first_word = src1q[block].qs[inner_id];
    let second_word = src1q[block].qs[inner_id + 4u];
    return vec2<u32>(first_word, second_word);
}
// Returns the `d` and `s` fields of src1q block `block`, converted to f32,
// packed as (.x = d, .y = s).
fn repack_b_dm(block: u32) -> B_DS_TYPE {
    let d_value = f32(src1q[block].d);
    let s_value = f32(src1q[block].s);
    return B_DS_TYPE(d_value, s_value);
}
// Reads the f16 stored at the very start of the src0 block and widens it to f32.
fn get_dm(block_byte_base: u32) -> f32 {
    let d_raw = load_f16_at_src0(block_byte_base);
    return f32(d_raw);
}
// Converts one thread's integer dot product into its f32 contribution.
//   row_sum : integer dot product for this thread's slice of the block
//   da      : src0 block scale
//   b_ds    : (d, s) pair of the matching src1q block
// The correction term is divided by THREADS_PER_BLOCK; every thread sharing
// the block is handed the same `da` and `b_ds` and applies an equal share.
// NOTE(review): the constant 8 is presumably the zero point of the 4-bit
// values and `b_ds.y` the src1q block sum term — confirm against the format.
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
    let scaled_dot = f32(row_sum) * (da * b_ds.x);
    let offset_share = 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
    return scaled_dot - offset_share;
}
#endif

#ifdef MUL_ACC_Q4_1
#define BLOCK_SIZE_BYTES 20
#define B_DS_TYPE vec2<f32>
// Loads the 4 quant bytes for lane `inner_id` (quant data starts 4 bytes into
// the block, after the two f16 values read by get_dm) and splits each byte
// into its two nibbles.
// Returns (.x = low nibbles, .y = high nibbles), one nibble per byte lane.
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
    let nibble_mask = 0x0F0F0F0Fu;
    let word_addr = block_byte_base + 4u + 4u * inner_id;
    let qs_word = load_u32_at_src0(word_addr);
    let lo_nibbles = qs_word & nibble_mask;
    let hi_nibbles = (qs_word >> 4u) & nibble_mask;
    return vec2<u32>(lo_nibbles, hi_nibbles);
}
// Fetches the two quant words of src1q block `block` used by lane `inner_id`:
// word `inner_id` and word `inner_id + 4`.
fn repack_b_qs(block: u32, inner_id: u32) -> vec2<u32> {
    let first_word = src1q[block].qs[inner_id];
    let second_word = src1q[block].qs[inner_id + 4u];
    return vec2<u32>(first_word, second_word);
}
// Returns the `d` and `s` fields of src1q block `block`, converted to f32,
// packed as (.x = d, .y = s).
fn repack_b_dm(block: u32) -> B_DS_TYPE {
    let d_value = f32(src1q[block].d);
    let s_value = f32(src1q[block].s);
    return B_DS_TYPE(d_value, s_value);
}
// Reads the two f16 values at byte offsets 0 and 2 of the src0 block and
// returns them widened to f32 as (.x = first, .y = second).
fn get_dm(block_byte_base: u32) -> vec2<f32> {
    let first = f32(load_f16_at_src0(block_byte_base));
    let second = f32(load_f16_at_src0(block_byte_base + 2u));
    return vec2<f32>(first, second);
}
// Converts one thread's integer dot product into its f32 contribution.
//   row_sum : integer dot product for this thread's slice of the block
//   dma     : the two f16 values of the src0 block as returned by get_dm
//   b_ds    : (d, s) pair of the matching src1q block
// The additive term is divided by THREADS_PER_BLOCK; every thread sharing the
// block is handed the same `dma` and `b_ds` and applies an equal share.
// NOTE(review): `dma.y` is presumably the block minimum and `b_ds.y` the
// src1q block sum term — confirm against the format definition.
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
    let scaled_dot = f32(row_sum) * (dma.x * b_ds.x);
    let min_share = dma.y * b_ds.y / THREADS_PER_BLOCK;
    return scaled_dot + min_share;
}
#endif

#ifdef MUL_ACC_Q8_0
#define BLOCK_SIZE_BYTES 34
#define B_DS_TYPE f32
// Loads the two consecutive 32-bit quant words (8 bytes) for lane `inner_id`.
// Quant data starts 2 bytes into the block, so lane i reads words 2i and 2i+1.
fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> {
    let qs_base = block_byte_base + 2u;
    let first_word_idx = inner_id * 2u;
    let word0 = load_u32_at_src0(qs_base + 4u * first_word_idx);
    let word1 = load_u32_at_src0(qs_base + 4u * (first_word_idx + 1u));
    return vec2<u32>(word0, word1);
}
// Fetches the two consecutive quant words of src1q block `block` used by lane
// `inner_id`: words 2 * inner_id and 2 * inner_id + 1.
fn repack_b_qs(block: u32, inner_id: u32) -> vec2<u32> {
    let first_word_idx = inner_id * 2u;
    let word0 = src1q[block].qs[first_word_idx];
    let word1 = src1q[block].qs[first_word_idx + 1u];
    return vec2<u32>(word0, word1);
}
// Returns the `d` field of src1q block `block` converted to B_DS_TYPE.
fn repack_b_dm(block: u32) -> B_DS_TYPE {
    let d_raw = src1q[block].d;
    return B_DS_TYPE(d_raw);
}
// Reads the f16 stored at the very start of the src0 block and widens it to f32.
fn get_dm(block_byte_base: u32) -> f32 {
    let d_raw = load_f16_at_src0(block_byte_base);
    return f32(d_raw);
}
// Scales one thread's integer dot product by the product of the src0 block
// scale `da` and the src1q block scale `b_ds`.
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
    let combined_scale = da * b_ds;
    return f32(row_sum) * combined_scale;
}
#endif

#ifdef LEGACY_QUANTS
// Computes one thread's share of the dot product between a src0 block and the
// matching src1q block.
//   a_byte_base : byte offset of the src0 block
//   b_inner_id  : lane within the block, selects which quant words are used
//   b_repacked  : the two src1q quant words for this lane
//   b_ds        : src1q scale data as produced by repack_b_dm
// The two packed 4x8-bit dot products are summed as integers and handed to
// mul_q8_1 together with the src0 block scale data for the f32 conversion.
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
    let a_words = repack_a(a_byte_base, b_inner_id);
    let int_dot = dot4I8Packed(a_words.x, b_repacked.x)
                + dot4I8Packed(a_words.y, b_repacked.y);
    return mul_q8_1(int_dot, get_dm(a_byte_base), b_ds);
}

// Accumulates this thread's partial dot products for up to OUTPUTS_PER_WG
// consecutive output rows, starting at `row_base`.
//   thread_id         : index of the thread; thread_id / THREADS_PER_BLOCK
//                       picks the first block, thread_id % THREADS_PER_BLOCK
//                       the lane within every block it visits
//   row_base          : first output row handled by this call
//   src0_batch_offset : offset added to the src0 block index, in blocks
//   src1q_idx_base    : index of the first src1q block of the vector
// Returns one partial sum per row; rows at or beyond params.m are skipped and
// keep their initial zero value.
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    let num_blocks = params.k / BLOCK_SIZE;
    // The lane does not depend on the block, so compute it once up front.
    let b_inner_id = thread_id % THREADS_PER_BLOCK;

    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        let b_block_idx = src1q_idx_base + block;

        // The src1q data is shared by every output row, so load it once per block.
        let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
        let b_ds = repack_b_dm(b_block_idx);

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
            }
        }
    }

    return acc;
}
#endif

#ifdef MUL_ACC_Q2_K
#define BLOCK_SIZE_BYTES 84
#define B_DS_TYPE f32
// Loads the 16 quant bytes used by thread `tid` and extracts one 2-bit field
// from every byte.
// Quant data starts 16 bytes into the block. The 16-byte chunk is chosen by
// bit 3 and bit 0 of `tid`; bits 1-2 of `tid` give the bit shift (0, 2, 4, 6)
// selecting which 2-bit field of each byte is kept.
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
    let half_idx = tid / 8u;
    let chunk_idx = 2u * half_idx + (tid % 2u);
    let chunk_base = block_byte_base + 16u + 16u * chunk_idx;
    let bit_shift = tid & 6u;

    let raw_words = vec4<u32>(
        load_u32_at_src0_aligned(chunk_base),
        load_u32_at_src0_aligned(chunk_base + 4u),
        load_u32_at_src0_aligned(chunk_base + 8u),
        load_u32_at_src0_aligned(chunk_base + 12u),
    );
    return (raw_words >> vec4<u32>(bit_shift)) & vec4<u32>(0x03030303u);
}
// Fetches four consecutive quant words from src1q block `q8_block_idx`:
// words 0-3 for even `tid`, words 4-7 for odd `tid`.
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
    let first_word_idx = 4u * (tid % 2u);
    let word0 = src1q[q8_block_idx].qs[first_word_idx];
    let word1 = src1q[q8_block_idx].qs[first_word_idx + 1u];
    let word2 = src1q[q8_block_idx].qs[first_word_idx + 2u];
    let word3 = src1q[q8_block_idx].qs[first_word_idx + 3u];
    return vec4<u32>(word0, word1, word2, word3);
}
// Returns the `d` field of src1q block `q8_block_idx` converted to B_DS_TYPE.
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
    let d_raw = src1q[q8_block_idx].d;
    return B_DS_TYPE(d_raw);
}
// Reads the two f16 values at byte offsets 80 and 82 of the src0 block and
// returns them widened to f32 as (.x = first, .y = second).
fn get_dm(block_byte_base: u32) -> vec2<f32> {
    let first = f32(load_f16_at_src0(block_byte_base + 80u));
    let second = f32(load_f16_at_src0(block_byte_base + 82u));
    return vec2<f32>(first, second);
}
// Reads byte `tid` of the src0 block and splits it into its two nibbles.
// Returns (.x = low nibble, .y = high nibble) as f32; mmvq_dot_product uses
// .x as the scale and .y as the min of this thread's slice.
// NOTE(review): assumes load_u32_at_src0_aligned returns the aligned 32-bit
// word containing the given byte address — confirm against the helper.
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
    let byte_addr = block_byte_base + tid;
    let containing_word = load_u32_at_src0_aligned(byte_addr);
    let scale_min_byte = byte_of(containing_word, byte_addr & 3u);
    let lo_nibble = scale_min_byte & 0xFu;
    let hi_nibble = scale_min_byte >> 4u;
    return vec2<f32>(f32(lo_nibble), f32(hi_nibble));
}
// Computes thread `tid`'s share of the dot product between a src0 block and
// the matching src1q data.
//   a_byte_base : byte offset of the src0 block
//   tid         : thread index within the block (selects quants and scale byte)
//   b_repacked  : four src1q quant words for this thread
//   b_ds        : src1q scale as produced by repack_b_dm
// Result: b_ds * (dm.x * scale * dot(b, a) - dm.y * min * sum(b)), where the
// sum of the src1q values is obtained by a packed dot product against the min
// value replicated into all four byte lanes.
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
    let a_words = repack_a(a_byte_base, tid);
    let dm = get_dm(a_byte_base);
    let scale_min = get_scale_min(a_byte_base, tid);

    let sub_scale = i32(scale_min.x);
    // Replicate the min into every byte lane so one packed dot yields min * sum(b).
    let min_splat = u32(scale_min.y) * 0x01010101u;

    let quant_dot = dot4I8Packed(b_repacked.x, a_words.x) + dot4I8Packed(b_repacked.y, a_words.y)
                  + dot4I8Packed(b_repacked.z, a_words.z) + dot4I8Packed(b_repacked.w, a_words.w);
    let row_sum_d = quant_dot * sub_scale;

    let row_sum_m = dot4I8Packed(b_repacked.x, min_splat) + dot4I8Packed(b_repacked.y, min_splat)
                  + dot4I8Packed(b_repacked.z, min_splat) + dot4I8Packed(b_repacked.w, min_splat);

    let scaled_part = dm.x * f32(row_sum_d);
    let min_part = dm.y * f32(row_sum_m);
    return b_ds * (scaled_part - min_part);
}
#endif

#ifdef MUL_ACC_Q4_K
#define BLOCK_SIZE_BYTES 144
#define B_DS_TYPE vec2<f32>
// Loads the 16 quant bytes used by thread `tid` and extracts one nibble from
// every byte.
// Quant data starts 16 bytes into the block and is addressed in 32-byte
// groups (tid / 4), each split into two 16-byte halves (bit 0 of tid).
// Bit 1 of `tid` selects the low (0) or high (1) nibble of each byte.
fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> {
    let group_idx = tid / 4u;
    let half_sel = tid % 2u;
    let nibble_sel = (tid >> 1u) % 2u;
    let chunk_base = block_byte_base + 16u + 32u * group_idx + 16u * half_sel;
    let bit_shift = 4u * nibble_sel;

    let raw_words = vec4<u32>(
        load_u32_at_src0_aligned(chunk_base),
        load_u32_at_src0_aligned(chunk_base + 4u),
        load_u32_at_src0_aligned(chunk_base + 8u),
        load_u32_at_src0_aligned(chunk_base + 12u),
    );
    return (raw_words >> vec4<u32>(bit_shift)) & vec4<u32>(0x0F0F0F0Fu);
}
// Fetches four consecutive quant words from src1q block `q8_block_idx`:
// words 0-3 for even `tid`, words 4-7 for odd `tid`.
fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> {
    let first_word_idx = 4u * (tid % 2u);
    let word0 = src1q[q8_block_idx].qs[first_word_idx];
    let word1 = src1q[q8_block_idx].qs[first_word_idx + 1u];
    let word2 = src1q[q8_block_idx].qs[first_word_idx + 2u];
    let word3 = src1q[q8_block_idx].qs[first_word_idx + 3u];
    return vec4<u32>(word0, word1, word2, word3);
}
// Returns the `d` and `s` fields of src1q block `q8_block_idx`, converted to
// f32, packed as (.x = d, .y = s).
fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE {
    let d_value = f32(src1q[q8_block_idx].d);
    let s_value = f32(src1q[q8_block_idx].s);
    return B_DS_TYPE(d_value, s_value);
}
// Reads the two f16 values at byte offsets 0 and 2 of the src0 block and
// returns them widened to f32 as (.x = first, .y = second).
fn get_dm(block_byte_base: u32) -> vec2<f32> {
    let first = f32(load_f16_at_src0(block_byte_base + 0u));
    let second = f32(load_f16_at_src0(block_byte_base + 2u));
    return vec2<f32>(first, second);
}
// Decodes the 6-bit (scale, min) pair for thread `tid` from the 12 bytes that
// start 4 bytes into the src0 block. Two threads share one pair (index tid / 2).
//   pairs 0-3: scale = low 6 bits of byte i,     min = low 6 bits of byte i + 4
//   pairs 4-7: scale = low nibble of byte i + 4 | top 2 bits of byte i - 4,
//              min   = high nibble of byte i + 4 | top 2 bits of byte i
// Returns (.x = scale, .y = min) as f32.
fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
    let pair_idx = tid / 2u;
    let scales_addr = block_byte_base + 4u;
    let bytes0_3  = load_u32_at_src0_aligned(scales_addr);
    let bytes4_7  = load_u32_at_src0_aligned(scales_addr + 4u);
    let bytes8_11 = load_u32_at_src0_aligned(scales_addr + 8u);

    // Same byte lane is used in all three words.
    let lane = pair_idx & 3u;
    let first_byte = byte_of(bytes0_3, lane);
    let second_byte = byte_of(bytes4_7, lane);
    let third_byte = byte_of(bytes8_11, lane);

    let low_pair = vec2<u32>(first_byte & 0x3Fu, second_byte & 0x3Fu);
    let high_pair = vec2<u32>(
        (third_byte & 0x0Fu) | ((first_byte & 0xC0u) >> 2u),
        (third_byte >> 4u) | ((second_byte & 0xC0u) >> 2u),
    );

    return vec2<f32>(select(low_pair, high_pair, pair_idx >= 4u));
}
// Computes thread `tid`'s share of the dot product between a src0 block and
// the matching src1q data.
//   a_byte_base : byte offset of the src0 block
//   tid         : thread index within the block (selects quants and scale pair)
//   b_repacked  : four src1q quant words for this thread
//   b_ds        : (d, s) pair of the src1q block as produced by repack_b_dm
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
    let a_words = repack_a(a_byte_base, tid);
    let dm = get_dm(a_byte_base);
    let scale_min = get_scale_min(a_byte_base, tid);

    let int_dot = dot4I8Packed(a_words.x, b_repacked.x) + dot4I8Packed(a_words.y, b_repacked.y)
                + dot4I8Packed(a_words.z, b_repacked.z) + dot4I8Packed(a_words.w, b_repacked.w);

    let quant_term = b_ds.x * dm.x * scale_min.x * f32(int_dot);

    // A thread covers ELEMS_PER_THREAD of the Q8_BLOCK_SIZE elements in the
    // src1q block, so it applies only the matching fraction of b_ds.y.
    let sum_share = b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD);
    let min_term = dm.y * scale_min.y * sum_share;

    return quant_term - min_term;
}
#endif

#ifdef K_QUANTS
// Accumulates this thread's partial dot products for up to OUTPUTS_PER_WG
// consecutive output rows, starting at `row_base`.
//   thread_id         : index of the thread; thread_id / THREADS_PER_BLOCK
//                       picks the first block, thread_id % THREADS_PER_BLOCK
//                       the slice within every block it visits
//   row_base          : first output row handled by this call
//   src0_batch_offset : offset added to the src0 block index, in blocks
//   src1q_idx_base    : index of the first src1q block of the vector
// Returns one partial sum per row; rows at or beyond params.m are skipped and
// keep their initial zero value.
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
    var acc: array<f32, OUTPUTS_PER_WG>;

    // Loop-invariant values, computed once instead of on every iteration.
    let num_blocks = params.k / BLOCK_SIZE;
    let tid = thread_id % THREADS_PER_BLOCK;
    let elem_offset_in_block = ELEMS_PER_THREAD * tid;

    for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
        // src1q blocks hold Q8_BLOCK_SIZE elements, so convert this thread's
        // absolute element position into a src1q block index.
        let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + elem_offset_in_block) / Q8_BLOCK_SIZE;

        // The src1q data is shared by every output row, so load it once per block.
        let b_repacked = repack_b_qs(src1q_idx, tid);
        let b_ds = repack_b_dm(src1q_idx);

        for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
            let output_row = row_base + row;
            if (output_row < params.m) {
                let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
                acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
            }
        }
    }

    return acc;
}
#endif