use metal::{
Buffer, CommandBufferRef, ComputePipelineState, MTLSize, NSUInteger,
};
use super::metal::{MetalBackend, MetalError};
use super::mtl_weight_buf::MtlWeightBuf;
use super::variants::GROUP_SIZE;
/// Parameters for one quantized matrix-vector product encoded by [`encode_matvec`].
///
/// The three `*_off` fields are byte offsets into the single weight buffer handed
/// to [`encode_matvec`]; that buffer is bound three times, once per offset.
pub struct MatvecSpec<'buf> {
    /// Vector read by the kernel; bound at buffer index 3, offset 0.
    pub input: &'buf Buffer,
    /// Vector written by the kernel; bound at buffer index 4, offset 0.
    pub output: &'buf Buffer,

    /// Byte offset of the packed weights in the weight buffer (buffer index 0).
    pub w_off: u64,
    /// Byte offset bound at buffer index 1 (presumably the scales, by the name — confirm).
    pub s_off: u64,
    /// Byte offset bound at buffer index 2 (presumably the biases, by the name — confirm).
    pub b_off: u64,

    /// Number of output rows; passed to the kernel as a `u32` constant and used to size the grid.
    pub out_dim: u32,
    /// Input length; passed to the kernel and used to pick between the 4-bit kernels.
    pub in_dim: u32,
    /// Quantization width. `8` selects the 8-bit kernel; anything else takes a 4-bit path.
    pub bits: u32,
}
/// The compute pipelines that [`encode_matvec`] chooses between.
///
/// Each field is an owned clone of a pipeline looked up by kernel name from a
/// [`MetalBackend`]; build one with [`MatvecPipelines::fetch`].
pub struct MatvecPipelines {
    /// Kernel `dequant_matvec_4bit_v3`; chosen for non-8-bit specs with `in_dim <= 4096`.
    pub v3_4bit: ComputePipelineState,
    /// Kernel `dequant_matvec_4bit_fast`; chosen for non-8-bit specs with `in_dim > 4096`.
    pub fast_4bit: ComputePipelineState,
    /// Kernel `dequant_matvec_8bit_v3`; chosen whenever `bits == 8`.
    pub v3_8bit: ComputePipelineState,
}

impl MatvecPipelines {
    /// Look up the three matvec kernels in `metal` and keep owned handles to them.
    ///
    /// # Errors
    /// Returns the first error produced by `MetalBackend::pipeline`. Lookups run in
    /// the order 4-bit v3, 4-bit fast, 8-bit v3, and stop at the first failure.
    pub fn fetch(metal: &mut MetalBackend) -> Result<Self, MetalError> {
        let v3_4bit = metal.pipeline("dequant_matvec_4bit_v3")?.clone();
        let fast_4bit = metal.pipeline("dequant_matvec_4bit_fast")?.clone();
        let v3_8bit = metal.pipeline("dequant_matvec_8bit_v3")?.clone();

        Ok(Self {
            v3_4bit,
            fast_4bit,
            v3_8bit,
        })
    }
}
/// Encode one dequantizing matrix-vector product into `cmdbuf`.
///
/// Opens a fresh compute encoder and binds:
/// * buffers 0 / 1 / 2: `wf_buf` at byte offsets `spec.w_off` / `spec.s_off` / `spec.b_off`;
/// * buffer 3: `spec.input`, buffer 4: `spec.output` (both at offset 0);
/// * indices 5 / 6 / 7: `out_dim`, `in_dim` and `GROUP_SIZE` as inline `u32` constants.
///
/// It then dispatches and ends the encoder. The kernel and grid are chosen as follows:
///
/// | condition                      | pipeline    | threadgroups        | threads/TG |
/// |--------------------------------|-------------|---------------------|------------|
/// | `bits == 8`                    | `v3_8bit`   | `ceil(out_dim / 8)` | 256        |
/// | `bits != 8`, `in_dim <= 4096`  | `v3_4bit`   | `ceil(out_dim / 8)` | 256        |
/// | `bits != 8`, `in_dim > 4096`   | `fast_4bit` | `out_dim`           | 64         |
///
/// If `spec.out_dim == 0`, nothing is encoded.
///
/// # Panics
/// In debug builds, panics if `spec.bits` is not 4 or 8. Only those widths have
/// pipelines, and any other value would silently run a 4-bit kernel. Release
/// builds keep that fallback.
pub fn encode_matvec(
    cmdbuf: &CommandBufferRef,
    pipes: &MatvecPipelines,
    wf_buf: &MtlWeightBuf,
    spec: &MatvecSpec,
) {
    // Largest `in_dim` routed to the v3 4-bit kernel; wider inputs use `fast_4bit`.
    const V3_4BIT_MAX_IN_DIM: u32 = 4096;
    // v3 layout: one threadgroup per this many output rows, with this many threads.
    // These must match what the v3 kernels expect.
    const V3_ROWS_PER_TG: NSUInteger = 8;
    const V3_THREADS_PER_TG: NSUInteger = 256;
    // fast layout: one threadgroup per output row.
    const FAST_THREADS_PER_TG: NSUInteger = 64;
    // Byte length of each inline `u32` constant passed via `set_bytes`.
    const U32_BYTES: NSUInteger = std::mem::size_of::<u32>() as NSUInteger;

    debug_assert!(
        spec.bits == 4 || spec.bits == 8,
        "encode_matvec: unsupported bit width {} (only 4 and 8 have pipelines)",
        spec.bits
    );

    // An empty output has nothing to compute, and Metal API validation rejects
    // a dispatch with zero threadgroups, so skip encoding entirely.
    if spec.out_dim == 0 {
        return;
    }

    let group_size = GROUP_SIZE as u32;
    let (pipeline, use_v3_layout) = if spec.bits == 8 {
        (&pipes.v3_8bit, true)
    } else if spec.in_dim <= V3_4BIT_MAX_IN_DIM {
        (&pipes.v3_4bit, true)
    } else {
        (&pipes.fast_4bit, false)
    };

    // Widen before rounding up so `out_dim` near `u32::MAX` cannot overflow.
    let out_dim = spec.out_dim as NSUInteger;
    let (threadgroups, threads_per_tg) = if use_v3_layout {
        (
            (out_dim + V3_ROWS_PER_TG - 1) / V3_ROWS_PER_TG,
            V3_THREADS_PER_TG,
        )
    } else {
        (out_dim, FAST_THREADS_PER_TG)
    };

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);
    // Weights, and the data at `s_off` / `b_off`, all come from the same MTLBuffer.
    enc.set_buffer(0, Some(wf_buf.buffer()), spec.w_off as NSUInteger);
    enc.set_buffer(1, Some(wf_buf.buffer()), spec.s_off as NSUInteger);
    enc.set_buffer(2, Some(wf_buf.buffer()), spec.b_off as NSUInteger);
    enc.set_buffer(3, Some(spec.input), 0);
    enc.set_buffer(4, Some(spec.output), 0);
    enc.set_bytes(5, U32_BYTES, (&spec.out_dim as *const u32).cast());
    enc.set_bytes(6, U32_BYTES, (&spec.in_dim as *const u32).cast());
    enc.set_bytes(7, U32_BYTES, (&group_size as *const u32).cast());
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups, 1, 1),
        MTLSize::new(threads_per_tg, 1, 1),
    );
    enc.end_encoding();
}