chat-mlx 0.0.0

Local-inference chat-rs provider (and CLI) for MiniCPM5 / Llama / Qwen models on Apple Silicon via MLX.
docs.rs failed to build chat-mlx-0.0.0
Please check the build logs for more information.
See Builds for ideas on how to fix a failed build, or Metadata for how to configure docs.rs builds.
If you believe this is docs.rs' fault, open an issue.

chat-mlx

Local-inference chat-rs provider (and CLI) for MiniCPM5 / Llama / Qwen-family models on Apple Silicon, via MLX. It implements CompletionProvider + StreamProvider, so it drops into chat_core::ChatBuilder and participates in the tool-calling, structured-output, and streaming chat loop — the same surface as chat-claude, chat-openai, etc.

It owns the raw token loop (no daemon, no HTTP): tokenization, chat-templating, sampling, KV cache, tool-call parsing, and JSON-constrained decoding all happen in-process. The closest sibling is chat-mistralrs, but that is a thin wrapper over mistral.rs and rejects tools/structured output; here we implement the parsing and templating it hides.

lib + thin bin: the chat_mlx library exposes the provider; the chat-mlx binary is a small CLI over it. Depends on chat-core directly ({ path = "../chat-rs/core", version = "0.4.2" }).

Layout

src/engine/         the inference core (no chat-rs types)
  config.rs         parse HF config.json / generation_config.json -> ModelArgs
  model.rs          the architecture: Attention/Mlp/Decoder composed from mlx-rs nn modules
  cache.rs          KV cache: pre-allocated growable + rotating attention-sink window
  sampler.rs        on-device greedy / temperature / top-k / top-p sampling
  generate.rs       prefill + decode loop; plus generate_constrained (logit-masked)
  constraint.rs     LogitMask trait (constrained decoding hook)
  template.rs       ChatML rendering
src/loader.rs       HF download + weight load + tie_lm_head (shared by CLI and builder)
src/builder.rs      MlxBuilder (type-state) -> MlxClient
src/client.rs       MlxClient (Arc<Mutex<Model>>, Clone) + StructuredMode
src/api/            CompletionProvider / StreamProvider impls + request/response mapping
src/parsers/        reasoning (<think>), tool (families + stripper), json (validator+mask), structured
src/main.rs         CLI over the lib

mlx-rs ships the nn bricks (Linear, Rope, RmsNorm, Embedding, fast SDPA); we compose the architecture. Struct field names in model.rs mirror HF tensor keys so load_safetensors maps official weights with no manual remapping.

Use as a provider

use chat_core::builder::ChatBuilder;
use chat_core::types::messages::{Messages, content};
use chat_core::parts;
use chat_mlx::MlxBuilder;

let client = MlxBuilder::new().with_model("Qwen/Qwen3-0.6B").build()?;
let mut chat = ChatBuilder::new().with_model(client).build();

let mut msgs = Messages::default();
msgs.push(content::from_user(parts!["Explain RoPE in one sentence."]));
let out = chat.complete(&mut msgs).await?; // ChatOutcome<ChatResponse>

Builder knobs: with_quantize(bool), with_max_context(i32), with_sink_tokens, with_tokens_per_eval, with_tool_format / with_tool_pattern, with_structured_mode.

Tool calling

Register tools-rs tools on the ChatBuilder as usual. The provider advertises them in the prompt and parses the model's calls back out. Formats are families:

  • Hermes/Qwen (<tool_call>{…}</tool_call>) — auto-detected, the default.
  • Custom patternMlxBuilder::with_tool_pattern(open, close): we strip the delimiters and parse the JSON inside.

Parsed calls become PartEnum::Tool with complete_reason = ToolCall, which the chat loop executes and feeds back. Streaming hides the in-progress call markup and surfaces StreamEvent::ToolCall / ToolResult instead.

Structured output

ChatBuilder::with_structured_output::<T>() (T: JsonSchema + Deserialize) works two ways, selected by MlxBuilder::with_structured_mode:

  • StructuredMode::Prompt (default) — inject the schema into the prompt, parse the emitted JSON; the chat loop retries on a parse miss.
  • StructuredMode::Constrained — mask logits each decode step so only tokens keeping the output a valid-JSON prefix can be sampled; EOS only once a complete value is formed. Guarantees well-formed JSON (the schema's types/required fields are still validated on the typed deserialize). It enforces JSON syntax; full schema-level masking would need a grammar engine (llguidance/outlines).

Streaming structured output isn't expressible through chat-core's StreamProvider::stream (it carries no schema), so the provider exposes a native MlxClient::stream_structured(messages, options, schema, mode) that streams the JSON live and emits a final StreamEvent::Structured.

CLI

First build is heavy (compiles the MLX C++/Metal backend; needs Xcode CLT + cmake).

cargo run --release -- --prompt "Explain RoPE in one sentence."
# flags: --model --system --max-tokens --temp --top-k --top-p --quantize
#        --tokens-per-eval --max-context --sink-tokens --seed
cargo run --release -- --model Qwen/Qwen3-0.6B --prompt ""

Default model: openbmb/MiniCPM5-1B (bf16). The 4-bit *-MLX repos need pre-quantized loading (TODO); for now load bf16 and pass --quantize to quantize at runtime.

Examples

cargo run --release --example chat        -- --model Qwen/Qwen3-0.6B --temp 0.7 --top-k 40
cargo run --release --example structured   -- Qwen/Qwen2.5-1.5B-Instruct
  • chat — interactive streaming REPL: multi-turn, gray <think> reasoning, the get_weather tool round-trip shown inline.
  • structured — streaming structured output in both modes side by side.

Supported model families

Architecture is config-driven (config.json), no per-family source files:

  • Llama / MiniCPM — GQA, SwiGLU, RoPE; bias per the attention_bias flag.
  • Qwen2 / Qwen2.5 (model_type == "qwen2") — adds QKV bias (output projection unbiased) and tied embeddings (no shipped lm_head.weight; tie_lm_head shares the input embedding).
  • Qwen3 (model_type == "qwen3") — per-head QK-Norm, no QKV bias.

KV cache

Decode used to slow down quadratically because every token re-concatenated the whole K/V tensor. engine/cache.rs now pre-allocates the K/V buffers and grows them in 256-token chunks, writing each new token into a slice. With --max-context (default 4096) it becomes a rotating attention-sink cache: the first --sink-tokens (default 4) are pinned, the rest of the window rotates ring-buffer style, so KV memory is bounded regardless of generation length. --max-context 0 disables the cap. Because RoPE is relative, pinned sinks + the recent window keep correct relative positions to the current query.

Sampling

Entirely on-device and lazy, so it composes into MLX's batched eval graph instead of forcing a per-token host sync: temperature → categorical, top-k via argpartition masking, top-p via sorted cumsum + which. Determinism from --seed. Sampling method doesn't affect decode throughput. Constrained decoding (generate_constrained) adds a per-step additive logit mask before sampling.

Perf (MiniCPM5-1B, M-series, 256-token decode)

build / config decode tok/s
debug ~8
release (bf16) ~28
release + --quantize (4-bit), tpe=8 ~87
release + --quantize (4-bit), tpe=16 ~85

Decode is memory-bandwidth bound on the per-token weight read, so quantization is the dominant lever (~3×). --tokens-per-eval is now ~flat: with the sampler moved on-device there's no per-token GPU↔host sync left to amortize by batching. Constrained decoding is slower (one token per eval + a vocab-sized mask each step) — fine for short structured extractions.

GGUF feasibility study — "can we build the blocks on GGUF?"

Open research item. Today chat-mlx loads bf16 safetensors and can runtime-quantize to 4-bit via MLX's group-affine nn::quantize (--quantize). Can we instead load GGUF directly so the blocks (MaybeQuantized linears) consume it?

  • mlx-rs 0.25 has no GGUF reader — only .safetensors / .npy. No gguf symbols in mlx-rs or mlx-sys; MLX's C++ mlx_load_gguf is unbound in Rust.
  • MLX's quantization ≠ GGUF's — MLX is group-affine; GGUF is k-quants (Q4_K_M, …). No 1:1 mapping; a transcode is required.
  • The blocks are already quant-capable — the gap is purely the loader/transcode.

Options, increasing effort/payoff: (1) pure-Rust GGUF reader → dequantize to bf16/fp16 on load → feed the existing Model (lowest risk; e.g. candle's reader); (2) FFI to mlx_load_gguf if upstream exposes it; (3) transcode k-quant blocks → MLX quantized layers (best memory, most work). Pragmatic first step: (1) behind a --gguf flag.

Status

  • bf16 generate (MiniCPM5-1B / Qwen3 / Qwen2.5), coherent output
  • on-device top-p / top-k / temperature sampling (seeded)
  • KV cache rewrite: growable + rotating attention-sink window (bounded memory)
  • config-driven QKV bias + tied embeddings (Qwen2.5, Llama)
  • chat-rs provider: completion, streaming, tool families, structured output
  • constrained (valid-JSON) decoding via logit masking
  • load pre-quantized *-MLX safetensors directly
  • GGUF loader (study above)
  • more families (Llama-3, Mistral) + a load matrix
  • a custom fused Metal kernel experiment