1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
9use hanzo_nn::Linear;
10
11use crate::{
12 cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
13 generate_isq, generate_isq_imatrix,
14 hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
15 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
16 AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType,
17 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
18};
19
/// Linear layer whose weight is stored as a plain tensor rather than in a
/// quantized format.
#[derive(Debug)]
pub struct UnquantLinear {
    /// Weight tensor; the forward pass multiplies the input by its transpose.
    w: Tensor,
    /// Optional bias that the forward pass adds to the matmul result.
    b: Option<Tensor>,
    /// Activation statistics collector. `None` until `begin_track_stats` is
    /// called; while present, `forward_raw` feeds every input through it.
    stats: Option<ImatrixLayerStats>,
}
26
27impl QuantMethod for UnquantLinear {
28 fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
29 where
30 Self: Sized,
31 {
32 match method {
33 QuantMethodConfig::Gguf { .. }
34 | QuantMethodConfig::GptqAwq { .. }
35 | QuantMethodConfig::Hqq { .. }
36 | QuantMethodConfig::Dummy
37 | QuantMethodConfig::FP8 { .. }
38 | QuantMethodConfig::Bnb { .. }
39 | QuantMethodConfig::BlockwiseFP8 { .. }
40 | QuantMethodConfig::PerTensorFP8 { .. }
41 | QuantMethodConfig::Afq { .. }
42 | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
43 QuantMethodConfig::Unquantized(l) => Ok(Self {
44 w: l.weight().clone(),
45 b: l.bias().cloned(),
46 stats: None,
47 }),
48 }
49 }
50
51 fn dequantize_w(&self) -> Result<Tensor> {
52 Ok(self.w.clone())
53 }
54
55 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
56 maybe_init_cublas_lt_wrapper(a.device().clone());
58
59 #[cfg(feature = "cuda")]
61 if crate::gemv::should_use_gemv(a, &self.w) {
62 return crate::gemv::gemv(a, &self.w, self.b.as_ref());
63 }
64
65 let w = match *a.dims() {
66 [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
67 [bsize, _, _] => self.w.broadcast_left(bsize)?,
68 _ => self.w.clone(),
69 };
70
71 if let Some(stats) = &self.stats {
72 stats.process(a)?;
73 }
74
75 if let Some(b) = self.b.as_ref() {
76 let mut tgt_shape = a.dims().to_vec();
77 tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
78 let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;
79
80 match a.device().location() {
81 DeviceLocation::Cuda { .. } => {
82 if let (Device::Cuda(_), Some(cublaslt)) =
84 (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
85 {
86 cublaslt
87 .batch_matmul(
88 a,
89 &w,
90 Some(&b.t()?.contiguous()?),
91 None,
92 Some(1.0),
93 None,
94 None,
95 )?
96 .t()
97 } else {
98 let matmul_result = a.matmul(&w.t()?)?;
99 matmul_result.broadcast_add(&b)
100 }
101 }
102 DeviceLocation::Metal { .. } => {
103 let matmul_result = a.matmul(&w.t()?)?;
104 matmul_result.broadcast_add(&b)
105 }
106 #[cfg(feature = "rocm")]
107 DeviceLocation::Rocm { .. } => {
108 let matmul_result = a.matmul(&w.t()?)?;
109 matmul_result.broadcast_add(&b)
110 }
111 #[cfg(feature = "vulkan")]
112 DeviceLocation::Vulkan { .. } => {
113 let matmul_result = a.matmul(&w.t()?)?;
114 matmul_result.broadcast_add(&b)
115 }
116 DeviceLocation::Cpu => {
117 #[cfg(feature = "accelerate")]
118 {
119 let original_dtype = a.dtype();
120 let a_f32 = a.to_dtype(DType::F32)?;
121 let w_f32 = w.t()?.to_dtype(DType::F32)?;
122 let b_f32 = b.to_dtype(DType::F32)?;
123 let matmul_result = a_f32.matmul(&w_f32)?;
124 matmul_result
125 .broadcast_add(&b_f32)?
126 .to_dtype(original_dtype)
127 }
128 #[cfg(not(feature = "accelerate"))]
129 {
130 let matmul_result = a.matmul(&w.t()?)?;
131 matmul_result.broadcast_add(&b)
132 }
133 }
134 }
135 } else {
136 match a.device().location() {
137 DeviceLocation::Cuda { .. } => {
138 if let (Device::Cuda(_), Some(cublaslt)) =
139 (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
140 {
141 if a.rank() >= 3 && w.rank() >= 3 {
143 cublaslt
144 .batch_matmul(a, &w, None, None, None, None, None)?
145 .t()
146 } else {
147 a.matmul(&w.t()?)
148 }
149 } else {
150 a.matmul(&w.t()?)
151 }
152 }
153 DeviceLocation::Metal { .. } => a.matmul(&w.t()?),
154 #[cfg(feature = "rocm")]
155 DeviceLocation::Rocm { .. } => a.matmul(&w.t()?),
156 #[cfg(feature = "vulkan")]
157 DeviceLocation::Vulkan { .. } => a.matmul(&w.t()?),
158 DeviceLocation::Cpu => {
159 #[cfg(feature = "accelerate")]
160 {
161 let original_dtype = a.dtype();
162 a.to_dtype(DType::F32)?
163 .matmul(&w.t()?.to_dtype(DType::F32)?)?
164 .to_dtype(original_dtype)
165 }
166 #[cfg(not(feature = "accelerate"))]
167 {
168 a.matmul(&w.t()?)
169 }
170 }
171 }
172 }
173 }
174
175 fn gather_forward_raw(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
176 let w = &self.w;
185 let (_num_experts, out_features, _in_features) = w.dims3()?;
186
187 match a.dims() {
188 &[b_size, seq_len, 1, 1, hidden_dim] => {
190 let (_b, _s, num_experts_per_tok) = indices.dims3()?;
191 let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;
193
194 let selected_w = w.index_select(&flat_indices, 0)?;
196
197 let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;
199
200 let a_expanded = a_flat
203 .unsqueeze(1)?
204 .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
205 .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;
206
207 let result = a_expanded
209 .unsqueeze(1)?
210 .matmul(&selected_w.transpose(1, 2)?)?
211 .squeeze(1)?;
212
213 result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
215 }
216 &[num_tokens, 1, hidden_dim] => {
218 let (_, num_experts_per_tok) = indices.dims2()?;
219
220 let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;
222
223 let selected_w = w.index_select(&flat_indices, 0)?;
225
226 let a_expanded = a
228 .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
229 .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;
230
231 let result = a_expanded
233 .unsqueeze(1)?
234 .matmul(&selected_w.transpose(1, 2)?)?
235 .squeeze(1)?;
236
237 result.reshape((num_tokens, num_experts_per_tok, out_features))
239 }
240 &[num_tokens, num_experts_per_tok, hidden_dim] => {
241 let (indices_num_tokens, indices_num_experts_per_tok) = indices.dims2()?;
242 if num_tokens != indices_num_tokens
243 || num_experts_per_tok != indices_num_experts_per_tok
244 {
245 hanzo_ml::bail!(
246 "UnquantLinear::gather_forward: input shape {:?} does not match indices shape {:?}",
247 a.dims(),
248 indices.dims()
249 );
250 }
251
252 let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;
253 let selected_w = w.index_select(&flat_indices, 0)?;
254 let a_flat = a.reshape((num_tokens * num_experts_per_tok, hidden_dim))?;
255
256 let result = a_flat
257 .unsqueeze(1)?
258 .matmul(&selected_w.transpose(1, 2)?)?
259 .squeeze(1)?;
260
261 result.reshape((num_tokens, num_experts_per_tok, out_features))
262 }
263 dims => {
264 hanzo_ml::bail!(
265 "UnquantLinear::gather_forward: unsupported input shape {:?}",
266 dims
267 );
268 }
269 }
270 }
271
272 fn quantized_act_type(&self) -> Option<DType> {
273 None
274 }
275
276 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
277 Ok(Arc::new(Self {
278 w: (&self.w + delta)?,
279 b: self.b.clone(),
280 stats: self.stats.clone(),
281 }))
282 }
283
284 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
285 (self.w.dtype(), self.w.device().clone())
286 }
287
288 fn apply_isq(
289 self: Arc<Self>,
290 dtype: Option<IsqType>,
291 device: Device,
292 n_quantized: &AtomicUsize,
293 imatrix_weight: Option<Vec<f32>>,
294 guard: QuantizeOntoGuard,
295 ) -> Result<Arc<dyn QuantMethod>> {
296 match dtype {
297 Some(IsqType::HQQ4 | IsqType::HQQ8) => {
299 let _acquired_quantize_guard = guard.acquire(&device);
300 if imatrix_weight.is_some() {
301 hanzo_ml::bail!("HQQ does not support imatrix.");
303 }
304
305 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
306 let bits = match dtype.unwrap() {
307 IsqType::HQQ8 => HqqBits::Eight,
308 IsqType::HQQ4 => HqqBits::Four,
309 _ => unreachable!(),
313 };
314 let cfg = HqqConfig {
315 bits,
316 group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
317 axis: HqqAxis::Zero,
318 optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
319 round_zeros: false,
320 channel_wise: true,
321 };
322 let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
323 if let Some(bias) = &self.b {
324 let bias = bias
325 .to_device(&device)?
326 .to_dtype(res.dtype_and_device().0)?;
327 Ok(Arc::new(res.with_bias(bias)))
328 } else {
329 Ok(Arc::new(res))
330 }
331 }
332 Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
333 let _acquired_quantize_guard = guard.acquire(&device);
334 if imatrix_weight.is_some() {
335 hanzo_ml::bail!("AFQ does not support imatrix.");
337 }
338
339 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
340 let bits = match dtype.unwrap() {
341 IsqType::AFQ8 => AfqBits::Eight,
342 IsqType::AFQ6 => AfqBits::Six,
343 IsqType::AFQ4 => AfqBits::Four,
344 IsqType::AFQ3 => AfqBits::Three,
345 IsqType::AFQ2 => AfqBits::Two,
346 _ => unreachable!(),
347 };
348
349 Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
350 weight: self.w.to_device(&device)?,
351 bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
352 bits,
353 group_size: AfqGroupSize::default(),
354 })?))
355 }
356 Some(
357 IsqType::Q2K
358 | IsqType::Q3K
359 | IsqType::Q4K
360 | IsqType::Q4_0
361 | IsqType::Q4_1
362 | IsqType::Q5K
363 | IsqType::Q5_0
364 | IsqType::Q5_1
365 | IsqType::Q6K
366 | IsqType::Q8K
367 | IsqType::Q8_0
368 | IsqType::Q8_1,
369 ) => {
370 let dtype: GgmlDType = dtype.unwrap().try_into()?;
371 let res = if let Some(imatrix_weight) = imatrix_weight {
372 generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
373 } else {
374 generate_isq!(self.w, device, dtype, n_quantized, guard)
375 };
376 Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
377 q_weight: res,
378 b: self
379 .b
380 .as_ref()
381 .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
382 })?))
383 }
384 Some(IsqType::F8E4M3) => {
385 let _acquired_quantize_guard = guard.acquire(&device);
386 if imatrix_weight.is_some() {
387 hanzo_ml::bail!("F8E4M3 does not support imatrix.");
389 }
390
391 let w = self.w.to_device(&device)?;
392 let b = if let Some(b) = &self.b {
393 Some(b.to_device(&device)?)
394 } else {
395 None
396 };
397 Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
398 lin: Linear::new(w, b),
399 dtype: DType::F8E4M3,
400 })?))
401 }
402 Some(IsqType::MXFP4) => {
403 let _acquired_quantize_guard = guard.acquire(&device);
404 if imatrix_weight.is_some() {
405 hanzo_ml::bail!("MXFP4 does not support imatrix.");
406 }
407
408 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
409 let w = self.w.to_device(&device)?;
410 let b = self.b.as_ref().map(|b| b.to_device(&device)).transpose()?;
411 crate::MXFP4Layer::quantize(&w, b, &device)
412 }
413 Some(IsqType::F8Q8) => {
414 let _acquired_quantize_guard = guard.acquire(&device);
415 if imatrix_weight.is_some() {
416 hanzo_ml::bail!("F8Q8 does not support imatrix.");
417 }
418
419 let w = self.w.to_device(&device)?;
420 let b = if let Some(b) = &self.b {
421 Some(b.to_device(&device)?)
422 } else {
423 None
424 };
425 Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
426 }
427 None => {
428 let _acquired_quantize_guard = guard.acquire(&device);
429 let w = self.w.to_device(&device)?;
432 let b = if let Some(b) = &self.b {
433 Some(b.to_device(&device)?)
434 } else {
435 None
436 };
437 Ok(Arc::new(UnquantLinear::new(
438 QuantMethodConfig::Unquantized(Linear::new(w, b)),
439 )?))
440 }
441 }
442 }
443
444 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
445 Some((self.w.clone(), self.b.clone()))
446 }
447
448 fn begin_track_stats(&mut self) -> Result<()> {
449 self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
450 Ok(())
451 }
452
453 fn end_track_stats(&self) -> Result<Tensor> {
454 if let Some(stats) = &self.stats {
455 let imatrix = stats.compute_imatrix()?;
456 stats.clear()?;
457 Ok(imatrix)
458 } else {
459 hanzo_ml::bail!("`{}` does not support tracking stats.", self.name())
460 }
461 }
462}
463
464impl QuantizedSerde for UnquantLinear {
479 fn isq_serde_supported(&self) -> bool {
480 true
481 }
482 fn name(&self) -> &'static str {
483 "unquant-linear"
484 }
485 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
486 self.serialize_with_bias(self.b.clone())
487 }
488 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
489 let mut buffer = Vec::new();
490
491 buffer.extend(&UQFF_VERSION.to_le_bytes());
494
495 buffer.push(QuantizedSerdeType::Unquant as u8);
497
498 buffer.push(bias.is_some() as u8);
500
501 serialize_tensor(&mut buffer, &self.w)?;
503
504 if let Some(bias) = &bias {
505 serialize_tensor(&mut buffer, bias)?;
507 }
508
509 Ok(Cow::from(buffer))
510 }
511
512 fn deserialize(
513 data: Cow<[u8]>,
514 device: &Device,
515 _comm: &Arc<crate::Comm>,
516 guard: QuantizeOntoGuard,
517 ) -> Result<Arc<dyn QuantMethod>>
518 where
519 Self: Sized,
520 {
521 let mut buffer = Cursor::new(data);
522
523 let version = buffer.read_u32::<LittleEndian>()?;
524 if let Err(e) = version_is_compatible(version) {
525 return Err(hanzo_ml::Error::wrap(e));
526 }
527
528 let isq_type = buffer.read_u8()? as usize;
529 if isq_type != QuantizedSerdeType::Unquant as usize {
530 hanzo_ml::bail!(
531 "ISQ type ({isq_type}) doesn't match expected type {}",
532 QuantizedSerdeType::Unquant as usize
533 );
534 }
535
536 let has_bias = buffer.read_u8()? != 0;
537
538 let _acquired_load_guard = guard.acquire(device);
539 let w = deserialize_tensor(&mut buffer, device)?;
540
541 let b = if has_bias {
542 Some(deserialize_tensor(&mut buffer, device)?)
543 } else {
544 None
545 };
546
547 Ok(Arc::new(Self { w, b, stats: None }))
548 }
549 fn deserialize_ext_bias(
550 data: Cow<[u8]>,
551 device: &Device,
552 guard: QuantizeOntoGuard,
553 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
554 where
555 Self: Sized,
556 {
557 let mut buffer = Cursor::new(data);
558
559 let version = buffer.read_u32::<LittleEndian>()?;
560 if let Err(e) = version_is_compatible(version) {
561 return Err(hanzo_ml::Error::wrap(e));
562 }
563
564 let isq_type = buffer.read_u8()? as usize;
565 if isq_type != QuantizedSerdeType::Unquant as usize {
566 hanzo_ml::bail!(
567 "ISQ type ({isq_type}) doesn't match expected type {}",
568 QuantizedSerdeType::Unquant as usize
569 );
570 }
571
572 let has_bias = buffer.read_u8()? != 0;
573
574 let _acquired_load_guard = guard.acquire(device);
575 let w = deserialize_tensor(&mut buffer, device)?;
576
577 let b = if has_bias {
578 Some(deserialize_tensor(&mut buffer, device)?)
579 } else {
580 None
581 };
582
583 Ok((
584 Arc::new(Self {
585 w,
586 b: None,
587 stats: None,
588 }),
589 b,
590 ))
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bias-free layer holding two experts, each a 2x3 matrix:
    /// expert 0 = [[1, 0, 0], [0, 1, 0]], expert 1 = [[0, 0, 1], [1, 1, 1]].
    fn test_layer(device: &Device) -> Result<UnquantLinear> {
        let expert_values = vec![1f32, 0., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1.];
        let weight = Tensor::from_vec(expert_values, (2, 2, 3), device)?;
        let config = QuantMethodConfig::Unquantized(Linear::new(weight, None));
        <UnquantLinear as QuantMethod>::new(config)
    }

    /// A `[tokens, 1, hidden]` input is repeated for every expert selected
    /// for that token.
    #[test]
    fn gather_forward_expands_single_route_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = test_layer(&device)?;
        let tokens = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (2, 1, 3), &device)?;
        let routes = Tensor::from_vec(vec![0u32, 1, 1, 0], (2, 2), &device)?;

        let output = layer.gather_forward(&tokens, &routes)?;
        let flat = output.flatten_all()?.to_vec1::<f32>()?;

        assert_eq!(output.dims(), &[2, 2, 2]);
        assert_eq!(flat, &[1., 2., 3., 6., 6., 15., 4., 5.]);
        Ok(())
    }

    /// A `[tokens, k, hidden]` input already carries one row per route and
    /// is used as-is.
    #[test]
    fn gather_forward_accepts_per_route_input() -> Result<()> {
        let device = Device::Cpu;
        let layer = test_layer(&device)?;
        let rows = Tensor::from_vec(vec![1f32, 2., 3., 4., 5., 6.], (1, 2, 3), &device)?;
        let routes = Tensor::from_vec(vec![0u32, 1], (1, 2), &device)?;

        let output = layer.gather_forward(&rows, &routes)?;
        let flat = output.flatten_all()?.to_vec1::<f32>()?;

        assert_eq!(output.dims(), &[1, 2, 2]);
        assert_eq!(flat, &[1., 2., 6., 15.]);
        Ok(())
    }
}