1use std::{
2 borrow::Cow,
3 fmt::Debug,
4 num::NonZeroUsize,
5 sync::{atomic::AtomicUsize, Arc, Condvar, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9#[cfg(feature = "metal")]
10use hanzo_ml::D;
11use hanzo_ml::{
12 quantized::{GgmlDType, QMatMul, QTensor},
13 DType, Device, Result, Tensor,
14};
15use pertensor_fp8::pertensor_fp8_linear_b;
16
17#[cfg(feature = "metal")]
18pub mod metal_kernels;
19
20mod afq;
21mod bitsandbytes;
22mod blockwise_fp8;
23pub mod cublaslt;
24pub mod distributed;
25mod dummy;
26pub mod f8q8;
27mod fp8;
28pub mod gemv;
29mod gguf;
30mod gptq;
31mod hqq;
32mod imatrix;
33mod lora;
34mod mxfp4;
35mod pending_layer;
36mod pertensor_fp8;
37pub mod rotary;
38pub mod safetensors;
39mod scalar_fp8;
40mod unquantized;
41mod utils;
42mod vector_fp8;
43
44use gptq::gptq_linear;
45use lora::merge_lora_weights;
46use regex::Regex;
47pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
48
49pub use afq::{AfqBits, AfqGroupSize, AfqInner, AfqLayer};
50pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
51pub use blockwise_fp8::{
52 blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
53};
54pub use distributed::{
55 layers::{
56 compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
57 ReplicatedLayer, RowParallelLayer,
58 },
59 socket::{Client, Server},
60 BarrierLike, Comm, Id, RingConfig, SumAllReduce,
61};
62pub use dummy::{DummyLayer, DummyLayerInfo};
63pub use f8q8::F8Q8Linear;
64pub use fp8::FP8Linear;
65#[cfg(feature = "cuda")]
66pub use gemv::gemv;
67pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
68#[cfg(feature = "cuda")]
69pub use gguf::cuda::{
70 grouped_moe_gemm_prequantized, indexed_moe_fused_decode, moe_dispatch_build,
71 moe_weighted_reduce_flat, quantize_input_q8_1, ACT_GELU_PYTORCH_TANH, ACT_SILU,
72};
73#[cfg(feature = "cuda")]
74pub use gguf::fast_mmq::{
75 grouped as grouped_moe_mmq, grouped_from_glu_pair as grouped_moe_mmq_from_glu_pair,
76 grouped_pair as grouped_moe_mmq_pair, supports as supports_mmq,
77};
78pub use gguf::GgufMatMul;
79pub use gptq::GptqLayer;
80pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
81pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
82pub use lora::{
83 clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
84 LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
85};
86pub use mxfp4::MXFP4Layer;
87pub use pending_layer::PendingIsqLayer;
88pub use pertensor_fp8::PerTensorFP8Linear;
89pub use unquantized::UnquantLinear;
90pub use utils::flash_attn_sinks_metal;
91pub use utils::flash_attn_sinks_varlen_metal;
92#[cfg(feature = "cuda")]
93pub use utils::gptoss_swiglu_fused;
94#[cfg(feature = "cuda")]
95pub use utils::gptoss_swiglu_interleaved;
96pub use utils::isq::apply_immediate_isq;
97pub use utils::softmax_with_sinks;
98pub use utils::{fused_glu, GluActivationType};
99pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
100pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
101
102use hanzo_nn::{Conv1d, Conv2d, Linear, Module};
103use serde::{Deserialize, Deserializer, Serialize};
104
/// Counting gate that limits how many slots may be held at the same time.
///
/// A slot is taken with [`IsqBackpressure::acquire`], which blocks while `max` slots are
/// already held, and given back with [`IsqBackpressure::release`].
pub struct IsqBackpressure {
    /// Number of slots currently held.
    count: Mutex<usize>,
    /// Notified on every release so that one blocked `acquire` can re-check the count.
    cvar: Condvar,
    /// Maximum number of slots that may be held simultaneously.
    max: usize,
}

impl IsqBackpressure {
    /// Builds a gate with no slots held and room for `max` concurrent holders.
    ///
    /// With `max == 0` the condition in `acquire` can never become false, so every
    /// `acquire` call blocks forever.
    pub fn new(max: usize) -> Self {
        Self {
            max,
            cvar: Condvar::new(),
            count: Mutex::new(0),
        }
    }

    /// Blocks the calling thread until fewer than `max` slots are held, then takes one.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn acquire(&self) {
        let guard = self.count.lock().expect("ISQ backpressure lock poisoned");
        let mut outstanding = self
            .cvar
            .wait_while(guard, |held| *held >= self.max)
            .expect("ISQ backpressure lock poisoned");
        *outstanding += 1;
    }

    /// Returns one slot and wakes a single waiter.
    ///
    /// The count saturates at zero, so an unmatched `release` is a no-op on the counter.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn release(&self) {
        let mut outstanding = self.count.lock().expect("ISQ backpressure lock poisoned");
        *outstanding = outstanding.saturating_sub(1);
        self.cvar.notify_one();
    }
}

impl Debug for IsqBackpressure {
    /// Prints the number of held slots and the limit; a poisoned lock is reported as `0` held.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let outstanding = match self.count.lock() {
            Ok(held) => *held,
            Err(_) => 0,
        };
        f.debug_struct("IsqBackpressure")
            .field("outstanding", &outstanding)
            .field("max", &self.max)
            .finish()
    }
}
154
155#[derive(Clone, Debug)]
156pub struct ImmediateIsqParams {
157 pub guard: QuantizeOntoGuard,
158 pub ty: Option<IsqType>,
159 pub predicates: Vec<Regex>,
160 pub overrides: Vec<ImmediateIsqOverride>,
161 pub pool: Option<Arc<rayon::ThreadPool>>,
165 pub backpressure: Arc<IsqBackpressure>,
167}
168
169#[derive(Clone, Debug)]
170pub struct ImmediateIsqOverride {
171 pub predicate: Regex,
172 pub ty: Option<IsqType>,
173 pub device: Option<Device>,
174}
175
176#[derive(Clone, Debug)]
177pub struct ImmediateIsqMatch {
178 pub ty: IsqType,
179 pub device: Option<Device>,
180}
181
182thread_local! {
183 static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
184}
185
186pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
187 let (pool, _) = create_isq_thread_pool(isq);
188 set_immediate_isq_with_pool(isq, predicates, Vec::new(), pool);
189}
190
191pub fn set_immediate_isq_with_pool(
192 isq: Option<IsqType>,
193 predicates: Vec<Regex>,
194 overrides: Vec<ImmediateIsqOverride>,
195 pool: rayon::ThreadPool,
196) {
197 let max_outstanding = pool.current_num_threads() + 1;
200 ENGINE_IMMEDIATE_ISQ.with(|cell| {
201 *cell.borrow_mut() = Some(ImmediateIsqParams {
202 guard: QuantizeOntoGuard::new(),
203 ty: isq,
204 predicates,
205 overrides,
206 backpressure: Arc::new(IsqBackpressure::new(max_outstanding)),
207 pool: Some(Arc::new(pool)),
208 });
209 });
210}
211
212pub fn create_isq_thread_pool(ty: Option<IsqType>) -> (rayon::ThreadPool, usize) {
219 let num_threads = if std::env::var("HANZO_ISQ_SINGLETHREAD").is_ok() {
220 1
221 } else if let Some(ty) = ty {
222 ty.get_max_isq_cpu_threads()
223 .map(usize::from)
224 .unwrap_or_else(rayon::current_num_threads)
225 } else {
226 rayon::current_num_threads()
227 };
228
229 let pool = rayon::ThreadPoolBuilder::new()
230 .num_threads(num_threads)
231 .build()
232 .expect("Failed to create ISQ thread pool");
233 (pool, num_threads)
234}
235
236pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
237 ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
238}
239
240pub fn clear_immediate_isq() {
241 ENGINE_IMMEDIATE_ISQ.with(|cell| {
242 *cell.borrow_mut() = None;
243 });
244}
245
246pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
247 immediate_isq_match(vb).is_some()
248}
249
250pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
251 let immediate_isq = get_immediate_isq()?;
252 let prefix = format!("{}.weight", vb.prefix());
254 resolve_immediate_isq(&immediate_isq, &prefix)
255}
256
257fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
258 if let Some(override_hit) = params
259 .overrides
260 .iter()
261 .find(|override_pred| override_pred.predicate.is_match(prefix))
262 {
263 if let Some(ty) = override_hit.ty.or(params.ty) {
264 return Some(ImmediateIsqMatch {
265 ty,
266 device: override_hit.device.clone(),
267 });
268 }
269 return None;
270 }
271
272 if let Some(ty) = params.ty {
273 if params
274 .predicates
275 .iter()
276 .any(|predicate| predicate.is_match(prefix))
277 {
278 return Some(ImmediateIsqMatch { ty, device: None });
279 }
280 }
281
282 None
283}
284
285#[derive(Debug, Clone, Serialize)]
286#[serde(tag = "quant_method", rename_all = "lowercase")]
287pub enum QuantizedConfig {
288 GptqAwq {
289 bits: usize,
290 group_size: usize,
291 checkpoint_format: Option<String>,
292 is_awq: bool,
293 },
294 Fp8 {
295 weight_block_size: Option<Vec<usize>>,
296 },
297 Bitsandbytes {
298 bnb_4bit_quant_type: Option<String>,
299 },
300 Afq {
301 bits: usize,
302 group_size: usize,
303 },
304 MXFP4 {},
305}
306
307#[derive(Deserialize)]
309struct RawConfig {
310 quant_method: Option<String>,
311 bits: Option<usize>,
312 group_size: Option<usize>,
313 checkpoint_format: Option<String>,
314 weight_block_size: Option<Vec<usize>>,
315 bnb_4bit_quant_type: Option<String>,
316}
317
318impl<'de> Deserialize<'de> for QuantizedConfig {
320 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
321 where
322 D: Deserializer<'de>,
323 {
324 let raw = RawConfig::deserialize(deserializer)?;
325
326 match &raw.quant_method {
327 Some(m) if m == "gptq" || m == "awq" => {
328 let bits = raw
329 .bits
330 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
331 let group_size = raw
332 .group_size
333 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
334 Ok(QuantizedConfig::GptqAwq {
335 bits,
336 group_size,
337 checkpoint_format: raw.checkpoint_format,
338 is_awq: m == "awq",
339 })
340 }
341 Some(m) if m == "fp8" => {
342 Ok(QuantizedConfig::Fp8 {
344 weight_block_size: raw.weight_block_size,
345 })
346 }
347 Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
348 bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
349 }),
350 Some(m) if m == "afq" => {
351 let bits = raw
352 .bits
353 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
354 let group_size = raw
355 .group_size
356 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
357 Ok(QuantizedConfig::Afq { bits, group_size })
358 }
359 Some(m) if m == "mxfp4" => {
360 Ok(QuantizedConfig::MXFP4 { })
361 }
362 None => {
363 let bits = raw
364 .bits
365 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
366 let group_size = raw
367 .group_size
368 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
369 Ok(QuantizedConfig::Afq { bits, group_size })
370 }
371 Some(unknown_method) => {
372 Err(serde::de::Error::custom(format!(
373 "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
374 )))
375 },
376 }
377 }
378}
379
380impl QuantizedConfig {
381 pub fn name(&self) -> &'static str {
382 match self {
383 Self::GptqAwq { .. } => "gptq",
384 Self::Fp8 { .. } => "fp8",
385 Self::Bitsandbytes { .. } => "bitsandbytes",
386 Self::Afq { .. } => "afq",
387 Self::MXFP4 { .. } => "mxfp4",
388 }
389 }
390
391 pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
392 match self {
393 Self::GptqAwq { bits, .. } => format!("{bits} bits"),
394 Self::Fp8 { .. } => "8 bits".to_string(),
395 Self::Bitsandbytes {
396 bnb_4bit_quant_type: Some(_),
397 } => "4 bits".to_string(),
398 Self::Bitsandbytes {
399 bnb_4bit_quant_type: None,
400 } => "8 bits".to_string(),
401 Self::Afq { bits, .. } => format!("{bits} bits"),
402 Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
403 }
404 }
405
406 pub fn pack_factor(&self, dtype: DType) -> usize {
407 match self {
408 Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
409 2 => IsqType::Q2K.pack_factor(dtype),
410 3 => IsqType::Q3K.pack_factor(dtype),
411 4 => IsqType::Q4K.pack_factor(dtype),
412 5 => IsqType::Q5K.pack_factor(dtype),
413 6 => IsqType::Q6K.pack_factor(dtype),
414 8 => IsqType::Q8_0.pack_factor(dtype),
415 40 => 4, other => panic!("Unexpected bits in `pack_factor` {other}"),
417 },
418 Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
419 Self::Bitsandbytes {
420 bnb_4bit_quant_type: Some(_),
421 }
422 | Self::Bitsandbytes {
423 bnb_4bit_quant_type: None,
424 } => IsqType::Q4K.pack_factor(dtype),
425 Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
426 }
427 }
428}
429
430#[derive(Debug, Clone)]
431pub enum QuantMethodConfig {
432 GptqAwq {
433 bits: i32,
434 use_exllama: bool,
435 q_weight: Tensor,
436 qzeros: Option<Tensor>,
437 scales: Tensor,
438 g_idx: Option<Tensor>,
439 bias: Option<Tensor>,
440 workspace: Option<Tensor>,
441 is_marlin: bool,
442 is_awq: bool,
443 },
444 Gguf {
445 q_weight: Arc<QTensor>,
446 b: Option<Tensor>,
447 },
448 Unquantized(Linear),
449 Hqq {
450 tensor: Tensor,
451 bits: HqqBits,
452 group_size: NonZeroUsize,
453 axis: HqqAxis,
454 optimization_steps: Option<usize>,
455 round_zeros: Option<bool>,
456 channel_wise: Option<bool>,
457 bias: Option<Tensor>,
458 },
459 Dummy,
460 FP8 {
461 lin: Linear,
462 dtype: DType,
463 },
464 Bnb {
465 weight: Tensor,
466 bias: Option<Tensor>,
467 params: BnbQuantParams,
468 quant_ty: BnbQuantType,
469 },
470 BlockwiseFP8 {
471 weight: Tensor,
472 weight_scale_inv: Tensor,
473 bias: Option<Tensor>,
474 dequant_dtype: DType,
475 weight_block_size: Vec<usize>,
476 },
477 PerTensorFP8 {
478 weight: Tensor,
479 weight_scale_inv: Tensor,
480 activation_scale: Option<Tensor>,
481 bias: Option<Tensor>,
482 dequant_dtype: DType,
483 },
484 Afq {
485 weight: Tensor,
486 bias: Option<Tensor>,
487 bits: AfqBits,
488 group_size: AfqGroupSize,
489 },
490 MXFP4 {
491 blocks: Tensor,
492 scales: Tensor,
493 bias: Option<Tensor>,
494 },
495}
496
497pub struct MatMul;
500
501impl MatMul {
502 pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
504 #[cfg(feature = "accelerate")]
505 {
506 let original_dtype = a.dtype();
507 a.to_dtype(DType::F32)?
508 .matmul(&b.to_dtype(DType::F32)?)?
509 .to_dtype(original_dtype)
510 }
511 #[cfg(not(feature = "accelerate"))]
512 {
513 if a.device().is_cpu() {
514 let original_dtype = a.dtype();
515 a.to_dtype(DType::F16)?
516 .matmul(&b.to_dtype(DType::F16)?)?
517 .to_dtype(original_dtype)
518 } else {
519 a.matmul(b)
520 }
521 }
522 }
523
524 pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
527 self.matmul(a, b)? / scale
529 }
530
531 pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
534 self.matmul(a, b)? * scale
536 }
537
538 pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
540 matmul.forward(x)
541 }
542}
543
544pub struct Convolution;
547
548impl Convolution {
549 pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
550 if x.device().is_cpu() {
551 let original_dtype = x.dtype();
552 Conv1d::new(
553 layer.weight().to_dtype(DType::F32)?,
554 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
555 *layer.config(),
556 )
557 .forward(&x.to_dtype(DType::F32)?)?
558 .to_dtype(original_dtype)
559 } else {
560 layer.forward(x)
561 }
562 }
563
564 pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
565 if x.device().is_cpu() {
566 let original_dtype = x.dtype();
567 Conv2d::new(
568 layer.weight().to_dtype(DType::F32)?,
569 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
570 *layer.config(),
571 )
572 .forward(&x.to_dtype(DType::F32)?)?
573 .to_dtype(original_dtype)
574 } else {
575 layer.forward(x)
576 }
577 }
578}
579
580#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
582pub enum IsqType {
583 Q4_0,
584 Q4_1,
585 Q5_0,
586 Q5_1,
587 Q8_0,
588 Q8_1,
589 Q2K,
590 Q3K,
591 Q4K,
592 Q5K,
593 Q6K,
594 Q8K,
595 HQQ8,
596 HQQ4,
597 F8E4M3,
601 AFQ8,
602 AFQ6,
603 AFQ4,
604 AFQ3,
605 AFQ2,
606 F8Q8,
607 MXFP4,
608}
609
610#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
614pub enum IsqBits {
615 Two,
617 Three,
619 Four,
621 Five,
623 Six,
625 Eight,
627}
628
629impl IsqBits {
630 pub fn resolve(self, device: &Device) -> IsqType {
632 match (self, device.is_metal()) {
633 (Self::Two, true) => IsqType::AFQ2,
634 (Self::Two, false) => IsqType::Q2K,
635 (Self::Three, true) => IsqType::AFQ3,
636 (Self::Three, false) => IsqType::Q3K,
637 (Self::Four, true) => IsqType::AFQ4,
638 (Self::Four, false) => IsqType::Q4K,
639 (Self::Five, _) => IsqType::Q5K,
640 (Self::Six, true) => IsqType::AFQ6,
641 (Self::Six, false) => IsqType::Q6K,
642 (Self::Eight, true) => IsqType::AFQ8,
643 (Self::Eight, false) => IsqType::Q8_0,
644 }
645 }
646
647 pub fn expand(self) -> Vec<IsqType> {
650 #[cfg(feature = "metal")]
651 match self {
652 Self::Two => vec![IsqType::AFQ2, IsqType::Q2K],
653 Self::Three => vec![IsqType::AFQ3, IsqType::Q3K],
654 Self::Four => vec![IsqType::AFQ4, IsqType::Q4K],
655 Self::Five => vec![IsqType::Q5K],
656 Self::Six => vec![IsqType::AFQ6, IsqType::Q6K],
657 Self::Eight => vec![IsqType::AFQ8, IsqType::Q8_0],
658 }
659 #[cfg(not(feature = "metal"))]
660 match self {
661 Self::Two => vec![IsqType::Q2K, IsqType::AFQ2],
662 Self::Three => vec![IsqType::Q3K, IsqType::AFQ3],
663 Self::Four => vec![IsqType::Q4K, IsqType::AFQ4],
664 Self::Five => vec![IsqType::Q5K],
665 Self::Six => vec![IsqType::Q6K, IsqType::AFQ6],
666 Self::Eight => vec![IsqType::Q8_0, IsqType::AFQ8],
667 }
668 }
669}
670
671impl TryFrom<&str> for IsqBits {
672 type Error = ();
673 fn try_from(s: &str) -> std::result::Result<Self, ()> {
674 match s {
675 "2" => Ok(Self::Two),
676 "3" => Ok(Self::Three),
677 "4" => Ok(Self::Four),
678 "5" => Ok(Self::Five),
679 "6" => Ok(Self::Six),
680 "8" => Ok(Self::Eight),
681 _ => Err(()),
682 }
683 }
684}
685
686impl std::fmt::Display for IsqType {
687 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
688 match self {
689 Self::Q4_0 => write!(f, "q4_0"),
690 Self::Q4_1 => write!(f, "q4_1"),
691 Self::Q5_0 => write!(f, "q5_0"),
692 Self::Q5_1 => write!(f, "q5_1"),
693 Self::Q8_0 => write!(f, "q8_0"),
694 Self::Q8_1 => write!(f, "q8_1"),
695 Self::Q2K => write!(f, "q2k"),
696 Self::Q3K => write!(f, "q3k"),
697 Self::Q4K => write!(f, "q4k"),
698 Self::Q5K => write!(f, "q5k"),
699 Self::Q6K => write!(f, "q6k"),
700 Self::Q8K => write!(f, "q8k"),
701 Self::HQQ8 => write!(f, "hqq8"),
702 Self::HQQ4 => write!(f, "hqq4"),
703 Self::F8E4M3 => write!(f, "fp8"),
704 Self::AFQ8 => write!(f, "afq8"),
705 Self::AFQ6 => write!(f, "afq6"),
706 Self::AFQ4 => write!(f, "afq4"),
707 Self::AFQ3 => write!(f, "afq3"),
708 Self::AFQ2 => write!(f, "afq2"),
709 Self::F8Q8 => write!(f, "f8q8"),
710 Self::MXFP4 => write!(f, "mxfp4"),
711 }
712 }
713}
714
715impl IsqType {
716 pub fn pack_factor(&self, dtype: DType) -> usize {
719 match self {
720 Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
721 .div_ceil(GgmlDType::Q4_0.type_size()),
722 Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
723 .div_ceil(GgmlDType::Q4_1.type_size()),
724 Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
725 .div_ceil(GgmlDType::Q5_0.type_size()),
726 Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
727 .div_ceil(GgmlDType::Q5_1.type_size()),
728 Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
729 .div_ceil(GgmlDType::Q8_0.type_size()),
730 Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
731 .div_ceil(GgmlDType::Q8_1.type_size()),
732 Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
733 .div_ceil(GgmlDType::Q2K.type_size()),
734 Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
735 .div_ceil(GgmlDType::Q3K.type_size()),
736 Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
737 .div_ceil(GgmlDType::Q4K.type_size()),
738 Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
739 .div_ceil(GgmlDType::Q5K.type_size()),
740 Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
741 .div_ceil(GgmlDType::Q6K.type_size()),
742 Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
743 .div_ceil(GgmlDType::Q8K.type_size()),
744 Self::F8Q8 => (dtype.size_in_bytes() * 32).div_ceil(33),
746 Self::HQQ4 => 4,
748 Self::HQQ8 => 2,
749 Self::F8E4M3 => 2,
750 Self::MXFP4 => 3,
753 }
754 }
755
756 pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
757 match self {
758 IsqType::HQQ4
760 | IsqType::HQQ8
761 | IsqType::AFQ2
762 | IsqType::AFQ3
763 | IsqType::AFQ4
764 | IsqType::AFQ6
765 | IsqType::AFQ8
766 | IsqType::MXFP4 => {
767 Some(1.try_into().unwrap())
769 }
770 IsqType::F8E4M3 | IsqType::F8Q8 => None,
771 IsqType::Q2K
772 | IsqType::Q3K
773 | IsqType::Q4K
774 | IsqType::Q4_0
775 | IsqType::Q4_1
776 | IsqType::Q5K
777 | IsqType::Q5_0
778 | IsqType::Q5_1
779 | IsqType::Q6K
780 | IsqType::Q8K
781 | IsqType::Q8_0
782 | IsqType::Q8_1 => None,
783 }
784 }
785}
786
787impl TryFrom<IsqType> for GgmlDType {
788 type Error = hanzo_ml::Error;
789
790 fn try_from(value: IsqType) -> Result<Self> {
791 let tp = match value {
792 IsqType::Q2K => Self::Q2K,
793 IsqType::Q3K => Self::Q3K,
794 IsqType::Q4K => Self::Q4K,
795 IsqType::Q4_0 => Self::Q4_0,
796 IsqType::Q4_1 => Self::Q4_1,
797 IsqType::Q5K => Self::Q5K,
798 IsqType::Q5_0 => Self::Q5_0,
799 IsqType::Q5_1 => Self::Q5_1,
800 IsqType::Q6K => Self::Q6K,
801 IsqType::Q8K => Self::Q8K,
802 IsqType::Q8_0 => Self::Q8_0,
803 IsqType::Q8_1 => Self::Q8_1,
804 _ => hanzo_ml::bail!("Expected valid GGML ISQ type."),
805 };
806 #[cfg(feature = "cuda")]
807 {
808 if !matches!(
809 tp,
810 GgmlDType::Q4_0
811 | GgmlDType::Q4_1
812 | GgmlDType::Q5_0
813 | GgmlDType::Q5_1
814 | GgmlDType::Q8_0
815 | GgmlDType::Q2K
816 | GgmlDType::Q3K
817 | GgmlDType::Q4K
818 | GgmlDType::Q5K
819 | GgmlDType::Q6K
820 ) {
821 hanzo_ml::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
822 }
823 }
824 Ok(tp)
825 }
826}
827
828impl TryFrom<GgmlDType> for IsqType {
829 type Error = hanzo_ml::Error;
830
831 fn try_from(value: GgmlDType) -> Result<Self> {
832 match value {
833 GgmlDType::Q2K => Ok(Self::Q2K),
834 GgmlDType::Q3K => Ok(Self::Q3K),
835 GgmlDType::Q4K => Ok(Self::Q4K),
836 GgmlDType::Q5K => Ok(Self::Q5K),
837 GgmlDType::Q6K => Ok(Self::Q6K),
838 GgmlDType::Q4_0 => Ok(Self::Q4_0),
839 GgmlDType::Q4_1 => Ok(Self::Q4_1),
840 GgmlDType::Q5_0 => Ok(Self::Q5_0),
841 GgmlDType::Q5_1 => Ok(Self::Q5_1),
842 GgmlDType::Q8_0 => Ok(Self::Q8_0),
843 GgmlDType::Q8_1 => Ok(Self::Q8_1),
844 GgmlDType::Q8K => Ok(Self::Q8K),
845 GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
846 hanzo_ml::bail!("Expected valid GGML ISQ type.")
847 }
848 }
849 }
850}
851
852#[derive(Debug, Clone, Copy)]
853pub enum QuantizedSerdeType {
854 Gguf = 0,
855 Unquant = 1,
856 Hqq = 2,
857 Fp8 = 3,
858 Afq = 4,
859 F8Q8 = 5,
860 Mxfp4 = 6,
861}
862
863impl TryFrom<usize> for QuantizedSerdeType {
864 type Error = hanzo_ml::Error;
865 fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
866 match value {
867 0 => Ok(Self::Gguf),
868 1 => Ok(Self::Unquant),
869 2 => Ok(Self::Hqq),
870 3 => Ok(Self::Fp8),
871 4 => Ok(Self::Afq),
872 5 => Ok(Self::F8Q8),
873 6 => Ok(Self::Mxfp4),
874 other => hanzo_ml::bail!("QuantizedSerdeType {other} is invalid."),
875 }
876 }
877}
878
879pub trait QuantizedSerde {
880 fn name(&self) -> &'static str;
881 fn isq_serde_supported(&self) -> bool {
882 false
883 }
884 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
885 hanzo_ml::bail!("`QuantizedSerde::serialize` is not supported.")
886 }
887 fn deserialize(
888 _data: Cow<[u8]>,
889 _device: &Device,
890 _comm: &Arc<crate::Comm>,
891 _guard: QuantizeOntoGuard,
892 ) -> Result<Arc<dyn QuantMethod>>
893 where
894 Self: Sized,
895 {
896 hanzo_ml::bail!("`QuantizedSerde::deserialize` is not supported.")
897 }
898 fn deserialize_ext_bias(
899 _data: Cow<[u8]>,
900 _device: &Device,
901 _guard: QuantizeOntoGuard,
902 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
903 where
904 Self: Sized,
905 {
906 hanzo_ml::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
907 }
908 fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
910 hanzo_ml::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
911 }
912}
913
914#[derive(Clone, Debug)]
916#[allow(unused)]
917pub struct QuantizeOntoGuard {
918 pub inner: Arc<Mutex<()>>,
919}
920
921pub enum QuantizeOntoDropGuard<'a> {
923 Real(MutexGuard<'a, ()>),
924 Fake,
925}
926
927impl Default for QuantizeOntoGuard {
928 fn default() -> Self {
929 Self::new()
930 }
931}
932
933impl QuantizeOntoGuard {
934 pub fn new() -> Self {
935 QuantizeOntoGuard {
936 inner: Arc::new(Mutex::new(())),
937 }
938 }
939
940 pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
944 #[cfg(feature = "cuda")]
945 {
946 let _ = device;
947 QuantizeOntoDropGuard::Fake
948 }
949
950 #[cfg(not(feature = "cuda"))]
951 {
952 #[cfg(feature = "metal")]
953 if let Device::Metal(dev) = device {
954 dev.wait_until_completed()
956 .expect("Failed to flush command buffer.");
957 }
958 #[cfg(not(feature = "metal"))]
959 let _ = device;
960
961 QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
962 }
963 }
964}
965
/// Kind of distributed layer, as reported by `QuantMethod::is_distributed`.
pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}
971
972pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
974 fn new(method: QuantMethodConfig) -> Result<Self>
975 where
976 Self: Sized;
977
978 fn dequantize_w(&self) -> Result<Tensor>;
979
980 fn forward(&self, a: &Tensor) -> Result<Tensor> {
983 if let Some(t) = self.quantized_act_type() {
984 let original_ty = a.dtype();
985 self.forward_raw(&a.to_dtype(t)?)?.to_dtype(original_ty)
986 } else {
987 self.forward_raw(a)
988 }
989 }
990
991 fn forward_raw(&self, a: &Tensor) -> Result<Tensor>;
994
995 fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
1001 if let Some(t) = self.quantized_act_type() {
1002 let original_ty = a.dtype();
1003 self.gather_forward_raw(&a.to_dtype(t)?, indices)?
1004 .to_dtype(original_ty)
1005 } else {
1006 self.gather_forward_raw(a, indices)
1007 }
1008 }
1009
1010 fn gather_forward_raw(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
1013 hanzo_ml::bail!(
1014 "{} does not support `gather_forward`. Please raise an issue.",
1015 self.name()
1016 )
1017 }
1018
1019 #[cfg(feature = "cuda")]
1022 fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
1023 None
1024 }
1025
1026 fn afq_inner(&self) -> Option<crate::afq::AfqInner<'_>> {
1029 None
1030 }
1031
1032 fn quantized_act_type(&self) -> Option<DType>;
1034
1035 fn dtype_and_device(&self) -> (DType, Device);
1037
1038 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;
1040
1041 fn apply_isq(
1043 self: Arc<Self>,
1044 dtype: Option<IsqType>,
1045 device: Device,
1046 n_quantized: &AtomicUsize,
1047 imatrix_weight: Option<Vec<f32>>,
1048 guard: QuantizeOntoGuard,
1049 ) -> Result<Arc<dyn QuantMethod>>;
1050
1051 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
1052 None
1053 }
1054
1055 fn has_bias(&self) -> bool {
1056 false
1057 }
1058
1059 fn begin_track_stats(&mut self) -> Result<()> {
1061 hanzo_ml::bail!("`{}` does not support tracking stats.", self.name())
1062 }
1063
1064 fn end_track_stats(&self) -> Result<Tensor> {
1066 hanzo_ml::bail!("`{}` does not support tracking stats.", self.name())
1067 }
1068
1069 fn is_distributed(&self) -> Option<DistributedKind> {
1070 None
1071 }
1072
1073 fn dummy_info(&self) -> Option<&DummyLayerInfo> {
1074 None
1075 }
1076}
1077
1078impl Module for dyn QuantMethod {
1079 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1080 QuantMethod::forward(self, xs)
1081 }
1082}
1083
/// Tries to compute the fused gate/up GLU of two quantized layers in one kernel call.
///
/// Returns `Ok(None)` whenever a precondition is not met, so the caller can fall back to
/// the unfused path:
/// - neither layer has a bias;
/// - `xs` is `BF16`, `F16` or `F32`;
/// - both layers expose a `Q8_0` quantized tensor, and the two shapes are equal;
/// - the product of `xs`'s leading dimensions is in `1..=MMVQ_MAX_BATCH`;
/// - `xs`'s last dimension equals the weight's second dimension.
///
/// # Errors
///
/// Fails if the gate tensor is not rank 2 or if the fused kernel fails.
#[cfg(feature = "cuda")]
pub fn try_fused_quantized_gate_up(
    xs: &Tensor,
    gate: &dyn QuantMethod,
    up: &dyn QuantMethod,
    activation: GluActivationType,
) -> Result<Option<Tensor>> {
    let bias_present = gate.has_bias() || up.has_bias();
    if bias_present || !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32) {
        return Ok(None);
    }

    let Some(gate_q) = gate.get_qtensor() else {
        return Ok(None);
    };
    let Some(up_q) = up.get_qtensor() else {
        return Ok(None);
    };
    let both_q8_0 = gate_q.dtype() == GgmlDType::Q8_0 && up_q.dtype() == GgmlDType::Q8_0;
    if !both_q8_0 || gate_q.shape() != up_q.shape() {
        return Ok(None);
    }

    let Some((&k, batch_dims)) = xs.dims().split_last() else {
        return Ok(None);
    };
    let flat_batch: usize = batch_dims.iter().product();
    if !(1..=gguf::fast_mmvq::MMVQ_MAX_BATCH).contains(&flat_batch) {
        return Ok(None);
    }
    let (_, ncols) = gate_q.shape().dims2()?;
    if k != ncols {
        return Ok(None);
    }

    let fused = gguf::fast_mmvq::fused_glu(gate_q, up_q, xs, activation)?;
    Ok(Some(fused))
}
1127
/// Tries to compute the Q, K and V projections of `xs` in one fused kernel call.
///
/// Returns `Ok(None)` whenever a precondition is not met, so the caller can fall back to
/// three separate projections:
/// - none of the layers has a bias;
/// - `xs` is `BF16`, `F16` or `F32`;
/// - all three layers expose a quantized tensor of the same dtype, and
///   `fast_mmvq::supports` accepts that dtype;
/// - the product of `xs`'s leading dimensions is in `1..=MMVQ_MAX_BATCH`;
/// - `xs`'s last dimension equals the second dimension of each weight.
///
/// # Errors
///
/// Fails if any weight is not rank 2 or if the fused kernel fails.
#[cfg(feature = "cuda")]
pub fn try_fused_quantized_qkv(
    xs: &Tensor,
    q: &dyn QuantMethod,
    k: &dyn QuantMethod,
    v: &dyn QuantMethod,
) -> Result<Option<(Tensor, Tensor, Tensor)>> {
    let bias_present = q.has_bias() || k.has_bias() || v.has_bias();
    if bias_present || !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32) {
        return Ok(None);
    }

    let Some(q_q) = q.get_qtensor() else {
        return Ok(None);
    };
    let Some(k_q) = k.get_qtensor() else {
        return Ok(None);
    };
    let Some(v_q) = v.get_qtensor() else {
        return Ok(None);
    };
    let dtype = q_q.dtype();
    let same_dtype = dtype == k_q.dtype() && dtype == v_q.dtype();
    if !same_dtype || !gguf::fast_mmvq::supports(dtype) {
        return Ok(None);
    }

    let Some((&input_cols, batch_dims)) = xs.dims().split_last() else {
        return Ok(None);
    };
    let flat_batch: usize = batch_dims.iter().product();
    if !(1..=gguf::fast_mmvq::MMVQ_MAX_BATCH).contains(&flat_batch) {
        return Ok(None);
    }
    let (_, q_cols) = q_q.shape().dims2()?;
    let (_, k_cols) = k_q.shape().dims2()?;
    let (_, v_cols) = v_q.shape().dims2()?;
    let cols_match = input_cols == q_cols && input_cols == k_cols && input_cols == v_cols;
    if !cols_match {
        return Ok(None);
    }

    let fused = gguf::fast_mmvq::fused_qkv(q_q, k_q, v_q, xs)?;
    Ok(Some(fused))
}
1172
/// Tries to compute the fused gate/up GLU of two AFQ layers with a single Metal kernel.
///
/// Returns `Ok(None)` whenever a precondition is not met, so the caller can fall back to
/// the unfused path:
/// - neither layer has a bias;
/// - `xs` is `BF16`, `F16` or `F32` and lives on a Metal device;
/// - both layers expose AFQ internals with equal bit width, group size and scale dtype;
/// - both quantized weights are rank 2 with the same first dimension;
/// - `xs`'s last dimension `k` is non-zero and `xs` holds at least 16 rows of length `k`;
/// - the gate weight's second dimension equals `k * bits / 8 / 4`;
/// - every tensor involved is backed by Metal storage.
///
/// On success the output has `xs`'s shape with the last dimension replaced by the weights'
/// first dimension, and `xs`'s dtype.
///
/// # Errors
///
/// Fails if `xs` has no dimensions, or if buffer allocation, encoder creation, or the
/// kernel launch fails.
#[cfg(feature = "metal")]
pub fn try_fused_gate_up_metal(
    xs: &Tensor,
    gate: &dyn QuantMethod,
    up: &dyn QuantMethod,
    activation: GluActivationType,
) -> Result<Option<Tensor>> {
    use hanzo_ml::{backend::BackendStorage, MetalStorage, Shape, Storage};

    if gate.has_bias() || up.has_bias() {
        return Ok(None);
    }
    if !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32) {
        return Ok(None);
    }
    if !xs.device().is_metal() {
        return Ok(None);
    }

    let Some(gi) = gate.afq_inner() else {
        return Ok(None);
    };
    let Some(ui) = up.afq_inner() else {
        return Ok(None);
    };
    if gi.bits != ui.bits || gi.group_size != ui.group_size {
        return Ok(None);
    }
    if gi.scales.dtype() != ui.scales.dtype() {
        return Ok(None);
    }
    if gi.w_q.rank() != 2 || ui.w_q.rank() != 2 {
        return Ok(None);
    }
    let k = xs.dim(D::Minus1)?;
    let n_gate = gi.w_q.dim(0)?;
    let n_up = ui.w_q.dim(0)?;
    if n_gate != n_up {
        return Ok(None);
    }
    let n = n_gate;
    // An empty last dimension would make the row count below a division by zero.
    if k == 0 {
        return Ok(None);
    }
    // Number of rows of length `k`; fewer than 16 rows are left to the unfused path.
    let m = xs.elem_count() / k;
    if m < 16 {
        return Ok(None);
    }
    // Expected packed width of the weight (presumably counted in 32-bit words — confirm
    // against the AFQ packing code).
    if k * gi.bits as usize / 8 / 4 != gi.w_q.dim(1)? {
        return Ok(None);
    }

    // Activation selector passed to the kernel.
    let act_code: u32 = match activation {
        GluActivationType::Silu => 0,
        GluActivationType::Gelu => 1,
        GluActivationType::GeluErf => 2,
        GluActivationType::Relu => 3,
    };

    // Making `xs` contiguous does not change its element count, so `m` remains valid.
    let xs = xs.contiguous()?;

    let (xs_storage, xs_layout) = xs.storage_and_layout();
    let Storage::Metal(xs_storage) = &*xs_storage else {
        return Ok(None);
    };
    let (g_w_s, _) = gi.w_q.storage_and_layout();
    let Storage::Metal(g_w_s) = &*g_w_s else {
        return Ok(None);
    };
    let (g_s_s, _) = gi.scales.storage_and_layout();
    let Storage::Metal(g_s_s) = &*g_s_s else {
        return Ok(None);
    };
    let (g_b_s, _) = gi.biases.storage_and_layout();
    let Storage::Metal(g_b_s) = &*g_b_s else {
        return Ok(None);
    };
    let (u_w_s, _) = ui.w_q.storage_and_layout();
    let Storage::Metal(u_w_s) = &*u_w_s else {
        return Ok(None);
    };
    let (u_s_s, _) = ui.scales.storage_and_layout();
    let Storage::Metal(u_s_s) = &*u_s_s else {
        return Ok(None);
    };
    let (u_b_s, _) = ui.biases.storage_and_layout();
    let Storage::Metal(u_b_s) = &*u_b_s else {
        return Ok(None);
    };

    let device = xs_storage.device().clone();
    let dtype = xs.dtype();
    // Same shape as the input, with the last dimension replaced by the output width.
    let mut out_shape = xs.dims().to_vec();
    *out_shape
        .last_mut()
        .expect("xs has at least one dimension because `dim(D::Minus1)` succeeded") = n;
    let out = device.new_buffer(out_shape.iter().product(), dtype, "afq-gate-up-out")?;

    let encoder = device.command_encoder()?;
    encoder.set_label("afq-gate-up");

    metal_kernels::call_afq_qmm_gate_up(
        device.device(),
        &encoder,
        &metal_kernels::Kernels::new(),
        dtype,
        (
            xs_storage.buffer(),
            // Byte offset of the first element of `xs` within its buffer.
            xs_layout.start_offset() * dtype.size_in_bytes(),
        ),
        g_w_s.buffer(),
        g_s_s.buffer(),
        g_b_s.buffer(),
        u_w_s.buffer(),
        u_s_s.buffer(),
        u_b_s.buffer(),
        &out,
        m,
        n,
        k,
        gi.bits as usize,
        gi.group_size as usize,
        act_code,
    )
    .map_err(hanzo_ml::Error::wrap)?;

    let out_t = Tensor::from((
        Storage::Metal(MetalStorage::new(
            out,
            device.clone(),
            out_shape.iter().product(),
            dtype,
        )),
        Shape::from(out_shape),
    ));
    Ok(Some(out_t))
}
1314
/// Tries to compute the Q, K and V projections of `xs` with one fused AFQ Metal kernel launch.
///
/// Returns `Ok(Some((q, k, v)))` when the fused kernel was dispatched. Each output keeps the
/// leading dimensions of `xs` and uses the row count of the matching packed weight as its last
/// dimension. Returns `Ok(None)` when the fused path does not apply:
///
/// * any projection reports a bias, or `afq_inner()` yields `None` for it;
/// * `xs` is not BF16/F16/F32, or is not on a Metal device;
/// * the three layers disagree on bit width, group size or scale dtype;
/// * a packed weight is not rank 2, its row count is not a multiple of 32, or its packed width
///   does not match the last dimension of `xs`;
/// * `xs` has an empty last dimension or fewer than 16 rows;
/// * any involved tensor is not backed by Metal storage.
///
/// # Errors
///
/// Propagates errors from shape queries, `contiguous()`, buffer allocation, encoder creation
/// and the kernel dispatch.
#[cfg(feature = "metal")]
pub fn try_fused_qkv_metal(
    xs: &Tensor,
    q: &dyn QuantMethod,
    k: &dyn QuantMethod,
    v: &dyn QuantMethod,
) -> Result<Option<(Tensor, Tensor, Tensor)>> {
    use hanzo_ml::{backend::BackendStorage, MetalStorage, Shape, Storage};

    // Every output width must be a multiple of this to take the fused path.
    // NOTE(review): presumably the kernel's output tile width — confirm against the shader.
    const OUT_DIM_MULTIPLE: usize = 32;
    // Inputs with fewer rows than this are not sent to the fused kernel.
    const MIN_ROWS: usize = 16;

    if q.has_bias() || k.has_bias() || v.has_bias() {
        return Ok(None);
    }
    if !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32) {
        return Ok(None);
    }
    if !xs.device().is_metal() {
        return Ok(None);
    }

    let Some(qi) = q.afq_inner() else {
        return Ok(None);
    };
    let Some(ki) = k.afq_inner() else {
        return Ok(None);
    };
    let Some(vi) = v.afq_inner() else {
        return Ok(None);
    };

    // A single (bits, group_size) pair is passed to the kernel, so all three must agree.
    if qi.bits != ki.bits || qi.bits != vi.bits {
        return Ok(None);
    }
    if qi.group_size != ki.group_size || qi.group_size != vi.group_size {
        return Ok(None);
    }
    if qi.scales.dtype() != ki.scales.dtype() || qi.scales.dtype() != vi.scales.dtype() {
        return Ok(None);
    }
    if qi.w_q.rank() != 2 || ki.w_q.rank() != 2 || vi.w_q.rank() != 2 {
        return Ok(None);
    }

    let n_q = qi.w_q.dim(0)?;
    let n_k = ki.w_q.dim(0)?;
    let n_v = vi.w_q.dim(0)?;
    if [n_q, n_k, n_v].iter().any(|n| n % OUT_DIM_MULTIPLE != 0) {
        return Ok(None);
    }

    let k_dim = xs.dim(D::Minus1)?;
    // Bail out on an empty reduction dimension: the row count below divides by `k_dim`.
    if k_dim == 0 {
        return Ok(None);
    }
    // Same packed-width check the fused gate/up path performs: the kernel is told the
    // reduction length is `k_dim`, so every weight must actually be packed for that length.
    let packed_k = k_dim * qi.bits as usize / 8 / 4;
    if qi.w_q.dim(1)? != packed_k || ki.w_q.dim(1)? != packed_k || vi.w_q.dim(1)? != packed_k {
        return Ok(None);
    }

    // `contiguous()` does not change the element count, so the row count is computed once.
    let m = xs.elem_count() / k_dim;
    if m < MIN_ROWS {
        return Ok(None);
    }

    let xs = xs.contiguous()?;

    let (xs_s, xs_l) = xs.storage_and_layout();
    let Storage::Metal(xs_s) = &*xs_s else {
        return Ok(None);
    };

    // Keep the storage guards alive until the kernel has been encoded.
    let qws = qi.w_q.storage_and_layout().0;
    let qss = qi.scales.storage_and_layout().0;
    let qbs = qi.biases.storage_and_layout().0;
    let kws = ki.w_q.storage_and_layout().0;
    let kss = ki.scales.storage_and_layout().0;
    let kbs = ki.biases.storage_and_layout().0;
    let vws = vi.w_q.storage_and_layout().0;
    let vss = vi.scales.storage_and_layout().0;
    let vbs = vi.biases.storage_and_layout().0;
    let (Storage::Metal(qw_m), Storage::Metal(qs_m), Storage::Metal(qb_m)) = (&*qws, &*qss, &*qbs)
    else {
        return Ok(None);
    };
    let (Storage::Metal(kw_m), Storage::Metal(ks_m), Storage::Metal(kb_m)) = (&*kws, &*kss, &*kbs)
    else {
        return Ok(None);
    };
    let (Storage::Metal(vw_m), Storage::Metal(vs_m), Storage::Metal(vb_m)) = (&*vws, &*vss, &*vbs)
    else {
        return Ok(None);
    };

    let device = xs_s.device().clone();
    let dtype = xs.dtype();

    // Output shapes: the input shape with the last dimension swapped for each projection width.
    let mut q_shape = xs.dims().to_vec();
    let mut k_shape = q_shape.clone();
    let mut v_shape = q_shape.clone();
    *q_shape.last_mut().unwrap() = n_q;
    *k_shape.last_mut().unwrap() = n_k;
    *v_shape.last_mut().unwrap() = n_v;
    let q_out = device.new_buffer(q_shape.iter().product(), dtype, "afq-qkv-q")?;
    let k_out = device.new_buffer(k_shape.iter().product(), dtype, "afq-qkv-k")?;
    let v_out = device.new_buffer(v_shape.iter().product(), dtype, "afq-qkv-v")?;

    let encoder = device.command_encoder()?;
    encoder.set_label("afq-qkv");

    metal_kernels::call_afq_qmm_qkv(
        device.device(),
        &encoder,
        &metal_kernels::Kernels::new(),
        dtype,
        // The layout offset is in elements; the kernel call is given a byte offset.
        (xs_s.buffer(), xs_l.start_offset() * dtype.size_in_bytes()),
        qw_m.buffer(),
        qs_m.buffer(),
        qb_m.buffer(),
        kw_m.buffer(),
        ks_m.buffer(),
        kb_m.buffer(),
        vw_m.buffer(),
        vs_m.buffer(),
        vb_m.buffer(),
        &q_out,
        &k_out,
        &v_out,
        m,
        n_q,
        n_k,
        n_v,
        k_dim,
        qi.bits as usize,
        qi.group_size as usize,
    )
    .map_err(hanzo_ml::Error::wrap)?;

    let q_t = Tensor::from((
        Storage::Metal(MetalStorage::new(
            q_out,
            device.clone(),
            q_shape.iter().product(),
            dtype,
        )),
        Shape::from(q_shape),
    ));
    let k_t = Tensor::from((
        Storage::Metal(MetalStorage::new(
            k_out,
            device.clone(),
            k_shape.iter().product(),
            dtype,
        )),
        Shape::from(k_shape),
    ));
    let v_t = Tensor::from((
        Storage::Metal(MetalStorage::new(
            v_out,
            device.clone(),
            v_shape.iter().product(),
            dtype,
        )),
        Shape::from(v_shape),
    ));
    Ok(Some((q_t, k_t, v_t)))
}
1478
1479fn tensor_prefix(vb: &ShardedVarBuilder) -> String {
1480 let prefix = vb.prefix();
1481 if prefix.is_empty() {
1482 "<root>".to_string()
1483 } else {
1484 prefix
1485 }
1486}
1487
1488fn missing_required_tensors(vb: &ShardedVarBuilder, required: &[&str]) -> Vec<String> {
1489 required
1490 .iter()
1491 .copied()
1492 .filter(|name| !vb.contains_tensor(name))
1493 .map(|name| safetensors::full_tensor_name(vb, name))
1494 .collect()
1495}
1496
1497pub(crate) fn has_missing_required_tensors(vb: &ShardedVarBuilder, required: &[&str]) -> bool {
1498 required.iter().any(|name| !vb.contains_tensor(name))
1499}
1500
1501pub(crate) fn make_dummy_or_error(
1502 context: &str,
1503 vb: &ShardedVarBuilder,
1504 required: &[&str],
1505) -> Result<Arc<dyn QuantMethod>> {
1506 let missing = missing_required_tensors(vb, required);
1507 if missing.is_empty() {
1508 hanzo_ml::bail!(
1509 "Internal error: requested DummyLayer for {context} without missing tensors"
1510 );
1511 }
1512
1513 let has_uqff_placeholder = required
1514 .iter()
1515 .any(|name| safetensors::is_uqff_dummy_tensor(vb, name));
1516 if !has_uqff_placeholder {
1517 hanzo_ml::bail!(
1518 "Missing required tensor(s) for {context} at prefix `{}`: {}. Dummy layers are only allowed for tensors intentionally omitted while loading UQFF artifacts.",
1519 tensor_prefix(vb),
1520 missing.join(", ")
1521 );
1522 }
1523
1524 Ok(Arc::new(DummyLayer::placeholder(DummyLayerInfo {
1525 context: context.to_string(),
1526 prefix: tensor_prefix(vb),
1527 missing_tensors: missing,
1528 })))
1529}
1530
1531pub fn linear_no_bias(
1532 in_dim: usize,
1533 out_dim: usize,
1534 config: &Option<QuantizedConfig>,
1535 vb: ShardedVarBuilder,
1536) -> Result<Arc<dyn QuantMethod>> {
1537 let base_vb = vb.clone();
1538 let vb = if should_apply_immediate_isq(&vb) {
1539 vb.set_device(Device::Cpu)
1540 } else {
1541 vb
1542 };
1543
1544 let layer = if let Some(quant_conf) = &config {
1545 match quant_conf {
1546 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1547 QuantizedConfig::Fp8 { weight_block_size } => {
1548 if weight_block_size.is_some() {
1549 blockwise_fp8_linear_b(
1550 in_dim,
1551 out_dim,
1552 quant_conf,
1553 false,
1554 Default::default(),
1555 vb,
1556 )?
1557 } else {
1558 pertensor_fp8_linear_b(
1559 in_dim,
1560 out_dim,
1561 quant_conf,
1562 false,
1563 Default::default(),
1564 vb,
1565 )?
1566 }
1567 }
1568 QuantizedConfig::Bitsandbytes { .. } => {
1569 Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
1570 }
1571 QuantizedConfig::Afq { .. } => {
1572 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
1573 }
1574 QuantizedConfig::MXFP4 {} => {
1575 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
1576 }
1577 }
1578 } else {
1579 if !vb.contains_tensor("weight") {
1580 make_dummy_or_error("linear_no_bias", &vb, &["weight"])?
1581 } else {
1582 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1583 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1584
1585 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1586 Linear::new(weight, None),
1587 ))?;
1588 Arc::new(layer) as Arc<dyn QuantMethod>
1589 }
1590 };
1591 apply_immediate_isq(layer, base_vb)
1592}
1593
1594pub fn linear(
1595 in_dim: usize,
1596 out_dim: usize,
1597 config: &Option<QuantizedConfig>,
1598 vb: ShardedVarBuilder,
1599) -> Result<Arc<dyn QuantMethod>> {
1600 let base_vb = vb.clone();
1601 let vb = if should_apply_immediate_isq(&vb) {
1602 vb.set_device(Device::Cpu)
1603 } else {
1604 vb
1605 };
1606
1607 let layer = if let Some(quant_conf) = &config {
1608 match quant_conf {
1609 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1610 QuantizedConfig::Fp8 { weight_block_size } => {
1611 if weight_block_size.is_some() {
1612 blockwise_fp8_linear_b(
1613 in_dim,
1614 out_dim,
1615 quant_conf,
1616 true,
1617 Default::default(),
1618 vb,
1619 )?
1620 } else {
1621 pertensor_fp8_linear_b(
1622 in_dim,
1623 out_dim,
1624 quant_conf,
1625 true,
1626 Default::default(),
1627 vb,
1628 )?
1629 }
1630 }
1631 QuantizedConfig::Bitsandbytes { .. } => {
1632 Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
1633 }
1634 QuantizedConfig::Afq { .. } => {
1635 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
1636 }
1637 QuantizedConfig::MXFP4 {} => {
1638 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
1639 }
1640 }
1641 } else {
1642 if has_missing_required_tensors(&vb, &["weight", "bias"]) {
1643 make_dummy_or_error("linear", &vb, &["weight", "bias"])?
1644 } else {
1645 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1646 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1647 let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
1648
1649 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1650 Linear::new(weight, Some(bias)),
1651 ))?;
1652 Arc::new(layer) as Arc<dyn QuantMethod>
1653 }
1654 };
1655 apply_immediate_isq(layer, base_vb)
1656}
1657
1658pub fn linear_b(
1659 in_dim: usize,
1660 out_dim: usize,
1661 bias: bool,
1662 config: &Option<QuantizedConfig>,
1663 vb: ShardedVarBuilder,
1664) -> Result<Arc<dyn QuantMethod>> {
1665 if bias {
1666 linear(in_dim, out_dim, config, vb)
1667 } else {
1668 linear_no_bias(in_dim, out_dim, config, vb)
1669 }
1670}
1671
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Creates a CPU/F32 var builder over an empty tensor map, optionally configured with the
    /// given dummy-tensor regex patterns. Panics if a pattern fails to compile.
    fn empty_vb(make_dummy_regexes: Option<Vec<&str>>) -> ShardedVarBuilder {
        let compiled = match make_dummy_regexes {
            Some(patterns) => {
                let mut regexes = Vec::with_capacity(patterns.len());
                for pattern in patterns {
                    regexes.push(Regex::new(pattern).unwrap());
                }
                Some(Arc::new(regexes))
            }
            None => None,
        };
        let tensors: HashMap<String, Tensor> = HashMap::new();
        ShardedSafeTensors::wrap_with_dummy_regexes(
            Box::new(tensors),
            DType::F32,
            Device::Cpu,
            compiled,
        )
    }

    /// A missing weight with no dummy regexes configured must be a hard error.
    #[test]
    fn missing_linear_weight_outside_uqff_errors() {
        let vb = empty_vb(None).pp("foo");
        let msg = linear_no_bias(2, 3, &None, vb).unwrap_err().to_string();

        for needle in ["Missing required tensor(s)", "foo.weight", "UQFF"] {
            assert!(msg.contains(needle));
        }
    }

    /// A missing weight matched by a dummy regex yields a placeholder layer that records its
    /// origin and refuses to run a forward pass.
    #[test]
    fn missing_uqff_placeholder_creates_contextual_dummy() -> Result<()> {
        let vb = empty_vb(Some(vec![r"^foo\.weight$"])).pp("foo");
        let layer = linear_no_bias(2, 3, &None, vb)?;

        let info = layer.dummy_info().unwrap();
        assert_eq!(layer.name(), "dummy");
        assert_eq!(info.context, "linear_no_bias");
        assert_eq!(info.prefix, "foo");
        assert_eq!(info.missing_tensors, vec!["foo.weight"]);

        let input = Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?;
        let msg = layer.forward_raw(&input).unwrap_err().to_string();
        for needle in ["forward pass", "foo.weight", "temporary UQFF placeholders"] {
            assert!(msg.contains(needle));
        }

        Ok(())
    }
}