use std::{
collections::{HashMap, HashSet},
path::Path,
};
use regex::Regex;
use crate::{
array::Array,
error::{
CapExceededPayload, Error, FileIoPayload, FileOp, InvariantViolationPayload, LayerKeyedPayload,
LengthMismatchPayload, MissingFieldPayload, MissingKeyPayload, OutOfRangePayload, ParsePayload,
RankMismatchPayload, Result, ShapePairMismatchPayload, UnknownEnumValuePayload,
},
lm::{
load::Weights,
quant::{PerLayerQuantization, Quantization},
},
ops,
};
/// Scale used by MLX-LM style parameters when neither `alpha` (with `rank > 0`) nor an
/// explicit `scale` is configured.
pub const DEFAULT_LORA_SCALE: f32 = 20.0;
/// Low-rank dimension `r` used when a config does not specify one.
pub const DEFAULT_LORA_RANK: i32 = 8;
/// PEFT `lora_alpha` used when a PEFT config omits it.
pub const DEFAULT_PEFT_LORA_ALPHA: f32 = 8.0;
/// MLX-LM `num_layers` used when a config omits it.
pub const DEFAULT_NUM_LAYERS: i32 = 16;
/// Byte cap for adapter safetensors files: 2 GiB, i.e. 2^31 bytes.
pub const MAX_ADAPTER_SAFETENSORS_BYTES: u64 = 1 << 31;
/// Fine-tuning method recorded in an MLX-LM `adapter_config.json` (`fine_tune_type`).
///
/// Serialized in lowercase (`"lora"`, `"dora"`, `"full"`); `Display` renders the same text
/// via [`FineTuneType::as_str`]. The default is [`FineTuneType::Lora`].
#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    PartialEq,
    Eq,
    serde::Deserialize,
    serde::Serialize,
    derive_more::Display,
    derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
#[non_exhaustive]
#[serde(rename_all = "lowercase")]
pub enum FineTuneType {
    #[default]
    Lora,
    Dora,
    Full,
}
impl FineTuneType {
    /// Lowercase name of the variant, identical to its serde representation.
    pub const fn as_str(self) -> &'static str {
        match self {
            FineTuneType::Lora => "lora",
            FineTuneType::Dora => "dora",
            FineTuneType::Full => "full",
        }
    }
}
/// `lora_parameters` section of an MLX-LM `adapter_config.json`.
///
/// For PEFT configs, the `Deserialize` impl of `LoraConfig` synthesizes this from the
/// top-level `r`, `lora_alpha` and `lora_dropout` fields instead.
#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct LoraParameters {
    /// Low-rank dimension `r`; `DEFAULT_LORA_RANK` when absent.
    #[serde(default = "default_rank")]
    pub rank: i32,
    /// Explicit scale; consulted by `resolved_scale` only when `alpha` does not apply.
    #[serde(default)]
    pub scale: Option<f32>,
    /// LoRA alpha; with `rank > 0`, `resolved_scale` returns `alpha / rank`.
    #[serde(default)]
    pub alpha: Option<f32>,
    /// Module keys listed in the config (empty when absent); exposed via `keys_slice`.
    #[serde(default)]
    keys: Vec<String>,
    /// Dropout value from the config (`lora_dropout` for PEFT configs).
    #[serde(default)]
    pub dropout: Option<f32>,
}
/// Serde fallback for [`LoraParameters::rank`] (also used for a missing PEFT `r`).
fn default_rank() -> i32 {
    DEFAULT_LORA_RANK
}
impl Default for LoraParameters {
    /// Default rank, no explicit scale/alpha/dropout and no keys: the same values serde
    /// fills in for an empty `lora_parameters` object.
    fn default() -> Self {
        let keys = Vec::new();
        Self {
            keys,
            dropout: None,
            alpha: None,
            scale: None,
            rank: default_rank(),
        }
    }
}
impl LoraParameters {
    /// Module keys listed under `lora_parameters.keys` (empty when the config omits them).
    #[inline(always)]
    pub fn keys_slice(&self) -> &[String] {
        self.keys.as_slice()
    }

    /// Effective output scale for MLX-LM style parameters.
    ///
    /// Precedence: `alpha / rank` when `alpha` is set and `rank > 0`; otherwise the explicit
    /// `scale`; otherwise [`DEFAULT_LORA_SCALE`].
    pub fn resolved_scale(&self) -> f32 {
        match (self.alpha, self.rank) {
            (Some(alpha), rank) if rank > 0 => alpha / rank as f32,
            _ => self.scale.unwrap_or(DEFAULT_LORA_SCALE),
        }
    }
}
/// Decides which module paths receive an adapter (PEFT `target_modules` /
/// `exclude_modules`).
#[derive(Debug, Clone, derive_more::IsVariant)]
pub enum ModuleMatcher {
    /// Matches a path equal to one of the names, or ending in `.{name}`.
    List(Vec<String>),
    /// Matches when the leftmost match from `Regex::find` spans the entire path. Build the
    /// regex anchored (`^(?:…)$`) to get full-match semantics for alternations.
    Regex(Box<Regex>),
    /// Matches every path whose last `.`-component is not `lm_head`.
    AllLinear,
}
impl ModuleMatcher {
    /// Returns `true` when `module_key` is selected by this matcher.
    pub fn matches(&self, module_key: &str) -> bool {
        match self {
            ModuleMatcher::List(names) => names
                .iter()
                .any(|name| matches_dotted_suffix(module_key, name)),
            ModuleMatcher::Regex(re) => re
                .find(module_key)
                .is_some_and(|m| m.start() == 0 && m.end() == module_key.len()),
            ModuleMatcher::AllLinear => !is_output_head_path(module_key),
        }
    }
}
/// `true` when `path == name` or `path` ends with `.{name}`.
///
/// Equivalent to `path == name || path.ends_with(&format!(".{name}"))` but does not
/// allocate, which matters because it runs once per listed name per module lookup.
fn matches_dotted_suffix(path: &str, name: &str) -> bool {
    path.strip_suffix(name)
        .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
}
/// `true` when the last `.`-separated component of `path` is `lm_head`.
fn is_output_head_path(path: &str) -> bool {
    path.rsplit('.').next() == Some("lm_head")
}
/// PEFT-specific targeting and scaling options parsed from `adapter_config.json`.
#[derive(Debug, Clone)]
pub struct PeftSelection {
    /// `target_modules`; `None` when the config omits it.
    pub target_modules: Option<ModuleMatcher>,
    /// `exclude_modules`; `None` when the config omits it.
    pub exclude_modules: Option<ModuleMatcher>,
    /// `layers_to_transform`; a single integer in the config is normalized to a one-element list.
    pub layers_to_transform: Option<Vec<i32>>,
    /// `layers_pattern` normalized to a list (empty when absent).
    pub layers_pattern: Vec<String>,
    /// `rank_pattern` entries in document order; the first matching pattern wins.
    pub rank_pattern: Vec<(String, i32)>,
    /// `alpha_pattern` entries in document order; the first matching pattern wins.
    pub alpha_pattern: Vec<(String, f32)>,
    /// rsLoRA: `scale_for` divides by `sqrt(r)` instead of `r`.
    pub use_rslora: bool,
    /// PEFT `fan_in_fan_out` flag; how it is applied is outside this view.
    pub fan_in_fan_out: bool,
}
impl PeftSelection {
    /// Rank for `module_path`: the first matching `rank_pattern` value, else `default_rank`.
    pub fn rank_for(&self, module_path: &str, default_rank: i32) -> i32 {
        let overridden = pattern_lookup(&self.rank_pattern, module_path);
        overridden.unwrap_or(default_rank)
    }

    /// Alpha for `module_path`: the first matching `alpha_pattern` value, else `default_alpha`.
    pub fn alpha_for(&self, module_path: &str, default_alpha: f32) -> f32 {
        let overridden = pattern_lookup(&self.alpha_pattern, module_path);
        overridden.unwrap_or(default_alpha)
    }

    /// Adapter scale for `module_path`.
    ///
    /// Returns `alpha / r`, or `alpha / sqrt(r)` with rsLoRA. Both values honor the per-module
    /// pattern overrides. A resolved rank `<= 0` yields `0.0`.
    pub fn scale_for(&self, module_path: &str, default_rank: i32, default_alpha: f32) -> f32 {
        let rank = self.rank_for(module_path, default_rank);
        if rank <= 0 {
            return 0.0;
        }
        let alpha = self.alpha_for(module_path, default_alpha);
        let divisor = if self.use_rslora {
            (rank as f32).sqrt()
        } else {
            rank as f32
        };
        alpha / divisor
    }
}
fn pattern_lookup<T: Copy>(patterns: &[(String, T)], module_path: &str) -> Option<T> {
for (pattern, value) in patterns {
let Ok(re) = Regex::new(&format!(r"(.*\.)?({pattern})$")) else {
continue;
};
if re
.find(module_path)
.is_some_and(|m| m.start() == 0 && m.end() == module_path.len())
{
return Some(*value);
}
}
None
}
struct OrderedPattern<T>(Vec<(String, T)>);
impl<'de, T: serde::Deserialize<'de>> serde::Deserialize<'de> for OrderedPattern<T> {
fn deserialize<D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<Self, D::Error> {
struct OrderedVisitor<T>(std::marker::PhantomData<T>);
impl<'de, T: serde::Deserialize<'de>> serde::de::Visitor<'de> for OrderedVisitor<T> {
type Value = Vec<(String, T)>;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("a `rank_pattern` / `alpha_pattern` JSON object {pattern: value}")
}
fn visit_map<M: serde::de::MapAccess<'de>>(
self,
mut access: M,
) -> std::result::Result<Self::Value, M::Error> {
let mut out = Vec::with_capacity(access.size_hint().unwrap_or(0));
while let Some((k, v)) = access.next_entry::<String, T>()? {
out.push((k, v));
}
Ok(out)
}
}
deserializer
.deserialize_map(OrderedVisitor(std::marker::PhantomData))
.map(OrderedPattern)
}
}
fn ordered_pattern<T>(pattern: Option<OrderedPattern<T>>) -> Vec<(String, T)> {
pattern.map(|p| p.0).unwrap_or_default()
}
/// How adapter targets are selected, depending on which config flavor was parsed.
#[derive(Debug, Clone)]
pub enum AdapterSelection {
    /// MLX-LM style config (a `lora_parameters` object is present, or no PEFT markers are).
    /// `num_layers` falls back to `DEFAULT_NUM_LAYERS`.
    MlxLm {
        num_layers: i32,
    },
    /// PEFT style config with per-module selection and rank/alpha overrides.
    Peft(PeftSelection),
}
/// Parsed `adapter_config.json` (MLX-LM or PEFT flavor), produced by the custom
/// `Deserialize` impl for this type.
#[derive(Debug, Clone)]
pub struct LoraConfig {
    /// `fine_tune_type` for MLX-LM configs; always `Lora` for PEFT configs.
    pub fine_tune_type: FineTuneType,
    /// Rank/scale/alpha/dropout; synthesized from `r` / `lora_alpha` / `lora_dropout` for PEFT.
    pub lora_parameters: LoraParameters,
    /// `use_dora` flag; DoRA is active when this is set or `fine_tune_type` is `Dora`.
    pub use_dora: bool,
    /// Which modules/layers the adapter targets.
    pub selection: AdapterSelection,
}
/// Fallback for MLX-LM `num_layers` when the config omits it.
fn default_num_layers() -> i32 {
    DEFAULT_NUM_LAYERS
}
/// Permissive superset of the MLX-LM and PEFT `adapter_config.json` schemas.
///
/// Every modeled field is optional (`#[serde(default)]`). Any other key is captured in
/// `extra` so that `reject_unknown_active_peft_fields` can vet it on the PEFT path.
#[derive(serde::Deserialize)]
struct RawLoraConfig {
    // --- MLX-LM fields ---
    #[serde(default)]
    fine_tune_type: Option<FineTuneType>,
    #[serde(default)]
    num_layers: Option<i32>,
    // When present, this object selects the MLX-LM path outright.
    #[serde(default)]
    lora_parameters: Option<LoraParameters>,
    // Read on both the MLX-LM and PEFT paths.
    #[serde(default)]
    use_dora: bool,
    // --- PEFT fields ---
    // Must equal "LORA" (ASCII case-insensitive) when present.
    #[serde(default)]
    peft_type: Option<String>,
    #[serde(default)]
    r: Option<i32>,
    #[serde(default)]
    lora_alpha: Option<f32>,
    #[serde(default)]
    target_modules: Option<StrOrList>,
    #[serde(default)]
    exclude_modules: Option<StrOrList>,
    #[serde(default)]
    lora_dropout: Option<f32>,
    #[serde(default)]
    use_rslora: bool,
    #[serde(default)]
    fan_in_fan_out: bool,
    // Rejected when true (no lora_B bias slot).
    #[serde(default)]
    lora_bias: bool,
    // Rejected unless "none" (ASCII case-insensitive).
    #[serde(default)]
    bias: Option<String>,
    // Rejected when non-empty.
    #[serde(default)]
    modules_to_save: Option<Vec<String>>,
    #[serde(default)]
    layers_to_transform: Option<IntOrList>,
    #[serde(default)]
    layers_pattern: Option<StrOrList>,
    #[serde(default)]
    rank_pattern: Option<OrderedPattern<i32>>,
    #[serde(default)]
    alpha_pattern: Option<OrderedPattern<f32>>,
    // --- Exotic LoRA variants, rejected by `reject_exotic_variants` when active ---
    #[serde(default)]
    use_qalora: bool,
    #[serde(default)]
    alora_invocation_tokens: Option<serde_json::Value>,
    #[serde(default)]
    velora_config: Option<serde_json::Value>,
    #[serde(default)]
    monteclora_config: Option<serde_json::Value>,
    // Every key not modeled above.
    #[serde(flatten)]
    extra: HashMap<String, serde_json::Value>,
}
/// A config value written either as a list of strings or as a single string.
/// `untagged`: serde tries the variants in declaration order (`List`, then `One`).
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum StrOrList {
    List(Vec<String>),
    One(String),
}
/// A config value written either as a list of integers or as a single integer.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum IntOrList {
    List(Vec<i32>),
    One(i32),
}
const ALL_LINEAR_SENTINEL: &str = "all-linear";
fn module_matcher_from<E: serde::de::Error>(
value: StrOrList,
field: &str,
is_target_modules: bool,
) -> std::result::Result<ModuleMatcher, E> {
match value {
StrOrList::List(names) => Ok(ModuleMatcher::List(names)),
StrOrList::One(pattern) => {
if is_target_modules && pattern.eq_ignore_ascii_case(ALL_LINEAR_SENTINEL) {
return Ok(ModuleMatcher::AllLinear);
}
let re = Regex::new(&pattern).map_err(|e| {
E::custom(format!(
"adapter_config.json `{field}` is the regex string {pattern:?}, which failed to \
compile: {e}"
))
})?;
Ok(ModuleMatcher::Regex(Box::new(re)))
}
}
}
fn is_active_config_value(value: &serde_json::Value) -> bool {
!matches!(
value,
serde_json::Value::Null | serde_json::Value::Bool(false)
)
}
/// Allowlist of PEFT config fields that `reject_unknown_active_peft_fields` skips whatever
/// their value. Matching is exact and case-sensitive.
fn is_benign_ignore_field(field: &str) -> bool {
    const BENIGN_IGNORE_FIELDS: &[&str] = &[
        "task_type",
        "auto_mapping",
        "peft_version",
        "base_model_name_or_path",
        "revision",
        "inference_mode",
        "init_lora_weights",
        "eva_config",
        "corda_config",
        "lora_ga_config",
        "loftq_config",
        "megatron_config",
        "megatron_core",
        "runtime_config",
        "qalora_group_size",
        "ensure_weight_tying",
    ];
    BENIGN_IGNORE_FIELDS.contains(&field)
}
/// `true` when an `init_lora_weights` string names a mode that only seeds the LoRA factors
/// (`gaussian`, `eva`, `orthogonal`; ASCII case-insensitive).
fn is_factor_only_init_mode(mode: &str) -> bool {
    ["gaussian", "eva", "orthogonal"]
        .iter()
        .any(|known| mode.eq_ignore_ascii_case(known))
}
/// Rejects PEFT configs that set fields this loader does not model.
///
/// 1. A string `init_lora_weights` must be a factor-only init mode (booleans and other
///    non-string values pass this check).
/// 2. Any other unmodeled field (collected in `raw.extra`) that is not on the benign
///    allowlist must be inactive (`null` / `false`).
///
/// `raw.extra` is a `HashMap`, whose iteration order is randomized per process. When
/// several fields offend, the lexicographically smallest name is reported so the error is
/// reproducible across runs.
///
/// # Errors
/// `E::custom` describing the offending field.
fn reject_unknown_active_peft_fields<E: serde::de::Error>(
    raw: &RawLoraConfig,
) -> std::result::Result<(), E> {
    if let Some(init) = raw.extra.get("init_lora_weights")
        && let Some(mode) = init.as_str()
        && !is_factor_only_init_mode(mode)
    {
        return Err(E::custom(format!(
            "adapter_config.json sets `init_lora_weights: {mode:?}` — this loader only supports the \
             pure factor-seed init modes (`gaussian` / `eva` / `orthogonal`) and the booleans \
             `true` / `false`. Other modes either mutate the base model weight at init (`olora`, \
             `pissa` incl. `pissa_niter_<N>`, `corda` incl. prefixed variants, `loftq`, `lora_ga` — \
             they subtract a low-rank residual from `base_layer.weight`, pairing the raw saved factors \
             with a *modified* base) or are not understood; applying them to this loader's unmodified \
             base would be silently wrong, so they are rejected. (A checkpoint converted via PEFT's \
             conversion path reports `init_lora_weights: true` and loads fine.)"
        )));
    }
    let offending = raw
        .extra
        .iter()
        .filter(|(field, value)| !is_benign_ignore_field(field) && is_active_config_value(value))
        .map(|(field, _)| field)
        .min();
    if let Some(field) = offending {
        return Err(E::custom(format!(
            "adapter_config.json sets the unsupported / unmodeled PEFT field {field:?} to an active \
             value; this loader models only a known subset of PEFT `LoraConfig` and rejects any \
             other field that is set (not `null` / `false`), so a future forward-switching variant \
             fails loudly instead of silently running as vanilla LoRA. If {field:?} does not affect \
             inference, it must be added to the benign-ignore allowlist; otherwise it needs explicit \
             support."
        )));
    }
    Ok(())
}
/// Rejects LoRA variants whose forward pass differs from plain LoRA: QALoRA, Activated-LoRA,
/// VeLoRA and MonteCLoRA. The checks run in that order and the first one present is reported.
///
/// # Errors
/// `E::custom` naming the detected variant.
fn reject_exotic_variants<E: serde::de::Error>(raw: &RawLoraConfig) -> std::result::Result<(), E> {
    let checks: [(bool, &str); 4] = [
        (
            raw.use_qalora,
            "adapter_config.json sets `use_qalora: true` — Quantization-Aware LoRA pools the lora_A \
             input before the low-rank matmul, a forward this loader does not implement; a QALoRA \
             adapter is not supported (loading it as plain LoRA would be wrong)",
        ),
        (
            raw.alora_invocation_tokens.is_some(),
            "adapter_config.json sets `alora_invocation_tokens` — Activated-LoRA applies the adapter \
             only to tokens at/after an invocation sequence, a token-position-dependent forward this \
             loader does not implement; an aLoRA adapter is not supported (applying it \
             unconditionally would be wrong)",
        ),
        (
            raw.velora_config.is_some(),
            "adapter_config.json carries a `velora_config` — VeLoRA alters the adapter's numerics \
             with a custom compressed-activation backward; a VeLoRA adapter is not supported by this \
             loader (loading it as plain LoRA would be wrong)",
        ),
        (
            raw.monteclora_config.is_some(),
            "adapter_config.json carries a `monteclora_config` — MonteCLoRA adds variational \
             Monte-Carlo sampling over the LoRA adapters, changing the forward; a MonteCLoRA adapter \
             is not supported by this loader (loading it as plain LoRA would be wrong)",
        ),
    ];
    match checks.into_iter().find(|&(present, _)| present) {
        Some((_, message)) => Err(E::custom(message)),
        None => Ok(()),
    }
}
impl<'de> serde::Deserialize<'de> for LoraConfig {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error as _;
let raw = RawLoraConfig::deserialize(deserializer)?;
reject_exotic_variants::<D::Error>(&raw)?;
if let Some(lora_parameters) = raw.lora_parameters {
return Ok(LoraConfig {
fine_tune_type: raw.fine_tune_type.unwrap_or_default(),
lora_parameters,
use_dora: raw.use_dora,
selection: AdapterSelection::MlxLm {
num_layers: raw.num_layers.unwrap_or_else(default_num_layers),
},
});
}
let is_peft = raw.peft_type.is_some()
|| raw.r.is_some()
|| raw.lora_alpha.is_some()
|| raw.target_modules.is_some();
if is_peft {
if let Some(peft_type) = &raw.peft_type
&& !peft_type.eq_ignore_ascii_case("LORA")
{
return Err(D::Error::custom(format!(
"adapter_config.json `peft_type` is {peft_type:?}, but this loader handles only \
LoRA/DoRA adapters (`peft_type` \"LORA\"); a different PEFT method (LOHA, LOKR, \
IA3, prompt-tuning, …) is not supported"
)));
}
if raw.lora_bias {
return Err(D::Error::custom(
"adapter_config.json sets `lora_bias: true` (a bias on the lora_B projection); this \
loader's LoRALinear has no lora_B-bias term, so a `lora_bias` adapter is not \
supported (it would silently drop the bias)",
));
}
if let Some(bias) = &raw.bias
&& !bias.eq_ignore_ascii_case("none")
{
return Err(D::Error::custom(format!(
"adapter_config.json sets `bias: {bias:?}` (PEFT trains+saves base/adapter `.bias` \
tensors for `\"all\"` / `\"lora_only\"`); this loader's LoRALinear has no adapted-bias \
slot, so a PEFT `bias` adapter is not supported (it would silently drop the bias \
tensors)"
)));
}
if raw.modules_to_save.as_ref().is_some_and(|m| !m.is_empty()) {
return Err(D::Error::custom(
"adapter_config.json sets a non-empty `modules_to_save` (PEFT trains+saves these \
modules in full alongside the LoRA factors); this loader applies only the low-rank \
factors and has no saved-full-module slot, so a `modules_to_save` adapter is not \
supported (it would silently drop the saved module weights)",
));
}
reject_unknown_active_peft_fields::<D::Error>(&raw)?;
let target_modules = match raw.target_modules {
None => None,
Some(v) => Some(module_matcher_from::<D::Error>(v, "target_modules", true)?),
};
let exclude_modules = match raw.exclude_modules {
None => None,
Some(v) => Some(module_matcher_from::<D::Error>(
v,
"exclude_modules",
false,
)?),
};
let layers_to_transform = raw.layers_to_transform.map(|v| match v {
IntOrList::List(xs) => xs,
IntOrList::One(x) => vec![x],
});
let layers_pattern = match raw.layers_pattern {
None => Vec::new(),
Some(StrOrList::List(xs)) => xs,
Some(StrOrList::One(x)) => vec![x],
};
let peft = PeftSelection {
target_modules,
exclude_modules,
layers_to_transform,
layers_pattern,
rank_pattern: ordered_pattern(raw.rank_pattern),
alpha_pattern: ordered_pattern(raw.alpha_pattern),
use_rslora: raw.use_rslora,
fan_in_fan_out: raw.fan_in_fan_out,
};
let lora_parameters = LoraParameters {
rank: raw.r.unwrap_or_else(default_rank),
scale: None,
alpha: Some(raw.lora_alpha.unwrap_or(DEFAULT_PEFT_LORA_ALPHA)),
keys: Vec::new(),
dropout: raw.lora_dropout,
};
return Ok(LoraConfig {
fine_tune_type: FineTuneType::Lora,
lora_parameters,
use_dora: raw.use_dora,
selection: AdapterSelection::Peft(peft),
});
}
Ok(LoraConfig {
fine_tune_type: raw.fine_tune_type.unwrap_or_default(),
lora_parameters: LoraParameters::default(),
use_dora: raw.use_dora,
selection: AdapterSelection::MlxLm {
num_layers: raw.num_layers.unwrap_or_else(default_num_layers),
},
})
}
}
impl LoraConfig {
pub fn from_json(json: &str) -> Result<LoraConfig> {
serde_json::from_str(json).map_err(|e| {
Error::Parse(ParsePayload::new(
"LoraConfig::from_json",
"adapter_config.json",
e,
))
})
}
pub fn is_dora(&self) -> bool {
self.fine_tune_type == FineTuneType::Dora || self.use_dora
}
pub fn scale(&self) -> f32 {
self.lora_parameters.resolved_scale()
}
pub fn scale_for(&self, module_path: &str) -> f32 {
match &self.selection {
AdapterSelection::MlxLm { .. } => self.lora_parameters.resolved_scale(),
AdapterSelection::Peft(peft) => peft.scale_for(
module_path,
self.lora_parameters.rank,
self
.lora_parameters
.alpha
.unwrap_or(DEFAULT_PEFT_LORA_ALPHA),
),
}
}
pub fn rank_for(&self, module_path: &str) -> i32 {
match &self.selection {
AdapterSelection::MlxLm { .. } => self.lora_parameters.rank,
AdapterSelection::Peft(peft) => peft.rank_for(module_path, self.lora_parameters.rank),
}
}
pub fn rank(&self) -> i32 {
self.lora_parameters.rank
}
pub fn peft(&self) -> Option<&PeftSelection> {
match &self.selection {
AdapterSelection::Peft(p) => Some(p),
AdapterSelection::MlxLm { .. } => None,
}
}
pub fn fan_in_fan_out(&self) -> bool {
self.peft().is_some_and(|p| p.fan_in_fan_out)
}
}
/// Adapter tensors for one module: the two low-rank factors and, for DoRA, the magnitude.
#[derive(Debug)]
pub struct AdapterParams {
    /// Factor applied first in the forward pass (`x @ lora_a`).
    pub lora_a: Array,
    /// Factor applied second in the forward pass (`(x @ lora_a) @ lora_b`).
    pub lora_b: Array,
    /// DoRA magnitude `m`; required by the DoRA layers and unused by `LoRALinear`.
    pub magnitude: Option<Array>,
}
impl AdapterParams {
    /// Clones every tensor with `Array::try_clone`, propagating the first failure.
    pub fn try_clone(&self) -> Result<Self> {
        let lora_a = self.lora_a.try_clone()?;
        let lora_b = self.lora_b.try_clone()?;
        let magnitude = self
            .magnitude
            .as_ref()
            .map(|m| m.try_clone())
            .transpose()?;
        Ok(Self {
            lora_a,
            lora_b,
            magnitude,
        })
    }
}
/// Frozen base projection underneath an adapter: a dense weight or a quantized one.
#[derive(Debug)]
pub enum BaseLinear {
    /// Unquantized weight and optional bias; `BaseLinear::dense` validates them as
    /// `[output_dims, input_dims]` and `[output_dims]`.
    Dense {
        weight: Array,
        bias: Option<Array>,
    },
    /// Quantized weight plus the metadata needed to dequantize / matmul it.
    Quantized {
        weight: Array,
        scales: Array,
        // Quantization biases: required for "affine", must be None for mxfp4/mxfp8/nvfp4
        // (enforced by `BaseLinear::quantized`, not by the type).
        quant_biases: Option<Array>,
        // Layer output bias, added after the matmul; independent of quantization.
        bias: Option<Array>,
        group_size: i32,
        bits: i32,
        // One of "affine", "mxfp4", "mxfp8", "nvfp4" when built via `BaseLinear::quantized`.
        mode: String,
    },
}
impl BaseLinear {
    /// Builds a dense base layer.
    ///
    /// # Errors
    /// * `RankMismatch` when `weight` is not 2-D (`[output_dims, input_dims]`).
    /// * `ShapePairMismatch` when `bias` is present but not `[output_dims]`.
    pub fn dense(weight: Array, bias: Option<Array>) -> Result<Self> {
        let w_shape = weight.shape();
        let w_rank = w_shape.len();
        // Read before the rank check because `w_shape` is moved into the error below;
        // falls back to 0 for a 0-D weight, which the rank check rejects anyway.
        let w_output_dims = w_shape.first().copied().unwrap_or(0);
        if w_rank != 2 {
            return Err(Error::RankMismatch(RankMismatchPayload::new(
                "BaseLinear::dense: weight must be 2-D [output_dims, input_dims]",
                w_rank as u32,
                w_shape,
            )));
        }
        if let Some(b) = &bias {
            let b_shape = b.shape();
            // `b_shape[0]` is only evaluated after the length check has passed.
            if b_shape.len() != 1 || b_shape[0] != w_output_dims {
                return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
                    "BaseLinear::dense: bias must be [output_dims]",
                    vec![w_output_dims],
                    b_shape,
                )));
            }
        }
        Ok(BaseLinear::Dense { weight, bias })
    }
    /// Builds a quantized base layer after validating its quantization metadata.
    ///
    /// `mode` must be `"affine"` (which requires `quant_biases`) or one of the scale-only
    /// modes `"mxfp4"` / `"mxfp8"` / `"nvfp4"` (which require `quant_biases == None`).
    /// `bits` and `group_size` must be positive. Tensor shapes are not checked here.
    ///
    /// # Errors
    /// `MissingField`, `InvariantViolation`, `UnknownEnumValue` or `OutOfRange` for the
    /// respective violations above, checked in that order.
    pub fn quantized(
        weight: Array,
        scales: Array,
        quant_biases: Option<Array>,
        bias: Option<Array>,
        group_size: i32,
        bits: i32,
        mode: String,
    ) -> Result<Self> {
        // The mode decides whether quantization biases must be present or absent.
        match (mode.as_str(), &quant_biases) {
            ("affine", None) => {
                return Err(Error::MissingField(MissingFieldPayload::new(
                    "BaseLinear::quantized",
                    "quant_biases (affine mode requires it; mlx affine_quantize writes {w_q, scales, biases})",
                )));
            }
            ("mxfp4" | "mxfp8" | "nvfp4", Some(_)) => {
                return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                    "BaseLinear::quantized: quant_biases",
                    "must be None for scale-only modes (mxfp4/mxfp8/nvfp4 — mlx fp_quantize writes {w_q, scales})",
                )));
            }
            // Valid combinations.
            ("affine", Some(_)) | ("mxfp4" | "mxfp8" | "nvfp4", None) => {}
            (other, _) => {
                return Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
                    "BaseLinear::quantized: mode",
                    other.to_string(),
                    &["affine", "mxfp4", "mxfp8", "nvfp4"],
                )));
            }
        }
        if bits <= 0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "BaseLinear::quantized: bits",
                "must be > 0",
                bits.to_string(),
            )));
        }
        if group_size <= 0 {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "BaseLinear::quantized: group_size",
                "must be > 0",
                group_size.to_string(),
            )));
        }
        Ok(BaseLinear::Quantized {
            weight,
            scales,
            quant_biases,
            bias,
            group_size,
            bits,
            mode,
        })
    }
    /// Layer output bias (not the quantization biases), if any.
    pub fn bias(&self) -> Option<&Array> {
        match self {
            BaseLinear::Dense { bias, .. } => bias.as_ref(),
            BaseLinear::Quantized { bias, .. } => bias.as_ref(),
        }
    }
    /// Full-precision view of the base weight.
    ///
    /// Dense: a `try_clone` of the weight. Quantized: `ops::quantized::dequantize` with the
    /// stored metadata. The two trailing `None` arguments are passed through; their meaning
    /// is defined by that helper.
    pub fn dequantized_weight(&self) -> Result<Array> {
        match self {
            BaseLinear::Dense { weight, .. } => weight.try_clone(),
            BaseLinear::Quantized {
                weight,
                scales,
                quant_biases,
                group_size,
                bits,
                mode,
                ..
            } => ops::quantized::dequantize(
                weight,
                scales,
                quant_biases.as_ref(),
                *group_size,
                *bits,
                mode,
                None,
                None,
            ),
        }
    }
    /// `x @ W^T` without the output bias (DoRA needs the un-biased product).
    fn base_output_no_bias(&self, x: &Array) -> Result<Array> {
        match self {
            BaseLinear::Dense { weight, .. } => {
                let wt = weight.transpose()?;
                x.matmul(&wt)
            }
            BaseLinear::Quantized {
                weight,
                scales,
                quant_biases,
                group_size,
                bits,
                mode,
                ..
            } => {
                // NOTE(review): the `true` argument presumably selects the transposed
                // (`x @ W^T`) layout, matching the dense branch; confirm against
                // `ops::quantized::quantized_matmul`.
                ops::quantized::quantized_matmul(
                    x,
                    weight,
                    scales,
                    quant_biases.as_ref(),
                    true,
                    *group_size,
                    *bits,
                    mode,
                )
            }
        }
    }
    /// `x @ W^T`, plus the output bias when present.
    fn base_output(&self, x: &Array) -> Result<Array> {
        let y = self.base_output_no_bias(x)?;
        match self.bias() {
            Some(b) => y.add(b),
            None => Ok(y),
        }
    }
    /// Builds a layer of the same kind as `self` from a fused full-precision weight.
    ///
    /// Dense stays dense. Quantized is re-quantized with the same group size, bits and mode,
    /// and goes through `BaseLinear::quantized` validation again.
    fn requantize_fused(&self, fused_weight: Array, fused_bias: Option<Array>) -> Result<BaseLinear> {
        match self {
            BaseLinear::Dense { .. } => BaseLinear::dense(fused_weight, fused_bias),
            BaseLinear::Quantized {
                group_size,
                bits,
                mode,
                ..
            } => {
                let (w_q, scales, q_biases) =
                    ops::quantized::quantize(&fused_weight, *group_size, *bits, mode, None)?;
                BaseLinear::quantized(
                    w_q,
                    scales,
                    q_biases,
                    fused_bias,
                    *group_size,
                    *bits,
                    mode.clone(),
                )
            }
        }
    }
    /// `true` for the `Quantized` variant.
    fn is_quantized(&self) -> bool {
        matches!(self, BaseLinear::Quantized { .. })
    }
}
/// Multiplies `arr` by the scalar `scale`.
///
/// The scalar is cast to `arr`'s dtype when that dtype can be read; otherwise the `f32`
/// scalar is used as-is.
fn scaled(arr: &Array, scale: f32) -> Result<Array> {
    let factor = Array::from_slice::<f32>(&[scale], &(1usize,))?;
    let factor = if let Ok(dtype) = arr.dtype() {
        factor.astype(dtype)?
    } else {
        factor
    };
    arr.multiply(&factor)
}
/// Dense weight delta `(scale * lora_b^T) @ lora_a^T`, with `scale` folded into `lora_b^T`
/// before the matmul.
///
/// Assuming `lora_a` is `[input_dims, r]` and `lora_b` is `[r, output_dims]` (the layout used
/// by `x @ lora_a @ lora_b`), the result is `[output_dims, input_dims]`, like the base weight.
fn lora_delta(params: &AdapterParams, scale: f32) -> Result<Array> {
    let b_transposed = params.lora_b.transpose()?;
    let a_transposed = params.lora_a.transpose()?;
    scaled(&b_transposed, scale)?.matmul(&a_transposed)
}
/// Unscaled low-rank branch `(x @ lora_a) @ lora_b`.
fn lora_z(x: &Array, params: &AdapterParams) -> Result<Array> {
    // Previously `¶ms` (mojibake of `&params` from an HTML `&para;` entity), which
    // does not compile.
    let xa = x.matmul(&params.lora_a)?;
    xa.matmul(&params.lora_b)
}
/// Linear layer with a LoRA adapter: `base(x) + scale * (x @ lora_a @ lora_b)`.
#[derive(Debug)]
pub struct LoRALinear {
    base: BaseLinear,
    params: AdapterParams,
    scale: f32,
}
impl LoRALinear {
    /// Wraps `base` with the LoRA factors in `params`.
    ///
    /// # Errors
    /// Propagates factor-shape errors from `validate_factor_shapes`.
    pub fn new(base: BaseLinear, params: AdapterParams, scale: f32) -> Result<Self> {
        validate_factor_shapes(&base, &params, LinearValidationContext::LoraLinear)?;
        Ok(Self {
            base,
            params,
            scale,
        })
    }
    /// Adapter scale applied to the low-rank branch.
    pub fn scale(&self) -> f32 {
        self.scale
    }
    /// The wrapped base layer.
    pub fn base(&self) -> &BaseLinear {
        &self.base
    }
    /// `base(x)` (bias included) plus the scaled low-rank branch. The branch is cast to
    /// `x`'s dtype, when readable, before the add.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let y = self.base.base_output(x)?;
        let z = lora_z(x, &self.params)?;
        let scaled_z = scaled(&z, self.scale)?;
        let scaled_z = match x.dtype() {
            Ok(dt) => scaled_z.astype(dt)?,
            Err(_) => scaled_z,
        };
        y.add(&scaled_z)
    }
    /// Folds the adapter into the base: `W + scale * lora_b^T @ lora_a^T`, with the delta
    /// cast to `W`'s dtype. The bias is kept.
    ///
    /// A quantized base is re-quantized with its original settings unless `dequantize` is
    /// true; otherwise a dense layer is returned.
    pub fn fuse(&self, dequantize: bool) -> Result<BaseLinear> {
        let weight = self.base.dequantized_weight()?;
        let delta = lora_delta(&self.params, self.scale)?;
        let delta = match weight.dtype() {
            Ok(dt) => delta.astype(dt)?,
            Err(_) => delta,
        };
        let fused_weight = weight.add(&delta)?;
        let fused_bias = match self.base.bias() {
            Some(b) => Some(b.try_clone()?),
            None => None,
        };
        if self.base.is_quantized() && !dequantize {
            self.base.requantize_fused(fused_weight, fused_bias)
        } else {
            BaseLinear::dense(fused_weight, fused_bias)
        }
    }
}
#[derive(Debug)]
pub struct DoRALinear {
base: BaseLinear,
params: AdapterParams,
magnitude: Array,
scale: f32,
}
impl DoRALinear {
pub fn new(base: BaseLinear, params: AdapterParams, scale: f32) -> Result<Self> {
validate_factor_shapes(&base, ¶ms, LinearValidationContext::DoraLinear)?;
let magnitude = match ¶ms.magnitude {
Some(m) => m.try_clone()?,
None => {
return Err(Error::MissingField(MissingFieldPayload::new(
"DoRALinear::new",
"magnitude `m` (loaded from adapters.safetensors; DoRA requires it)",
)));
}
};
let output_dims = base_output_dims(&base)?;
let m_shape = magnitude.shape();
if m_shape.len() != 1 || m_shape[0] != output_dims {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"DoRALinear::new: magnitude `m` must be [output_dims]",
vec![output_dims],
m_shape,
)));
}
Ok(Self {
base,
params,
magnitude,
scale,
})
}
pub fn scale(&self) -> f32 {
self.scale
}
pub fn base(&self) -> &BaseLinear {
&self.base
}
pub fn magnitude(&self) -> &Array {
&self.magnitude
}
pub fn forward(&self, x: &Array) -> Result<Array> {
let y = self.base.base_output_no_bias(x)?;
let z = lora_z(x, &self.params)?;
let scaled_z = scaled(&z, self.scale)?;
let scaled_z = match x.dtype() {
Ok(dt) => scaled_z.astype(dt)?,
Err(_) => scaled_z,
};
let out = y.add(&scaled_z)?;
let w = self.base.dequantized_weight()?;
let delta = lora_delta(&self.params, self.scale)?;
let delta = match w.dtype() {
Ok(dt) => delta.astype(dt)?,
Err(_) => delta,
};
let adapted = w.add(&delta)?;
let denom = ops::linalg_full::norm(&adapted, 2.0, &[1], false)?;
let norm_scale = self.magnitude.divide(&denom)?;
let norm_scale = match x.dtype() {
Ok(dt) => norm_scale.astype(dt)?,
Err(_) => norm_scale,
};
let mut out = out.multiply(&norm_scale)?;
if let Some(bias) = self.base.bias() {
out = out.add(bias)?;
}
Ok(out)
}
pub fn fuse(&self, dequantize: bool) -> Result<BaseLinear> {
let weight = self.base.dequantized_weight()?;
let delta = lora_delta(&self.params, self.scale)?;
let delta = match weight.dtype() {
Ok(dt) => delta.astype(dt)?,
Err(_) => delta,
};
let adapted = weight.add(&delta)?;
let denom = ops::linalg_full::norm(&adapted, 2.0, &[1], false)?;
let norm_scale = self.magnitude.divide(&denom)?;
let norm_scale_col = norm_scale.expand_dims_axes(&[-1])?;
let fused_weight = norm_scale_col.multiply(&adapted)?;
let fused_bias = match self.base.bias() {
Some(b) => Some(b.try_clone()?),
None => None,
};
if self.base.is_quantized() && !dequantize {
self.base.requantize_fused(fused_weight, fused_bias)
} else {
BaseLinear::dense(fused_weight, fused_bias)
}
}
}
#[derive(Debug)]
pub enum BaseEmbedding {
Dense {
weight: Array,
},
}
impl BaseEmbedding {
pub fn dense(weight: Array) -> Result<Self> {
let shape = weight.shape();
let rank = shape.len();
if rank != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"BaseEmbedding::dense: weight must be 2-D [num_embeddings, dims]",
rank as u32,
shape,
)));
}
Ok(BaseEmbedding::Dense { weight })
}
pub fn weight(&self) -> &Array {
match self {
BaseEmbedding::Dense { weight } => weight,
}
}
fn num_embeddings(&self) -> Result<usize> {
let shape = self.weight().shape();
shape.first().copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"BaseEmbedding: weight must be rank-2 [num_embeddings, dims] to determine num_embeddings",
shape.len() as u32,
shape.clone(),
))
})
}
fn dims(&self) -> Result<usize> {
let shape = self.weight().shape();
shape.get(1).copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"BaseEmbedding: weight must be rank-2 [num_embeddings, dims] to determine dims",
shape.len() as u32,
shape.clone(),
))
})
}
fn lookup(&self, ids: &Array) -> Result<Array> {
self.weight().take_axis(ids, 0)
}
fn as_linear(&self, x: &Array) -> Result<Array> {
let wt = self.weight().transpose()?;
x.matmul(&wt)
}
}
#[derive(Debug)]
pub struct DoRAEmbedding {
base: BaseEmbedding,
params: AdapterParams,
magnitude: Array,
scale: f32,
}
impl DoRAEmbedding {
pub fn new(base: BaseEmbedding, params: AdapterParams, scale: f32) -> Result<Self> {
validate_embedding_factor_shapes(&base, ¶ms, EmbeddingValidationContext::DoraEmbedding)?;
let magnitude = match ¶ms.magnitude {
Some(m) => m.try_clone()?,
None => {
return Err(Error::MissingField(MissingFieldPayload::new(
"DoRAEmbedding::new",
"magnitude `m` (loaded from adapters.safetensors; DoRA requires it)",
)));
}
};
let num_embeddings = base.num_embeddings()?;
let m_shape = magnitude.shape();
if m_shape.len() != 1 || m_shape[0] != num_embeddings {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"DoRAEmbedding::new: magnitude `m` must be [num_embeddings]",
vec![num_embeddings],
m_shape,
)));
}
Ok(Self {
base,
params,
magnitude,
scale,
})
}
pub fn scale(&self) -> f32 {
self.scale
}
pub fn base(&self) -> &BaseEmbedding {
&self.base
}
pub fn magnitude(&self) -> &Array {
&self.magnitude
}
pub fn forward(&self, x: &Array) -> Result<Array> {
let y = self.base.lookup(x)?;
let la_gathered = self.params.lora_a.take_axis(x, 0)?;
let mut z = la_gathered.matmul(&self.params.lora_b)?;
z = scaled(&z, self.scale)?;
let z_for_out = match y.dtype() {
Ok(dt) => z.astype(dt)?,
Err(_) => z.try_clone()?,
};
let out = y.add(&z_for_out)?;
let adapted = y.add(&z)?;
let denom = ops::linalg_full::norm(&adapted, 2.0, &[-1], false)?;
let m_gathered = self.magnitude.take_axis(x, 0)?;
let norm_scale = m_gathered.divide(&denom)?;
let norm_scale = norm_scale.expand_dims_axes(&[-1])?;
norm_scale.multiply(&out)
}
pub fn as_linear(&self, x: &Array) -> Result<Array> {
let y = self.base.as_linear(x)?;
let lb_t = self.params.lora_b.transpose()?;
let la_t = self.params.lora_a.transpose()?;
let xb = x.matmul(&lb_t)?;
let z = xb.matmul(&la_t)?;
let scaled_z = scaled(&z, self.scale)?;
let scaled_z_for_out = match x.dtype() {
Ok(dt) => scaled_z.astype(dt)?,
Err(_) => scaled_z.try_clone()?,
};
let out = y.add(&scaled_z_for_out)?;
let scaled_la = scaled(&self.params.lora_a, self.scale)?;
let delta = scaled_la.matmul(&self.params.lora_b)?;
let w = self.base.weight();
let adapted = w.add(&delta)?;
let denom = ops::linalg_full::norm(&adapted, 2.0, &[1], false)?;
let norm_scale = self.magnitude.divide(&denom)?;
out.multiply(&norm_scale)
}
pub fn fuse(&self) -> Result<BaseEmbedding> {
let scaled_la = scaled(&self.params.lora_a, self.scale)?;
let delta = scaled_la.matmul(&self.params.lora_b)?;
let w = self.base.weight();
let delta = match w.dtype() {
Ok(dt) => delta.astype(dt)?,
Err(_) => delta,
};
let adapted = w.add(&delta)?;
let denom = ops::linalg_full::norm(&adapted, 2.0, &[1], false)?;
let norm_scale = self.magnitude.divide(&denom)?;
let norm_scale_col = norm_scale.expand_dims_axes(&[-1])?;
let fused_weight = norm_scale_col.multiply(&adapted)?;
BaseEmbedding::dense(fused_weight)
}
}
/// Checks that the embedding adapter factors line up with the base table.
///
/// `lora_a` must be `[num_embeddings, r]` and `lora_b` must be `[r, dims]`. The checks run in
/// this order: both ranks, then the shared `r`, then `num_embeddings`, then `dims`. The first
/// failure is returned, with context strings taken from `who`.
fn validate_embedding_factor_shapes(
    base: &BaseEmbedding,
    params: &AdapterParams,
    who: EmbeddingValidationContext,
) -> Result<()> {
    let a_shape = params.lora_a.shape();
    let b_shape = params.lora_b.shape();
    if a_shape.len() != 2 {
        let found_rank = a_shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            who.lora_a_rank2(),
            found_rank,
            a_shape,
        )));
    }
    if b_shape.len() != 2 {
        let found_rank = b_shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            who.lora_b_rank2(),
            found_rank,
            b_shape,
        )));
    }
    // Both factors are 2-D from here on, so direct indexing cannot panic.
    let (a_rows, a_r) = (a_shape[0], a_shape[1]);
    let (b_r, b_cols) = (b_shape[0], b_shape[1]);
    if a_r != b_r {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            who.shared_rank(),
            b_r,
            a_r,
        )));
    }
    let num_embeddings = base.num_embeddings()?;
    if a_rows != num_embeddings {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            who.a_leading_vs_num_embeddings(),
            num_embeddings,
            a_rows,
        )));
    }
    let dims = base.dims()?;
    if b_cols != dims {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            who.b_last_vs_dims(),
            dims,
            b_cols,
        )));
    }
    Ok(())
}
/// Identifies which embedding adapter is validating its factors, so that shape errors carry
/// a caller-specific context string.
#[derive(Debug, Clone, Copy)]
enum EmbeddingValidationContext {
    /// Used by `DoRAEmbedding::new`.
    DoraEmbedding,
}
impl EmbeddingValidationContext {
    /// Context for a `lora_a` that is not 2-D.
    const fn lora_a_rank2(self) -> &'static str {
        match self {
            EmbeddingValidationContext::DoraEmbedding => {
                "DoRAEmbedding: lora_a must be 2-D [num_embeddings, r]"
            }
        }
    }
    /// Context for a `lora_b` that is not 2-D.
    const fn lora_b_rank2(self) -> &'static str {
        match self {
            EmbeddingValidationContext::DoraEmbedding => {
                "DoRAEmbedding: lora_b must be 2-D [r, dims]"
            }
        }
    }
    /// Context for a disagreement on the shared rank `r`.
    const fn shared_rank(self) -> &'static str {
        match self {
            EmbeddingValidationContext::DoraEmbedding => {
                "DoRAEmbedding: lora_a last axis vs lora_b leading axis (shared rank `r`)"
            }
        }
    }
    /// Context for a `lora_a` row count that differs from the table's `num_embeddings`.
    const fn a_leading_vs_num_embeddings(self) -> &'static str {
        match self {
            EmbeddingValidationContext::DoraEmbedding => {
                "DoRAEmbedding: lora_a leading axis vs base num_embeddings"
            }
        }
    }
    /// Context for a `lora_b` width that differs from the table's `dims`.
    const fn b_last_vs_dims(self) -> &'static str {
        match self {
            EmbeddingValidationContext::DoraEmbedding => {
                "DoRAEmbedding: lora_b last axis vs base dims"
            }
        }
    }
}
#[derive(Debug)]
pub enum LoraLayer {
Lora(LoRALinear),
Dora(DoRALinear),
DoraEmbedding(DoRAEmbedding),
}
impl LoraLayer {
pub fn forward(&self, x: &Array) -> Result<Array> {
match self {
LoraLayer::Lora(l) => l.forward(x),
LoraLayer::Dora(d) => d.forward(x),
LoraLayer::DoraEmbedding(d) => d.forward(x),
}
}
pub fn fuse(&self, dequantize: bool) -> Result<BaseLinear> {
match self {
LoraLayer::Lora(l) => l.fuse(dequantize),
LoraLayer::Dora(d) => d.fuse(dequantize),
LoraLayer::DoraEmbedding(_) => {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"LoraLayer::fuse: variant",
"is a DoRA embedding layer; call `fuse_embedding` to obtain a `BaseEmbedding`",
)))
}
}
}
pub fn fuse_embedding(&self) -> Result<BaseEmbedding> {
match self {
LoraLayer::DoraEmbedding(d) => d.fuse(),
LoraLayer::Lora(_) | LoraLayer::Dora(_) => {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"LoraLayer::fuse_embedding: variant",
"is a linear LoRA/DoRA layer; call `fuse` to obtain a `BaseLinear`",
)))
}
}
}
pub fn base(&self) -> Option<&BaseLinear> {
match self {
LoraLayer::Lora(l) => Some(l.base()),
LoraLayer::Dora(d) => Some(d.base()),
LoraLayer::DoraEmbedding(_) => None,
}
}
pub fn base_embedding(&self) -> Option<&BaseEmbedding> {
match self {
LoraLayer::DoraEmbedding(d) => Some(d.base()),
LoraLayer::Lora(_) | LoraLayer::Dora(_) => None,
}
}
}
/// Adapted layers keyed by a module name string (presumably the dotted module path;
/// NOTE(review): confirm against the code that builds this map).
pub type LoraLayers = HashMap<String, LoraLayer>;
fn base_output_dims(base: &BaseLinear) -> Result<usize> {
let shape = match base {
BaseLinear::Dense { weight, .. } => weight.shape(),
BaseLinear::Quantized { weight, .. } => weight.shape(),
};
shape.first().copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"base linear: weight must be rank-2 [output_dims, input_dims] to determine output_dims",
shape.len() as u32,
shape.clone(),
))
})
}
fn base_input_dims(base: &BaseLinear) -> Result<usize> {
match base {
BaseLinear::Dense { weight, .. } => {
let shape = weight.shape();
let rank = shape.len() as u32;
shape.get(1).copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"dense base weight must be 2-D [output_dims, input_dims]",
rank,
shape.clone(),
))
})
}
BaseLinear::Quantized { weight, bits, .. } => {
let shape = weight.shape();
let rank = shape.len() as u32;
let packed = shape.get(1).copied().ok_or_else(|| {
Error::RankMismatch(RankMismatchPayload::new(
"quantized base weight must be 2-D [output_dims, input_dims*bits/32]",
rank,
shape.clone(),
))
})?;
Ok(packed * 32 / (*bits as usize))
}
}
}
/// Checks that a linear adapter's factors agree with each other and with the base layer.
///
/// The expected layout is `lora_a: [input_dims, r]` and `lora_b: [r, output_dims]`. Checks run
/// in this order, and the first failure is returned:
/// 1. both factors are rank-2;
/// 2. they share `r`;
/// 3. `lora_a`'s leading axis matches the base input dims;
/// 4. `lora_b`'s last axis matches the base output dims.
///
/// `who` only selects the diagnostic text.
fn validate_factor_shapes(
    base: &BaseLinear,
    params: &AdapterParams,
    who: LinearValidationContext,
) -> Result<()> {
    let a_shape = params.lora_a.shape();
    let b_shape = params.lora_b.shape();
    if a_shape.len() != 2 {
        let found = a_shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            who.lora_a_rank2(),
            found,
            a_shape,
        )));
    }
    if b_shape.len() != 2 {
        let found = b_shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            who.lora_b_rank2(),
            found,
            b_shape,
        )));
    }
    // Both shapes are now known to have exactly two axes, so indexing cannot panic.
    let (a_in, a_r) = (a_shape[0], a_shape[1]);
    let (b_r, b_out) = (b_shape[0], b_shape[1]);
    if a_r != b_r {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            who.shared_rank(),
            b_r,
            a_r,
        )));
    }
    let expected_in = base_input_dims(base)?;
    if a_in != expected_in {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            who.a_leading_vs_input_dims(),
            expected_in,
            a_in,
        )));
    }
    let expected_out = base_output_dims(base)?;
    if b_out != expected_out {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            who.b_last_vs_output_dims(),
            expected_out,
            b_out,
        )));
    }
    Ok(())
}
/// Which linear adapter type is being validated; selects the prefix of every diagnostic.
#[derive(Debug, Clone, Copy)]
enum LinearValidationContext {
    LoraLinear,
    DoraLinear,
}
impl LinearValidationContext {
    /// Returns `lora` in the LoRA context and `dora` in the DoRA context.
    const fn pick(self, lora: &'static str, dora: &'static str) -> &'static str {
        match self {
            Self::LoraLinear => lora,
            Self::DoraLinear => dora,
        }
    }
    const fn lora_a_rank2(self) -> &'static str {
        self.pick(
            "LoRALinear: lora_a must be 2-D [input_dims, r]",
            "DoRALinear: lora_a must be 2-D [input_dims, r]",
        )
    }
    const fn lora_b_rank2(self) -> &'static str {
        self.pick(
            "LoRALinear: lora_b must be 2-D [r, output_dims]",
            "DoRALinear: lora_b must be 2-D [r, output_dims]",
        )
    }
    const fn shared_rank(self) -> &'static str {
        self.pick(
            "LoRALinear: lora_a last axis vs lora_b leading axis (shared rank `r`)",
            "DoRALinear: lora_a last axis vs lora_b leading axis (shared rank `r`)",
        )
    }
    const fn a_leading_vs_input_dims(self) -> &'static str {
        self.pick(
            "LoRALinear: lora_a leading axis vs base input_dims",
            "DoRALinear: lora_a leading axis vs base input_dims",
        )
    }
    const fn b_last_vs_output_dims(self) -> &'static str {
        self.pick(
            "LoRALinear: lora_b last axis vs base output_dims",
            "DoRALinear: lora_b last axis vs base output_dims",
        )
    }
    const fn config_rank_vs_lora_a_rank(self) -> &'static str {
        self.pick(
            "LoRALinear: adapter_config.json rank vs lora_a actual rank axis",
            "DoRALinear: adapter_config.json rank vs lora_a actual rank axis",
        )
    }
    const fn config_rank_vs_lora_b_rank(self) -> &'static str {
        self.pick(
            "LoRALinear: adapter_config.json rank vs lora_b actual rank axis",
            "DoRALinear: adapter_config.json rank vs lora_b actual rank axis",
        )
    }
}
/// Checks that both factors carry the rank declared in `adapter_config.json`.
///
/// The rank axis is axis 1 of `lora_a` and axis 0 of `lora_b`. A missing axis reads as 0,
/// which then fails the comparison unless `config_rank` is 0. `lora_a` is checked first.
fn validate_config_rank(
    params: &AdapterParams,
    config_rank: usize,
    who: LinearValidationContext,
) -> Result<()> {
    let found_in_a = params.lora_a.shape().get(1).copied().unwrap_or_default();
    let found_in_b = params.lora_b.shape().first().copied().unwrap_or_default();
    let checks = [
        (found_in_a, who.config_rank_vs_lora_a_rank()),
        (found_in_b, who.config_rank_vs_lora_b_rank()),
    ];
    for (found, context) in checks {
        if found != config_rank {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                context,
                config_rank,
                found,
            )));
        }
    }
    Ok(())
}
pub fn linear_to_lora_layers(
weights: &Weights,
config: &LoraConfig,
adapter_params: &HashMap<String, AdapterParams>,
quant: Option<&PerLayerQuantization>,
num_blocks: i32,
) -> Result<LoraLayers> {
let mut out: LoraLayers = HashMap::new();
let is_dora = config.is_dora();
let fan_in_fan_out = config.fan_in_fan_out();
let first_adapted = match &config.selection {
AdapterSelection::MlxLm { num_layers } if *num_layers > 0 => (num_blocks - num_layers).max(0),
_ => 0,
};
let mut consumed: HashSet<&str> = HashSet::with_capacity(adapter_params.len());
let mut selected_without_factors: Vec<&str> = Vec::new();
let explicit_selection = match &config.selection {
AdapterSelection::MlxLm { .. } => !config.lora_parameters.keys.is_empty(),
AdapterSelection::Peft(peft) => {
matches!(&peft.target_modules, Some(m) if !matches!(m, ModuleMatcher::AllLinear))
}
};
for (key, weight) in weights {
let Some(path) = key.strip_suffix(".weight") else {
continue;
};
if !module_is_selected(path, weight, config, first_adapted) {
continue;
}
let Some(params) = adapter_params.get(path) else {
selected_without_factors.push(path);
continue;
};
consumed.insert(path);
let module_rank = config.rank_for(path);
let scale = config.scale_for(path);
if let Some(rank) = usize::try_from(module_rank).ok().filter(|&r| r > 0) {
let who = if is_dora {
LinearValidationContext::DoraLinear
} else {
LinearValidationContext::LoraLinear
};
validate_config_rank(params, rank, who)?;
}
let base = build_base_linear(weights, path, weight, quant, fan_in_fan_out)?;
let layer = if is_dora {
LoraLayer::Dora(DoRALinear::new(base, params.try_clone()?, scale)?)
} else {
LoraLayer::Lora(LoRALinear::new(base, params.try_clone()?, scale)?)
};
out.insert(path.to_string(), layer);
}
check_adapter_completeness(
&out,
adapter_params,
&consumed,
&selected_without_factors,
explicit_selection,
)?;
Ok(out)
}
/// Decides whether the base module at `path` (with weight `weight`) is an adapter target.
///
/// mlx-lm configs work in two steps:
/// * Match: with a non-empty `keys`, the path must match one of them (see `path_matches_key`);
///   otherwise every rank-2 weight matches.
/// * Block filter: a path with a `layers.<n>` block index is kept only if
///   `n >= first_adapted`; a path without one is kept.
///
/// PEFT configs are delegated to `peft_module_is_selected`.
fn module_is_selected(path: &str, weight: &Array, config: &LoraConfig, first_adapted: i32) -> bool {
    match &config.selection {
        AdapterSelection::Peft(peft) => peft_module_is_selected(path, weight, peft),
        AdapterSelection::MlxLm { .. } => {
            let keys = &config.lora_parameters.keys;
            let matched = if keys.is_empty() {
                weight.shape().len() == 2
            } else {
                keys.iter().any(|k| path_matches_key(path, k))
            };
            matched && parse_block_index(path).is_none_or(|block| block >= first_adapted)
        }
    }
}
/// PEFT target selection for the module at `path`.
///
/// The checks run in order:
/// 1. `exclude_modules` wins.
/// 2. `target_modules` must match. `all-linear` and an absent matcher additionally require a
///    rank-2 weight.
/// 3. If `layers_to_transform` is non-empty, the module's layer index (from
///    `peft_layer_index`) must be listed; a path without an index is rejected.
fn peft_module_is_selected(path: &str, weight: &Array, peft: &PeftSelection) -> bool {
    let excluded = peft
        .exclude_modules
        .as_ref()
        .is_some_and(|exclude| exclude.matches(path));
    if excluded {
        return false;
    }
    let targeted = match &peft.target_modules {
        Some(ModuleMatcher::AllLinear) => {
            weight.shape().len() == 2 && ModuleMatcher::AllLinear.matches(path)
        }
        Some(target) => target.matches(path),
        None => weight.shape().len() == 2,
    };
    if !targeted {
        return false;
    }
    match &peft.layers_to_transform {
        Some(layers) if !layers.is_empty() => peft_layer_index(path, &peft.layers_pattern)
            .is_some_and(|idx| layers.contains(&idx)),
        _ => true,
    }
}
fn peft_layer_index(path: &str, layers_pattern: &[String]) -> Option<i32> {
if layers_pattern.is_empty() {
let re = Regex::new(r"^.*\.[^.]*\.(\d+)\.").ok()?;
let caps = re.captures(path)?;
return caps.get(1)?.as_str().parse::<i32>().ok();
}
for pattern in layers_pattern {
let escaped = regex::escape(pattern);
let Ok(re) = Regex::new(&format!(r"^.*\.{escaped}\.(\d+)\.")) else {
continue;
};
if let Some(caps) = re.captures(path)
&& let Some(m) = caps.get(1)
&& let Ok(idx) = m.as_str().parse::<i32>()
{
return Some(idx);
}
}
None
}
fn check_adapter_completeness(
applied: &LoraLayers,
adapter_params: &HashMap<String, AdapterParams>,
consumed: &HashSet<&str>,
selected_without_factors: &[&str],
explicit_selection: bool,
) -> Result<()> {
if explicit_selection && !selected_without_factors.is_empty() {
let mut missing: Vec<&str> = selected_without_factors.to_vec();
missing.sort_unstable();
return Err(Error::MissingKey(MissingKeyPayload::new(
"load_adapters: explicitly-selected adapter target (adapter_config.json target \
selection does not match adapters.safetensors contents)",
missing[0].to_string(),
)));
}
let mut unused: Vec<&str> = adapter_params
.keys()
.map(String::as_str)
.filter(|p| !consumed.contains(p))
.collect();
if !unused.is_empty() {
unused.sort_unstable();
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
unused[0].to_string(),
Error::InvariantViolation(InvariantViolationPayload::new(
"load_adapters: adapter factor group",
"must match a base layer (the adapters.safetensors paths do not align with \
the base model weights — path-prefix mismatch or config drift)",
)),
)));
}
if applied.is_empty() {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"load_adapters: adapted-layer count",
"must be >= 1 (the adapter_config.json target selection — mlx-lm `keys`/`num_layers` \
or PEFT `target_modules`/`layers_to_transform` — matched nothing in the base model, \
or adapters.safetensors carried no factors)",
)));
}
Ok(())
}
fn build_base_linear(
weights: &Weights,
path: &str,
weight: &Array,
quant: Option<&PerLayerQuantization>,
fan_in_fan_out: bool,
) -> Result<BaseLinear> {
let scales_key = format!("{path}.scales");
let biases_key = format!("{path}.biases");
let bias_key = format!("{path}.bias");
let q: Option<Quantization> = quant.and_then(|c| c.quantization_for(path));
if let (Some(q), Some(scales)) = (q, weights.get(&scales_key)) {
if fan_in_fan_out {
return Err(Error::LayerKeyed(LayerKeyedPayload::new(
path.to_string(),
Error::InvariantViolation(InvariantViolationPayload::new(
"load_adapters: quantized base + adapter_config.json `fan_in_fan_out`",
"must not be combined (a packed quantized weight cannot be transposed without \
corrupting the bit-packing; `fan_in_fan_out` applies only to a dense Conv1D-style \
base)",
)),
)));
}
let quant_biases = weights.get(&biases_key).map(Array::try_clone).transpose()?;
let bias = weights.get(&bias_key).map(Array::try_clone).transpose()?;
return BaseLinear::quantized(
weight.try_clone()?,
scales.try_clone()?,
quant_biases,
bias,
q.group_size,
q.bits,
q.mode.as_str().to_string(),
);
}
let bias = weights.get(&bias_key).map(Array::try_clone).transpose()?;
let dense_weight = if fan_in_fan_out {
weight.transpose()?
} else {
weight.try_clone()?
};
BaseLinear::dense(dense_weight, bias)
}
/// True when `key` names `path` exactly or is a whole dotted suffix of it
/// (`"q_proj"` matches `"model.layers.0.q_proj"` but not `"model.layers.0.xq_proj"`).
fn path_matches_key(path: &str, key: &str) -> bool {
    path.strip_suffix(key)
        .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
}
/// Block index following the first `layers.` substring of `path`, parsed up to the next `.`.
///
/// Returns `None` when there is no `layers.` or the following segment is not an `i32`.
/// The match is a plain substring search, so e.g. `sublayers.4` also yields `4`.
fn parse_block_index(path: &str) -> Option<i32> {
    let (_, after_marker) = path.split_once("layers.")?;
    let segment = after_marker.split('.').next()?;
    segment.parse::<i32>().ok()
}
pub fn load_adapters(
base_weights: &Weights,
dir: &Path,
quant: Option<&PerLayerQuantization>,
num_blocks: i32,
) -> Result<LoraLayers> {
let config_text = read_bounded_adapter_config(dir)?;
let config = LoraConfig::from_json(&config_text)?;
load_adapters_with_config(base_weights, dir, &config, quant, num_blocks)
}
pub fn load_adapters_with_config(
base_weights: &Weights,
dir: &Path,
config: &LoraConfig,
quant: Option<&PerLayerQuantization>,
num_blocks: i32,
) -> Result<LoraLayers> {
if config.fine_tune_type == FineTuneType::Full {
return Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
"load_adapters: fine_tune_type (supported adapter types only — `full` is a full-weight \
fine-tune; merge it at the weight-map level via lm::load::load_weights)",
"full",
&["lora", "dora"],
)));
}
if config.rank() <= 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"load_adapters: adapter rank",
"must be > 0",
format!("{}", config.rank()),
)));
}
let st_path = locate_adapter_safetensors(dir, config)?;
check_adapter_safetensors(&st_path)?;
let adapter_arrays = crate::io::load_safetensors(&st_path)
.map_err(|e| Error::LayerKeyed(LayerKeyedPayload::new(st_path.display().to_string(), e)))?;
let adapter_arrays = match config.selection {
AdapterSelection::Peft(_) => translate_peft_keys(adapter_arrays)?,
AdapterSelection::MlxLm { .. } => adapter_arrays,
};
let adapter_params = split_adapter_params(adapter_arrays, config.is_dora())?;
linear_to_lora_layers(base_weights, config, &adapter_params, quant, num_blocks)
}
/// Reads and parses `dir/adapter_config.json`, subject to the bounded-read checks of
/// `read_bounded_adapter_config`.
pub fn read_adapter_config(dir: &Path) -> Result<LoraConfig> {
    read_bounded_adapter_config(dir).and_then(|text| LoraConfig::from_json(&text))
}
/// Adapter weights file name used by mlx-lm style adapters.
const MLX_LM_ADAPTER_FILE: &str = "adapters.safetensors";
/// Adapter weights file name used by PEFT style adapters.
const PEFT_ADAPTER_FILE: &str = "adapter_model.safetensors";
/// Finds the adapter weights file in `dir`.
///
/// The file name matching the config's flavour is tried first, then the other one.
///
/// # Errors
/// A probe error from `adapter_candidate_present`. If neither file exists, returns a
/// `NotFound` `FileIo` error naming the preferred path.
fn locate_adapter_safetensors(dir: &Path, config: &LoraConfig) -> Result<std::path::PathBuf> {
    let search_order = match config.selection {
        AdapterSelection::Peft(_) => [PEFT_ADAPTER_FILE, MLX_LM_ADAPTER_FILE],
        AdapterSelection::MlxLm { .. } => [MLX_LM_ADAPTER_FILE, PEFT_ADAPTER_FILE],
    };
    for file_name in search_order {
        let candidate = dir.join(file_name);
        if adapter_candidate_present(&candidate)? {
            return Ok(candidate);
        }
    }
    Err(Error::FileIo(FileIoPayload::new(
        "load_adapters: no adapter weights file (expected adapters.safetensors or \
         adapter_model.safetensors)",
        FileOp::Open,
        dir.join(search_order[0]),
        std::io::Error::from(std::io::ErrorKind::NotFound),
    )))
}
/// Outcome of probing one candidate adapter-weights path.
enum CandidateProbe {
    /// Nothing exists at the path.
    Absent,
    /// The path resolves to a regular file.
    Present,
    /// Something exists at the path but it is not a regular file (e.g. a directory).
    NonRegular,
    /// `metadata` failed for a reason other than "not found".
    IoError(std::io::Error),
}
/// Classifies `path` using `std::fs::metadata`, which follows symlinks.
fn probe_candidate(path: &Path) -> CandidateProbe {
    let metadata = match std::fs::metadata(path) {
        Ok(metadata) => metadata,
        Err(err) => {
            return if err.kind() == std::io::ErrorKind::NotFound {
                CandidateProbe::Absent
            } else {
                CandidateProbe::IoError(err)
            };
        }
    };
    if metadata.is_file() {
        CandidateProbe::Present
    } else {
        CandidateProbe::NonRegular
    }
}
fn adapter_candidate_present(path: &Path) -> Result<bool> {
match probe_candidate(path) {
CandidateProbe::Present => Ok(true),
CandidateProbe::Absent => Ok(false),
CandidateProbe::NonRegular => Err(Error::FileIo(FileIoPayload::new(
"load_adapters: adapter weights candidate exists but is not a regular file",
FileOp::Stat,
path.to_path_buf(),
std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"adapter candidate path exists but is not a regular file",
),
))),
CandidateProbe::IoError(e) => Err(Error::FileIo(FileIoPayload::new(
"load_adapters: cannot stat adapter weights candidate",
FileOp::Stat,
path.to_path_buf(),
e,
))),
}
}
/// Prefix in front of every PEFT adapter tensor name this loader translates.
const PEFT_KEY_PREFIX: &str = "base_model.model.";
/// Renames PEFT adapter tensors into the mlx-lm layout used by `split_adapter_params`.
///
/// After the prefix is stripped, each tensor is mapped as follows:
/// * `.lora_A.weight` becomes `.lora_a`, and `.lora_B.weight` becomes `.lora_b`. Both are
///   transposed into this loader's `[input_dims, r]` / `[r, output_dims]` factor orientation.
/// * `.lora_magnitude_vector[.weight]` becomes `.m`, unchanged.
/// * Keys without the prefix are skipped.
///
/// # Errors
/// `LayerKeyed` (`InvariantViolation`) for embedding-LoRA factors and for any other prefixed
/// tensor. Dropping those silently would change the adapter's behavior. Transpose failures are
/// also returned.
fn translate_peft_keys(arrays: HashMap<String, Array>) -> Result<HashMap<String, Array>> {
    let mut translated: HashMap<String, Array> = HashMap::with_capacity(arrays.len());
    for (key, arr) in arrays {
        let Some(module_key) = key.strip_prefix(PEFT_KEY_PREFIX) else {
            continue;
        };
        if let Some(path) = module_key.strip_suffix(".lora_A.weight") {
            translated.insert(format!("{path}.lora_a"), arr.transpose()?);
            continue;
        }
        if let Some(path) = module_key.strip_suffix(".lora_B.weight") {
            translated.insert(format!("{path}.lora_b"), arr.transpose()?);
            continue;
        }
        let magnitude_path = module_key
            .strip_suffix(".lora_magnitude_vector.weight")
            .or_else(|| module_key.strip_suffix(".lora_magnitude_vector"));
        if let Some(path) = magnitude_path {
            translated.insert(format!("{path}.m"), arr);
            continue;
        }
        let is_embedding_factor =
            module_key.contains(".lora_embedding_A") || module_key.contains(".lora_embedding_B");
        let reason = if is_embedding_factor {
            "is an embedding LoRA factor (`lora_embedding_A` / `lora_embedding_B`); embedding \
             LoRA is not supported by this loader (only linear-layer `lora_A` / `lora_B` \
             low-rank factors are applied)"
        } else {
            "must be one of `.lora_A.weight` / `.lora_B.weight` / `.lora_magnitude_vector` \
             (this is a PEFT `bias` or `modules_to_save` tensor, which affects inference and \
             this low-rank loader has no slot for — dropping it would silently corrupt the \
             adapter)"
        };
        return Err(Error::LayerKeyed(LayerKeyedPayload::new(
            key,
            Error::InvariantViolation(InvariantViolationPayload::new(
                "load_adapters: PEFT adapter tensor",
                reason,
            )),
        )));
    }
    Ok(translated)
}
/// Groups flat adapter tensors into one [`AdapterParams`] per module path.
///
/// Recognized tensors are `<path>.lora_a`, `<path>.lora_b` and the optional `<path>.m`
/// magnitude. Keys with any other suffix are ignored.
///
/// # Errors
/// `MissingKey` when a group is incomplete:
/// * `lora_a` without `lora_b`;
/// * a DoRA adapter (`expect_dora`) without `m`;
/// * `lora_b` without `lora_a`;
/// * `m` without `lora_a`.
///
/// An orphan `m` was previously dropped silently, which would quietly change a DoRA adapter.
/// Paths are visited in sorted order, so with several broken groups the reported key is the
/// same on every run (not dependent on `HashMap` iteration order).
fn split_adapter_params(
    arrays: HashMap<String, Array>,
    expect_dora: bool,
) -> Result<HashMap<String, AdapterParams>> {
    let mut a_map: HashMap<String, Array> = HashMap::new();
    let mut b_map: HashMap<String, Array> = HashMap::new();
    let mut m_map: HashMap<String, Array> = HashMap::new();
    for (key, arr) in arrays {
        if let Some(path) = key.strip_suffix(".lora_a") {
            a_map.insert(path.to_string(), arr);
        } else if let Some(path) = key.strip_suffix(".lora_b") {
            b_map.insert(path.to_string(), arr);
        } else if let Some(path) = key.strip_suffix(".m") {
            m_map.insert(path.to_string(), arr);
        }
    }
    // Sorted traversal makes the first reported error deterministic.
    let mut a_entries: Vec<(String, Array)> = a_map.into_iter().collect();
    a_entries.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
    let mut out: HashMap<String, AdapterParams> = HashMap::with_capacity(a_entries.len());
    for (path, lora_a) in a_entries {
        let Some(lora_b) = b_map.remove(&path) else {
            return Err(Error::MissingKey(MissingKeyPayload::new(
                "load_adapters: adapter has `lora_a` but no matching `lora_b`",
                format!("{path}.lora_b"),
            )));
        };
        let magnitude = m_map.remove(&path);
        if expect_dora && magnitude.is_none() {
            return Err(Error::MissingKey(MissingKeyPayload::new(
                "load_adapters: DoRA adapter is missing its magnitude `m`",
                format!("{path}.m"),
            )));
        }
        out.insert(
            path,
            AdapterParams {
                lora_a,
                lora_b,
                magnitude,
            },
        );
    }
    // Anything left in `b_map` / `m_map` had no `lora_a` partner.
    if let Some(path) = b_map.keys().min() {
        return Err(Error::MissingKey(MissingKeyPayload::new(
            "load_adapters: adapter has `lora_b` but no matching `lora_a`",
            format!("{path}.lora_a"),
        )));
    }
    if let Some(path) = m_map.keys().min() {
        return Err(Error::MissingKey(MissingKeyPayload::new(
            "load_adapters: adapter has magnitude `m` but no matching `lora_a`",
            format!("{path}.lora_a"),
        )));
    }
    Ok(out)
}
fn read_bounded_adapter_config(dir: &Path) -> Result<String> {
use std::io::Read;
let path = dir.join("adapter_config.json");
#[cfg(unix)]
let file = {
use std::os::unix::fs::OpenOptionsExt;
std::fs::OpenOptions::new()
.read(true)
.custom_flags(libc::O_NONBLOCK | libc::O_CLOEXEC)
.open(&path)
.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Open,
path.to_path_buf(),
e,
))
})?
};
#[cfg(not(unix))]
let file = std::fs::File::open(&path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Open,
path.to_path_buf(),
e,
))
})?;
let meta = file.metadata().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Stat,
path.to_path_buf(),
e,
))
})?;
if !meta.is_file() {
return Err(Error::FileIo(FileIoPayload::new(
"load_adapters: adapter_config.json must be a regular file",
FileOp::Stat,
path,
std::io::Error::from(std::io::ErrorKind::InvalidInput),
)));
}
let cap = crate::lm::load::MAX_CONFIG_BYTES;
let mut bytes = Vec::new();
file.take(cap + 1).read_to_end(&mut bytes).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Read,
path.clone(),
e,
))
})?;
if bytes.len() as u64 > cap {
return Err(Error::CapExceeded(CapExceededPayload::new(
"load_adapters: adapter_config.json body",
"MAX_CONFIG_BYTES",
cap,
bytes.len() as u64,
)));
}
String::from_utf8(bytes).map_err(|e| {
Error::Parse(ParsePayload::new(
"load_adapters: adapter_config.json",
"UTF-8",
e,
))
})
}
fn check_adapter_safetensors(path: &Path) -> Result<()> {
#[cfg(unix)]
let file = {
use std::os::unix::fs::OpenOptionsExt;
std::fs::OpenOptions::new()
.read(true)
.custom_flags(libc::O_NONBLOCK | libc::O_CLOEXEC)
.open(path)
.map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Open,
path.to_path_buf(),
e,
))
})?
};
#[cfg(not(unix))]
let file = std::fs::File::open(path).map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Open,
path.to_path_buf(),
e,
))
})?;
let meta = file.metadata().map_err(|e| {
Error::FileIo(FileIoPayload::new(
"load_adapters",
FileOp::Stat,
path.to_path_buf(),
e,
))
})?;
if !meta.is_file() {
return Err(Error::FileIo(FileIoPayload::new(
"load_adapters: adapter safetensors must be a regular file",
FileOp::Stat,
path.to_path_buf(),
std::io::Error::from(std::io::ErrorKind::InvalidInput),
)));
}
if meta.len() > MAX_ADAPTER_SAFETENSORS_BYTES {
return Err(Error::CapExceeded(CapExceededPayload::new(
"load_adapters: adapter safetensors body",
"MAX_ADAPTER_SAFETENSORS_BYTES",
MAX_ADAPTER_SAFETENSORS_BYTES,
meta.len(),
)));
}
Ok(())
}
#[cfg(test)]
mod tests;