use cudarc::driver::{CudaSlice, CudaStream};
use rlx_ir::quant::QuantScheme;
use std::sync::Arc;
/// Maps a GGUF K-quant [`QuantScheme`] to the numeric id used by this module.
///
/// The table is the exact inverse of [`scheme_from_id`]; the two must be
/// kept in sync when a scheme is added.
///
/// # Panics
///
/// Panics for any scheme other than the six GGUF K-quant variants listed
/// below.
pub fn gguf_scheme_id(scheme: QuantScheme) -> u32 {
match scheme {
QuantScheme::GgufQ4K => 0,
QuantScheme::GgufQ5K => 1,
QuantScheme::GgufQ6K => 2,
QuantScheme::GgufQ8K => 3,
// Note: ids are not ordered by bit width; Q2K and Q3K follow Q8K.
QuantScheme::GgufQ2K => 4,
QuantScheme::GgufQ3K => 5,
other => panic!("rlx-cuda gguf_host: unsupported scheme {other:?}"),
}
}
/// Maps a numeric scheme id back to its GGUF K-quant [`QuantScheme`].
///
/// This is the exact inverse of [`gguf_scheme_id`]; the two tables must be
/// kept in sync when a scheme is added.
///
/// # Panics
///
/// Panics if `scheme_id` is not in `0..=5`.
pub fn scheme_from_id(scheme_id: u32) -> QuantScheme {
match scheme_id {
0 => QuantScheme::GgufQ4K,
1 => QuantScheme::GgufQ5K,
2 => QuantScheme::GgufQ6K,
3 => QuantScheme::GgufQ8K,
// Note: ids are not ordered by bit width; Q2K and Q3K follow Q8K.
4 => QuantScheme::GgufQ2K,
5 => QuantScheme::GgufQ3K,
_ => panic!("rlx-cuda gguf_host: bad scheme_id {scheme_id}"),
}
}
/// Reads `len` raw bytes starting at byte offset `byte_off` from an
/// `f32`-typed device buffer and returns them as a host `Vec<u8>`.
///
/// The buffer can only be sliced in whole `f32` words, so the copy covers
/// every word that overlaps `byte_off..byte_off + len`; the extra bytes at
/// either edge are dropped before returning.
///
/// # Panics
///
/// Panics with `"rlx-cuda: gguf dtoh failed"` if the device-to-host copy
/// reports an error.
fn dtoh_bytes(
    stream: &Arc<CudaStream>,
    buffer: &CudaSlice<f32>,
    byte_off: usize,
    len: usize,
) -> Vec<u8> {
    // Word range that fully covers the requested byte range.
    let first_word = byte_off / 4;
    let last_word = (byte_off + len).div_ceil(4);

    let mut staged = vec![0f32; last_word - first_word];
    stream
        .memcpy_dtoh(&buffer.slice(first_word..last_word), &mut staged)
        .expect("rlx-cuda: gguf dtoh failed");

    // Flatten the staged words into their little-endian byte representation.
    let mut bytes = vec![0u8; staged.len() * 4];
    for (dst, word) in bytes.chunks_exact_mut(4).zip(&staged) {
        dst.copy_from_slice(&word.to_le_bytes());
    }

    // Trim the trailing padding first, then the leading padding, so only the
    // requested `len` bytes remain.
    let lead = byte_off % 4;
    bytes.truncate(lead + len);
    bytes.drain(..lead);
    bytes
}
/// Writes the raw bytes in `data` into an `f32`-typed device buffer,
/// starting at byte offset `byte_off`.
///
/// The buffer can only be written in whole `f32` words. When the byte range
/// starts or ends in the middle of a word, the covering words are first read
/// back from the device so that the bytes outside `data` are preserved
/// (read-modify-write). When the range is word-aligned at both ends every
/// staged byte is overwritten by `data`, so the read-back is skipped and the
/// data is uploaded directly. An empty `data` is a no-op.
///
/// # Panics
///
/// Panics with `"rlx-cuda: gguf htod staging dtoh failed"` if the read-back
/// fails, or with `"rlx-cuda: gguf htod failed"` if the upload fails.
fn htod_bytes(
    stream: &Arc<CudaStream>,
    buffer: &mut CudaSlice<f32>,
    byte_off: usize,
    data: &[u8],
) {
    if data.is_empty() {
        // Nothing to write; avoid issuing device copies entirely.
        return;
    }

    // Word range that fully covers the destination byte range.
    let start_f32 = byte_off / 4;
    let end_f32 = (byte_off + data.len()).div_ceil(4);
    let lead = byte_off % 4;

    let mut raw = vec![0u8; (end_f32 - start_f32) * 4];

    // Only a partially covered first or last word carries bytes that must
    // survive the write; in that case fetch the current device contents.
    let has_partial_word = lead != 0 || data.len() % 4 != 0;
    if has_partial_word {
        let mut staged = vec![0f32; end_f32 - start_f32];
        stream
            .memcpy_dtoh(&buffer.slice(start_f32..end_f32), &mut staged)
            .expect("rlx-cuda: gguf htod staging dtoh failed");
        for (dst, word) in raw.chunks_exact_mut(4).zip(&staged) {
            dst.copy_from_slice(&word.to_le_bytes());
        }
    }

    raw[lead..lead + data.len()].copy_from_slice(data);

    // Re-pack the bytes as f32 words purely as a transport type; the values
    // are never interpreted as floats here.
    let words: Vec<f32> = raw
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();

    stream
        .memcpy_htod(&words, &mut buffer.slice_mut(start_f32..end_f32))
        .expect("rlx-cuda: gguf htod failed");
}
/// Runs a GGUF dequantize-and-matmul on the host for operands that live in
/// a device buffer.
///
/// The stream is synchronized, then `m * k` `f32` inputs and the packed
/// weight bytes are copied to the host,
/// `rlx_cpu::gguf_matmul::gguf_matmul_bt` computes `m * n` outputs, and the
/// result is copied back into `buffer`.
///
/// * `scheme_id` - numeric scheme id, decoded with [`scheme_from_id`].
/// * `x_byte_off` / `out_byte_off` - byte offsets of the input and output
///   regions. They are converted to `f32` indices with a truncating `/ 4`,
///   so they are presumably expected to be 4-byte aligned (TODO confirm
///   against callers).
/// * `w_byte_off` - byte offset of the packed weights; read with byte
///   granularity via `dtoh_bytes`.
///
/// The weight region is sized as `(k * n) / block_size * block_bytes`, with
/// the block geometry taken from the decoded scheme.
///
/// # Panics
///
/// Panics on an unknown `scheme_id` or if any stream operation fails.
pub fn run_dequant_matmul_gguf(
    stream: &Arc<CudaStream>,
    buffer: &mut CudaSlice<f32>,
    m: usize,
    k: usize,
    n: usize,
    scheme_id: u32,
    x_byte_off: usize,
    w_byte_off: usize,
    out_byte_off: usize,
) {
    let scheme = scheme_from_id(scheme_id);
    let bytes_per_block = scheme.gguf_block_bytes() as usize;
    let elems_per_block = scheme.gguf_block_size() as usize;
    let weight_bytes = (k * n) / elems_per_block * bytes_per_block;

    // Make sure everything previously queued on the stream has finished
    // before reading the operands back.
    stream
        .synchronize()
        .expect("rlx-cuda: gguf pre-sync failed");

    // Input activations: m * k floats.
    let x_start = x_byte_off / 4;
    let x_len = m * k;
    let mut activations = vec![0f32; x_len];
    stream
        .memcpy_dtoh(&buffer.slice(x_start..x_start + x_len), &mut activations)
        .expect("rlx-cuda: gguf x dtoh failed");

    // Packed quantized weights, fetched as raw bytes.
    let weights = dtoh_bytes(stream, buffer, w_byte_off, weight_bytes);

    let out_len = m * n;
    let mut result = vec![0f32; out_len];
    rlx_cpu::gguf_matmul::gguf_matmul_bt(&activations, &weights, &mut result, m, k, n, scheme);

    // Publish the result back into the device buffer.
    let out_start = out_byte_off / 4;
    let mut out_view = buffer.slice_mut(out_start..out_start + out_len);
    stream
        .memcpy_htod(&result, &mut out_view)
        .expect("rlx-cuda: gguf out htod failed");
}
/// Runs a grouped (per-expert) GGUF dequantize-and-matmul on the host for
/// operands that live in a device buffer.
///
/// The stream is synchronized, then `m * k` `f32` inputs, the packed weight
/// bytes for all `num_experts` slabs, and `m` `f32` index values are copied
/// to the host. `rlx_cpu::gguf_matmul::gguf_grouped_matmul_bt` computes
/// `m * n` outputs, which are copied back into `buffer`.
///
/// * `scheme_id` - numeric scheme id, decoded with [`scheme_from_id`].
/// * `x_byte_off` / `idx_byte_off` / `out_byte_off` - byte offsets of the
///   input, index and output regions. They are converted to `f32` indices
///   with a truncating `/ 4`, so they are presumably expected to be 4-byte
///   aligned (TODO confirm against callers).
/// * `w_byte_off` - byte offset of the packed weights; read with byte
///   granularity via `dtoh_bytes`.
///
/// Each expert slab is sized as `(k * n) / block_size * block_bytes`, with
/// the block geometry taken from the decoded scheme. The index region is
/// read as `f32` values; presumably one expert selector per input row
/// (TODO confirm against the CPU kernel).
///
/// # Panics
///
/// Panics on an unknown `scheme_id` or if any stream operation fails.
pub fn run_dequant_grouped_matmul_gguf(
    stream: &Arc<CudaStream>,
    buffer: &mut CudaSlice<f32>,
    m: usize,
    k: usize,
    n: usize,
    num_experts: usize,
    scheme_id: u32,
    x_byte_off: usize,
    w_byte_off: usize,
    idx_byte_off: usize,
    out_byte_off: usize,
) {
    let scheme = scheme_from_id(scheme_id);
    let bytes_per_block = scheme.gguf_block_bytes() as usize;
    let elems_per_block = scheme.gguf_block_size() as usize;
    let bytes_per_expert = (k * n) / elems_per_block * bytes_per_block;
    let weight_bytes = num_experts * bytes_per_expert;

    // Make sure everything previously queued on the stream has finished
    // before reading the operands back.
    stream
        .synchronize()
        .expect("rlx-cuda: grouped gguf pre-sync failed");

    // Input activations: m * k floats.
    let x_start = x_byte_off / 4;
    let x_len = m * k;
    let mut activations = vec![0f32; x_len];
    stream
        .memcpy_dtoh(&buffer.slice(x_start..x_start + x_len), &mut activations)
        .expect("rlx-cuda: grouped gguf x dtoh failed");

    // Packed quantized weights for every expert, fetched as raw bytes.
    let weights = dtoh_bytes(stream, buffer, w_byte_off, weight_bytes);

    // Index values: m floats.
    let idx_start = idx_byte_off / 4;
    let mut indices = vec![0f32; m];
    stream
        .memcpy_dtoh(&buffer.slice(idx_start..idx_start + m), &mut indices)
        .expect("rlx-cuda: grouped gguf idx dtoh failed");

    let out_len = m * n;
    let mut result = vec![0f32; out_len];
    rlx_cpu::gguf_matmul::gguf_grouped_matmul_bt(
        &activations,
        &weights,
        &indices,
        &mut result,
        m,
        k,
        n,
        num_experts,
        scheme,
    );

    // Publish the result back into the device buffer.
    let out_start = out_byte_off / 4;
    let mut out_view = buffer.slice_mut(out_start..out_start + out_len);
    stream
        .memcpy_htod(&result, &mut out_view)
        .expect("rlx-cuda: grouped gguf out htod failed");
}
/// Uploads the raw bytes in `data` into `buffer`, starting at byte offset
/// `byte_off`.
///
/// This is the public entry point for the private `htod_bytes` helper and
/// simply forwards its arguments. `byte_off` and `data.len()` do not need to
/// be multiples of four: bytes of partially covered `f32` words that lie
/// outside `data` keep their current device contents.
///
/// # Panics
///
/// Panics if a device copy issued by `htod_bytes` fails.
pub fn upload_param_bytes(
stream: &Arc<CudaStream>,
buffer: &mut CudaSlice<f32>,
byte_off: usize,
data: &[u8],
) {
htod_bytes(stream, buffer, byte_off, data);
}