use std::collections::HashMap;
use std::path::Path;
use mlx_rs::module::{Module, ModuleParameters, Param};
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::IndexOp;
use mlx_rs::Array;
use tokenizers::Tokenizer;
use tracing::info;
use crate::InferenceError;
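/// Model hyperparameters deserialized from the checkpoint's `config.json`
/// (HuggingFace layout). Optional fields cover quantized and MoE variants.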
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Qwen3Config {
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
#[serde(default)]
pub head_dim: Option<usize>,
pub vocab_size: usize,
pub rms_norm_eps: f32,
#[serde(default = "default_rope_theta")]
pub rope_theta: f32,
#[serde(default = "default_max_position")]
pub max_position_embeddings: usize,
#[serde(default)]
pub use_sliding_window: bool,
#[serde(default)]
pub quantization: Option<QuantConfig>,
#[serde(default)]
pub num_experts: Option<usize>,
#[serde(default)]
pub num_experts_per_tok: Option<usize>,
}
impl Qwen3Config {
pub fn resolved_head_dim(&self) -> usize {
self.head_dim
.unwrap_or_else(|| self.hidden_size / self.num_attention_heads)
}
}
fn default_rope_theta() -> f32 {
1_000_000.0
}
fn default_max_position() -> usize {
32768
}
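/// Affine quantization parameters from the `quantization` section of
/// `config.json`: weights are grouped in blocks of `group_size` values,
/// each group stored with `bits` bits plus a per-group scale and bias.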
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct QuantConfig {
pub group_size: i32,
pub bits: i32,
}
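/// A linear layer that is either a plain dense projection or an MLX
/// quantized projection, chosen at load time based on which tensors the
/// checkpoint provides. Set `CAR_QLINEAR_PROFILE` to log per-call timings.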
pub(crate) enum QLinear {
Dense(nn::Linear),
Quantized(nn::QuantizedLinear),
}
impl QLinear {
pub(crate) fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let profile = std::env::var("CAR_QLINEAR_PROFILE").is_ok();
let input_shape = x.shape().to_vec();
if profile {
let _ = mlx_rs::transforms::eval([x]);
}
let t0 = std::time::Instant::now();
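// Fast path: 1-D/2-D inputs go straight through the underlying layer.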
if input_shape.len() <= 2 {
let out = match self {
Self::Dense(l) => l.forward(x)?,
Self::Quantized(l) => l.forward(x)?,
};
if profile {
mlx_rs::transforms::eval([&out])?;
tracing::info!(
in_shape = ?input_shape,
kind = if matches!(self, Self::Quantized(_)) { "q" } else { "dense" },
elapsed_ms = t0.elapsed().as_millis() as u64,
"qlinear forward (2D)"
);
}
return Ok(out);
}
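// Higher-rank inputs: flatten to 2-D, run the projection, then restore
// the leading dimensions on the way out.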
let in_dim = *input_shape.last().unwrap_or(&0);
let flat = ops::reshape(x, &[-1, in_dim])?;
if profile {
let _ = mlx_rs::transforms::eval([&flat]);
}
let t_pre_matmul = t0.elapsed();
let flat_out = match self {
Self::Dense(l) => l.forward(&flat)?,
Self::Quantized(l) => {
let t_raw = std::time::Instant::now();
let mut out = ops::quantized_matmul(
&flat,
&l.inner.weight,
&l.scales,
&l.biases,
true,
l.group_size,
l.bits,
)?;
if profile {
let _ = mlx_rs::transforms::eval([&out]);
tracing::info!(
raw_qmatmul_ms = t_raw.elapsed().as_millis() as u64,
"raw quantized_matmul only"
);
}
if let Some(bias) = l.inner.bias.value.as_ref() {
out = ops::add(&out, bias)?;
}
out
}
};
if profile {
let _ = mlx_rs::transforms::eval([&flat_out]);
}
let t_post_matmul = t0.elapsed();
let out_dim = *flat_out.shape().last().unwrap_or(&0);
let mut output_shape = input_shape[..input_shape.len() - 1].to_vec();
output_shape.push(out_dim);
let out = ops::reshape(&flat_out, &output_shape)?;
if profile {
mlx_rs::transforms::eval([&out])?;
tracing::info!(
in_shape = ?input_shape,
kind = if matches!(self, Self::Quantized(_)) { "q" } else { "dense" },
pre_matmul_ms = t_pre_matmul.as_millis() as u64,
matmul_only_ms = (t_post_matmul - t_pre_matmul).as_millis() as u64,
total_ms = t0.elapsed().as_millis() as u64,
"qlinear forward (3D)"
);
}
Ok(out)
}
fn parameters(&self) -> Vec<(String, &Array)> {
match self {
Self::Dense(l) => l
.parameters()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
Self::Quantized(l) => l
.parameters()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
}
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
match self {
Self::Dense(l) => l
.parameters_mut()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
Self::Quantized(l) => l
.parameters_mut()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
}
}
}
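/// Token embedding table, dense or quantized depending on the checkpoint.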
pub(crate) enum QEmbedding {
Dense(nn::Embedding),
Quantized(nn::QuantizedEmbedding),
}
impl QEmbedding {
pub(crate) fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
match self {
Self::Dense(e) => e.forward(x),
Self::Quantized(e) => e.forward(x),
}
}
fn parameters(&self) -> Vec<(String, &Array)> {
match self {
Self::Dense(e) => e
.parameters()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
Self::Quantized(e) => e
.parameters()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
}
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
match self {
Self::Dense(e) => e
.parameters_mut()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
Self::Quantized(e) => e
.parameters_mut()
.flatten()
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
}
}
}
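/// Root-mean-square layer norm: `y = x * rsqrt(mean(x^2, axis=-1) + eps) * weight`.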
struct RmsNorm {
weight: Array,
eps: f32,
}
impl RmsNorm {
#[allow(dead_code)]
fn new(dim: usize, eps: f32) -> Self {
Self {
weight: Array::ones::<f32>(&[dim as i32]).unwrap(),
eps,
}
}
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let x_sq = ops::multiply(x, x)?;
let mean = x_sq.mean_axes(&[-1], true)?;
let eps_arr = Array::from_f32(self.eps);
let norm = ops::rsqrt(&ops::add(&mean, &eps_arr)?)?;
let normed = ops::multiply(x, &norm)?;
ops::multiply(&normed, &self.weight)
}
fn parameters(&self) -> Vec<(&str, &Array)> {
vec![("weight", &self.weight)]
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(&str, &mut Array)> {
vec![("weight", &mut self.weight)]
}
}
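/// Rotary position embedding. A thin wrapper around `mlx_rs::fast::rope`,
/// which derives the rotation frequencies itself, so only the head
/// dimension and base `theta` need to be kept here.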
struct RotaryEmbedding {
head_dim: usize,
theta: f32,
}
impl RotaryEmbedding {
fn new(
head_dim: usize,
_max_seq_len: usize,
theta: f32,
) -> Result<Self, mlx_rs::error::Exception> {
Ok(Self { head_dim, theta })
}
fn apply(&self, x: &Array, offset: usize) -> Result<Array, mlx_rs::error::Exception> {
mlx_rs::fast::rope(
x,
self.head_dim as i32,
false,
self.theta,
1.0,
offset as i32,
None::<&Array>,
)
}
}
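/// SwiGLU feed-forward block: `down(silu(gate(x)) * up(x))`.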
struct Mlp {
gate_proj: QLinear,
up_proj: QLinear,
down_proj: QLinear,
}
impl Mlp {
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let gate = nn::silu(&self.gate_proj.forward(x)?)?;
let up = self.up_proj.forward(x)?;
self.down_proj.forward(&ops::multiply(&gate, &up)?)
}
fn parameters(&self) -> Vec<(String, &Array)> {
let mut params = Vec::new();
for (k, v) in self.gate_proj.parameters() {
params.push((format!("gate_proj.{k}"), v));
}
for (k, v) in self.up_proj.parameters() {
params.push((format!("up_proj.{k}"), v));
}
for (k, v) in self.down_proj.parameters() {
params.push((format!("down_proj.{k}"), v));
}
params
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
let mut params = Vec::new();
for (k, v) in self.gate_proj.parameters_mut() {
params.push((format!("gate_proj.{k}"), v));
}
for (k, v) in self.up_proj.parameters_mut() {
params.push((format!("up_proj.{k}"), v));
}
for (k, v) in self.down_proj.parameters_mut() {
params.push((format!("down_proj.{k}"), v));
}
params
}
}
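/// Stacked quantized expert weights with a leading expert dimension;
/// `forward_expert` slices out one expert and runs a quantized matmul.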
struct QuantizedExperts {
weight: Array,
scales: Array,
biases: Array,
group_size: i32,
bits: i32,
}
impl QuantizedExperts {
fn forward_expert(&self, x: &Array, eid: usize) -> Result<Array, mlx_rs::error::Exception> {
let ei = eid as i32;
let w = self.weight.index((ei..ei + 1, .., ..)).squeeze_axes(&[0])?;
let s = self.scales.index((ei..ei + 1, .., ..)).squeeze_axes(&[0])?;
let b = self.biases.index((ei..ei + 1, .., ..)).squeeze_axes(&[0])?;
ops::quantized_matmul(x, &w, &s, &b, true, self.group_size, self.bits)
}
fn all_arrays(&self) -> Vec<&Array> {
vec![&self.weight, &self.scales, &self.biases]
}
}
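/// Mixture-of-experts feed-forward. Expects 2-D input `[tokens, hidden]`
/// (callers flatten batch and sequence first): the router picks the top-k
/// experts per token on the host, renormalizes their weights, and combines
/// the expert outputs as a weighted sum.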
struct MoeLayer {
gate: QLinear,
gate_proj: QuantizedExperts,
up_proj: QuantizedExperts,
down_proj: QuantizedExperts,
num_experts: usize,
num_experts_per_tok: usize,
}
impl MoeLayer {
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let num_tokens = x.shape()[0] as usize;
let hidden_dim = x.shape()[1] as usize;
let router_logits = self.gate.forward(x)?;
let routing_weights = ops::softmax_axis(&router_logits, -1, None)?;
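// Materialize the routing weights on the host so top-k selection can run
// in plain Rust below.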
mlx_rs::transforms::eval([&routing_weights])?;
let rw_f32 = routing_weights.as_dtype(mlx_rs::Dtype::Float32)?;
rw_f32.eval()?;
let routing_data: &[f32] = rw_f32.as_slice();
let mut token_outputs = Vec::with_capacity(num_tokens);
for t in 0..num_tokens {
let row = &routing_data[t * self.num_experts..(t + 1) * self.num_experts];
let mut indices: Vec<usize> = (0..self.num_experts).collect();
indices.sort_by(|&a, &b| {
row[b]
.partial_cmp(&row[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let top_k_ids = &indices[..self.num_experts_per_tok];
let mut weights: Vec<f32> = top_k_ids.iter().map(|&i| row[i]).collect();
let weight_sum: f32 = weights.iter().sum();
if weight_sum > 0.0 {
for w in &mut weights {
*w /= weight_sum;
}
}
let token = x.index((t as i32..t as i32 + 1, ..));
let mut combined = Array::zeros::<f32>(&[1, hidden_dim as i32])?;
for (k, &eid) in top_k_ids.iter().enumerate() {
let gate_out = self.gate_proj.forward_expert(&token, eid)?;
let up_out = self.up_proj.forward_expert(&token, eid)?;
let activated = ops::multiply(&nn::silu(&gate_out)?, &up_out)?;
let expert_out = self.down_proj.forward_expert(&activated, eid)?;
let w = Array::from_f32(weights[k]);
combined = ops::add(&combined, &ops::multiply(&expert_out, &w)?)?;
}
token_outputs.push(combined);
}
let refs: Vec<&Array> = token_outputs.iter().collect();
ops::concatenate_axis(&refs, 0)
}
fn parameters(&self) -> Vec<(String, &Array)> {
let mut params = Vec::new();
for (k, v) in self.gate.parameters() {
params.push((format!("gate.{k}"), v));
}
for a in self.gate_proj.all_arrays() {
params.push(("switch_mlp.gate_proj".into(), a));
}
for a in self.up_proj.all_arrays() {
params.push(("switch_mlp.up_proj".into(), a));
}
for a in self.down_proj.all_arrays() {
params.push(("switch_mlp.down_proj".into(), a));
}
params
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
let mut params = Vec::new();
for (k, v) in self.gate.parameters_mut() {
params.push((format!("gate.{k}"), v));
}
params
}
}
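/// Dispatches between the dense MLP and the MoE feed-forward variants.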
enum FeedForward {
Dense(Mlp),
Moe(MoeLayer),
}
impl FeedForward {
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
match self {
Self::Dense(mlp) => mlp.forward(x),
Self::Moe(moe) => moe.forward(x),
}
}
fn parameters(&self) -> Vec<(String, &Array)> {
match self {
Self::Dense(mlp) => mlp.parameters(),
Self::Moe(moe) => moe.parameters(),
}
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
match self {
Self::Dense(mlp) => mlp.parameters_mut(),
Self::Moe(moe) => moe.parameters_mut(),
}
}
}
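/// Grouped-query attention with a grow-only KV cache, rotary position
/// embeddings, and optional per-head RMSNorm on Q/K (used by Qwen3).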
struct Attention {
q_proj: QLinear,
k_proj: QLinear,
v_proj: QLinear,
o_proj: QLinear,
q_norm: Option<RmsNorm>,
k_norm: Option<RmsNorm>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rope: std::sync::Arc<RotaryEmbedding>,
k_cache: Option<Array>,
v_cache: Option<Array>,
}
impl Attention {
fn forward(
&mut self,
x: &Array,
mask: Option<&Array>,
offset: usize,
) -> Result<Array, mlx_rs::error::Exception> {
let (batch, seq_len) = {
let s = x.shape();
(s[0] as usize, s[1] as usize)
};
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
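// Reshape the projections to [batch, heads, seq_len, head_dim].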
let q = ops::transpose_axes(
&ops::reshape(
&q,
&[
batch as i32,
seq_len as i32,
self.num_heads as i32,
self.head_dim as i32,
],
)?,
&[0, 2, 1, 3],
)?;
let k = ops::transpose_axes(
&ops::reshape(
&k,
&[
batch as i32,
seq_len as i32,
self.num_kv_heads as i32,
self.head_dim as i32,
],
)?,
&[0, 2, 1, 3],
)?;
let v = ops::transpose_axes(
&ops::reshape(
&v,
&[
batch as i32,
seq_len as i32,
self.num_kv_heads as i32,
self.head_dim as i32,
],
)?,
&[0, 2, 1, 3],
)?;
let q = match &self.q_norm {
Some(n) => self.apply_norm_to_heads(n, &q)?,
None => q,
};
let k = match &self.k_norm {
Some(n) => self.apply_norm_to_heads(n, &k)?,
None => k,
};
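// Apply rotary position embeddings at the current KV-cache offset.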
let q = self.rope.apply(&q, offset)?;
let k = self.rope.apply(&k, offset)?;
let (k, v) = self.update_kv_cache(k, v)?;
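// Grouped-query attention: each KV head serves `num_heads / num_kv_heads`
// query heads, so expand K/V when the ratio is above one.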
let groups = self.num_heads / self.num_kv_heads;
let (k, v) = if groups > 1 {
(Self::repeat_kv(&k, groups)?, Self::repeat_kv(&v, groups)?)
} else {
(k, v)
};
let scale = Array::from_f32(1.0 / (self.head_dim as f32).sqrt());
let scores = ops::multiply(
&ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?,
&scale,
)?;
let scores = if let Some(mask) = mask {
ops::add(&scores, mask)?
} else {
scores
};
let attn = ops::softmax_axis(&scores, -1, None)?;
let out = ops::matmul(&attn, &v)?;
let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
let out = ops::reshape(
&out,
&[
batch as i32,
seq_len as i32,
(self.num_heads * self.head_dim) as i32,
],
)?;
self.o_proj.forward(&out)
}
fn apply_norm_to_heads(
&self,
norm: &RmsNorm,
x: &Array,
) -> Result<Array, mlx_rs::error::Exception> {
let shape = x.shape().to_vec();
let flat = ops::reshape(x, &[-1, shape[3]])?;
let normed = norm.forward(&flat)?;
ops::reshape(&normed, &shape)
}
fn update_kv_cache(
&mut self,
k: Array,
v: Array,
) -> Result<(Array, Array), mlx_rs::error::Exception> {
let k = if let Some(ref cached) = self.k_cache {
ops::concatenate_axis(&[cached, &k], 2)?
} else {
k
};
let v = if let Some(ref cached) = self.v_cache {
ops::concatenate_axis(&[cached, &v], 2)?
} else {
v
};
self.k_cache = Some(k.clone());
self.v_cache = Some(v.clone());
Ok((k, v))
}
fn repeat_kv(x: &Array, groups: usize) -> Result<Array, mlx_rs::error::Exception> {
if groups == 1 {
return Ok(x.clone());
}
let shape = x.shape();
let expanded = ops::reshape(x, &[shape[0], shape[1], 1, shape[2], shape[3]])?;
let tiled = ops::tile(&expanded, &[1, 1, groups as i32, 1, 1])?;
ops::reshape(
&tiled,
&[shape[0], shape[1] * groups as i32, shape[2], shape[3]],
)
}
fn clear_kv_cache(&mut self) {
self.k_cache = None;
self.v_cache = None;
}
fn parameters(&self) -> Vec<(String, &Array)> {
let mut params = Vec::new();
for (k, v) in self.q_proj.parameters() {
params.push((format!("q_proj.{k}"), v));
}
for (k, v) in self.k_proj.parameters() {
params.push((format!("k_proj.{k}"), v));
}
for (k, v) in self.v_proj.parameters() {
params.push((format!("v_proj.{k}"), v));
}
for (k, v) in self.o_proj.parameters() {
params.push((format!("o_proj.{k}"), v));
}
if let Some(qn) = &self.q_norm {
for (k, v) in qn.parameters() {
params.push((format!("q_norm.{k}"), v));
}
}
if let Some(kn) = &self.k_norm {
for (k, v) in kn.parameters() {
params.push((format!("k_norm.{k}"), v));
}
}
params
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
let mut params = Vec::new();
for (k, v) in self.q_proj.parameters_mut() {
params.push((format!("q_proj.{k}"), v));
}
for (k, v) in self.k_proj.parameters_mut() {
params.push((format!("k_proj.{k}"), v));
}
for (k, v) in self.v_proj.parameters_mut() {
params.push((format!("v_proj.{k}"), v));
}
for (k, v) in self.o_proj.parameters_mut() {
params.push((format!("o_proj.{k}"), v));
}
if let Some(qn) = &mut self.q_norm {
for (k, v) in qn.parameters_mut() {
params.push((format!("q_norm.{k}"), v));
}
}
if let Some(kn) = &mut self.k_norm {
for (k, v) in kn.parameters_mut() {
params.push((format!("k_norm.{k}"), v));
}
}
params
}
}
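/// Pre-norm transformer block: `x + attn(norm(x))`, then `x + ffn(norm(x))`.
/// The MoE path flattens `[batch, seq, hidden]` to `[tokens, hidden]` for
/// the feed-forward and restores the shape afterwards.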
struct TransformerLayer {
self_attn: Attention,
ffn: FeedForward,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl TransformerLayer {
fn forward(
&mut self,
x: &Array,
mask: Option<&Array>,
offset: usize,
) -> Result<Array, mlx_rs::error::Exception> {
let residual = x.clone();
let h = self.input_layernorm.forward(x)?;
let h = self.self_attn.forward(&h, mask, offset)?;
let x = ops::add(&residual, &h)?;
let residual = x.clone();
let h = self.post_attention_layernorm.forward(&x)?;
let h = match &self.ffn {
FeedForward::Moe(_) => {
let shape = h.shape().to_vec();
let flat = ops::reshape(&h, &[shape[0] * shape[1], shape[2]])?;
let out = self.ffn.forward(&flat)?;
ops::reshape(&out, &shape)?
}
FeedForward::Dense(_) => self.ffn.forward(&h)?,
};
ops::add(&residual, &h)
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache();
}
fn parameters(&self) -> Vec<(String, &Array)> {
let mut params = Vec::new();
for (k, v) in self.self_attn.parameters() {
params.push((format!("self_attn.{k}"), v));
}
for (k, v) in self.ffn.parameters() {
params.push((format!("mlp.{k}"), v));
}
for (k, v) in self.input_layernorm.parameters() {
params.push((format!("input_layernorm.{k}"), v));
}
for (k, v) in self.post_attention_layernorm.parameters() {
params.push((format!("post_attention_layernorm.{k}"), v));
}
params
}
#[allow(dead_code)]
fn parameters_mut(&mut self) -> Vec<(String, &mut Array)> {
let mut params = Vec::new();
for (k, v) in self.self_attn.parameters_mut() {
params.push((format!("self_attn.{k}"), v));
}
for (k, v) in self.ffn.parameters_mut() {
params.push((format!("mlp.{k}"), v));
}
for (k, v) in self.input_layernorm.parameters_mut() {
params.push((format!("input_layernorm.{k}"), v));
}
for (k, v) in self.post_attention_layernorm.parameters_mut() {
params.push((format!("post_attention_layernorm.{k}"), v));
}
params
}
}
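/// Loads every tensor for a model, resolving weight files in this order:
/// `model.safetensors.index.json` (sharded), a single `model.safetensors`,
/// then a recursive scan of the directory for `*.safetensors` files.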
pub(crate) fn load_all_tensors(model_dir: &Path) -> Result<HashMap<String, Array>, InferenceError> {
let index_path = model_dir.join("model.safetensors.index.json");
let single_path = model_dir.join("model.safetensors");
let weight_files: Vec<std::path::PathBuf> = if index_path.exists() {
let index_json: serde_json::Value = serde_json::from_str(
&std::fs::read_to_string(&index_path)
.map_err(|e| InferenceError::InferenceFailed(format!("read index: {e}")))?,
)
.map_err(|e| InferenceError::InferenceFailed(format!("parse index: {e}")))?;
let weight_map = index_json
.get("weight_map")
.and_then(|m| m.as_object())
.ok_or_else(|| InferenceError::InferenceFailed("missing weight_map".into()))?;
let mut files: std::collections::HashSet<String> = std::collections::HashSet::new();
for filename in weight_map.values() {
if let Some(f) = filename.as_str() {
files.insert(f.to_string());
}
}
files.into_iter().map(|f| model_dir.join(f)).collect()
} else if single_path.exists() {
vec![single_path]
} else {
let mut recursive = Vec::new();
collect_safetensors_files(model_dir, &mut recursive)?;
if recursive.is_empty() {
return Err(InferenceError::InferenceFailed(format!(
"no safetensors weights found in {}",
model_dir.display()
)));
}
recursive.sort();
recursive
};
let mut all_tensors = HashMap::new();
for weight_file in &weight_files {
info!(path = %weight_file.display(), "loading safetensors file");
let tensors = Array::load_safetensors(weight_file)
.map_err(|e| InferenceError::InferenceFailed(format!("load safetensors: {e}")))?;
info!(path = %weight_file.display(), tensors = tensors.len(), "loaded safetensors file");
for (name, array) in tensors {
all_tensors.insert(name, array);
}
}
add_flux_mlx_aliases(&mut all_tensors);
Ok(all_tensors)
}
fn collect_safetensors_files(
root: &Path,
files: &mut Vec<std::path::PathBuf>,
) -> Result<(), InferenceError> {
let entries = std::fs::read_dir(root).map_err(|e| {
InferenceError::InferenceFailed(format!("read dir {}: {e}", root.display()))
})?;
for entry in entries {
let entry = entry.map_err(|e| {
InferenceError::InferenceFailed(format!("read dir entry {}: {e}", root.display()))
})?;
let path = entry.path();
if path.is_dir() {
collect_safetensors_files(&path, files)?;
} else if path
.extension()
.and_then(|value| value.to_str())
.map(|value| value.eq_ignore_ascii_case("safetensors"))
.unwrap_or(false)
{
files.push(path);
}
}
Ok(())
}
fn add_flux_mlx_aliases(tensors: &mut HashMap<String, Array>) {
let keys: Vec<String> = tensors.keys().cloned().collect();
for key in keys {
let Some(value) = tensors.get(&key).cloned() else {
continue;
};
if let Some(rest) = key.strip_prefix("text_model.") {
tensors
.entry(format!("text_encoders.clip.transformer.text_model.{rest}"))
.or_insert(value);
continue;
}
if let Some(rest) = key.strip_prefix("shared.") {
tensors
.entry(format!("text_encoders.t5.transformer.shared.{rest}"))
.or_insert(value);
continue;
}
if let Some(rest) = key.strip_prefix("t5_blocks.") {
let alias = rest
.replace(".attention.SelfAttention.", ".self_attn.")
.replace(".attention.layer_norm.", ".norm1.")
.replace(".ff.DenseReluDense.", ".ff.")
.replace(".ff.layer_norm.", ".norm2.");
tensors
.entry(format!("text_encoders.t5.transformer.t5_blocks.{alias}"))
.or_insert(value);
continue;
}
if key == "final_layer_norm.weight" {
tensors
.entry("text_encoders.t5.transformer.final_layer_norm.weight".to_string())
.or_insert(value);
continue;
}
if let Some(rest) = key.strip_prefix("decoder.") {
let alias = rest
.replace("conv_in.conv2d.", "conv_in.")
.replace("conv_out.conv2d.", "conv_out.")
.replace("conv_norm_out.norm.", "conv_norm_out.");
tensors
.entry(format!("vae.decoder.{alias}"))
.or_insert(value);
continue;
}
if key.starts_with("x_embedder.")
|| key.starts_with("context_embedder.")
|| key.starts_with("time_text_embed.")
|| key.starts_with("transformer_blocks.")
|| key.starts_with("single_transformer_blocks.")
|| key.starts_with("norm_out.")
|| key.starts_with("proj_out.")
{
let alias = key
.replace(
"time_text_embed.timestep_embedder.linear_1.",
"time_text_embed.timestep_embedder.0.",
)
.replace(
"time_text_embed.timestep_embedder.linear_2.",
"time_text_embed.timestep_embedder.2.",
)
.replace(
"time_text_embed.text_embedder.linear_1.",
"time_text_embed.text_embedder.0.",
)
.replace(
"time_text_embed.text_embedder.linear_2.",
"time_text_embed.text_embedder.2.",
)
.replace(
"time_text_embed.guidance_embedder.linear_1.",
"time_text_embed.guidance_embedder.0.",
)
.replace(
"time_text_embed.guidance_embedder.linear_2.",
"time_text_embed.guidance_embedder.2.",
);
tensors
.entry(format!("transformer.{alias}"))
.or_insert(value);
}
}
}
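/// Builds a linear layer from the checkpoint tensors at `prefix`. When a
/// quantization config is present and `scales`/`biases` tensors exist, a
/// frozen `QuantizedLinear` is constructed; otherwise a dense layer.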
pub(crate) fn build_qlinear(
tensors: &HashMap<String, Array>,
prefix: &str,
quant: Option<&QuantConfig>,
) -> Result<QLinear, InferenceError> {
let weight_key = format!("{prefix}.weight");
let scales_key = format!("{prefix}.scales");
let biases_key = format!("{prefix}.biases");
let weight = tensors
.get(&weight_key)
.ok_or_else(|| InferenceError::InferenceFailed(format!("missing {weight_key}")))?;
if let Some(q) = quant {
if let (Some(scales), Some(biases)) = (tensors.get(&scales_key), tensors.get(&biases_key)) {
let linear_bias = tensors.get(&format!("{prefix}.bias")).cloned();
let inner = nn::Linear {
weight: Param::new(weight.clone()),
bias: Param::new(linear_bias),
};
let mut ql = nn::QuantizedLinear {
group_size: q.group_size,
bits: q.bits,
scales: Param::new(scales.clone()),
biases: Param::new(biases.clone()),
inner,
};
ql.freeze_parameters(true);
return Ok(QLinear::Quantized(ql));
}
}
let mut l = nn::Linear {
weight: Param::new(weight.clone()),
bias: Param::new(None),
};
if let Some(bias) = tensors.get(&format!("{prefix}.bias")) {
l.bias = Param::new(Some(bias.clone()));
}
Ok(QLinear::Dense(l))
}
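/// Builds an embedding from the checkpoint tensors at `prefix`, quantized
/// when the config and `scales`/`biases` tensors allow it, dense otherwise.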
pub(crate) fn build_qembedding(
tensors: &HashMap<String, Array>,
prefix: &str,
quant: Option<&QuantConfig>,
) -> Result<QEmbedding, InferenceError> {
let weight_key = format!("{prefix}.weight");
let scales_key = format!("{prefix}.scales");
let biases_key = format!("{prefix}.biases");
let weight = tensors
.get(&weight_key)
.ok_or_else(|| InferenceError::InferenceFailed(format!("missing {weight_key}")))?;
if let Some(q) = quant {
if let (Some(scales), Some(biases)) = (tensors.get(&scales_key), tensors.get(&biases_key)) {
let inner = nn::Embedding {
weight: Param::new(weight.clone()),
};
let mut qe = nn::QuantizedEmbedding {
group_size: q.group_size,
bits: q.bits,
scales: Param::new(scales.clone()),
biases: Param::new(biases.clone()),
inner,
};
qe.freeze_parameters(true);
return Ok(QEmbedding::Quantized(qe));
}
}
Ok(QEmbedding::Dense(nn::Embedding {
weight: Param::new(weight.clone()),
}))
}
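/// The Qwen3 decoder stack: token embedding, transformer layers, final
/// RMSNorm, and an LM head that falls back to the (tied) embedding weights
/// when the checkpoint has no separate `lm_head.weight`.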
struct Qwen3Model {
embed_tokens: QEmbedding,
layers: Vec<TransformerLayer>,
norm: RmsNorm,
lm_head: QLinear,
_config: Qwen3Config,
}
impl Qwen3Model {
fn from_tensors(
config: &Qwen3Config,
tensors: &HashMap<String, Array>,
) -> Result<Self, InferenceError> {
let quant = config.quantization.as_ref();
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
let head_dim = config.resolved_head_dim();
let rope = std::sync::Arc::new(
RotaryEmbedding::new(
head_dim,
config.max_position_embeddings.min(8192),
config.rope_theta,
)
.map_err(map_err)?,
);
let embed_tokens = build_qembedding(tensors, "model.embed_tokens", quant)?;
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for i in 0..config.num_hidden_layers {
let pfx = format!("model.layers.{i}");
let self_attn = Attention {
q_proj: build_qlinear(tensors, &format!("{pfx}.self_attn.q_proj"), quant)?,
k_proj: build_qlinear(tensors, &format!("{pfx}.self_attn.k_proj"), quant)?,
v_proj: build_qlinear(tensors, &format!("{pfx}.self_attn.v_proj"), quant)?,
o_proj: build_qlinear(tensors, &format!("{pfx}.self_attn.o_proj"), quant)?,
q_norm: if tensors.contains_key(&format!("{pfx}.self_attn.q_norm.weight")) {
Some(load_rms_norm(
tensors,
&format!("{pfx}.self_attn.q_norm"),
config,
)?)
} else {
None
},
k_norm: if tensors.contains_key(&format!("{pfx}.self_attn.k_norm.weight")) {
Some(load_rms_norm(
tensors,
&format!("{pfx}.self_attn.k_norm"),
config,
)?)
} else {
None
},
num_heads: config.num_attention_heads,
num_kv_heads: config.num_key_value_heads,
head_dim,
rope: rope.clone(),
k_cache: None,
v_cache: None,
};
let ffn = if let (Some(ne), Some(nek)) =
(config.num_experts, config.num_experts_per_tok)
{
let q = quant.ok_or_else(|| {
InferenceError::InferenceFailed("MoE models require quantization config".into())
})?;
let get = |key: &str| -> Result<Array, InferenceError> {
tensors
.get(key)
.cloned()
.ok_or_else(|| InferenceError::InferenceFailed(format!("missing {key}")))
};
let load_experts = |proj: &str| -> Result<QuantizedExperts, InferenceError> {
let base = format!("{pfx}.mlp.switch_mlp.{proj}");
Ok(QuantizedExperts {
weight: get(&format!("{base}.weight"))?,
scales: get(&format!("{base}.scales"))?,
biases: get(&format!("{base}.biases"))?,
group_size: q.group_size,
bits: q.bits,
})
};
let moe = MoeLayer {
gate: build_qlinear(tensors, &format!("{pfx}.mlp.gate"), quant)?,
gate_proj: load_experts("gate_proj")?,
up_proj: load_experts("up_proj")?,
down_proj: load_experts("down_proj")?,
num_experts: ne,
num_experts_per_tok: nek,
};
FeedForward::Moe(moe)
} else {
FeedForward::Dense(Mlp {
gate_proj: build_qlinear(tensors, &format!("{pfx}.mlp.gate_proj"), quant)?,
up_proj: build_qlinear(tensors, &format!("{pfx}.mlp.up_proj"), quant)?,
down_proj: build_qlinear(tensors, &format!("{pfx}.mlp.down_proj"), quant)?,
})
};
layers.push(TransformerLayer {
self_attn,
ffn,
input_layernorm: load_rms_norm(tensors, &format!("{pfx}.input_layernorm"), config)?,
post_attention_layernorm: load_rms_norm(
tensors,
&format!("{pfx}.post_attention_layernorm"),
config,
)?,
});
}
let norm = load_rms_norm(tensors, "model.norm", config)?;
let lm_head = if tensors.contains_key("lm_head.weight") {
build_qlinear(tensors, "lm_head", quant)?
} else {
build_qlinear(tensors, "model.embed_tokens", quant)?
};
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
_config: config.clone(),
})
}
fn forward(
&mut self,
tokens: &Array,
offset: usize,
) -> Result<Array, mlx_rs::error::Exception> {
let h = self.forward_hidden(tokens, offset)?;
self.lm_head.forward(&h)
}
fn forward_hidden(
&mut self,
tokens: &Array,
offset: usize,
) -> Result<Array, mlx_rs::error::Exception> {
let seq_len = tokens.shape()[1] as usize;
let mut h = self.embed_tokens.forward(tokens)?;
let mask = if seq_len > 1 {
Some(Self::causal_mask(seq_len, offset)?)
} else {
None
};
for layer in &mut self.layers {
h = layer.forward(&h, mask.as_ref(), offset)?;
}
self.norm.forward(&h)
}
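/// Builds an additive causal mask of shape `[1, 1, seq_len, offset + seq_len]`:
/// entry `(i, j)` is `-inf` when `j > i + offset`, else `0`. For example,
/// with `seq_len = 2` and `offset = 1`, row 0 may attend to positions 0..=1
/// and row 1 to positions 0..=2.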
fn causal_mask(seq_len: usize, offset: usize) -> Result<Array, mlx_rs::error::Exception> {
let total = seq_len + offset;
let mut mask_data = vec![0.0f32; seq_len * total];
for i in 0..seq_len {
for j in 0..total {
if j > i + offset {
mask_data[i * total + j] = f32::NEG_INFINITY;
}
}
}
let mask = Array::from_slice(&mask_data, &[seq_len as i32, total as i32]);
ops::reshape(&mask, &[1, 1, seq_len as i32, total as i32])
}
fn clear_kv_cache(&mut self) {
for layer in &mut self.layers {
layer.clear_kv_cache();
}
}
fn all_parameters(&self) -> Vec<&Array> {
let mut params = Vec::new();
for (_, v) in self.embed_tokens.parameters() {
params.push(v);
}
for layer in &self.layers {
for (_, v) in layer.parameters() {
params.push(v);
}
}
for (_, v) in self.norm.parameters() {
params.push(v);
}
for (_, v) in self.lm_head.parameters() {
params.push(v);
}
params
}
}
fn load_rms_norm(
tensors: &HashMap<String, Array>,
prefix: &str,
config: &Qwen3Config,
) -> Result<RmsNorm, InferenceError> {
let key = format!("{prefix}.weight");
let weight = tensors
.get(&key)
.ok_or_else(|| InferenceError::InferenceFailed(format!("missing {key}")))?;
Ok(RmsNorm {
weight: weight.clone(),
eps: config.rms_norm_eps,
})
}
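/// Native MLX inference backend for a Qwen3-family checkpoint: owns the
/// model, tokenizer, and parsed config.
///
/// Illustrative usage sketch (the model path here is hypothetical):
///
/// ```ignore
/// let mut backend = MlxBackend::load(std::path::Path::new("models/qwen3"))
///     .expect("load model");
/// let tokens = backend.encode("Hello").expect("tokenize");
/// let logits = backend.forward(&tokens, 0).expect("forward pass");
/// ```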
pub struct MlxBackend {
model: Qwen3Model,
pub tokenizer: Tokenizer,
config: Qwen3Config,
}
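// SAFETY: `Array` handles are not `Send`/`Sync` on their own. These impls
// assume the backend is only ever used with exclusive access (e.g. behind
// a `Mutex` or owned by a single worker thread); callers must uphold that.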
unsafe impl Send for MlxBackend {}
unsafe impl Sync for MlxBackend {}
impl MlxBackend {
pub fn supports_capability(&self, cap: crate::schema::ModelCapability) -> bool {
use crate::schema::ModelCapability as C;
match cap {
C::Generate
| C::ToolUse
| C::MultiToolCall
| C::Reasoning
| C::Summarize
| C::Code
| C::Classify
| C::Embed
| C::Rerank => true,
C::Grounding => false,
C::Vision
| C::VideoUnderstanding
| C::AudioUnderstanding
| C::SpeechToText
| C::TextToSpeech
| C::ImageGeneration
| C::VideoGeneration => false,
}
}
pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
let config_path = model_dir.join("config.json");
let config: Qwen3Config = serde_json::from_str(
&std::fs::read_to_string(&config_path)
.map_err(|e| InferenceError::InferenceFailed(format!("read config.json: {e}")))?,
)
.map_err(|e| InferenceError::InferenceFailed(format!("parse config.json: {e}")))?;
info!(
hidden = config.hidden_size,
layers = config.num_hidden_layers,
heads = config.num_attention_heads,
head_dim = config.resolved_head_dim(),
vocab = config.vocab_size,
"loading Qwen3 model via MLX"
);
if config.use_sliding_window {
return Err(InferenceError::InferenceFailed(
"sliding-window attention (use_sliding_window=true) isn't \
implemented in the native MLX backend; route via vllm-mlx \
or a remote API"
.into(),
));
}
if let (Some(ne), Some(nek)) = (config.num_experts, config.num_experts_per_tok) {
info!(
num_experts = ne,
experts_per_tok = nek,
"MoE model detected"
);
}
if let Some(ref q) = config.quantization {
info!(bits = q.bits, group_size = q.group_size, "quantized model");
}
#[cfg(feature = "mlx-metal")]
let default_device = mlx_rs::Device::gpu();
#[cfg(not(feature = "mlx-metal"))]
let default_device = mlx_rs::Device::cpu();
match std::env::var("CAR_MLX_DEVICE").ok().as_deref() {
Some("cpu") => mlx_rs::Device::set_default(&mlx_rs::Device::cpu()),
#[cfg(feature = "mlx-metal")]
Some("gpu") => mlx_rs::Device::set_default(&mlx_rs::Device::gpu()),
_ => mlx_rs::Device::set_default(&default_device),
}
info!("loading safetensors weights");
let mut tensors = load_all_tensors(model_dir)?;
info!(tensors = tensors.len(), "tensors loaded");
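// Qwen2.5-VL checkpoints nest the text model under `language_model.*`;
// keep that subtree and drop the vision tensors (text-only path).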
if tensors.contains_key("language_model.model.embed_tokens.weight")
|| tensors.contains_key("language_model.model.embed_tokens.biases")
{
let before = tensors.len();
let mut stripped: HashMap<String, Array> = HashMap::with_capacity(tensors.len());
let mut dropped_prefixes: std::collections::HashMap<String, usize> = Default::default();
for (k, v) in tensors.drain() {
if let Some(rest) = k.strip_prefix("language_model.") {
stripped.insert(rest.to_string(), v);
continue;
}
let prefix = k.split('.').next().unwrap_or(&k).to_string();
*dropped_prefixes.entry(prefix).or_insert(0) += 1;
}
let dropped_total = before - stripped.len();
info!(
text_keys = stripped.len(),
dropped = dropped_total,
?dropped_prefixes,
"qwen2.5-vl: kept `language_model.*` subtree, dropped vision tensors (text-only path — see GH #58)"
);
tensors = stripped;
}
let model = Qwen3Model::from_tensors(&config, &tensors)?;
mlx_rs::transforms::eval(model.all_parameters())
.map_err(|e| InferenceError::InferenceFailed(format!("eval weights: {e}")))?;
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::InferenceFailed(format!("load tokenizer: {e}")))?;
info!("MLX model loaded successfully");
Ok(Self {
model,
tokenizer,
config,
})
}
pub fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError> {
let seq_len = tokens.len();
let input = Array::from_slice(
&tokens.iter().map(|&t| t as i32).collect::<Vec<_>>(),
&[1, seq_len as i32],
);
let logits = self
.model
.forward(&input, pos)
.map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}")))?;
let last_logits = logits.index((.., -1, ..));
let flat = ops::reshape(&last_logits, &[-1])
.map_err(|e| InferenceError::InferenceFailed(format!("reshape logits: {e}")))?;
flat.eval()
.map_err(|e| InferenceError::InferenceFailed(format!("eval logits: {e}")))?;
let data: &[f32] = flat.as_slice();
Ok(data.to_vec())
}
pub fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
Ok(encoding.get_ids().to_vec())
}
pub fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
self.tokenizer
.decode(tokens, true)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))
}
pub fn tokenize_raw(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
let encoding = self
.tokenizer
.encode(text, false)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
Ok(encoding.get_ids().to_vec())
}
pub fn detokenize_raw(&self, tokens: &[u32]) -> Result<String, InferenceError> {
self.tokenizer
.decode(tokens, false)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))
}
pub fn eos_token_id(&self) -> Option<u32> {
self.tokenizer
.token_to_id("<|endoftext|>")
.or_else(|| self.tokenizer.token_to_id("</s>"))
}
pub fn token_id(&self, token: &str) -> Option<u32> {
self.tokenizer.token_to_id(token)
}
pub fn context_length(&self) -> usize {
self.config.max_position_embeddings
}
pub fn clear_kv_cache(&mut self) {
self.model.clear_kv_cache();
}
pub fn hidden_size(&self) -> usize {
self.config.hidden_size
}
pub fn embed_one(&mut self, text: &str) -> Result<Vec<f32>, InferenceError> {
self.model.clear_kv_cache();
let tokens = self.encode(text)?;
if tokens.is_empty() {
return Ok(vec![0.0; self.config.hidden_size]);
}
let input = Array::from_slice(
&tokens.iter().map(|&t| t as i32).collect::<Vec<_>>(),
&[1, tokens.len() as i32],
);
let hidden = self
.model
.forward_hidden(&input, 0)
.map_err(|e| InferenceError::InferenceFailed(format!("forward_hidden: {e}")))?;
let last = hidden.index((.., -1, ..));
let flat = ops::reshape(&last, &[-1])
.map_err(|e| InferenceError::InferenceFailed(format!("reshape: {e}")))?;
flat.eval()
.map_err(|e| InferenceError::InferenceFailed(format!("eval: {e}")))?;
let data: &[f32] = flat.as_slice();
Ok(l2_normalize(data))
}
pub fn embed_query(
&mut self,
text: &str,
instruction: &str,
) -> Result<Vec<f32>, InferenceError> {
let formatted = format!("Instruct: {instruction}\nQuery: {text}");
self.embed_one(&formatted)
}
}
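/// L2-normalizes a vector, returning it unchanged when the norm is zero.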
fn l2_normalize(v: &[f32]) -> Vec<f32> {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
v.iter().map(|x| x / norm).collect()
} else {
v.to_vec()
}
}