use std::collections::{BTreeSet, HashMap, HashSet};
use serde::{Deserialize, Deserializer};
use crate::{
array::Array,
dtype::Dtype,
error::{
ArithmeticOverflowPayload, DivisibilityConstraintPayload, Error, InvariantViolationPayload,
KeyCollisionPayload, LayerKeyedPayload, LengthMismatchPayload, MissingKeyPayload,
OutOfRangePayload, ParsePayload, RankMismatchPayload, Result, ShapePairMismatchPayload,
UnknownEnumValuePayload, UnsupportedDtypePayload,
},
lm::load::Weights,
ops,
};
/// Quantization scheme of a layer, spelled the way mlx's `quantize` / `dequantize` `mode`
/// argument expects (`"affine"`, `"mxfp4"`, `"mxfp8"`, `"nvfp4"`).
///
/// Serializes in lowercase, and `Display` prints the same string as [`QuantMode::as_str`].
/// The default is [`QuantMode::Affine`]. That is also what a `quantization` object without
/// a `mode` key deserializes to.
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Eq,
    Hash,
    serde::Deserialize,
    serde::Serialize,
    derive_more::Display,
    derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum QuantMode {
    /// Scale-and-bias quantization: a quantized layer carries `.scales` and `.biases`.
    #[default]
    Affine,
    /// Scale-only format: a quantized layer carries `.scales` but no `.biases`.
    Mxfp4,
    /// Scale-only format: a quantized layer carries `.scales` but no `.biases`.
    Mxfp8,
    /// Scale-only format: a quantized layer carries `.scales` but no `.biases`.
    Nvfp4,
}
impl QuantMode {
    /// The lowercase mode string passed to mlx's quantization ops.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            QuantMode::Affine => "affine",
            QuantMode::Mxfp4 => "mxfp4",
            QuantMode::Mxfp8 => "mxfp8",
            QuantMode::Nvfp4 => "nvfp4",
        }
    }
}
/// Quantization parameters for a layer (or the global default).
///
/// When deserializing, `mode` may be omitted and then falls back to [`QuantMode::default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
pub struct Quantization {
    /// Group size along the weight's last axis. That axis must be divisible by it for a
    /// layer to be quantized.
    pub group_size: i32,
    /// Bit width of each quantized element.
    pub bits: i32,
    /// Quantization scheme; absent in the config means affine.
    #[serde(default)]
    pub mode: QuantMode,
}
impl Quantization {
    /// Shorthand for an [`QuantMode::Affine`] configuration.
    pub fn affine(group_size: i32, bits: i32) -> Self {
        let mode = QuantMode::Affine;
        Self {
            group_size,
            bits,
            mode,
        }
    }
}
/// A per-layer entry of a `quantization` block.
///
/// A `false` value in the config becomes [`QuantizationOption::Skip`]. A quantization
/// object becomes [`QuantizationOption::Quantize`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::IsVariant)]
#[non_exhaustive]
pub enum QuantizationOption {
    /// The layer is explicitly left unquantized.
    Skip,
    /// The layer uses these parameters instead of the global ones.
    Quantize(Quantization),
}
/// A resolved `quantization` configuration: an optional global default plus per-layer
/// overrides.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PerLayerQuantization {
    /// Parameters for every layer that has no entry in `per_layer`.
    pub quantization: Option<Quantization>,
    /// Overrides keyed by layer path: the weight key without its `.weight` suffix.
    per_layer: HashMap<String, QuantizationOption>,
}
impl PerLayerQuantization {
    /// Builds a configuration from an explicit global default and a set of overrides.
    pub fn new(
        quantization: Option<Quantization>,
        per_layer: HashMap<String, QuantizationOption>,
    ) -> Self {
        Self {
            quantization,
            per_layer,
        }
    }
    /// A configuration that applies `q` to every layer and has no overrides.
    pub fn from_global(q: Quantization) -> Self {
        Self::new(Some(q), HashMap::new())
    }
    /// Read-only view of the per-layer overrides.
    // Plain `#[inline]` is enough for a cross-crate getter; `inline(always)` only forces
    // the decision on the optimizer without benefit.
    #[inline]
    pub fn per_layer_ref(&self) -> &HashMap<String, QuantizationOption> {
        &self.per_layer
    }
    /// Resolves the parameters for `layer`.
    ///
    /// An explicit override wins. [`QuantizationOption::Skip`] yields `None`. With no
    /// override, the global default is returned (which may itself be `None`).
    pub fn quantization_for(&self, layer: &str) -> Option<Quantization> {
        match self.per_layer.get(layer) {
            Some(QuantizationOption::Skip) => None,
            Some(QuantizationOption::Quantize(q)) => Some(*q),
            None => self.quantization,
        }
    }
}
impl<'de> Deserialize<'de> for PerLayerQuantization {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error as _;
use serde_json::Value;
let value = Value::deserialize(deserializer)?;
let Value::Object(map) = value else {
return Err(D::Error::custom("quantization block must be a JSON object"));
};
const RESERVED: &[&str] = &[
"group_size",
"bits",
"mode",
"quant_method",
"linear_class",
"quantization_mode",
];
let is_reserved = |k: &str| RESERVED.contains(&k);
if !map.contains_key("group_size") {
return Err(D::Error::custom(
"`quantization` block is missing required key `group_size`",
));
}
if !map.contains_key("bits") {
return Err(D::Error::custom(
"`quantization` block is missing required key `bits`",
));
}
let mut globals = serde_json::Map::new();
for k in ["group_size", "bits", "mode"] {
if let Some(v) = map.get(k) {
globals.insert(k.to_string(), v.clone());
}
}
let quantization = Some(
serde_json::from_value::<Quantization>(Value::Object(globals)).map_err(D::Error::custom)?,
);
let mut per_layer: HashMap<String, QuantizationOption> = HashMap::new();
for (key, v) in &map {
if is_reserved(key) {
continue;
}
let opt = match v {
Value::Bool(false) => QuantizationOption::Skip,
Value::Bool(true) => continue,
Value::Object(_) => {
let q = serde_json::from_value::<Quantization>(v.clone()).map_err(D::Error::custom)?;
QuantizationOption::Quantize(q)
}
other => {
return Err(D::Error::custom(format!(
"per-layer quantization value at {key:?} must be `false` or a quantization object, got {other:?}"
)));
}
};
per_layer.insert(key.clone(), opt);
}
Ok(PerLayerQuantization {
quantization,
per_layer,
})
}
}
/// Extracts and parses the optional `quantization` block from `config.json` text.
///
/// Returns `Ok(None)` when the parsed document has no `quantization` key, which includes
/// the case where the top-level JSON value is not an object.
///
/// # Errors
///
/// [`Error::Parse`] when `config_json` is not valid JSON, or when the `quantization` value
/// is not a well-formed block (see the `Deserialize` impl of [`PerLayerQuantization`]).
pub fn parse_quantization(config_json: &str) -> Result<Option<PerLayerQuantization>> {
    let mut value: serde_json::Value = serde_json::from_str(config_json)
        .map_err(|e| Error::Parse(ParsePayload::new("parse_quantization", "config JSON", e)))?;
    // Move the block out of the freshly parsed document instead of deep-cloning it.
    let Some(block) = value
        .get_mut("quantization")
        .map(serde_json::Value::take)
    else {
        return Ok(None);
    };
    let plq: PerLayerQuantization = serde_json::from_value(block).map_err(|e| {
        Error::Parse(ParsePayload::new(
            "parse_quantization",
            "`quantization` block",
            e,
        ))
    })?;
    Ok(Some(plq))
}
/// Key suffix of a layer's weight tensor (dense, or uint32-packed once quantized).
const WEIGHT_SUFFIX: &str = ".weight";
/// Key suffix of a quantized layer's scales.
const SCALES_SUFFIX: &str = ".scales";
/// Key suffix of an affine-quantized layer's biases.
const BIASES_SUFFIX: &str = ".biases";

/// Predicate telling [`quantize_weights`] whether the layer at the given path (with the
/// given weight array) may be quantized.
pub type Eligible<'a> = dyn Fn(&str, &Array) -> bool + 'a;

/// An [`Eligible`] predicate that accepts every layer.
pub fn default_eligible(_path: &str, _weight: &Array) -> bool {
    true
}
/// Result of inspecting the `.scales` / `.biases` keys next to a layer's `.weight`.
enum TripleClass {
    /// Neither `.scales` nor `.biases` is present, so the layer is still dense.
    Absent,
    /// A consistent, already-quantized triple that should be left untouched.
    Valid,
    /// An inconsistent triple; quantization must abort with this error.
    Invalid(Error),
}
fn classify_triple(
weights: &Weights,
layer_path: &str,
layer_weight: &Array,
cfg: &PerLayerQuantization,
) -> TripleClass {
let scales_key = format!("{layer_path}{SCALES_SUFFIX}");
let biases_key = format!("{layer_path}{BIASES_SUFFIX}");
let scales = weights.get(&scales_key);
let biases = weights.get(&biases_key);
match (scales, biases) {
(None, None) => TripleClass::Absent,
(None, Some(_)) => TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::MissingKey(MissingKeyPayload::new(
"quantize_weights: stale `.biases` with no matching `.scales` \
(mlx `quantize` always writes `.scales` alongside `.biases`; refusing to \
silently overwrite the generated bias)",
scales_key.clone(),
)),
))),
(Some(s), b_opt) => {
let q = match cfg.per_layer.get(layer_path) {
Some(QuantizationOption::Skip) => {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::KeyCollision(KeyCollisionPayload::new(
"quantize_weights: input carries `.scales` but the per-layer config marks this \
layer as `Skip` (not quantized) — refusing to silently treat the stale triple \
as a valid already-quantized layer",
scales_key.clone(),
)),
)));
}
Some(QuantizationOption::Quantize(q)) => *q,
None => match cfg.quantization {
Some(q) => q,
None => {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::InvariantViolation(InvariantViolationPayload::new(
"quantize_weights: input carries `.scales` but `cfg` has no global \
`Quantization` and no per-layer override for this layer — cannot resolve \
expected `.scales` shape (defensive: any parsed `quantization` block \
carries group_size + bits)",
"quantization parameters must be resolvable",
)),
)));
}
},
};
match (q.mode, b_opt) {
(QuantMode::Affine, None) => {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::MissingKey(MissingKeyPayload::new(
"quantize_weights: `affine` mode requires `.biases` alongside `.scales` \
(mlx `affine_quantize` always writes `{w_q, scales, biases}`, \
mlx/ops.cpp:4793-4798); this is a structurally incomplete affine triple",
biases_key.clone(),
)),
)));
}
(QuantMode::Mxfp4 | QuantMode::Mxfp8 | QuantMode::Nvfp4, Some(_)) => {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::KeyCollision(KeyCollisionPayload::new(
"quantize_weights: scale-only mode (mxfp4 / mxfp8 / nvfp4) must not carry \
`.biases` (mlx `fp_quantize` writes `{w_q, scales}` with no biases, \
mlx/ops.cpp:4890,4898-4904); refusing to silently retain a bias from a \
different (affine) mode",
biases_key.clone(),
)),
)));
}
_ => {}
}
let w_dtype = match layer_weight.dtype() {
Ok(d) => d,
Err(e) => {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
e,
)));
}
};
if w_dtype != Dtype::U32 {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"quantize_weights: input has `.scales` but `.weight` dtype must be uint32 \
(mlx-quantized `.weight` is always uint32 — mlx/ops.cpp:4795,4900); this is \
a stale `.scales` orphan next to a dense `.weight`, not a valid \
already-quantized triple",
w_dtype,
&[Dtype::U32],
)),
)));
}
let w_shape = layer_weight.shape();
let w_rank = w_shape.len();
if w_rank < 2 {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::RankMismatch(RankMismatchPayload::new(
"quantize_weights: `.weight` next to `.scales` must be rank-2 (mlx `quantize` \
requires rank >= 2 inputs — mlx/ops.cpp:4925-4929; a uint32 1-D / scalar \
`.weight` next to a `.scales` is not a layout mlx's `quantize` can have produced)",
w_rank as u32,
w_shape,
)),
)));
}
let s_shape = s.shape();
if s_shape.len() != w_shape.len() {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::LengthMismatch(LengthMismatchPayload::new(
"quantize_weights: `.scales` rank vs `.weight` rank — mlx `quantize` preserves \
the leading shape across the packed `.weight` / `.scales` / `.biases` outputs \
(mlx/ops.cpp:4789-4798)",
w_shape.len(),
s_shape.len(),
)),
)));
}
if s_shape[..s_shape.len() - 1] != w_shape[..w_shape.len() - 1] {
return TripleClass::Invalid(Error::LayerKeyed(LayerKeyedPayload::new(
layer_path.to_string(),
Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"quantize_weights: `.scales` leading dims (all but the last) must match \
`.weight` leading dims — mlx `quantize` preserves all-but-last dims",
w_shape[..w_shape.len() - 1].to_vec(),
s_shape[..s_shape.len() - 1].to_vec(),
)),
)));
}
TripleClass::Valid
}
}
}
/// Quantizes every eligible dense `.weight` in `weights` according to `cfg`.
///
/// A layer is quantized only when all of the following hold:
/// * it has no `.scales` / `.biases` yet;
/// * `eligible` accepts it;
/// * `cfg` resolves parameters for it;
/// * its weight has rank >= 2;
/// * the last dimension is a multiple of a non-zero `group_size`.
///
/// A quantized layer's `.weight` is replaced by the packed output of
/// `ops::quantized::quantize`, its `.scales` is added, and its `.biases` is added when the
/// op returns one. Everything else, including valid already-quantized triples, passes
/// through unchanged.
///
/// # Errors
///
/// * an inconsistent existing triple (see [`classify_triple`]);
/// * a negative `group_size` for a layer that reaches the divisibility check;
/// * any error from `ops::quantized::quantize`.
pub fn quantize_weights(
    weights: Weights,
    cfg: &PerLayerQuantization,
    eligible: &Eligible<'_>,
) -> Result<Weights> {
    let mut out: Weights = HashMap::with_capacity(weights.len());

    // Pass 1 (borrowing): decide which layers to quantize and with which parameters.
    let mut plan: HashMap<String, Quantization> = HashMap::new();
    for (key, weight) in &weights {
        let Some(layer_path) = key.strip_suffix(WEIGHT_SUFFIX) else {
            continue;
        };
        match classify_triple(&weights, layer_path, weight, cfg) {
            TripleClass::Invalid(err) => return Err(err),
            TripleClass::Valid => continue,
            TripleClass::Absent => {}
        }
        if !eligible(layer_path, weight) {
            continue;
        }
        let Some(quant) = cfg.quantization_for(layer_path) else {
            continue;
        };
        let dims = weight.shape();
        if dims.len() < 2 {
            continue;
        }
        let last_dim = dims[dims.len() - 1];
        let group = usize::try_from(quant.group_size).map_err(|_| {
            Error::LayerKeyed(LayerKeyedPayload::new(
                layer_path.to_string(),
                Error::OutOfRange(OutOfRangePayload::new(
                    "quantize_weights: group_size",
                    "must be a non-negative i32",
                    quant.group_size.to_string(),
                )),
            ))
        })?;
        if group != 0 && last_dim % group == 0 {
            plan.insert(layer_path.to_string(), quant);
        }
    }

    // Pass 2 (consuming): emit quantized triples for planned layers and move the rest.
    for (key, arr) in weights {
        let planned = key
            .strip_suffix(WEIGHT_SUFFIX)
            .and_then(|path| plan.get(path).map(|quant| (path, *quant)));
        match planned {
            Some((path, quant)) => {
                let (w_q, scales, biases) = ops::quantized::quantize(
                    &arr,
                    quant.group_size,
                    quant.bits,
                    quant.mode.as_str(),
                    None,
                )?;
                out.insert(format!("{path}{WEIGHT_SUFFIX}"), w_q);
                out.insert(format!("{path}{SCALES_SUFFIX}"), scales);
                if let Some(biases) = biases {
                    out.insert(format!("{path}{BIASES_SUFFIX}"), biases);
                }
            }
            None => {
                out.insert(key, arr);
            }
        }
    }
    Ok(out)
}
/// Dequantizes every `.weight` + `.scales` (+ `.biases`) triple back to a dense `.weight`.
///
/// A layer path is treated as a triple when both `{path}.weight` and `{path}.scales` exist.
/// Its parameters come from `cfg.quantization_for(path)`. The triple's `.scales` and
/// `.biases` are dropped from the output; every other entry passes through unchanged.
///
/// # Errors
///
/// * a `.biases` with no `.scales` next to a uint32 `.weight` of rank >= 2 (a stale,
///   structurally incomplete triple);
/// * no resolvable [`Quantization`] for a triple's layer path;
/// * `.biases` missing for an affine layer, or present for a scale-only layer;
/// * any error from `ops::quantized::dequantize`.
pub fn dequantize_weights(weights: Weights, cfg: &PerLayerQuantization) -> Result<Weights> {
    // Refuse stale `.biases` without `.scales` next to what looks like a packed weight:
    // silently passing that uint32 weight through would corrupt the dense output.
    for key in weights.keys() {
        let Some(path) = key.strip_suffix(BIASES_SUFFIX) else {
            continue;
        };
        if weights.contains_key(&format!("{path}{SCALES_SUFFIX}")) {
            continue;
        }
        let Some(weight_arr) = weights.get(&format!("{path}{WEIGHT_SUFFIX}")) else {
            continue;
        };
        let w_dtype = weight_arr
            .dtype()
            .map_err(|e| Error::LayerKeyed(LayerKeyedPayload::new(path.to_string(), e)))?;
        if w_dtype != Dtype::U32 || weight_arr.shape().len() < 2 {
            continue;
        }
        return Err(Error::LayerKeyed(LayerKeyedPayload::new(
            path.to_string(),
            Error::MissingKey(MissingKeyPayload::new(
                "dequantize_weights: stale `.biases` with no matching `.scales` \
                 (mlx `quantize` always writes `.scales` alongside `.biases`, \
                 `mlx/ops.cpp:4793-4798`); this is a structurally incomplete triple, \
                 refusing to silently leave the `uint32`-packed `.weight` as a \
                 pass-through in the dequantized output",
                format!("{path}{SCALES_SUFFIX}"),
            )),
        )));
    }

    // Layer paths that have both `.scales` and `.weight`.
    let triple_set: HashSet<String> = weights
        .keys()
        .filter_map(|key| key.strip_suffix(SCALES_SUFFIX))
        .filter(|path| weights.contains_key(&format!("{path}{WEIGHT_SUFFIX}")))
        .map(str::to_string)
        .collect();

    // Move triple components into per-path slots (weight, scales, biases); pass the rest through.
    type StagedTriple = (Option<Array>, Option<Array>, Option<Array>);
    let mut staged: HashMap<String, StagedTriple> = HashMap::new();
    let mut out: Weights = HashMap::with_capacity(weights.len());
    for (key, arr) in weights {
        let staged_path = |suffix: &str| {
            key.strip_suffix(suffix)
                .filter(|path| triple_set.contains(*path))
                .map(str::to_string)
        };
        if let Some(path) = staged_path(WEIGHT_SUFFIX) {
            staged.entry(path).or_default().0 = Some(arr);
        } else if let Some(path) = staged_path(SCALES_SUFFIX) {
            staged.entry(path).or_default().1 = Some(arr);
        } else if let Some(path) = staged_path(BIASES_SUFFIX) {
            staged.entry(path).or_default().2 = Some(arr);
        } else {
            out.insert(key, arr);
        }
    }

    for (path, (w_opt, s_opt, b_opt)) in staged {
        let w = w_opt.ok_or_else(|| {
            Error::MissingKey(MissingKeyPayload::new(
                "dequantize_weights: triple missing `.weight`",
                format!("{path}{WEIGHT_SUFFIX}"),
            ))
        })?;
        let scales = s_opt.ok_or_else(|| {
            Error::MissingKey(MissingKeyPayload::new(
                "dequantize_weights: triple missing `.scales`",
                format!("{path}{SCALES_SUFFIX}"),
            ))
        })?;
        let q = cfg.quantization_for(&path).ok_or_else(|| {
            Error::LayerKeyed(LayerKeyedPayload::new(
                path.clone(),
                Error::InvariantViolation(InvariantViolationPayload::new(
                    "dequantize_weights: quantization parameters",
                    "must be resolvable (no global default and no per-layer override)",
                )),
            ))
        })?;
        match (q.mode, b_opt.is_some()) {
            (QuantMode::Affine, false) => {
                return Err(Error::LayerKeyed(LayerKeyedPayload::new(
                    path.clone(),
                    Error::MissingKey(MissingKeyPayload::new(
                        "dequantize_weights: `affine` mode requires `.biases` alongside `.scales` \
                         (mlx `affine_dequantize` takes `{w_q, scales, biases}`, mlx/ops.cpp:5085-5099)",
                        format!("{path}{BIASES_SUFFIX}"),
                    )),
                )));
            }
            (QuantMode::Mxfp4 | QuantMode::Mxfp8 | QuantMode::Nvfp4, true) => {
                return Err(Error::LayerKeyed(LayerKeyedPayload::new(
                    path.clone(),
                    Error::KeyCollision(KeyCollisionPayload::new(
                        "dequantize_weights: scale-only mode (mxfp4 / mxfp8 / nvfp4) must not carry \
                         `.biases` (mlx `fp_dequantize` takes `{w_q, scales}` with no biases, \
                         mlx/ops.cpp:5198-5210); refusing to silently retain a bias from a \
                         different (affine) mode",
                        format!("{path}{BIASES_SUFFIX}"),
                    )),
                )));
            }
            _ => {}
        }
        let dense = ops::quantized::dequantize(
            &w,
            &scales,
            b_opt.as_ref(),
            q.group_size,
            q.bits,
            q.mode.as_str(),
            None,
            None,
        )?;
        out.insert(format!("{path}{WEIGHT_SUFFIX}"), dense);
    }
    Ok(out)
}
/// AutoAWQ `quantization_config` fields read by [`transform_awq_weights`].
///
/// Absent fields deserialize to `bits = 4`, `group_size = 128`, `zero_point = true` and
/// `version = ""`. [`Default`] yields the same values.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct AwqLoadConfig {
    /// Bit width of the packed weights; the converter only accepts 4.
    #[serde(default = "AwqLoadConfig::default_bits")]
    pub bits: u32,
    /// Number of input features per quantization group (`n_groups = in_features / group_size`).
    #[serde(default = "AwqLoadConfig::default_group_size")]
    pub group_size: u32,
    /// When set, `.qzeros` (if present) provides the zero points used for `.biases`.
    #[serde(default = "AwqLoadConfig::default_zero_point")]
    pub zero_point: bool,
    /// Kernel version string; the converter accepts `""` and `"gemm"`.
    #[serde(default)]
    pub version: String,
}
impl AwqLoadConfig {
    /// Serde default for [`AwqLoadConfig::bits`].
    fn default_bits() -> u32 {
        4
    }
    /// Serde default for [`AwqLoadConfig::group_size`].
    fn default_group_size() -> u32 {
        128
    }
    /// Serde default for [`AwqLoadConfig::zero_point`].
    fn default_zero_point() -> bool {
        true
    }
}
impl Default for AwqLoadConfig {
    fn default() -> Self {
        Self {
            version: String::default(),
            zero_point: Self::default_zero_point(),
            group_size: Self::default_group_size(),
            bits: Self::default_bits(),
        }
    }
}
/// Bit width of one AutoAWQ packed element (the only width supported here).
const AWQ_BITS: u32 = 4;
/// Number of `AWQ_BITS`-wide values packed into one 32-bit word.
const AWQ_PACK_FACTOR: usize = 32 / (AWQ_BITS as usize);
/// Mask that keeps the low `AWQ_BITS` bits after a shift.
const AWQ_NIBBLE_MASK: u32 = (1 << AWQ_BITS) - 1;
/// Right-shift amount for each output position when unpacking one packed word.
///
/// AutoAWQ interleaves nibbles, so output position `k` is stored in packed nibble
/// `ORDER[k]`, i.e. at bit offset `ORDER[k] * AWQ_BITS`.
const AWQ_SHIFTS: [u32; 8] = {
    const ORDER: [u32; 8] = [0, 4, 1, 5, 2, 6, 3, 7];
    let mut shifts = [0u32; 8];
    let mut i = 0;
    while i < ORDER.len() {
        shifts[i] = ORDER[i] * AWQ_BITS;
        i += 1;
    }
    shifts
};
const _: () = assert!(AWQ_BITS == 4 && AWQ_SHIFTS[1] == 4 * AWQ_BITS);
pub fn unpack_awq_weights(qweight: &Array) -> Result<Array> {
let shape = qweight.shape();
let shape_len = shape.len();
if shape_len != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"unpack_awq_weights: qweight must be 2-D [rows, packed_cols]",
shape_len as u32,
shape,
)));
}
let shape = qweight.shape();
let dtype = qweight.dtype()?;
let owned_view;
let packed_u32: &Array = match dtype {
Dtype::U32 => qweight,
Dtype::I32 => {
owned_view = ops::misc::view(qweight, Dtype::U32)?;
&owned_view
}
other => {
return Err(Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"unpack_awq_weights: qweight (AutoAWQ stores qweight as 32-bit packed nibbles, \
utils.py:72-82; accepts uint32 (mlx-lm canonical) or int32 (AutoAWQ WQLinear_GEMM's \
default torch.int32 allocation))",
other,
&[Dtype::U32, Dtype::I32],
)));
}
};
let rows = shape[0];
let packed_cols = shape[1];
let cols = packed_cols.checked_mul(AWQ_PACK_FACTOR).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"unpack_awq_weights: unpacked col count `packed_cols * 8`",
"usize",
[
("packed_cols", packed_cols as u64),
("multiplier", AWQ_PACK_FACTOR as u64),
],
))
})?;
let shifts = Array::from_slice::<u32>(&AWQ_SHIFTS, &(AWQ_SHIFTS.len(),))?;
let mask = Array::from_slice::<u32>(&[AWQ_NIBBLE_MASK], &(1usize,))?;
let expanded = ops::shape::expand_dims_axes(packed_u32, &[2])?;
let shifted = ops::arithmetic::right_shift(&expanded, &shifts)?;
let nibbles = ops::arithmetic::bitwise_and(&shifted, &mask)?;
ops::shape::reshape(&nibbles, &(rows, cols))
}
fn resolve_awq_model_dtype(
weights: &Weights,
qweight_prefixes: &[String],
) -> Result<Option<Dtype>> {
if qweight_prefixes.is_empty() {
return Ok(None);
}
let mut best: Option<Dtype> = None;
let mut has_f16 = false;
let mut has_bf16 = false;
for prefix in qweight_prefixes {
let scales_key = format!("{prefix}.scales");
let scales = weights.get(&scales_key).ok_or_else(|| {
Error::MissingKey(MissingKeyPayload::new(
"transform_awq_weights: AWQ `.qweight` missing its `.scales` companion \
(AutoAWQ writes `.qweight` / `.scales` / `.qzeros` as a triple — refusing to \
silently drop the layer)",
scales_key.clone(),
))
})?;
let d = scales.dtype()?;
match d {
Dtype::F16 => has_f16 = true,
Dtype::BF16 => has_bf16 = true,
_ => {}
}
match best {
None => best = Some(d),
Some(prev) => {
if floating_dtype_precision_rank(d) > floating_dtype_precision_rank(prev) {
best = Some(d);
}
}
}
}
if let Some(b) = best
&& has_f16
&& has_bf16
&& b != Dtype::F32
&& b != Dtype::F64
{
best = Some(Dtype::F32);
}
Ok(best)
}
/// Precision rank used for dtype unification, from lowest to highest:
/// F16 = 1, BF16 = 2, F32 = 3, F64 = 4. Every non-floating dtype ranks 0.
fn floating_dtype_precision_rank(d: Dtype) -> u8 {
    match d {
        Dtype::F16 => 1,
        Dtype::BF16 => 2,
        Dtype::F32 => 3,
        Dtype::F64 => 4,
        _ => 0,
    }
}
fn validate_awq_scales_are_floating(weights: &Weights, qweight_prefixes: &[String]) -> Result<()> {
for prefix in qweight_prefixes {
let scales_key = format!("{prefix}.scales");
let Some(scales) = weights.get(&scales_key) else {
continue;
};
let d = scales.dtype()?;
if !is_floating(d) {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
scales_key.clone(),
Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"transform_awq_weights: AutoAWQ `.scales` (any other dtype would corrupt the \
dtype-unification cast)",
d,
&[Dtype::F16, Dtype::F32, Dtype::F64, Dtype::BF16],
)),
)));
}
}
Ok(())
}
/// Whether `d` is one of the floating-point dtypes (F16, BF16, F32, F64).
fn is_floating(d: Dtype) -> bool {
    matches!(d, Dtype::BF16 | Dtype::F16 | Dtype::F32 | Dtype::F64)
}
pub fn transform_awq_weights(
weights: Weights,
config: &AwqLoadConfig,
) -> Result<(Weights, PerLayerQuantization)> {
match config.version.as_str() {
"" | "gemm" => { }
other => {
return Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
"transform_awq_weights: AWQ version (only `gemm` is implemented — GEMV checkpoints \
use a different qweight shape, scales layout, and sequential packing; converting one \
through the GEMM path would silently produce corrupt weights — re-quantize with \
`awq --version gemm` if possible)",
other.to_string(),
&["gemm", ""],
)));
}
}
if config.bits != AWQ_BITS {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"transform_awq_weights: AWQ bits (mlx-lm/mlx_lm/utils.py:88-89)",
"must be 4 (only 4-bit AutoAWQ/GPTQ is supported)",
config.bits.to_string(),
)));
}
let group_size = config.group_size;
let group_size_i32 = i32::try_from(group_size).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"transform_awq_weights: group_size",
"must fit in i32",
group_size.to_string(),
))
})?;
for key in weights.keys() {
if key.ends_with(".g_idx") {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
key.clone(),
Error::InvariantViolation(InvariantViolationPayload::new(
"transform_awq_weights: GPTQ `.g_idx`",
"must not be present (models with non-contiguous group indices are not supported \
by mlx-lm's AutoAWQ on-load converter — mlx-lm/mlx_lm/utils.py:95-100 — please use \
a model without `g_idx` or re-quantize via `mlx_lm.convert`)",
)),
)));
}
}
let mut qweight_prefixes: Vec<String> = weights
.keys()
.filter_map(|k| k.strip_suffix(".qweight").map(str::to_string))
.collect();
qweight_prefixes.sort();
validate_awq_scales_are_floating(&weights, &qweight_prefixes)?;
let model_dtype = resolve_awq_model_dtype(&weights, &qweight_prefixes)?;
for prefix in &qweight_prefixes {
let qweight_key = format!("{prefix}.qweight");
let scales_key = format!("{prefix}.scales");
let weight_key = format!("{prefix}.weight");
if weights.contains_key(&weight_key) {
return Err(Error::KeyCollision(KeyCollisionPayload::new(
"transform_awq_weights: input contains both `.qweight` and `.weight` (the generated \
AWQ output would overwrite the stale dense `.weight`); remove the stale dense weight \
before fusing (precedent: non-AWQ `quantize_weights` refuses analogous orphan/stale \
collisions via `classify_triple`)",
weight_key,
)));
}
let biases_key = format!("{prefix}.biases");
if weights.contains_key(&biases_key) {
return Err(Error::KeyCollision(KeyCollisionPayload::new(
"transform_awq_weights: input contains both `.qweight` and `.biases` (the generated \
AWQ output would overwrite the stale `.biases`); remove the stale biases before \
fusing (precedent: non-AWQ `quantize_weights` refuses analogous orphan/stale \
collisions via `classify_triple`)",
biases_key,
)));
}
let Some(qweight) = weights.get(&qweight_key) else {
return Err(Error::MissingKey(MissingKeyPayload::new(
"transform_awq_weights: `.qweight` missing after prefix scan (defensive)",
qweight_key,
)));
};
let Some(scales) = weights.get(&scales_key) else {
return Err(Error::MissingKey(MissingKeyPayload::new(
"transform_awq_weights: AWQ `.qweight` missing its `.scales` companion \
(AutoAWQ writes `.qweight` / `.scales` / `.qzeros` as a triple — refusing to silently \
drop the layer)",
scales_key,
)));
};
let qw_dtype = qweight.dtype()?;
if !matches!(qw_dtype, Dtype::U32 | Dtype::I32) {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
qweight_key.clone(),
Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"transform_awq_weights: qweight (AutoAWQ stores packed nibbles as `uint32` \
(mlx-lm canonical) or `int32` (AutoAWQ `WQLinear_GEMM` default `torch.int32` \
allocation))",
qw_dtype,
&[Dtype::U32, Dtype::I32],
)),
)));
}
let q_shape0 = qweight.shape();
let s_shape0 = scales.shape();
let q_rank = q_shape0.len();
let s_rank = s_shape0.len();
if q_rank != 2 {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
qweight_key.clone(),
Error::RankMismatch(RankMismatchPayload::new(
"transform_awq_weights: qweight must be 2-D [in_features, packed_out]",
q_rank as u32,
q_shape0,
)),
)));
}
if s_rank != 2 {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
scales_key.clone(),
Error::RankMismatch(RankMismatchPayload::new(
"transform_awq_weights: scales must be 2-D [n_groups, out_features]",
s_rank as u32,
s_shape0,
)),
)));
}
let q_shape = qweight.shape();
let s_shape = scales.shape();
let in_features = q_shape[0];
let packed_out = q_shape[1];
let out_features = packed_out.checked_mul(AWQ_PACK_FACTOR).ok_or_else(|| {
Error::LayerKeyed(LayerKeyedPayload::new(
qweight_key.clone(),
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"transform_awq_weights: out_features = packed_out * AWQ_PACK_FACTOR",
"usize",
[
("packed_out", packed_out as u64),
("multiplier", AWQ_PACK_FACTOR as u64),
],
)),
))
})?;
if group_size as usize == 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"transform_awq_weights: group_size",
"must be > 0",
group_size.to_string(),
)));
}
if in_features % (group_size as usize) != 0 {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
qweight_key.clone(),
Error::DivisibilityConstraint(DivisibilityConstraintPayload::new(
"transform_awq_weights: in_features must be a multiple of group_size \
(utils.py:118: n_groups = in_features // group_size)",
"in_features",
in_features as u64,
"group_size",
group_size as u64,
)),
)));
}
let n_groups = in_features / (group_size as usize);
if s_shape[0] != n_groups || s_shape[1] != out_features {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
prefix.clone(),
Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"transform_awq_weights: scales must be [n_groups, out_features] \
(derived from qweight shape with group_size)",
vec![n_groups, out_features],
s_shape,
)),
)));
}
let qzeros_key = format!("{prefix}.qzeros");
if let Some(qzeros) = weights.get(&qzeros_key) {
let qz_dtype = qzeros.dtype()?;
if !matches!(qz_dtype, Dtype::U32 | Dtype::I32) {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
qzeros_key.clone(),
Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"transform_awq_weights: qzeros (accept `uint32` mlx-lm canonical or `int32` \
AutoAWQ default `torch.int32`)",
qz_dtype,
&[Dtype::U32, Dtype::I32],
)),
)));
}
let z_shape = qzeros.shape();
if z_shape.len() != 2 || z_shape[0] != n_groups || z_shape[1] != packed_out {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
qzeros_key.clone(),
Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"transform_awq_weights: qzeros must be [n_groups, packed_out] \
(derived from qweight shape with group_size)",
vec![n_groups, packed_out],
z_shape,
)),
)));
}
}
}
let mut new_weights: Weights = HashMap::with_capacity(weights.len());
let mut awq_generated_floating_keys: BTreeSet<String> = BTreeSet::new();
type AwqTriple = (Option<Array>, Option<Array>, Option<Array>);
let mut awq_components: HashMap<String, AwqTriple> =
HashMap::with_capacity(qweight_prefixes.len());
let mut remainder: Weights = HashMap::with_capacity(weights.len());
for (key, arr) in weights {
if let Some(prefix) = key.strip_suffix(".qweight") {
awq_components
.entry(prefix.to_string())
.or_insert((None, None, None))
.0 = Some(arr);
} else if let Some(prefix) = key.strip_suffix(".scales") {
if qweight_prefixes.binary_search(&prefix.to_string()).is_ok() {
awq_components
.entry(prefix.to_string())
.or_insert((None, None, None))
.1 = Some(arr);
} else {
remainder.insert(key, arr);
}
} else if let Some(prefix) = key.strip_suffix(".qzeros") {
if qweight_prefixes.binary_search(&prefix.to_string()).is_ok() {
awq_components
.entry(prefix.to_string())
.or_insert((None, None, None))
.2 = Some(arr);
} else {
remainder.insert(key, arr);
}
} else {
remainder.insert(key, arr);
}
}
for prefix in &qweight_prefixes {
let (qw_opt, sc_opt, qz_opt) = awq_components.remove(prefix).ok_or_else(|| {
Error::LayerKeyed(LayerKeyedPayload::new(
prefix.clone(),
Error::InvariantViolation(InvariantViolationPayload::new(
"transform_awq_weights: AWQ components",
"must be present (lost mid-pipeline — defensive)",
)),
))
})?;
let qweight = qw_opt.ok_or_else(|| {
Error::MissingKey(MissingKeyPayload::new(
"transform_awq_weights: `.qweight` disappeared mid-pipeline (defensive)",
format!("{prefix}.qweight"),
))
})?;
let scales = sc_opt.ok_or_else(|| {
Error::MissingKey(MissingKeyPayload::new(
"transform_awq_weights: `.scales` disappeared mid-pipeline (defensive)",
format!("{prefix}.scales"),
))
})?;
let q_shape = qweight.shape();
let in_features = q_shape[0];
let packed_out = q_shape[1];
let out_features = packed_out * AWQ_PACK_FACTOR; let packed_in = in_features / AWQ_PACK_FACTOR;
let unpacked = unpack_awq_weights(&qweight)?;
let unpacked_t = ops::shape::transpose(&unpacked)?;
let reshaped = ops::shape::reshape(&unpacked_t, &(out_features, packed_in, AWQ_PACK_FACTOR))?;
let pack_shifts_data: Vec<u32> = (0..AWQ_PACK_FACTOR as u32).map(|i| i * AWQ_BITS).collect();
let pack_shifts = Array::from_slice::<u32>(&pack_shifts_data, &(pack_shifts_data.len(),))?;
let reshaped_u32 = ops::misc::astype(&reshaped, Dtype::U32)?;
let shifted = ops::arithmetic::left_shift(&reshaped_u32, &pack_shifts)?;
let repacked = ops::reduction::sum_axes(&shifted, &[2_i32], false)?;
let new_weight = ops::misc::astype(&repacked, Dtype::U32)?;
let scales_t = ops::shape::transpose(&scales)?;
let scales_c = ops::shape::contiguous(&scales_t, false)?;
let scales_dtype = scales.dtype()?;
let biases = if config.zero_point {
match qz_opt {
Some(qzeros) => {
let unpacked_zeros = unpack_awq_weights(&qzeros)?;
let unpacked_zeros_t = ops::shape::transpose(&unpacked_zeros)?;
let zeros_f32 = ops::misc::astype(&unpacked_zeros_t, Dtype::F32)?;
let scales_f32 = ops::misc::astype(&scales_c, Dtype::F32)?;
let prod = ops::arithmetic::multiply(&zeros_f32, &scales_f32)?;
let neg = ops::arithmetic::negative(&prod)?;
ops::misc::astype(&neg, scales_dtype)?
}
None => {
symmetric_biases(&scales_c, scales_dtype)?
}
}
} else {
symmetric_biases(&scales_c, scales_dtype)?
};
let scales_key = format!("{prefix}.scales");
let biases_key = format!("{prefix}.biases");
new_weights.insert(format!("{prefix}.weight"), new_weight);
new_weights.insert(scales_key.clone(), scales_c);
new_weights.insert(biases_key.clone(), biases);
awq_generated_floating_keys.insert(scales_key);
awq_generated_floating_keys.insert(biases_key);
}
for (key, arr) in remainder {
new_weights.insert(key, arr);
}
if let Some(target) = model_dtype {
for key in &awq_generated_floating_keys {
let arr = new_weights
.get(key)
.expect("AWQ-generated key inserted moments ago must still be present");
let d = arr.dtype()?;
if is_floating(d) && d != target {
let cast = ops::misc::astype(arr, target)?;
new_weights.insert(key.clone(), cast);
}
}
}
let mlx_quantization = PerLayerQuantization::from_global(Quantization {
group_size: group_size_i32,
bits: i32::try_from(config.bits).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"transform_awq_weights: bits",
"must fit in i32",
config.bits.to_string(),
))
})?,
mode: QuantMode::Affine,
});
Ok((new_weights, mlx_quantization))
}
/// Symmetric (no zero-point) AWQ biases: `-(1 << (AWQ_BITS - 1)) * scales`, i.e.
/// `-8 * scales`, returned in `scales_dtype`.
fn symmetric_biases(scales_c: &Array, scales_dtype: Dtype) -> Result<Array> {
    // The midpoint of the unsigned AWQ_BITS-bit range acts as the implicit zero point.
    let midpoint = (1u32 << (AWQ_BITS - 1)) as f32;
    let neg_midpoint = ops::misc::full_like(scales_c, -midpoint)?;
    let biases = ops::arithmetic::multiply(scales_c, &neg_midpoint)?;
    match biases.dtype()? {
        d if d == scales_dtype => Ok(biases),
        _ => ops::misc::astype(&biases, scales_dtype),
    }
}
// Unit tests for this module live in a separate `tests` submodule file and are compiled
// only for test builds.
#[cfg(test)]
mod tests;