llama-cpp-sys-4 0.2.45

Low Level Bindings to llama.cpp
Documentation
enable f16;

#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif

struct Params {
    offset_src0: u32,
    offset_src1: u32,
    offset_src2: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    n_threads: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    n_dims: u32,
    mode: u32,
    theta_scale: f32,
    attn_factor: f32,
    freq_scale: f32,
    ext_factor: f32,
    corr_dim0: f32,
    corr_dim1: f32,
    sections0: u32,
    sections1: u32,
    sections2: u32,
    sections3: u32
};

@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;

#ifdef INPLACE

#ifdef FF_FUNC

@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

@group(0) @binding(3)
var<uniform> params: Params;

#else

@group(0) @binding(2)
var<uniform> params: Params;

#endif

#else

#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

@group(0) @binding(3)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(4)
var<uniform> params: Params;

#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;

@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif

#ifdef FF_FUNC
fn freq_factor(i: u32) -> f32 {
    return src2[params.offset_src2 + i/2];
}

#else
fn freq_factor(i: u32) -> f32 {
    return 1.0f;
}
#endif
#ifdef INPLACE
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
    src0[i_dst0] = DataType(out0);
    src0[i_dst1] = DataType(out1);
}
#else
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
    dst[i_dst0] = DataType(out0);
    dst[i_dst1] = DataType(out1);
}
#endif

fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
    let y = (f32(i / 2) - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
}

// returns vector of (cos_theta, sin_theta)
// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
    var mscale = params.attn_factor;
    var theta  = params.freq_scale * theta_extrap;
    if (params.ext_factor != 0.0f) {
        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
        theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
        mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
    }
    return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
}

fn pair_base(i0: u32, div_2: bool) -> u32 {
    if (div_2) {
        return i0 / 2;
    } else {
        return i0;
    }
}

fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
    if (is_vision) {
        return params.n_dims;
    } else if (is_neox || is_mrope) {
        return params.n_dims / 2;
    } else {
        return 1;
    }
}

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // two elements per n_threads
    if (gid.x >= params.n_threads) {
        return;
    }

    let is_neox = bool(params.mode & 2);
    let is_mrope = bool(params.mode & 8);
    let is_imrope = params.mode == 40;
    let is_vision = params.mode == 24;

    var i = gid.x * 2; // start index for this thread
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    if (i0 >= params.n_dims && !is_vision) {
        let i_src = i_src_row + i0;
        let i_dst = i_dst_row + i0;
        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
        return;
    }

    var theta_base_mult: u32 = 0;
    var theta_scale_pwr: u32 = i0 / 2;
    if (is_mrope) {
        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
        let sec_w = params.sections1 + params.sections0;
        let sec_e = params.sections2 + sec_w;
        let sector = (i0 / 2) % sect_dims;
        if (is_imrope) {
          if (sector % 3 == 1 && sector < 3 * params.sections1) {
              theta_base_mult = 1;
          } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
              theta_base_mult = 2;
          } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
              theta_base_mult = 0;
          } else {
              theta_base_mult = 3;
          }
        } else {
          if (sector >= params.sections0 && sector < sec_w) {
              theta_base_mult = 1;
              if (is_vision) {
                  theta_scale_pwr = sector - params.sections0;
              }
          } else if (sector >= sec_w && sector < sec_e) {
              theta_base_mult = 2;
              if (is_vision) {
                  theta_scale_pwr = sector - sec_w;
              }
          } else if (sector >= sec_e) {
              if (is_vision) {
                  theta_scale_pwr = sector - sec_e;
                  theta_scale_pwr = (i0 / 2) % sec_e;
              }
              theta_base_mult = 3;
          } else if (is_vision) {
              theta_scale_pwr = sector;
          }
        }
    }
    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
    let thetas = rope_yarn(theta_base/freq_factor(i0), i0);

    let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
    let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);

    let x0 = f32(src0[i_src]);
    let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
    rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);

}