1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
//! Marlin INT4xFP16 fused GEMM kernel (IST Austria).
//!
//! Near-ideal 3.9x speedup over FP16 cuBLAS for INT4 quantized weights.
//! Weights must be in Marlin packed format (different from GPTQ).
//!
//! Constraints: K % 128 == 0, N % 256 == 0, SM >= 8.0 (Ampere+).
use cudarc::driver::{CudaSlice, CudaStream, DevicePtr};
use std::sync::Arc;
// FFI declaration for the Marlin CUDA kernel.
// Only linked when the "marlin" feature is enabled (requires nvcc + SM >= 8.0).
//
// NOTE(review): the per-parameter notes below describe what `marlin_gemm` in
// this file passes for each argument; the authoritative contract is the
// C/CUDA definition, which is not visible here — confirm against it.
#[cfg(feature = "marlin")]
extern "C" {
// Returns an `i32` status; the caller in this file treats 0 as success and
// any other value as failure.
fn marlin_cuda(
// Device pointer to the FP16 activation matrix (A in C = A @ dequant(B)).
A: *const std::ffi::c_void,
// Device pointer to the packed INT4 weights (`MarlinWeight::qweight`).
B: *const std::ffi::c_void,
// Device pointer to the FP16 output matrix; written by the kernel.
C: *mut std::ffi::c_void,
// Device pointer to the FP16 scales (`MarlinWeight::scales`).
s: *const std::ffi::c_void,
// Problem dimensions: the caller passes m (rows of A), n and k from the weight.
prob_m: i32,
prob_n: i32,
prob_k: i32,
// Device pointer to an i32 scratch buffer that the caller zeroes before each
// launch (the caller's comments say the kernel uses it as mutex locks).
workspace: *mut std::ffi::c_void,
// Quantization group size (`MarlinWeight::group_size`).
groupsize: i32,
// CUDA device ordinal; the caller in this file always passes 0.
dev: i32,
// Stream the kernel is launched on.
stream: cudarc::driver::sys::CUstream,
// Tuning knobs; the caller passes -1 for each, annotated there as "auto".
thread_k: i32,
thread_n: i32,
sms: i32,
// The caller passes 16. NOTE(review): presumably the maximum number of
// parallel problem slices — confirm against the kernel source.
max_par: i32,
) -> i32;
}
/// Report whether this build was compiled with the `marlin` Cargo feature.
///
/// The answer is fixed at compile time; no device or driver is probed.
pub fn is_available() -> bool {
    const MARLIN_COMPILED_IN: bool = cfg!(feature = "marlin");
    MARLIN_COMPILED_IN
}
/// Marlin-format quantized weight for one linear layer.
///
/// Bundles the device buffers and dimensions that `marlin_gemm` forwards to
/// the kernel for a single `A[m, k] @ dequant(B[k, n])` product.
pub struct MarlinWeight {
/// Repacked INT4 weights in Marlin tile format: varies by K, N
pub qweight: CudaSlice<i32>,
/// Per-group FP16 scales (permuted for Marlin access pattern)
pub scales: CudaSlice<half::f16>,
/// Workspace for Marlin kernel: [N/128 * max_par] int32, zeroed.
/// `marlin_gemm` re-zeroes the whole buffer before every launch.
pub workspace: CudaSlice<i32>,
/// Reduction dimension K (in_features); passed to the kernel as `prob_k`.
pub k: usize,
/// Output dimension N (out_features); passed to the kernel as `prob_n`.
pub n: usize,
/// Quantization group size, forwarded unchanged as the kernel's `groupsize`.
/// NOTE(review): presumably -1 selects per-channel scales — confirm against
/// the kernel source.
pub group_size: i32,
}
/// Run Marlin INT4xFP16 fused GEMM.
///
/// Computes: C[m, n] = A[m, k] @ dequant(B[k, n])
/// where B is in Marlin packed INT4 format.
///
/// The workspace memset and the kernel launch are both enqueued on `stream`
/// and this function returns without synchronizing; the caller is responsible
/// for syncing the stream before reading `output` on the host.
///
/// Only available when compiled with `--features marlin`.
///
/// # Errors
///
/// Returns `candle_core::Error::Msg` when
/// * `weight.n` / `weight.k` do not fit in `i32`, or `m` is negative,
/// * `input` holds fewer than `m * k` or `output` fewer than `m * n` elements,
/// * zeroing the workspace fails, or
/// * `marlin_cuda` returns a non-zero status.
#[cfg(feature = "marlin")]
pub fn marlin_gemm(
    stream: &Arc<CudaStream>,
    input: &CudaSlice<half::f16>,
    weight: &MarlinWeight,
    output: &mut CudaSlice<half::f16>,
    m: i32,
) -> candle_core::Result<()> {
    // Checked conversions: a plain `as i32` would silently truncate huge dims.
    let dim_err = |name: &str, value: usize| {
        candle_core::Error::Msg(format!(
            "marlin_gemm: {name}={value} does not fit in i32"
        ))
    };
    let n = i32::try_from(weight.n).map_err(|_| dim_err("n", weight.n))?;
    let k = i32::try_from(weight.k).map_err(|_| dim_err("k", weight.k))?;
    let rows = usize::try_from(m).map_err(|_| {
        candle_core::Error::Msg(format!("marlin_gemm: m={m} must be non-negative"))
    })?;

    // The kernel only sees raw pointers, so make sure the activation and output
    // buffers really hold m*k and m*n elements before handing them over.
    let fits = |have: usize, cols: usize| {
        rows.checked_mul(cols).is_some_and(|need| have >= need)
    };
    if !fits(input.len(), weight.k) || !fits(output.len(), weight.n) {
        return Err(candle_core::Error::Msg(format!(
            "marlin_gemm: buffer too small (m={m}, n={n}, k={k}, input_len={}, output_len={})",
            input.len(),
            output.len()
        )));
    }

    let raw_stream = stream.cu_stream();

    // Zero workspace on the runner's stream — Marlin uses it as mutex locks.
    // All operations (memset + kernel) on same stream → naturally ordered.
    {
        let (ws_ptr, _guard) = weight.workspace.device_ptr(stream);
        // SAFETY: `ws_ptr` addresses `weight.workspace`, which holds exactly
        // `weight.workspace.len()` 32-bit elements and outlives this call.
        let status = unsafe {
            cudarc::driver::sys::cuMemsetD32Async(ws_ptr, 0, weight.workspace.len(), raw_stream)
        };
        if status != cudarc::driver::sys::CUresult::CUDA_SUCCESS {
            return Err(candle_core::Error::Msg(format!(
                "marlin_gemm: cuMemsetD32Async on workspace failed: {status:?}"
            )));
        }
    }

    // Get raw device pointers. The guards stay alive until the end of the
    // function, i.e. until after the kernel has been enqueued.
    let (a_ptr, _a_guard) = input.device_ptr(stream);
    let (b_ptr, _b_guard) = weight.qweight.device_ptr(stream);
    // NOTE(review): `output` is written by the kernel but obtained through the
    // read-only `device_ptr` accessor; consider `DevicePtrMut::device_ptr_mut`
    // if cross-stream write tracking matters — confirm against cudarc docs.
    let (c_ptr, _c_guard) = output.device_ptr(stream);
    let (s_ptr, _s_guard) = weight.scales.device_ptr(stream);
    let (ws_ptr, _ws_guard) = weight.workspace.device_ptr(stream);

    // SAFETY: all pointers come from live `CudaSlice`s borrowed for the whole
    // call; `input`/`output` sizes were validated against m, n, k above. The
    // weight/scale/workspace buffers are trusted to match `weight.n`/`weight.k`.
    let ret = unsafe {
        marlin_cuda(
            a_ptr as *const _,
            b_ptr as *const _,
            c_ptr as *mut _,
            s_ptr as *const _,
            m,
            n,
            k,
            ws_ptr as *mut _,
            weight.group_size,
            // NOTE(review): device ordinal is hard-coded to 0; on multi-GPU
            // setups this may not be the device that owns `stream`.
            0,  // dev
            raw_stream,
            -1, // auto thread_k
            -1, // auto thread_n
            -1, // auto sms
            16, // max_par
        )
    };
    if ret != 0 {
        return Err(candle_core::Error::Msg(format!(
            "marlin_cuda failed: ret={ret} (m={m}, n={n}, k={k}, gs={})",
            weight.group_size
        )));
    }

    // No per-call sync needed — all operations (memset + kernel) are on the
    // runner's stream. decode_step syncs once at the end before returning logits.
    Ok(())
}
/// Stub when Marlin feature is not enabled.
#[cfg(not(feature = "marlin"))]
pub fn marlin_gemm(
_stream: &Arc<CudaStream>,
_input: &CudaSlice<half::f16>,
_weight: &MarlinWeight,
_output: &mut CudaSlice<half::f16>,
_m: i32,
) -> candle_core::Result<()> {
Err(candle_core::Error::Msg(
"Marlin kernel not available (compile with --features marlin)".into(),
))
}
// ===================== Weight Repacking (GPTQ → Marlin) =====================
/// Repack GPTQ INT4 weights to Marlin format on CPU.
///
/// * GPTQ input: `qweight` `[K/8, N]` int32, row-major. Each word packs eight
///   consecutive K-rows of one output column, lowest nibble first.
/// * Marlin output: `[K/16, N*16/8]` int32 flattened row-major (`K*N/8` words),
///   tiled in 16x16 blocks and permuted for tensor-core access.
///
/// GPTQ already stores the weight in the `[K, N]` orientation that Marlin's
/// `Layer.pack()` obtains via `linear.weight.data.t()`, so no transpose is
/// performed. Only the nibble layout changes; the 4-bit values themselves are
/// copied verbatim (no zero-point or scale adjustment happens here).
///
/// Reference: IST-DASLab/marlin __init__.py Layer.pack()
///
/// # Panics
///
/// Panics if `k` or `n` is not a multiple of 16, if `k * n` overflows or is
/// not a multiple of 1024, or if `qweight_gptq` holds fewer than `k * n / 8`
/// words. Previously such inputs silently produced zeroed or truncated output.
pub fn repack_gptq_to_marlin(
    qweight_gptq: &[i32], // [K/8, N]
    k: usize,
    n: usize,
) -> Vec<i32> {
    const NIBBLES_PER_WORD: usize = 8;
    const TILE: usize = 16;
    const PERM_BLOCK: usize = 1024;

    assert!(
        k % TILE == 0 && n % TILE == 0,
        "repack_gptq_to_marlin: k={k} and n={n} must both be multiples of 16"
    );
    let total = k
        .checked_mul(n)
        .expect("repack_gptq_to_marlin: k * n overflows usize");
    assert!(
        total % PERM_BLOCK == 0,
        "repack_gptq_to_marlin: k*n={total} must be a multiple of 1024"
    );
    let packed_len = total / NIBBLES_PER_WORD;
    assert!(
        qweight_gptq.len() >= packed_len,
        "repack_gptq_to_marlin: qweight has {} words, need {packed_len} for k={k}, n={n}",
        qweight_gptq.len()
    );
    if total == 0 {
        return Vec::new();
    }

    // Step 1: Unpack GPTQ [K/8, N] → individual INT4 values [K, N].
    let mut kn = vec![0u8; total];
    for (pr, packed_row) in qweight_gptq[..packed_len].chunks_exact(n).enumerate() {
        for (col, &packed) in packed_row.iter().enumerate() {
            for i in 0..NIBBLES_PER_WORD {
                kn[(pr * NIBBLES_PER_WORD + i) * n + col] = ((packed >> (i * 4)) & 0xF) as u8;
            }
        }
    }

    // Step 2: (no-op) `kn` is already in the [K, N] orientation Marlin packs.

    // Step 3: Tile [K, N] → [K/16, 16, N/16, 16] → permute(0,2,1,3) → [K/16, N*16].
    // Each 16x16 tile becomes 256 contiguous values; tile rows are copied whole.
    let mut tiled = vec![0u8; total];
    for tk in 0..k / TILE {
        for tn in 0..n / TILE {
            let tile_base = tk * (n * TILE) + tn * (TILE * TILE);
            for ik in 0..TILE {
                let src = (tk * TILE + ik) * n + tn * TILE;
                let dst = tile_base + ik * TILE;
                tiled[dst..dst + TILE].copy_from_slice(&kn[src..src + TILE]);
            }
        }
    }

    // Step 4: Apply the Marlin permutation independently to each 1024-value block.
    let perm = build_marlin_perm();
    let mut permuted = vec![0u8; total];
    for (dst_block, src_block) in permuted
        .chunks_exact_mut(PERM_BLOCK)
        .zip(tiled.chunks_exact(PERM_BLOCK))
    {
        for (dst, &src) in dst_block.iter_mut().zip(perm.iter()) {
            *dst = src_block[src];
        }
    }

    // Step 5: Pack each run of 8 consecutive INT4 values into one int32,
    // lowest nibble first. Result shape: [K/16, N*16/8] = [K/16, N*2].
    permuted
        .chunks_exact(NIBBLES_PER_WORD)
        .map(|nibbles| {
            let word = nibbles
                .iter()
                .enumerate()
                .fold(0u32, |acc, (j, &v)| acc | (u32::from(v) << (j * 4)));
            word as i32
        })
        .collect()
}
/// Reorder GPTQ scales into the order the Marlin kernel reads them.
///
/// Input and output are both flat `[num_groups, N]` buffers with
/// `num_groups = k / group_size`. The flat buffer is cut into fixed-width
/// blocks and every block is shuffled with the same lookup table:
///
/// * more than one group → 64-entry table `i + 8*j` (Marlin `_scale_perm`)
/// * a single group      → 32-entry table `2*i + j`, `j ∈ {0,1,8,9,16,17,24,25}`
///   (Marlin `_scale_perm_single`)
///
/// Any tail shorter than one block is copied through unchanged.
///
/// Reference: IST-DASLab/marlin __init__.py _scale_perm / _scale_perm_single
pub fn repack_scales_to_marlin(
    scales_gptq: &[half::f16], // [num_groups, N]
    k: usize,
    n: usize,
    group_size: usize,
) -> Vec<half::f16> {
    let groups = k / group_size;

    // Lookup table: output slot `d` of a block takes input slot `lookup[d]`.
    let mut lookup: Vec<usize> = Vec::new();
    if groups <= 1 {
        // Per-channel scales.
        for i in 0..4 {
            for j in [0, 1, 8, 9, 16, 17, 24, 25] {
                lookup.push(2 * i + j);
            }
        }
    } else {
        // Grouped scales.
        for i in 0..8 {
            for j in 0..8 {
                lookup.push(i + 8 * j);
            }
        }
    }

    let count = groups * n;
    let width = lookup.len();
    // Number of elements covered by complete blocks.
    let whole = count - count % width;
    let mut shuffled = vec![half::f16::ZERO; count];

    let src_blocks = scales_gptq[..whole].chunks_exact(width);
    for (out_block, in_block) in shuffled[..whole].chunks_exact_mut(width).zip(src_blocks) {
        for (slot, &from) in out_block.iter_mut().zip(&lookup) {
            *slot = in_block[from];
        }
    }

    // Leftover elements that do not fill a block keep their position.
    shuffled[whole..].copy_from_slice(&scales_gptq[whole..count]);
    shuffled
}
/// Build the 1024-element Marlin weight permutation array.
///
/// Entry `d` of the result is the source index (within a 1024-value block of
/// the tiled weight) whose value ends up at position `d` after permutation.
/// The table is a bijection on `0..1024`.
///
/// Construction, mirroring the Python reference:
/// 1. For each `i` in `0..32`, take `col = i / 4` and the four rows
///    `2r, 2r+1, 2r+8, 2r+9` with `r = i % 4`; for `block` in `{0, 1}` emit
///    `16*row + col + 8*block` (8 indices).
/// 2. Repeat that group four times with offsets `0, 256, 512, 768`.
/// 3. Within every group of 8, reorder by `[0, 2, 4, 6, 1, 3, 5, 7]`.
///
/// Reference: IST-DASLab/marlin __init__.py _perm construction
fn build_marlin_perm() -> Vec<usize> {
    // Interleave applied inside each group of 8 (step 3 above); folded into
    // the emission loop so no second pass over the table is needed.
    const INTERLEAVE: [usize; 8] = [0, 2, 4, 6, 1, 3, 5, 7];

    let mut perm = Vec::with_capacity(1024);
    for i in 0..32usize {
        let col = i / 4;
        let r = i % 4;
        let rows = [2 * r, 2 * r + 1, 2 * (r + 4), 2 * (r + 4) + 1];

        // Step 1: the 8 base indices for this `i`.
        let mut group = [0usize; 8];
        for block in 0..2 {
            for (slot, &row) in rows.iter().enumerate() {
                group[block * 4 + slot] = 16 * row + col + 8 * block;
            }
        }

        // Steps 2 + 3: four shifted copies, each emitted in interleaved order.
        for j in 0..4 {
            perm.extend(INTERLEAVE.iter().map(|&src| group[src] + 256 * j));
        }
    }
    assert_eq!(perm.len(), 1024);
    perm
}