llama-cpp-sys-4 0.2.46

Low Level Bindings to llama.cpp
Documentation
enable f16;

#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif

#ifdef OP_REGLU
fn op(a: DataType, b: DataType) -> DataType {
    return max(a, 0) * b;
}
#endif

#ifdef OP_GEGLU
const SQRT_2_OVER_PI: DataType =  0.79788456080286535587989211986876;
const GELU_COEF_A: DataType = 0.044715;

fn op(a: DataType, b: DataType) -> DataType {
    let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
    return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
}
#endif

#ifdef OP_SWIGLU
fn op(a: DataType, b: DataType) -> DataType {
    return a / (1.0 + exp(-a)) * b;
}
#endif
#ifdef OP_SWIGLU_OAI
fn op(a: f32, b: f32) -> f32 {
    let xi = min(a, params.limit);
    let gi = max(min(b, params.limit), -params.limit);
    var out_glu = xi / (1.0 + exp(-xi * params.alpha));
    out_glu = out_glu * (1.0 + gi);
    return out_glu;
}
#endif
#ifdef OP_GEGLU_ERF
const p_erf: DataType = 0.3275911;
const a1_erf: DataType = 0.254829592;
const a2_erf: DataType = -0.284496736;
const a3_erf: DataType = 1.421413741;
const a4_erf: DataType = -1.453152027;
const a5_erf: DataType = 1.061405429;
const SQRT_2_INV: DataType = 0.7071067811865476;

fn op(a: DataType, b: DataType) -> DataType {
    let a_div_sqr2 = a * SQRT_2_INV;
    let sign_x = sign(a_div_sqr2);
    let x = abs(a_div_sqr2);
    let t = 1.0 / (1.0 + p_erf * x);
    let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
    let erf_approx = sign_x * y;
    return 0.5 * a * (1.0 + erf_approx) * b;
}
#endif
#ifdef OP_GEGLU_QUICK
const GELU_QUICK_COEF: DataType = -1.702;

fn op(a: DataType, b: DataType) -> DataType {
    return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#endif

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of dst
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    swapped: u32,
    alpha: f32,
    limit: f32,
}

@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;

#ifdef NO_SPLIT
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(2)
var<uniform> params: Params;

fn a_value(base: u32) -> DataType {
    let offset: u32 = select(0, params.ne0, params.swapped != 0);
    return src0[base + offset];
}

fn b_value(base: u32) -> DataType {
    let offset: u32 = select(params.ne0, 0, params.swapped != 0);
    return src0[base + offset];
}

#else
@group(0) @binding(1)
var<storage, read_write> src1: array<DataType>;

@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(3)
var<uniform> params: Params;

fn a_value(base: u32) -> DataType {
    return src0[base];
}

fn b_value(base: u32) -> DataType {
    return src1[base];
}

#endif

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
    let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
    let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;

    dst[i_dst] = op(a_value(i_a), b_value(i_b));
}