use metal::{Buffer, CommandBufferRef, MTLSize, NSUInteger};
use moeflux_metal::{QmmCall, Kernels, QuantWeights};
use super::encoder::pipeline_bundle;
use crate::riir::io::mtl_weight_buf::MtlWeightBuf;
use crate::riir::variants::GROUP_SIZE;
/// Per-call arguments for [`encode_matvec`].
///
/// The three `*_off` fields are offsets into the single weight buffer that
/// `encode_matvec` receives separately; that function binds the same buffer
/// three times (kernel buffer slots 0, 1 and 2), once per offset.
pub struct MatvecSpec<'a> {
/// Offset used for the slot-0 binding of the weight buffer
/// (presumably the packed quantized weights — confirm against the kernel).
pub w_off: u64,
/// Offset used for the slot-1 binding of the weight buffer
/// (presumably the per-group scales — confirm against the kernel).
pub s_off: u64,
/// Offset used for the slot-2 binding of the weight buffer
/// (presumably the per-group biases — confirm against the kernel).
pub b_off: u64,
/// Input buffer, bound at slot 3 with offset 0.
pub input: &'a Buffer,
/// Output buffer, bound at slot 4 with offset 0.
pub output: &'a Buffer,
/// Output dimension; passed to the kernel at slot 5 and used to size the
/// threadgroup grid.
pub out_dim: u32,
/// Input dimension; passed to the kernel at slot 6. For non-8-bit weights
/// it also picks between the `v3` (`<= 4096`) and `fast` kernels.
pub in_dim: u32,
/// Quantization width. `8` selects the 8-bit kernel; `encode_matvec`
/// treats every other value as 4-bit without validating it.
pub bits: u32,
}
// Pipeline set for the dequantizing matvec kernels.
//
// Each entry pairs a field name with a kernel function name. The expansion
// is done by `pipeline_bundle!` (imported from `super::encoder`, not visible
// in this file); the resulting fields are handed to
// `set_compute_pipeline_state` by the encoders below.
//
// The first three entries are used by `encode_matvec`; the `_n` entries are
// the multi-token counterparts used by `encode_matvec_n_tokens`.
pipeline_bundle! {
pub struct MatvecPipelines {
v3_4bit => "dequant_matvec_4bit_v3",
fast_4bit => "dequant_matvec_4bit_fast",
v3_8bit => "dequant_matvec_8bit_v3",
v3_4bit_n => "dequant_matvec_4bit_v3_n_tokens",
fast_4bit_n => "dequant_matvec_4bit_fast_n_tokens",
v3_8bit_n => "dequant_matvec_8bit_v3_n_tokens",
}
}
/// Encodes one dequantizing matrix-vector dispatch onto `cmdbuf`.
///
/// A new compute command encoder is opened on `cmdbuf`, populated, and closed
/// again before this function returns.
///
/// # Kernel selection
///
/// * `spec.bits == 8` uses `pipes.v3_8bit`.
/// * Otherwise, `spec.in_dim <= 4096` uses `pipes.v3_4bit`.
/// * Otherwise `pipes.fast_4bit` is used.
///
/// `spec.bits` is not validated: every value other than `8` takes the 4-bit
/// path.
///
/// # Bindings
///
/// | slot | value                               |
/// |------|-------------------------------------|
/// | 0    | `wf_buf` at `spec.w_off`            |
/// | 1    | `wf_buf` at `spec.s_off`            |
/// | 2    | `wf_buf` at `spec.b_off`            |
/// | 3    | `spec.input` at offset 0            |
/// | 4    | `spec.output` at offset 0           |
/// | 5    | `spec.out_dim` (`u32`)              |
/// | 6    | `spec.in_dim` (`u32`)               |
/// | 7    | `GROUP_SIZE` as `u32`               |
///
/// # Dispatch geometry
///
/// The `v3` kernels are launched with `ceil(out_dim / 8)` threadgroups of 256
/// threads; the `fast` kernel with `out_dim` threadgroups of 64 threads.
pub fn encode_matvec(
    cmdbuf: &CommandBufferRef,
    pipes: &MatvecPipelines,
    wf_buf: &MtlWeightBuf,
    spec: &MatvecSpec,
) {
    // Every scalar kernel argument is a `u32`, so they share one byte length.
    const SCALAR_LEN: NSUInteger = std::mem::size_of::<u32>() as NSUInteger;

    let group_size = GROUP_SIZE as u32;

    // The 8-bit kernel only exists in the v3 layout; 4-bit falls back to the
    // `fast` kernel once the input dimension exceeds 4096.
    let is_v3 = spec.bits == 8 || spec.in_dim <= 4096;
    let pipeline = match (spec.bits, is_v3) {
        (8, _) => &pipes.v3_8bit,
        (_, true) => &pipes.v3_4bit,
        (_, false) => &pipes.fast_4bit,
    };

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);

    // Slots 0..=2: the same weight buffer, viewed at three different offsets.
    let bind_weights = |slot: NSUInteger, offset: u64| {
        enc.set_buffer(slot, Some(wf_buf.buffer()), offset as NSUInteger);
    };
    bind_weights(0, spec.w_off);
    bind_weights(1, spec.s_off);
    bind_weights(2, spec.b_off);

    // Slots 3..=4: activations in, result out.
    enc.set_buffer(3, Some(spec.input), 0);
    enc.set_buffer(4, Some(spec.output), 0);

    // Slots 5..=7: scalar arguments.
    let bind_u32 = |slot: NSUInteger, value: &u32| {
        enc.set_bytes(slot, SCALAR_LEN, (value as *const u32).cast());
    };
    bind_u32(5, &spec.out_dim);
    bind_u32(6, &spec.in_dim);
    bind_u32(7, &group_size);

    let (threadgroups, threads_per_group): (NSUInteger, NSUInteger) = if is_v3 {
        (((spec.out_dim + 7) / 8) as NSUInteger, 256)
    } else {
        (spec.out_dim as NSUInteger, 64)
    };
    enc.dispatch_thread_groups(
        MTLSize::new(threadgroups, 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );

    enc.end_encoding();
}
/// Encodes a multi-token dequantizing matvec dispatch onto `cmdbuf`.
///
/// This is the `n_tokens` counterpart of `encode_matvec`: it uses the `_n`
/// pipelines from `pipes`, binds `input` / `output` at caller-supplied
/// offsets, and passes `n_tokens` to the kernel at slot 8.
///
/// Nothing is encoded when `n_tokens == 0` (the `bits` check still runs
/// first).
///
/// # Kernel selection
///
/// * `bits == 8` uses `pipes.v3_8bit_n`.
/// * `bits == 4` with `in_dim <= 4096` uses `pipes.v3_4bit_n`.
/// * `bits == 4` with a larger `in_dim` uses `pipes.fast_4bit_n`.
///
/// # Bindings
///
/// Slots 0–2 are `w_buf` at `w_off` / `s_off` / `b_off`, slot 3 is `input` at
/// `input_off`, slot 4 is `output` at `output_off`, and slots 5–8 carry
/// `out_dim`, `in_dim`, `GROUP_SIZE` and `n_tokens` as `u32`. The `v3` path
/// additionally passes `ceil(out_dim / 8)` at slot 9.
///
/// # Dispatch geometry
///
/// * `v3`: `ceil(out_dim / 8) * n_tokens` threadgroups of 256 threads.
/// * `fast`: `out_dim * n_tokens` threadgroups of 64 threads.
///
/// # Panics
///
/// Panics if `bits` is neither 4 nor 8.
#[allow(clippy::too_many_arguments)]
pub fn encode_matvec_n_tokens(
    cmdbuf: &CommandBufferRef,
    pipes: &MatvecPipelines,
    w_buf: &Buffer,
    w_off: u64,
    s_off: u64,
    b_off: u64,
    input: &Buffer,
    input_off: u64,
    output: &Buffer,
    output_off: u64,
    in_dim: u32,
    out_dim: u32,
    n_tokens: u32,
    bits: u32,
) {
    // Every scalar kernel argument is a `u32`, so they share one byte length.
    const SCALAR_LEN: NSUInteger = std::mem::size_of::<u32>() as NSUInteger;

    assert!(
        matches!(bits, 4 | 8),
        "encode_matvec_n_tokens: only 4-bit / 8-bit supported (got bits={})",
        bits
    );
    if n_tokens == 0 {
        return;
    }

    let group_size = GROUP_SIZE as u32;

    let use_v3 = bits == 8 || in_dim <= 4096;
    let pipeline = match (bits, use_v3) {
        (8, _) => &pipes.v3_8bit_n,
        (_, true) => &pipes.v3_4bit_n,
        (_, false) => &pipes.fast_4bit_n,
    };

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);

    // Slots 0..=2: one weight buffer viewed at three offsets.
    enc.set_buffer(0, Some(w_buf), w_off as NSUInteger);
    enc.set_buffer(1, Some(w_buf), s_off as NSUInteger);
    enc.set_buffer(2, Some(w_buf), b_off as NSUInteger);

    // Slots 3..=4: activations in, result out.
    enc.set_buffer(3, Some(input), input_off as NSUInteger);
    enc.set_buffer(4, Some(output), output_off as NSUInteger);

    // Slots 5..=8 (and 9 on the v3 path): scalar arguments.
    let bind_u32 = |slot: NSUInteger, value: &u32| {
        enc.set_bytes(slot, SCALAR_LEN, (value as *const u32).cast());
    };
    bind_u32(5, &out_dim);
    bind_u32(6, &in_dim);
    bind_u32(7, &group_size);
    bind_u32(8, &n_tokens);

    let dispatch = |threadgroups: u64, threads_per_group: NSUInteger| {
        enc.dispatch_thread_groups(
            MTLSize::new(threadgroups as NSUInteger, 1, 1),
            MTLSize::new(threads_per_group, 1, 1),
        );
    };

    let token_count = u64::from(n_tokens);
    if use_v3 {
        let num_row_tiles = (out_dim + 7) / 8;
        bind_u32(9, &num_row_tiles);
        dispatch(u64::from(num_row_tiles).saturating_mul(token_count), 256);
    } else {
        dispatch(u64::from(out_dim).saturating_mul(token_count), 64);
    }

    enc.end_encoding();
}
/// Encodes a multi-token quantized matmul, choosing the backend by bit width.
///
/// * `n_tokens == 0`: returns immediately without encoding or validating
///   anything.
/// * `bits == 4`: builds a `QmmCall` from the arguments and hands it to
///   `kernels.encode`. `pipes` is not used on this path.
/// * any other `bits`: forwards every argument unchanged to
///   `encode_matvec_n_tokens`.
///
/// # Panics
///
/// Because the fallback is only reached when `bits != 4`, and
/// `encode_matvec_n_tokens` accepts only 4 or 8, this function panics for any
/// `bits` other than 4 or 8 (when `n_tokens > 0`).
#[allow(clippy::too_many_arguments)]
pub fn encode_dense_matmul_n_tokens(
    cmdbuf: &CommandBufferRef,
    kernels: &Kernels,
    pipes: &MatvecPipelines,
    w_buf: &Buffer,
    w_off: u64,
    s_off: u64,
    b_off: u64,
    input: &Buffer,
    input_off: u64,
    output: &Buffer,
    output_off: u64,
    in_dim: u32,
    out_dim: u32,
    n_tokens: u32,
    bits: u32,
) {
    if n_tokens == 0 {
        return;
    }

    // Everything that is not 4-bit goes through the per-token matvec kernels.
    if bits != 4 {
        encode_matvec_n_tokens(
            cmdbuf, pipes, w_buf, w_off, s_off, b_off, input, input_off, output,
            output_off, in_dim, out_dim, n_tokens, bits,
        );
        return;
    }

    let weights = QuantWeights {
        buffer: w_buf,
        packed_offset: w_off,
        scales_offset: s_off,
        biases_offset: b_off,
    };
    let call = QmmCall {
        weights,
        input,
        input_offset: input_off,
        output,
        output_offset: output_off,
        in_dim,
        out_dim,
        n_tokens,
    };
    kernels.encode(cmdbuf, &call);
}
// Pipeline set for the bf16 kernels.
//
// Each entry pairs a field name with a kernel function name; the expansion
// is done by `pipeline_bundle!` (imported from `super::encoder`, not visible
// in this file). `bf16` is used by `encode_bf16_matvec` and `bf16_n` by
// `encode_bf16_matmul_n_tokens`.
pipeline_bundle! {
pub struct BfMatvecPipelines {
bf16 => "bf16_matvec",
bf16_n => "bf16_matmul_n_tokens",
}
}
/// Encodes one `bf16_matvec` dispatch onto `cmdbuf`.
///
/// A new compute command encoder is opened on `cmdbuf`, populated, and closed
/// again before this function returns.
///
/// # Bindings
///
/// | slot | value                    |
/// |------|--------------------------|
/// | 0    | `w_buf` at `w_off`       |
/// | 1    | `input` at offset 0      |
/// | 2    | `output` at offset 0     |
/// | 3    | `in_dim` (`u32`)         |
/// | 4    | `out_dim` (`u32`)        |
///
/// # Dispatch geometry
///
/// `out_dim` threadgroups of 256 threads each.
#[allow(clippy::too_many_arguments)]
pub fn encode_bf16_matvec(
    cmdbuf: &CommandBufferRef,
    pipes: &BfMatvecPipelines,
    w_buf: &Buffer,
    w_off: u64,
    input: &Buffer,
    output: &Buffer,
    in_dim: u32,
    out_dim: u32,
) {
    // Both scalar kernel arguments are `u32`, so they share one byte length.
    const SCALAR_LEN: NSUInteger = std::mem::size_of::<u32>() as NSUInteger;

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(&pipes.bf16);

    // Buffers: weights, activations in, result out.
    enc.set_buffer(0, Some(w_buf), w_off as NSUInteger);
    enc.set_buffer(1, Some(input), 0);
    enc.set_buffer(2, Some(output), 0);

    // Scalars: note that `in_dim` precedes `out_dim` for this kernel.
    let bind_u32 = |slot: NSUInteger, value: &u32| {
        enc.set_bytes(slot, SCALAR_LEN, (value as *const u32).cast());
    };
    bind_u32(3, &in_dim);
    bind_u32(4, &out_dim);

    let threadgroups = MTLSize::new(out_dim as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(threadgroups, threads_per_group);

    enc.end_encoding();
}
/// Encodes one `bf16_matmul_n_tokens` dispatch onto `cmdbuf`.
///
/// Nothing is encoded when `n_tokens == 0`. Otherwise a new compute command
/// encoder is opened on `cmdbuf`, populated, and closed before returning.
///
/// # Bindings
///
/// | slot | value                    |
/// |------|--------------------------|
/// | 0    | `w_buf` at `w_off`       |
/// | 1    | `input` at offset 0      |
/// | 2    | `output` at offset 0     |
/// | 3    | `in_dim` (`u32`)         |
/// | 4    | `out_dim` (`u32`)        |
/// | 5    | `n_tokens` (`u32`)       |
///
/// # Dispatch geometry
///
/// `out_dim * n_tokens` threadgroups of 256 threads each.
#[allow(clippy::too_many_arguments)]
pub fn encode_bf16_matmul_n_tokens(
    cmdbuf: &CommandBufferRef,
    pipes: &BfMatvecPipelines,
    w_buf: &Buffer,
    w_off: u64,
    input: &Buffer,
    output: &Buffer,
    in_dim: u32,
    out_dim: u32,
    n_tokens: u32,
) {
    // Every scalar kernel argument is a `u32`, so they share one byte length.
    const SCALAR_LEN: NSUInteger = std::mem::size_of::<u32>() as NSUInteger;

    if n_tokens == 0 {
        return;
    }

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(&pipes.bf16_n);

    // Buffers: weights, activations in, result out.
    enc.set_buffer(0, Some(w_buf), w_off as NSUInteger);
    enc.set_buffer(1, Some(input), 0);
    enc.set_buffer(2, Some(output), 0);

    // Scalars: note that `in_dim` precedes `out_dim` for this kernel.
    let bind_u32 = |slot: NSUInteger, value: &u32| {
        enc.set_bytes(slot, SCALAR_LEN, (value as *const u32).cast());
    };
    bind_u32(3, &in_dim);
    bind_u32(4, &out_dim);
    bind_u32(5, &n_tokens);

    // Widen before multiplying so the product of two `u32`s cannot wrap.
    let total_tgs = u64::from(out_dim).saturating_mul(u64::from(n_tokens));
    let threadgroups = MTLSize::new(total_tgs as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);
    enc.dispatch_thread_groups(threadgroups, threads_per_group);

    enc.end_encoding();
}