Skip to main content

hanzo_quant/
lib.rs

1use std::{
2    borrow::Cow,
3    fmt::Debug,
4    num::NonZeroUsize,
5    sync::{atomic::AtomicUsize, Arc, Condvar, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9#[cfg(feature = "metal")]
10use hanzo_ml::D;
11use hanzo_ml::{
12    quantized::{GgmlDType, QMatMul, QTensor},
13    DType, Device, Result, Tensor,
14};
15use pertensor_fp8::pertensor_fp8_linear_b;
16
17#[cfg(feature = "metal")]
18pub mod metal_kernels;
19
20mod afq;
21mod bitsandbytes;
22mod blockwise_fp8;
23pub mod cublaslt;
24pub mod distributed;
25mod dummy;
26pub mod f8q8;
27mod fp8;
28pub mod gemv;
29mod gguf;
30mod gptq;
31mod hqq;
32mod imatrix;
33mod lora;
34mod mxfp4;
35mod pending_layer;
36mod pertensor_fp8;
37pub mod rotary;
38pub mod safetensors;
39mod scalar_fp8;
40mod unquantized;
41mod utils;
42mod vector_fp8;
43
44use gptq::gptq_linear;
45use lora::merge_lora_weights;
46use regex::Regex;
47pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
48
49pub use afq::{AfqBits, AfqGroupSize, AfqInner, AfqLayer};
50pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
51pub use blockwise_fp8::{
52    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
53};
54pub use distributed::{
55    layers::{
56        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
57        ReplicatedLayer, RowParallelLayer,
58    },
59    socket::{Client, Server},
60    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
61};
62pub use dummy::{DummyLayer, DummyLayerInfo};
63pub use f8q8::F8Q8Linear;
64pub use fp8::FP8Linear;
65#[cfg(feature = "cuda")]
66pub use gemv::gemv;
67pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
68#[cfg(feature = "cuda")]
69pub use gguf::cuda::{
70    grouped_moe_gemm_prequantized, indexed_moe_fused_decode, moe_dispatch_build,
71    moe_weighted_reduce_flat, quantize_input_q8_1, ACT_GELU_PYTORCH_TANH, ACT_SILU,
72};
73#[cfg(feature = "cuda")]
74pub use gguf::fast_mmq::{
75    grouped as grouped_moe_mmq, grouped_from_glu_pair as grouped_moe_mmq_from_glu_pair,
76    grouped_pair as grouped_moe_mmq_pair, supports as supports_mmq,
77};
78pub use gguf::GgufMatMul;
79pub use gptq::GptqLayer;
80pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
81pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
82pub use lora::{
83    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
84    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
85};
86pub use mxfp4::MXFP4Layer;
87pub use pending_layer::PendingIsqLayer;
88pub use pertensor_fp8::PerTensorFP8Linear;
89pub use unquantized::UnquantLinear;
90pub use utils::flash_attn_sinks_metal;
91pub use utils::flash_attn_sinks_varlen_metal;
92#[cfg(feature = "cuda")]
93pub use utils::gptoss_swiglu_fused;
94#[cfg(feature = "cuda")]
95pub use utils::gptoss_swiglu_interleaved;
96pub use utils::isq::apply_immediate_isq;
97pub use utils::softmax_with_sinks;
98pub use utils::{fused_glu, GluActivationType};
99pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
100pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
101
102use hanzo_nn::{Conv1d, Conv2d, Linear, Module};
103use serde::{Deserialize, Deserializer, Serialize};
104
/// Limits outstanding async ISQ jobs to prevent unbounded memory growth.
///
/// Without backpressure, MoE models (e.g. Gemma4 with 128 experts × 30 layers)
/// queue BF16 tensor data in the rayon pool faster than the pool can quantize,
/// causing OOM on memory-constrained systems like macOS Metal with unified memory.
///
/// This is a counting gate: [`acquire`](Self::acquire) blocks while `max` jobs are
/// outstanding and [`release`](Self::release) frees one slot. The two calls must be
/// paired by the caller; there is no RAII guard.
pub struct IsqBackpressure {
    /// Number of jobs currently holding a slot.
    count: Mutex<usize>,
    /// Signalled by `release` each time a slot is freed.
    cvar: Condvar,
    /// Maximum number of simultaneously outstanding jobs (always at least 1).
    max: usize,
}

impl IsqBackpressure {
    /// Create a gate that admits at most `max` outstanding jobs.
    ///
    /// A `max` of `0` is raised to `1`: with zero slots the very first
    /// `acquire` could never succeed and would block forever.
    pub fn new(max: usize) -> Self {
        Self {
            count: Mutex::new(0),
            cvar: Condvar::new(),
            max: max.max(1),
        }
    }

    /// Block until a slot is available, then increment the outstanding count.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn acquire(&self) {
        let guard = self.count.lock().expect("ISQ backpressure lock poisoned");
        // `wait_while` re-checks the predicate after every wakeup, so spurious
        // wakeups never let a caller through while the gate is full.
        let mut outstanding = self
            .cvar
            .wait_while(guard, |outstanding| *outstanding >= self.max)
            .expect("ISQ backpressure lock poisoned");
        *outstanding += 1;
    }

    /// Decrement the outstanding count and wake a blocked loader thread.
    ///
    /// The count saturates at zero, so an unmatched `release` is harmless.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn release(&self) {
        let mut outstanding = self.count.lock().expect("ISQ backpressure lock poisoned");
        *outstanding = outstanding.saturating_sub(1);
        self.cvar.notify_one();
    }
}

impl Debug for IsqBackpressure {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting must never panic: if the lock is poisoned, still report the
        // real counter value instead of pretending nothing is outstanding.
        let outstanding = match self.count.lock() {
            Ok(guard) => *guard,
            Err(poisoned) => *poisoned.into_inner(),
        };
        f.debug_struct("IsqBackpressure")
            .field("outstanding", &outstanding)
            .field("max", &self.max)
            .finish()
    }
}
154
/// Configuration for immediate ISQ, stored per thread in `ENGINE_IMMEDIATE_ISQ`.
///
/// `resolve_immediate_isq` consults `overrides` first and falls back to
/// `predicates` + `ty` only when no override pattern matches.
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    /// Guard handed to quantization work (see `QuantizeOntoGuard`).
    pub guard: QuantizeOntoGuard,
    /// Default ISQ type. When `None`, only overrides that carry their own type can match.
    pub ty: Option<IsqType>,
    /// Patterns tested against the weight name (`"<prefix>.weight"`); any match selects `ty`.
    pub predicates: Vec<Regex>,
    /// Per-pattern overrides; the first override whose pattern matches wins.
    pub overrides: Vec<ImmediateIsqOverride>,
    /// Thread pool for parallel immediate ISQ on discrete GPUs.
    /// When `Some`, `apply_immediate_isq` will spawn quantization tasks
    /// on this pool and return `PendingIsqLayer` wrappers.
    pub pool: Option<Arc<rayon::ThreadPool>>,
    /// Backpressure to limit outstanding async ISQ jobs.
    pub backpressure: Arc<IsqBackpressure>,
}
168
/// A per-pattern override of the default immediate-ISQ settings.
#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    /// Pattern tested against the weight name (`"<prefix>.weight"`).
    pub predicate: Regex,
    /// ISQ type for matching weights; `None` falls back to `ImmediateIsqParams::ty`.
    pub ty: Option<IsqType>,
    /// Device copied into the resulting `ImmediateIsqMatch` (may be `None`).
    pub device: Option<Device>,
}
175
/// Result of resolving the immediate-ISQ configuration for one weight.
#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    /// ISQ type selected for the weight.
    pub ty: IsqType,
    /// Device taken from the matching override; `None` when the match came from
    /// the plain predicates or the override did not specify one.
    pub device: Option<Device>,
}
181
// Per-thread immediate-ISQ configuration. Starts as `None`, is filled in by
// `set_immediate_isq` / `set_immediate_isq_with_pool`, and is reset by
// `clear_immediate_isq`. Being thread-local, each thread only ever observes the
// configuration that was installed on that same thread.
thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
}
185
186pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
187    let (pool, _) = create_isq_thread_pool(isq);
188    set_immediate_isq_with_pool(isq, predicates, Vec::new(), pool);
189}
190
191pub fn set_immediate_isq_with_pool(
192    isq: Option<IsqType>,
193    predicates: Vec<Regex>,
194    overrides: Vec<ImmediateIsqOverride>,
195    pool: rayon::ThreadPool,
196) {
197    // Allow pool threads + 1 outstanding jobs: enough for pipeline overlap
198    // (load next tensor while pool quantizes current) without unbounded growth.
199    let max_outstanding = pool.current_num_threads() + 1;
200    ENGINE_IMMEDIATE_ISQ.with(|cell| {
201        *cell.borrow_mut() = Some(ImmediateIsqParams {
202            guard: QuantizeOntoGuard::new(),
203            ty: isq,
204            predicates,
205            overrides,
206            backpressure: Arc::new(IsqBackpressure::new(max_outstanding)),
207            pool: Some(Arc::new(pool)),
208        });
209    });
210}
211
212/// Create a rayon thread pool for parallel immediate ISQ.
213/// Returns `(pool, num_threads)` so callers can log the thread count.
214///
215/// Thread count is based on the quantization type:
216/// - GGML types (Q2K-Q8K) and F8E4M3: `rayon::current_num_threads()` (CPU quantization)
217/// - HQQ/AFQ: 1 thread (GPU quantization, serialized by `QuantizeOntoGuard`)
218pub fn create_isq_thread_pool(ty: Option<IsqType>) -> (rayon::ThreadPool, usize) {
219    let num_threads = if std::env::var("HANZO_ISQ_SINGLETHREAD").is_ok() {
220        1
221    } else if let Some(ty) = ty {
222        ty.get_max_isq_cpu_threads()
223            .map(usize::from)
224            .unwrap_or_else(rayon::current_num_threads)
225    } else {
226        rayon::current_num_threads()
227    };
228
229    let pool = rayon::ThreadPoolBuilder::new()
230        .num_threads(num_threads)
231        .build()
232        .expect("Failed to create ISQ thread pool");
233    (pool, num_threads)
234}
235
236pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
237    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
238}
239
240pub fn clear_immediate_isq() {
241    ENGINE_IMMEDIATE_ISQ.with(|cell| {
242        *cell.borrow_mut() = None;
243    });
244}
245
246pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
247    immediate_isq_match(vb).is_some()
248}
249
250pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
251    let immediate_isq = get_immediate_isq()?;
252    // Add a .weight to match the ISQ regexes!
253    let prefix = format!("{}.weight", vb.prefix());
254    resolve_immediate_isq(&immediate_isq, &prefix)
255}
256
257fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
258    if let Some(override_hit) = params
259        .overrides
260        .iter()
261        .find(|override_pred| override_pred.predicate.is_match(prefix))
262    {
263        if let Some(ty) = override_hit.ty.or(params.ty) {
264            return Some(ImmediateIsqMatch {
265                ty,
266                device: override_hit.device.clone(),
267            });
268        }
269        return None;
270    }
271
272    if let Some(ty) = params.ty {
273        if params
274            .predicates
275            .iter()
276            .any(|predicate| predicate.is_match(prefix))
277        {
278            return Some(ImmediateIsqMatch { ty, device: None });
279        }
280    }
281
282    None
283}
284
/// Quantization settings of a pre-quantized checkpoint.
///
/// Serialization is derived (tagged by `quant_method`, lowercase variant names);
/// deserialization is implemented by hand and goes through `RawConfig`.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    /// GPTQ or AWQ; the deserializer sets `is_awq` when `quant_method` is `"awq"`.
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    /// FP8; per the deserializer, `weight_block_size: None` means per-tensor quantization.
    Fp8 {
        weight_block_size: Option<Vec<usize>>,
    },
    /// bitsandbytes; `get_bits_name` reports 4 bits when `bnb_4bit_quant_type` is set, else 8.
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    /// AFQ; also what the deserializer produces when `quant_method` is absent.
    Afq {
        bits: usize,
        group_size: usize,
    },
    /// MXFP4; carries no parameters.
    MXFP4 {},
}
306
// Common fields for all variants.
// Every field is optional here; `QuantizedConfig`'s `Deserialize` impl decides,
// per `quant_method`, which ones are actually required.
#[derive(Deserialize)]
struct RawConfig {
    // Selects the variant; `None` is treated as AFQ by the deserializer.
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}
317
318// Custom deserializer implementation
319impl<'de> Deserialize<'de> for QuantizedConfig {
320    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
321    where
322        D: Deserializer<'de>,
323    {
324        let raw = RawConfig::deserialize(deserializer)?;
325
326        match &raw.quant_method {
327            Some(m) if m == "gptq" || m == "awq" => {
328                let bits = raw
329                    .bits
330                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
331                let group_size = raw
332                    .group_size
333                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
334                Ok(QuantizedConfig::GptqAwq {
335                    bits,
336                    group_size,
337                    checkpoint_format: raw.checkpoint_format,
338                    is_awq: m == "awq",
339                })
340            }
341            Some(m) if m == "fp8" => {
342                // weight_block_size is optional - None means per-tensor quantization
343                Ok(QuantizedConfig::Fp8 {
344                    weight_block_size: raw.weight_block_size,
345                })
346            }
347            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
348                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
349            }),
350            Some(m) if m == "afq" => {
351                let bits = raw
352                    .bits
353                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
354                let group_size = raw
355                    .group_size
356                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
357                Ok(QuantizedConfig::Afq { bits, group_size })
358            }
359            Some(m) if m == "mxfp4" => {
360                Ok(QuantizedConfig::MXFP4 {  })
361            }
362            None => {
363                let bits = raw
364                    .bits
365                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
366                let group_size = raw
367                    .group_size
368                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
369                Ok(QuantizedConfig::Afq { bits, group_size })
370            }
371            Some(unknown_method) => {
372                Err(serde::de::Error::custom(format!(
373                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
374                )))
375            },
376        }
377    }
378}
379
380impl QuantizedConfig {
381    pub fn name(&self) -> &'static str {
382        match self {
383            Self::GptqAwq { .. } => "gptq",
384            Self::Fp8 { .. } => "fp8",
385            Self::Bitsandbytes { .. } => "bitsandbytes",
386            Self::Afq { .. } => "afq",
387            Self::MXFP4 { .. } => "mxfp4",
388        }
389    }
390
391    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
392        match self {
393            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
394            Self::Fp8 { .. } => "8 bits".to_string(),
395            Self::Bitsandbytes {
396                bnb_4bit_quant_type: Some(_),
397            } => "4 bits".to_string(),
398            Self::Bitsandbytes {
399                bnb_4bit_quant_type: None,
400            } => "8 bits".to_string(),
401            Self::Afq { bits, .. } => format!("{bits} bits"),
402            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
403        }
404    }
405
406    pub fn pack_factor(&self, dtype: DType) -> usize {
407        match self {
408            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
409                2 => IsqType::Q2K.pack_factor(dtype),
410                3 => IsqType::Q3K.pack_factor(dtype),
411                4 => IsqType::Q4K.pack_factor(dtype),
412                5 => IsqType::Q5K.pack_factor(dtype),
413                6 => IsqType::Q6K.pack_factor(dtype),
414                8 => IsqType::Q8_0.pack_factor(dtype),
415                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
416                other => panic!("Unexpected bits in `pack_factor` {other}"),
417            },
418            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
419            Self::Bitsandbytes {
420                bnb_4bit_quant_type: Some(_),
421            }
422            | Self::Bitsandbytes {
423                bnb_4bit_quant_type: None,
424            } => IsqType::Q4K.pack_factor(dtype),
425            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
426        }
427    }
428}
429
/// Inputs for building a linear layer, one variant per quantization backend.
///
/// NOTE(review): the code that consumes these variants is not in this part of the
/// file; the per-field notes below only restate what the names and types show.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    /// GPTQ / AWQ packed weights (`is_awq` distinguishes the two).
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    /// An already-quantized GGUF tensor plus optional bias `b`.
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    /// A plain, unquantized linear layer.
    Unquantized(Linear),
    /// HQQ: `tensor` with the HQQ settings to apply.
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    /// Placeholder carrying no data.
    Dummy,
    /// FP8 built from an existing linear layer and a dtype.
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    /// bitsandbytes weights with their quantization parameters.
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    /// Blockwise FP8 weights with inverse scales and the block shape.
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    /// Per-tensor FP8 weights with inverse scale and optional activation scale.
    PerTensorFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        activation_scale: Option<Tensor>,
        bias: Option<Tensor>,
        dequant_dtype: DType,
    },
    /// AFQ: weight plus bit width and group size.
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    /// MXFP4: packed blocks and their scales.
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
496
497/// Device/configurable intelligent matrix multiplication
498/// - Handles limitation of `accelerate` which requires f32
499pub struct MatMul;
500
501impl MatMul {
502    /// Compute matrix-matrix product.
503    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
504        #[cfg(feature = "accelerate")]
505        {
506            let original_dtype = a.dtype();
507            a.to_dtype(DType::F32)?
508                .matmul(&b.to_dtype(DType::F32)?)?
509                .to_dtype(original_dtype)
510        }
511        #[cfg(not(feature = "accelerate"))]
512        {
513            if a.device().is_cpu() {
514                let original_dtype = a.dtype();
515                a.to_dtype(DType::F16)?
516                    .matmul(&b.to_dtype(DType::F16)?)?
517                    .to_dtype(original_dtype)
518            } else {
519                a.matmul(b)
520            }
521        }
522    }
523
524    /// Compute matrix-matrix product.
525    /// The result will be divided by the `scale` parameter in an affine division.
526    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
527        // TODO(hanzoai): Optimize this by using the gemm parameter?
528        self.matmul(a, b)? / scale
529    }
530
531    /// Compute matrix-matrix product.
532    /// The result will be divided by the `scale` parameter in an affine multiplication.
533    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
534        // TODO(hanzoai): Optimize this by using the gemm parameter?
535        self.matmul(a, b)? * scale
536    }
537
538    /// Compute quantized matrix-matrix product.
539    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
540        matmul.forward(x)
541    }
542}
543
544/// Device/configurable intelligent convolution
545/// - Handles limitation of cpu which requires f32
546pub struct Convolution;
547
548impl Convolution {
549    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
550        if x.device().is_cpu() {
551            let original_dtype = x.dtype();
552            Conv1d::new(
553                layer.weight().to_dtype(DType::F32)?,
554                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
555                *layer.config(),
556            )
557            .forward(&x.to_dtype(DType::F32)?)?
558            .to_dtype(original_dtype)
559        } else {
560            layer.forward(x)
561        }
562    }
563
564    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
565        if x.device().is_cpu() {
566            let original_dtype = x.dtype();
567            Conv2d::new(
568                layer.weight().to_dtype(DType::F32)?,
569                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
570                *layer.config(),
571            )
572            .forward(&x.to_dtype(DType::F32)?)?
573            .to_dtype(original_dtype)
574        } else {
575            layer.forward(x)
576        }
577    }
578}
579
/// In-situ quantization type specifying the format to apply to model weights.
///
/// The `Q*` variants convert to `GgmlDType` (see the `TryFrom` impls); the others
/// do not. The lowercase textual form is given by the `Display` impl.
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    // GGML block formats.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // GGML k-quant formats.
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    // HQQ (only the 8- and 4-bit variants are enabled).
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    // FP8 (E4M3); displayed as "fp8".
    F8E4M3,
    // AFQ at 8/6/4/3/2 bits.
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
    F8Q8,
    MXFP4,
}
609
/// Target bit width for automatic ISQ quantization.
///
/// On Metal, these select AFQ variants; on CUDA/CPU, they select Q*K variants
/// (`Q8_0` for 8 bits). The exception is 5 bits, which is `Q5K` everywhere.
///
/// Can be parsed from the strings `"2"`, `"3"`, `"4"`, `"5"`, `"6"` and `"8"`
/// via `TryFrom<&str>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IsqBits {
    /// 2-bit quantization (AFQ2 on Metal, Q2K otherwise).
    Two,
    /// 3-bit quantization (AFQ3 on Metal, Q3K otherwise).
    Three,
    /// 4-bit quantization (AFQ4 on Metal, Q4K otherwise).
    Four,
    /// 5-bit quantization (Q5K on all platforms).
    Five,
    /// 6-bit quantization (AFQ6 on Metal, Q6K otherwise).
    Six,
    /// 8-bit quantization (AFQ8 on Metal, Q8_0 otherwise).
    Eight,
}
628
629impl IsqBits {
630    /// Resolve to the platform-appropriate `IsqType` for the given device.
631    pub fn resolve(self, device: &Device) -> IsqType {
632        match (self, device.is_metal()) {
633            (Self::Two, true) => IsqType::AFQ2,
634            (Self::Two, false) => IsqType::Q2K,
635            (Self::Three, true) => IsqType::AFQ3,
636            (Self::Three, false) => IsqType::Q3K,
637            (Self::Four, true) => IsqType::AFQ4,
638            (Self::Four, false) => IsqType::Q4K,
639            (Self::Five, _) => IsqType::Q5K,
640            (Self::Six, true) => IsqType::AFQ6,
641            (Self::Six, false) => IsqType::Q6K,
642            (Self::Eight, true) => IsqType::AFQ8,
643            (Self::Eight, false) => IsqType::Q8_0,
644        }
645    }
646
647    /// Return all platform variants, with the current platform's preferred variant first.
648    /// On Metal, AFQ variants come first; on other platforms, GGUF/Q variants come first.
649    pub fn expand(self) -> Vec<IsqType> {
650        #[cfg(feature = "metal")]
651        match self {
652            Self::Two => vec![IsqType::AFQ2, IsqType::Q2K],
653            Self::Three => vec![IsqType::AFQ3, IsqType::Q3K],
654            Self::Four => vec![IsqType::AFQ4, IsqType::Q4K],
655            Self::Five => vec![IsqType::Q5K],
656            Self::Six => vec![IsqType::AFQ6, IsqType::Q6K],
657            Self::Eight => vec![IsqType::AFQ8, IsqType::Q8_0],
658        }
659        #[cfg(not(feature = "metal"))]
660        match self {
661            Self::Two => vec![IsqType::Q2K, IsqType::AFQ2],
662            Self::Three => vec![IsqType::Q3K, IsqType::AFQ3],
663            Self::Four => vec![IsqType::Q4K, IsqType::AFQ4],
664            Self::Five => vec![IsqType::Q5K],
665            Self::Six => vec![IsqType::Q6K, IsqType::AFQ6],
666            Self::Eight => vec![IsqType::Q8_0, IsqType::AFQ8],
667        }
668    }
669}
670
671impl TryFrom<&str> for IsqBits {
672    type Error = ();
673    fn try_from(s: &str) -> std::result::Result<Self, ()> {
674        match s {
675            "2" => Ok(Self::Two),
676            "3" => Ok(Self::Three),
677            "4" => Ok(Self::Four),
678            "5" => Ok(Self::Five),
679            "6" => Ok(Self::Six),
680            "8" => Ok(Self::Eight),
681            _ => Err(()),
682        }
683    }
684}
685
686impl std::fmt::Display for IsqType {
687    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
688        match self {
689            Self::Q4_0 => write!(f, "q4_0"),
690            Self::Q4_1 => write!(f, "q4_1"),
691            Self::Q5_0 => write!(f, "q5_0"),
692            Self::Q5_1 => write!(f, "q5_1"),
693            Self::Q8_0 => write!(f, "q8_0"),
694            Self::Q8_1 => write!(f, "q8_1"),
695            Self::Q2K => write!(f, "q2k"),
696            Self::Q3K => write!(f, "q3k"),
697            Self::Q4K => write!(f, "q4k"),
698            Self::Q5K => write!(f, "q5k"),
699            Self::Q6K => write!(f, "q6k"),
700            Self::Q8K => write!(f, "q8k"),
701            Self::HQQ8 => write!(f, "hqq8"),
702            Self::HQQ4 => write!(f, "hqq4"),
703            Self::F8E4M3 => write!(f, "fp8"),
704            Self::AFQ8 => write!(f, "afq8"),
705            Self::AFQ6 => write!(f, "afq6"),
706            Self::AFQ4 => write!(f, "afq4"),
707            Self::AFQ3 => write!(f, "afq3"),
708            Self::AFQ2 => write!(f, "afq2"),
709            Self::F8Q8 => write!(f, "f8q8"),
710            Self::MXFP4 => write!(f, "mxfp4"),
711        }
712    }
713}
714
715impl IsqType {
716    /// Factor by which the weight size is reduced over the given dtype.
717    /// original size / pack factor = quantized size
718    pub fn pack_factor(&self, dtype: DType) -> usize {
719        match self {
720            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
721                .div_ceil(GgmlDType::Q4_0.type_size()),
722            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
723                .div_ceil(GgmlDType::Q4_1.type_size()),
724            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
725                .div_ceil(GgmlDType::Q5_0.type_size()),
726            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
727                .div_ceil(GgmlDType::Q5_1.type_size()),
728            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
729                .div_ceil(GgmlDType::Q8_0.type_size()),
730            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
731                .div_ceil(GgmlDType::Q8_1.type_size()),
732            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
733                .div_ceil(GgmlDType::Q2K.type_size()),
734            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
735                .div_ceil(GgmlDType::Q3K.type_size()),
736            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
737                .div_ceil(GgmlDType::Q4K.type_size()),
738            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
739                .div_ceil(GgmlDType::Q5K.type_size()),
740            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
741                .div_ceil(GgmlDType::Q6K.type_size()),
742            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
743                .div_ceil(GgmlDType::Q8K.type_size()),
744            // F8Q8: 33 bytes per 32 values -> similar to Q8_0
745            Self::F8Q8 => (dtype.size_in_bytes() * 32).div_ceil(33),
746            // Estimates
747            Self::HQQ4 => 4,
748            Self::HQQ8 => 2,
749            Self::F8E4M3 => 2,
750            // MXFP4: 4 bits per value + 1 byte scale per 32 values
751            // For BF16 (2 bytes): (2*32)/(16+1) ≈ 3.76 → 3
752            Self::MXFP4 => 3,
753        }
754    }
755
756    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
757        match self {
758            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
759            IsqType::HQQ4
760            | IsqType::HQQ8
761            | IsqType::AFQ2
762            | IsqType::AFQ3
763            | IsqType::AFQ4
764            | IsqType::AFQ6
765            | IsqType::AFQ8
766            | IsqType::MXFP4 => {
767                // Use 1 because our HQQ quantizes on the GPU
768                Some(1.try_into().unwrap())
769            }
770            IsqType::F8E4M3 | IsqType::F8Q8 => None,
771            IsqType::Q2K
772            | IsqType::Q3K
773            | IsqType::Q4K
774            | IsqType::Q4_0
775            | IsqType::Q4_1
776            | IsqType::Q5K
777            | IsqType::Q5_0
778            | IsqType::Q5_1
779            | IsqType::Q6K
780            | IsqType::Q8K
781            | IsqType::Q8_0
782            | IsqType::Q8_1 => None,
783        }
784    }
785}
786
787impl TryFrom<IsqType> for GgmlDType {
788    type Error = hanzo_ml::Error;
789
790    fn try_from(value: IsqType) -> Result<Self> {
791        let tp = match value {
792            IsqType::Q2K => Self::Q2K,
793            IsqType::Q3K => Self::Q3K,
794            IsqType::Q4K => Self::Q4K,
795            IsqType::Q4_0 => Self::Q4_0,
796            IsqType::Q4_1 => Self::Q4_1,
797            IsqType::Q5K => Self::Q5K,
798            IsqType::Q5_0 => Self::Q5_0,
799            IsqType::Q5_1 => Self::Q5_1,
800            IsqType::Q6K => Self::Q6K,
801            IsqType::Q8K => Self::Q8K,
802            IsqType::Q8_0 => Self::Q8_0,
803            IsqType::Q8_1 => Self::Q8_1,
804            _ => hanzo_ml::bail!("Expected valid GGML ISQ type."),
805        };
806        #[cfg(feature = "cuda")]
807        {
808            if !matches!(
809                tp,
810                GgmlDType::Q4_0
811                    | GgmlDType::Q4_1
812                    | GgmlDType::Q5_0
813                    | GgmlDType::Q5_1
814                    | GgmlDType::Q8_0
815                    | GgmlDType::Q2K
816                    | GgmlDType::Q3K
817                    | GgmlDType::Q4K
818                    | GgmlDType::Q5K
819                    | GgmlDType::Q6K
820            ) {
821                hanzo_ml::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
822            }
823        }
824        Ok(tp)
825    }
826}
827
828impl TryFrom<GgmlDType> for IsqType {
829    type Error = hanzo_ml::Error;
830
831    fn try_from(value: GgmlDType) -> Result<Self> {
832        match value {
833            GgmlDType::Q2K => Ok(Self::Q2K),
834            GgmlDType::Q3K => Ok(Self::Q3K),
835            GgmlDType::Q4K => Ok(Self::Q4K),
836            GgmlDType::Q5K => Ok(Self::Q5K),
837            GgmlDType::Q6K => Ok(Self::Q6K),
838            GgmlDType::Q4_0 => Ok(Self::Q4_0),
839            GgmlDType::Q4_1 => Ok(Self::Q4_1),
840            GgmlDType::Q5_0 => Ok(Self::Q5_0),
841            GgmlDType::Q5_1 => Ok(Self::Q5_1),
842            GgmlDType::Q8_0 => Ok(Self::Q8_0),
843            GgmlDType::Q8_1 => Ok(Self::Q8_1),
844            GgmlDType::Q8K => Ok(Self::Q8K),
845            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
846                hanzo_ml::bail!("Expected valid GGML ISQ type.")
847            }
848        }
849    }
850}
851
/// Numeric tag identifying a quantized-layer format.
///
/// The discriminants are assigned explicitly and are mirrored by the
/// `TryFrom<usize>` impl, so the two must be kept in sync.
/// NOTE(review): presumably these values are written into serialized artifacts
/// and must therefore stay stable — confirm against the serializers.
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
    F8Q8 = 5,
    Mxfp4 = 6,
}
862
863impl TryFrom<usize> for QuantizedSerdeType {
864    type Error = hanzo_ml::Error;
865    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
866        match value {
867            0 => Ok(Self::Gguf),
868            1 => Ok(Self::Unquant),
869            2 => Ok(Self::Hqq),
870            3 => Ok(Self::Fp8),
871            4 => Ok(Self::Afq),
872            5 => Ok(Self::F8Q8),
873            6 => Ok(Self::Mxfp4),
874            other => hanzo_ml::bail!("QuantizedSerdeType {other} is invalid."),
875        }
876    }
877}
878
/// Serialization hooks for quantized layers.
///
/// Only `name` must be implemented. Every other method has a default:
/// `isq_serde_supported` returns `false`, and each (de)serialization method
/// returns an error saying the operation is not supported.
pub trait QuantizedSerde {
    /// Static name for this implementation.
    fn name(&self) -> &'static str;
    /// Whether this implementation supports ISQ (de)serialization. Defaults to `false`.
    fn isq_serde_supported(&self) -> bool {
        false
    }
    /// Serialize the layer to bytes. The default returns an error.
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        hanzo_ml::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    /// Rebuild a layer from serialized bytes. The default returns an error.
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        hanzo_ml::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    /// Rebuild a layer from serialized bytes, returning an optional bias tensor
    /// alongside the layer. The default returns an error.
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        hanzo_ml::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    ///
    /// Serialize the layer with a caller-supplied bias. The default returns an error.
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        hanzo_ml::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}
913
/// Gate that serializes access to quantizing onto the host device.
///
/// Cloning is cheap: all clones share the same underlying mutex, so they
/// all contend for the same critical section.
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    /// Shared lock; it protects no data and is used purely for mutual exclusion.
    pub inner: Arc<Mutex<()>>,
}
920
/// Token returned by `QuantizeOntoGuard::acquire`; the critical section lasts
/// for as long as this value is alive.
///
/// Builds with the `cuda` feature hand out [`QuantizeOntoDropGuard::Fake`];
/// all other builds (including Metal) hand out
/// [`QuantizeOntoDropGuard::Real`].
pub enum QuantizeOntoDropGuard<'a> {
    /// Holds the mutex; the lock is released when this value is dropped.
    Real(MutexGuard<'a, ()>),
    /// Holds nothing; no locking takes place.
    Fake,
}
926
927impl Default for QuantizeOntoGuard {
928    fn default() -> Self {
929        Self::new()
930    }
931}
932
933impl QuantizeOntoGuard {
934    pub fn new() -> Self {
935        QuantizeOntoGuard {
936            inner: Arc::new(Mutex::new(())),
937        }
938    }
939
940    /// Acquire the quantize drop guard to protect the critical section.
941    ///
942    /// On metal, this waits for outstanding work to finish to avoid "A command encoder is already encoding to this command buffer"
943    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
944        #[cfg(feature = "cuda")]
945        {
946            let _ = device;
947            QuantizeOntoDropGuard::Fake
948        }
949
950        #[cfg(not(feature = "cuda"))]
951        {
952            #[cfg(feature = "metal")]
953            if let Device::Metal(dev) = device {
954                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
955                dev.wait_until_completed()
956                    .expect("Failed to flush command buffer.");
957            }
958            #[cfg(not(feature = "metal"))]
959            let _ = device;
960
961            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
962        }
963    }
964}
965
/// The distributed layout a layer reports from `QuantMethod::is_distributed`.
pub enum DistributedKind {
    /// Column-parallel layer.
    ColumnParallel,
    /// Row-parallel layer.
    RowParallel,
    /// Replicated layer.
    Replicated,
}
971
972/// Quantized method for a quantized matmul.
973pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
974    fn new(method: QuantMethodConfig) -> Result<Self>
975    where
976        Self: Sized;
977
978    fn dequantize_w(&self) -> Result<Tensor>;
979
980    /// Compute matmul of `self` and `a`. `self` should contain the weights.
981    /// Automatically casts to the required quantization activation type and back.
982    fn forward(&self, a: &Tensor) -> Result<Tensor> {
983        if let Some(t) = self.quantized_act_type() {
984            let original_ty = a.dtype();
985            self.forward_raw(&a.to_dtype(t)?)?.to_dtype(original_ty)
986        } else {
987            self.forward_raw(a)
988        }
989    }
990
991    /// Raw matmul without dtype casting. Implementors override this.
992    /// Callers should use `forward` instead.
993    fn forward_raw(&self, a: &Tensor) -> Result<Tensor>;
994
995    /// Compute gather matmul of `self` and `a`. `self` should contain the weights.
996    /// Automatically casts to the required quantization activation type and back.
997    ///
998    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
999    /// then the indices are (n_tokens, n_experts).
1000    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
1001        if let Some(t) = self.quantized_act_type() {
1002            let original_ty = a.dtype();
1003            self.gather_forward_raw(&a.to_dtype(t)?, indices)?
1004                .to_dtype(original_ty)
1005        } else {
1006            self.gather_forward_raw(a, indices)
1007        }
1008    }
1009
1010    /// Raw gather matmul without dtype casting. Implementors override this.
1011    /// Callers should use `gather_forward` instead.
1012    fn gather_forward_raw(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
1013        hanzo_ml::bail!(
1014            "{} does not support `gather_forward`. Please raise an issue.",
1015            self.name()
1016        )
1017    }
1018
1019    /// Get the underlying QTensor if this is a GGUF quantized layer.
1020    /// Used for direct kernel access in the grouped MoE prefill path.
1021    #[cfg(feature = "cuda")]
1022    fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
1023        None
1024    }
1025
1026    /// If this is an AFQ layer, return its (w_q, scales, biases, bits, group_size).
1027    /// Used by Metal fused QKV / gate-up paths.
1028    fn afq_inner(&self) -> Option<crate::afq::AfqInner<'_>> {
1029        None
1030    }
1031
1032    /// If a quantized method, return the activation dtype.
1033    fn quantized_act_type(&self) -> Option<DType>;
1034
1035    /// Weight dtype and device
1036    fn dtype_and_device(&self) -> (DType, Device);
1037
1038    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
1039    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;
1040
1041    /// If the quant is backed by a qmatmul.
1042    fn apply_isq(
1043        self: Arc<Self>,
1044        dtype: Option<IsqType>,
1045        device: Device,
1046        n_quantized: &AtomicUsize,
1047        imatrix_weight: Option<Vec<f32>>,
1048        guard: QuantizeOntoGuard,
1049    ) -> Result<Arc<dyn QuantMethod>>;
1050
1051    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
1052        None
1053    }
1054
1055    fn has_bias(&self) -> bool {
1056        false
1057    }
1058
1059    /// Begin tracking stats into an ImatrixLayerStats
1060    fn begin_track_stats(&mut self) -> Result<()> {
1061        hanzo_ml::bail!("`{}` does not support tracking stats.", self.name())
1062    }
1063
1064    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
1065    fn end_track_stats(&self) -> Result<Tensor> {
1066        hanzo_ml::bail!("`{}` does not support tracking stats.", self.name())
1067    }
1068
1069    fn is_distributed(&self) -> Option<DistributedKind> {
1070        None
1071    }
1072
1073    fn dummy_info(&self) -> Option<&DummyLayerInfo> {
1074        None
1075    }
1076}
1077
1078impl Module for dyn QuantMethod {
1079    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1080        QuantMethod::forward(self, xs)
1081    }
1082}
1083
/// Attempt the fused CUDA gate/up GLU kernel.
///
/// Returns `Ok(None)` — meaning "use the regular, unfused path" — unless all
/// of the following hold:
/// * neither layer has a bias;
/// * `xs` is BF16, F16 or F32;
/// * both layers expose a `QTensor` of dtype `Q8_0` with identical shapes;
/// * `xs` has at least one dimension and the product of its leading
///   dimensions lies in `1..=MMVQ_MAX_BATCH`;
/// * the last dimension of `xs` equals the weights' column count.
///
/// # Errors
/// Fails if the gate weight is not two-dimensional or if the fused kernel
/// itself reports an error.
#[cfg(feature = "cuda")]
pub fn try_fused_quantized_gate_up(
    xs: &Tensor,
    gate: &dyn QuantMethod,
    up: &dyn QuantMethod,
    activation: GluActivationType,
) -> Result<Option<Tensor>> {
    if gate.has_bias()
        || up.has_bias()
        || !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32)
    {
        return Ok(None);
    }

    let gate_weight = match gate.get_qtensor() {
        Some(weight) => weight,
        None => return Ok(None),
    };
    let up_weight = match up.get_qtensor() {
        Some(weight) => weight,
        None => return Ok(None),
    };

    let both_q8_0 = [gate_weight, up_weight]
        .iter()
        .all(|weight| weight.dtype() == GgmlDType::Q8_0);
    if !both_q8_0 || gate_weight.shape() != up_weight.shape() {
        return Ok(None);
    }

    // Split the input shape into (leading batch dims, feature dim).
    let [leading_dims @ .., last_dim] = xs.dims() else {
        return Ok(None);
    };
    let in_features = *last_dim;
    let rows: usize = leading_dims.iter().product();
    if !(1..=gguf::fast_mmvq::MMVQ_MAX_BATCH).contains(&rows) {
        return Ok(None);
    }

    let (_, weight_cols) = gate_weight.shape().dims2()?;
    if in_features != weight_cols {
        return Ok(None);
    }

    let fused = gguf::fast_mmvq::fused_glu(gate_weight, up_weight, xs, activation)?;
    Ok(Some(fused))
}
1127
/// Attempt the fused CUDA Q/K/V projection kernel.
///
/// Returns `Ok(None)` — meaning "use the regular, unfused path" — unless all
/// of the following hold:
/// * none of the three layers has a bias;
/// * `xs` is BF16, F16 or F32;
/// * all three layers expose a `QTensor`, share one GGML dtype, and that
///   dtype is accepted by `gguf::fast_mmvq::supports`;
/// * `xs` has at least one dimension and the product of its leading
///   dimensions lies in `1..=MMVQ_MAX_BATCH`;
/// * the last dimension of `xs` equals the column count of every weight.
///
/// # Errors
/// Fails if any weight is not two-dimensional or if the fused kernel itself
/// reports an error.
#[cfg(feature = "cuda")]
pub fn try_fused_quantized_qkv(
    xs: &Tensor,
    q: &dyn QuantMethod,
    k: &dyn QuantMethod,
    v: &dyn QuantMethod,
) -> Result<Option<(Tensor, Tensor, Tensor)>> {
    if q.has_bias()
        || k.has_bias()
        || v.has_bias()
        || !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32)
    {
        return Ok(None);
    }

    let q_weight = match q.get_qtensor() {
        Some(weight) => weight,
        None => return Ok(None),
    };
    let k_weight = match k.get_qtensor() {
        Some(weight) => weight,
        None => return Ok(None),
    };
    let v_weight = match v.get_qtensor() {
        Some(weight) => weight,
        None => return Ok(None),
    };

    let weight_dtype = q_weight.dtype();
    let same_dtype = weight_dtype == k_weight.dtype() && weight_dtype == v_weight.dtype();
    if !same_dtype || !gguf::fast_mmvq::supports(weight_dtype) {
        return Ok(None);
    }

    // Split the input shape into (leading batch dims, feature dim).
    let [leading_dims @ .., last_dim] = xs.dims() else {
        return Ok(None);
    };
    let in_features = *last_dim;
    let rows: usize = leading_dims.iter().product();
    if !(1..=gguf::fast_mmvq::MMVQ_MAX_BATCH).contains(&rows) {
        return Ok(None);
    }

    // All three shapes are resolved before any comparison is made.
    let weight_cols = [
        q_weight.shape().dims2()?.1,
        k_weight.shape().dims2()?.1,
        v_weight.shape().dims2()?.1,
    ];
    if weight_cols.iter().any(|&cols| cols != in_features) {
        return Ok(None);
    }

    let fused = gguf::fast_mmvq::fused_qkv(q_weight, k_weight, v_weight, xs)?;
    Ok(Some(fused))
}
1172
/// Metal fused gate+up: single Metal kernel that does both matmuls with shared
/// x reads and applies the GLU activation in-register before writing one output.
///
/// Returns `Ok(None)` — meaning "use the regular, unfused path" — unless all
/// of the following hold:
/// * neither layer has a bias;
/// * `xs` is BF16, F16 or F32 and lives on a Metal device;
/// * both layers are AFQ layers with matching bits, group size, scale dtype,
///   two-dimensional packed weights, the same row count and the same packed
///   column count;
/// * the last dimension `k` of `xs` is non-zero and `xs` holds at least 16
///   rows of `k` elements;
/// * the packed column count matches `k * bits / 8 / 4`;
/// * every involved tensor is backed by Metal storage.
///
/// The output has the shape of `xs` with its last dimension replaced by the
/// weights' row count, and the dtype of `xs`.
///
/// # Errors
/// Fails if `xs` has no dimensions, if buffer/encoder creation fails, or if
/// the kernel launch reports an error.
#[cfg(feature = "metal")]
pub fn try_fused_gate_up_metal(
    xs: &Tensor,
    gate: &dyn QuantMethod,
    up: &dyn QuantMethod,
    activation: GluActivationType,
) -> Result<Option<Tensor>> {
    use hanzo_ml::{backend::BackendStorage, MetalStorage, Shape, Storage};

    if gate.has_bias() || up.has_bias() {
        return Ok(None);
    }
    if !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32) {
        return Ok(None);
    }
    if !xs.device().is_metal() {
        return Ok(None);
    }

    let Some(gi) = gate.afq_inner() else {
        return Ok(None);
    };
    let Some(ui) = up.afq_inner() else {
        return Ok(None);
    };
    if gi.bits != ui.bits || gi.group_size != ui.group_size {
        return Ok(None);
    }
    if gi.scales.dtype() != ui.scales.dtype() {
        return Ok(None);
    }
    if gi.w_q.rank() != 2 || ui.w_q.rank() != 2 {
        return Ok(None);
    }
    let k = xs.dim(D::Minus1)?;
    let n_gate = gi.w_q.dim(0)?;
    let n_up = ui.w_q.dim(0)?;
    if n_gate != n_up {
        return Ok(None);
    }
    let n = n_gate;
    // An empty feature dimension would make the row-count division below
    // panic; there is nothing to fuse in that case anyway.
    if k == 0 {
        return Ok(None);
    }
    // qmm_t kernel uses BM=32 tiles; for small M (decode) it wastes most of the
    // tile. Let the caller fall back to separate qmv-based forwards.
    // (This also covers M == 0.)
    let m = xs.elem_count() / k;
    if m < 16 {
        return Ok(None);
    }
    let packed_cols = gi.w_q.dim(1)?;
    if k * gi.bits as usize / 8 / 4 != packed_cols {
        // unexpected pack factor; let the generic path handle it
        return Ok(None);
    }
    if ui.w_q.dim(1)? != packed_cols {
        // Both weights are multiplied by the same input, so their packed
        // widths must agree; otherwise let the generic path handle it.
        return Ok(None);
    }

    // Numeric activation selector passed to the kernel.
    let act_code: u32 = match activation {
        GluActivationType::Silu => 0,
        GluActivationType::Gelu => 1,
        GluActivationType::GeluErf => 2,
        GluActivationType::Relu => 3,
    };

    // Making `xs` contiguous does not change its element count, so `m`
    // computed above remains valid.
    let xs = xs.contiguous()?;

    let (xs_storage, xs_layout) = xs.storage_and_layout();
    let Storage::Metal(xs_storage) = &*xs_storage else {
        return Ok(None);
    };
    let (g_w_s, _) = gi.w_q.storage_and_layout();
    let Storage::Metal(g_w_s) = &*g_w_s else {
        return Ok(None);
    };
    let (g_s_s, _) = gi.scales.storage_and_layout();
    let Storage::Metal(g_s_s) = &*g_s_s else {
        return Ok(None);
    };
    let (g_b_s, _) = gi.biases.storage_and_layout();
    let Storage::Metal(g_b_s) = &*g_b_s else {
        return Ok(None);
    };
    let (u_w_s, _) = ui.w_q.storage_and_layout();
    let Storage::Metal(u_w_s) = &*u_w_s else {
        return Ok(None);
    };
    let (u_s_s, _) = ui.scales.storage_and_layout();
    let Storage::Metal(u_s_s) = &*u_s_s else {
        return Ok(None);
    };
    let (u_b_s, _) = ui.biases.storage_and_layout();
    let Storage::Metal(u_b_s) = &*u_b_s else {
        return Ok(None);
    };

    let device = xs_storage.device().clone();
    let dtype = xs.dtype();
    // Output shape: same as the input, with the feature dim replaced by `n`.
    let mut out_shape = xs.dims().to_vec();
    *out_shape
        .last_mut()
        .expect("xs has at least one dim: `xs.dim(D::Minus1)` succeeded above") = n;
    let out_elems: usize = out_shape.iter().product();
    let out = device.new_buffer(out_elems, dtype, "afq-gate-up-out")?;

    let encoder = device.command_encoder()?;
    encoder.set_label("afq-gate-up");

    metal_kernels::call_afq_qmm_gate_up(
        device.device(),
        &encoder,
        &metal_kernels::Kernels::new(),
        dtype,
        (
            xs_storage.buffer(),
            // The layout offset is in elements; the kernel wants bytes.
            xs_layout.start_offset() * dtype.size_in_bytes(),
        ),
        g_w_s.buffer(),
        g_s_s.buffer(),
        g_b_s.buffer(),
        u_w_s.buffer(),
        u_s_s.buffer(),
        u_b_s.buffer(),
        &out,
        m,
        n,
        k,
        gi.bits as usize,
        gi.group_size as usize,
        act_code,
    )
    .map_err(hanzo_ml::Error::wrap)?;

    let out_t = Tensor::from((
        Storage::Metal(MetalStorage::new(out, device.clone(), out_elems, dtype)),
        Shape::from(out_shape),
    ));
    Ok(Some(out_t))
}
1314
/// Metal fused QKV: single Metal kernel that handles all three projections,
/// routing per-tile to the right weight matrix.
///
/// Returns `Ok(None)` — meaning "use the regular, unfused path" — unless all
/// of the following hold:
/// * none of the three layers has a bias;
/// * `xs` is BF16, F16 or F32 and lives on a Metal device;
/// * all three layers are AFQ layers with matching bits, group size, scale
///   dtype, two-dimensional packed weights and the same packed column count;
/// * each weight's row count is a multiple of 32;
/// * the last dimension of `xs` is non-zero and `xs` holds at least 16 rows;
/// * every involved tensor is backed by Metal storage.
///
/// Each output has the shape of `xs` with its last dimension replaced by the
/// corresponding weight's row count, and the dtype of `xs`.
///
/// # Errors
/// Fails if `xs` has no dimensions, if buffer/encoder creation fails, or if
/// the kernel launch reports an error.
#[cfg(feature = "metal")]
pub fn try_fused_qkv_metal(
    xs: &Tensor,
    q: &dyn QuantMethod,
    k: &dyn QuantMethod,
    v: &dyn QuantMethod,
) -> Result<Option<(Tensor, Tensor, Tensor)>> {
    use hanzo_ml::{backend::BackendStorage, MetalStorage, Shape, Storage};

    if q.has_bias() || k.has_bias() || v.has_bias() {
        return Ok(None);
    }
    if !matches!(xs.dtype(), DType::BF16 | DType::F16 | DType::F32) {
        return Ok(None);
    }
    if !xs.device().is_metal() {
        return Ok(None);
    }

    let Some(qi) = q.afq_inner() else {
        return Ok(None);
    };
    let Some(ki) = k.afq_inner() else {
        return Ok(None);
    };
    let Some(vi) = v.afq_inner() else {
        return Ok(None);
    };
    if qi.bits != ki.bits || qi.bits != vi.bits {
        return Ok(None);
    }
    if qi.group_size != ki.group_size || qi.group_size != vi.group_size {
        return Ok(None);
    }
    if qi.scales.dtype() != ki.scales.dtype() || qi.scales.dtype() != vi.scales.dtype() {
        return Ok(None);
    }
    if qi.w_q.rank() != 2 || ki.w_q.rank() != 2 || vi.w_q.rank() != 2 {
        return Ok(None);
    }
    let n_q = qi.w_q.dim(0)?;
    let n_k = ki.w_q.dim(0)?;
    let n_v = vi.w_q.dim(0)?;
    // The kernel routes by tile-aligned column boundaries; require N_q, N_k
    // and N_v to be multiples of the tile width (32). For Gemma-style models
    // those are already 32-multiples; fall back when they're not.
    if n_q % 32 != 0 || n_k % 32 != 0 || n_v % 32 != 0 {
        return Ok(None);
    }
    // All three weights are multiplied by the same input, so their packed
    // widths must agree; otherwise let the generic path handle it.
    let packed_cols = qi.w_q.dim(1)?;
    if ki.w_q.dim(1)? != packed_cols || vi.w_q.dim(1)? != packed_cols {
        return Ok(None);
    }
    let k_dim = xs.dim(D::Minus1)?;
    // An empty feature dimension would make the row-count division below
    // panic; there is nothing to fuse in that case anyway.
    if k_dim == 0 {
        return Ok(None);
    }
    // qmm_t kernel uses BM=32; for small M (decode) the tile is mostly empty.
    // Fall back to separate qmv calls. (This also covers M == 0.)
    let m = xs.elem_count() / k_dim;
    if m < 16 {
        return Ok(None);
    }

    // Making `xs` contiguous does not change its element count, so `m`
    // computed above remains valid.
    let xs = xs.contiguous()?;

    let (xs_s, xs_l) = xs.storage_and_layout();
    let Storage::Metal(xs_s) = &*xs_s else {
        return Ok(None);
    };
    // Keep every storage guard alive in a named local for as long as the
    // buffers borrowed from it are in use.
    let qws = qi.w_q.storage_and_layout().0;
    let qss = qi.scales.storage_and_layout().0;
    let qbs = qi.biases.storage_and_layout().0;
    let kws = ki.w_q.storage_and_layout().0;
    let kss = ki.scales.storage_and_layout().0;
    let kbs = ki.biases.storage_and_layout().0;
    let vws = vi.w_q.storage_and_layout().0;
    let vss = vi.scales.storage_and_layout().0;
    let vbs = vi.biases.storage_and_layout().0;
    let (Storage::Metal(qw_m), Storage::Metal(qs_m), Storage::Metal(qb_m)) = (&*qws, &*qss, &*qbs)
    else {
        return Ok(None);
    };
    let (Storage::Metal(kw_m), Storage::Metal(ks_m), Storage::Metal(kb_m)) = (&*kws, &*kss, &*kbs)
    else {
        return Ok(None);
    };
    let (Storage::Metal(vw_m), Storage::Metal(vs_m), Storage::Metal(vb_m)) = (&*vws, &*vss, &*vbs)
    else {
        return Ok(None);
    };

    let device = xs_s.device().clone();
    let dtype = xs.dtype();
    // Output shapes: same as the input, with the feature dim replaced by the
    // respective projection width.
    const HAS_DIM: &str = "xs has at least one dim: `xs.dim(D::Minus1)` succeeded above";
    let mut q_shape = xs.dims().to_vec();
    let mut k_shape = q_shape.clone();
    let mut v_shape = q_shape.clone();
    *q_shape.last_mut().expect(HAS_DIM) = n_q;
    *k_shape.last_mut().expect(HAS_DIM) = n_k;
    *v_shape.last_mut().expect(HAS_DIM) = n_v;
    let q_elems: usize = q_shape.iter().product();
    let k_elems: usize = k_shape.iter().product();
    let v_elems: usize = v_shape.iter().product();
    let q_out = device.new_buffer(q_elems, dtype, "afq-qkv-q")?;
    let k_out = device.new_buffer(k_elems, dtype, "afq-qkv-k")?;
    let v_out = device.new_buffer(v_elems, dtype, "afq-qkv-v")?;

    let encoder = device.command_encoder()?;
    encoder.set_label("afq-qkv");

    metal_kernels::call_afq_qmm_qkv(
        device.device(),
        &encoder,
        &metal_kernels::Kernels::new(),
        dtype,
        // The layout offset is in elements; the kernel wants bytes.
        (xs_s.buffer(), xs_l.start_offset() * dtype.size_in_bytes()),
        qw_m.buffer(),
        qs_m.buffer(),
        qb_m.buffer(),
        kw_m.buffer(),
        ks_m.buffer(),
        kb_m.buffer(),
        vw_m.buffer(),
        vs_m.buffer(),
        vb_m.buffer(),
        &q_out,
        &k_out,
        &v_out,
        m,
        n_q,
        n_k,
        n_v,
        k_dim,
        qi.bits as usize,
        qi.group_size as usize,
    )
    .map_err(hanzo_ml::Error::wrap)?;

    let q_t = Tensor::from((
        Storage::Metal(MetalStorage::new(q_out, device.clone(), q_elems, dtype)),
        Shape::from(q_shape),
    ));
    let k_t = Tensor::from((
        Storage::Metal(MetalStorage::new(k_out, device.clone(), k_elems, dtype)),
        Shape::from(k_shape),
    ));
    let v_t = Tensor::from((
        Storage::Metal(MetalStorage::new(v_out, device.clone(), v_elems, dtype)),
        Shape::from(v_shape),
    ));
    Ok(Some((q_t, k_t, v_t)))
}
1478
1479fn tensor_prefix(vb: &ShardedVarBuilder) -> String {
1480    let prefix = vb.prefix();
1481    if prefix.is_empty() {
1482        "<root>".to_string()
1483    } else {
1484        prefix
1485    }
1486}
1487
1488fn missing_required_tensors(vb: &ShardedVarBuilder, required: &[&str]) -> Vec<String> {
1489    required
1490        .iter()
1491        .copied()
1492        .filter(|name| !vb.contains_tensor(name))
1493        .map(|name| safetensors::full_tensor_name(vb, name))
1494        .collect()
1495}
1496
1497pub(crate) fn has_missing_required_tensors(vb: &ShardedVarBuilder, required: &[&str]) -> bool {
1498    required.iter().any(|name| !vb.contains_tensor(name))
1499}
1500
1501pub(crate) fn make_dummy_or_error(
1502    context: &str,
1503    vb: &ShardedVarBuilder,
1504    required: &[&str],
1505) -> Result<Arc<dyn QuantMethod>> {
1506    let missing = missing_required_tensors(vb, required);
1507    if missing.is_empty() {
1508        hanzo_ml::bail!(
1509            "Internal error: requested DummyLayer for {context} without missing tensors"
1510        );
1511    }
1512
1513    let has_uqff_placeholder = required
1514        .iter()
1515        .any(|name| safetensors::is_uqff_dummy_tensor(vb, name));
1516    if !has_uqff_placeholder {
1517        hanzo_ml::bail!(
1518            "Missing required tensor(s) for {context} at prefix `{}`: {}. Dummy layers are only allowed for tensors intentionally omitted while loading UQFF artifacts.",
1519            tensor_prefix(vb),
1520            missing.join(", ")
1521        );
1522    }
1523
1524    Ok(Arc::new(DummyLayer::placeholder(DummyLayerInfo {
1525        context: context.to_string(),
1526        prefix: tensor_prefix(vb),
1527        missing_tensors: missing,
1528    })))
1529}
1530
1531pub fn linear_no_bias(
1532    in_dim: usize,
1533    out_dim: usize,
1534    config: &Option<QuantizedConfig>,
1535    vb: ShardedVarBuilder,
1536) -> Result<Arc<dyn QuantMethod>> {
1537    let base_vb = vb.clone();
1538    let vb = if should_apply_immediate_isq(&vb) {
1539        vb.set_device(Device::Cpu)
1540    } else {
1541        vb
1542    };
1543
1544    let layer = if let Some(quant_conf) = &config {
1545        match quant_conf {
1546            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1547            QuantizedConfig::Fp8 { weight_block_size } => {
1548                if weight_block_size.is_some() {
1549                    blockwise_fp8_linear_b(
1550                        in_dim,
1551                        out_dim,
1552                        quant_conf,
1553                        false,
1554                        Default::default(),
1555                        vb,
1556                    )?
1557                } else {
1558                    pertensor_fp8_linear_b(
1559                        in_dim,
1560                        out_dim,
1561                        quant_conf,
1562                        false,
1563                        Default::default(),
1564                        vb,
1565                    )?
1566                }
1567            }
1568            QuantizedConfig::Bitsandbytes { .. } => {
1569                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
1570            }
1571            QuantizedConfig::Afq { .. } => {
1572                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
1573            }
1574            QuantizedConfig::MXFP4 {} => {
1575                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
1576            }
1577        }
1578    } else {
1579        if !vb.contains_tensor("weight") {
1580            make_dummy_or_error("linear_no_bias", &vb, &["weight"])?
1581        } else {
1582            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1583            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1584
1585            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1586                Linear::new(weight, None),
1587            ))?;
1588            Arc::new(layer) as Arc<dyn QuantMethod>
1589        }
1590    };
1591    apply_immediate_isq(layer, base_vb)
1592}
1593
1594pub fn linear(
1595    in_dim: usize,
1596    out_dim: usize,
1597    config: &Option<QuantizedConfig>,
1598    vb: ShardedVarBuilder,
1599) -> Result<Arc<dyn QuantMethod>> {
1600    let base_vb = vb.clone();
1601    let vb = if should_apply_immediate_isq(&vb) {
1602        vb.set_device(Device::Cpu)
1603    } else {
1604        vb
1605    };
1606
1607    let layer = if let Some(quant_conf) = &config {
1608        match quant_conf {
1609            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1610            QuantizedConfig::Fp8 { weight_block_size } => {
1611                if weight_block_size.is_some() {
1612                    blockwise_fp8_linear_b(
1613                        in_dim,
1614                        out_dim,
1615                        quant_conf,
1616                        true,
1617                        Default::default(),
1618                        vb,
1619                    )?
1620                } else {
1621                    pertensor_fp8_linear_b(
1622                        in_dim,
1623                        out_dim,
1624                        quant_conf,
1625                        true,
1626                        Default::default(),
1627                        vb,
1628                    )?
1629                }
1630            }
1631            QuantizedConfig::Bitsandbytes { .. } => {
1632                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
1633            }
1634            QuantizedConfig::Afq { .. } => {
1635                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
1636            }
1637            QuantizedConfig::MXFP4 {} => {
1638                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
1639            }
1640        }
1641    } else {
1642        if has_missing_required_tensors(&vb, &["weight", "bias"]) {
1643            make_dummy_or_error("linear", &vb, &["weight", "bias"])?
1644        } else {
1645            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1646            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1647            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
1648
1649            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1650                Linear::new(weight, Some(bias)),
1651            ))?;
1652            Arc::new(layer) as Arc<dyn QuantMethod>
1653        }
1654    };
1655    apply_immediate_isq(layer, base_vb)
1656}
1657
1658pub fn linear_b(
1659    in_dim: usize,
1660    out_dim: usize,
1661    bias: bool,
1662    config: &Option<QuantizedConfig>,
1663    vb: ShardedVarBuilder,
1664) -> Result<Arc<dyn QuantMethod>> {
1665    if bias {
1666        linear(in_dim, out_dim, config, vb)
1667    } else {
1668        linear_no_bias(in_dim, out_dim, config, vb)
1669    }
1670}
1671
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Builder over an empty tensor map (F32, CPU), optionally configured
    /// with regexes naming tensors that may be replaced by dummies.
    fn empty_vb(make_dummy_regexes: Option<Vec<&str>>) -> ShardedVarBuilder {
        let compiled = make_dummy_regexes.map(|patterns| {
            let regexes: Vec<Regex> = patterns
                .into_iter()
                .map(|pattern| Regex::new(pattern).unwrap())
                .collect();
            Arc::new(regexes)
        });
        let no_tensors = HashMap::<String, Tensor>::new();
        ShardedSafeTensors::wrap_with_dummy_regexes(
            Box::new(no_tensors),
            DType::F32,
            Device::Cpu,
            compiled,
        )
    }

    /// Without any dummy regex, a missing weight must be a descriptive error.
    #[test]
    fn missing_linear_weight_outside_uqff_errors() {
        let vb = empty_vb(None).pp("foo");
        let msg = linear_no_bias(2, 3, &None, vb).unwrap_err().to_string();

        for fragment in ["Missing required tensor(s)", "foo.weight", "UQFF"] {
            assert!(msg.contains(fragment));
        }
    }

    /// With a matching dummy regex, a missing weight yields a dummy layer
    /// that records its context and refuses to run a forward pass.
    #[test]
    fn missing_uqff_placeholder_creates_contextual_dummy() -> Result<()> {
        let vb = empty_vb(Some(vec![r"^foo\.weight$"])).pp("foo");
        let layer = linear_no_bias(2, 3, &None, vb)?;

        let info = layer.dummy_info().unwrap();
        assert_eq!(layer.name(), "dummy");
        assert_eq!(info.context, "linear_no_bias");
        assert_eq!(info.prefix, "foo");
        assert_eq!(info.missing_tensors, vec!["foo.weight"]);

        let input = Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?;
        let msg = layer.forward_raw(&input).unwrap_err().to_string();
        for fragment in ["forward pass", "foo.weight", "temporary UQFF placeholders"] {
            assert!(msg.contains(fragment));
        }

        Ok(())
    }
}