//! Embedding-scoped `mlx_fast` wrappers.
//!
//! `mlxrs` has no general `ops::fast` module yet (the fast-ops port is
//! out of M3 scope), so this module surfaces, scoped to embedding use,
//! the two fused-norm primitives that the [`pool`](super::pool)
//! dispatcher applies to the *pooled* sentence vector. That step runs
//! after pooling and before matryoshka truncation / L2 normalization; it
//! corresponds to the `applyLayerNorm` step of swift `Pooling`. These are
//! *not* the model's internal token-level normalization, which is
//! per-architecture and out of scope:
//!
//! - [`layer_norm`] → `mlx_fast_layer_norm` (backs the dispatcher's
//! `apply_layer_norm` flag — swift `MLXFast.layerNorm`, eps `1e-5`).
//! - [`rms_norm`] → `mlx_fast_rms_norm` (an RMSNorm post-pool variant —
//! some embedding backbones, e.g. gemma/llama-bidirec, normalize the
//! pooled vector with RMSNorm rather than LayerNorm).
//!
//! This is deliberately *not* a general `mlx_fast` port: only the two
//! norm functions are wrapped. `rope`, the Metal/CUDA custom-kernel
//! surface, `scaled_dot_product_attention`, etc. are intentionally
//! skipped because embedding pooling does not use them.
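//!
//! The ordering described above (pool, then the optional post-pool norm,
//! then matryoshka truncation and L2 normalization) can be sketched on
//! plain `Vec<f32>` rows. This is a std-only illustration under stated
//! assumptions, not this crate's implementation: it assumes mean pooling
//! with no attention mask, uses a hand-rolled layer norm in place of
//! `mlx_fast_layer_norm`, and the names `mean_pool`,
//! `layer_norm_no_affine`, `l2_normalize`, and `embed_pipeline` are
//! hypothetical.
//!
//! ```rust
//! // Hypothetical mean pooling: average the token rows (mask handling omitted).
//! fn mean_pool(tokens: &[Vec<f32>]) -> Vec<f32> {
//!     let hidden = tokens[0].len();
//!     let mut out = vec![0.0f32; hidden];
//!     for row in tokens {
//!         for (o, v) in out.iter_mut().zip(row) {
//!             *o += *v;
//!         }
//!     }
//!     let n = tokens.len() as f32;
//!     out.iter_mut().for_each(|o| *o /= n);
//!     out
//! }
//!
//! // Plain layer norm (no weight/bias), standing in for the `apply_layer_norm` step.
//! fn layer_norm_no_affine(x: &[f32], eps: f32) -> Vec<f32> {
//!     let n = x.len() as f32;
//!     let mean = x.iter().sum::<f32>() / n;
//!     let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
//!     let inv_std = 1.0 / (var + eps).sqrt();
//!     x.iter().map(|v| (v - mean) * inv_std).collect()
//! }
//!
//! // Scale to unit L2 norm; the tiny floor avoids dividing by zero.
//! fn l2_normalize(x: &[f32]) -> Vec<f32> {
//!     let norm = x.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-12);
//!     x.iter().map(|v| v / norm).collect()
//! }
//!
//! // Pool, layer-norm the pooled vector (eps 1e-5, as in swift), keep the
//! // first `dim` entries (matryoshka truncation), then L2-normalize.
//! fn embed_pipeline(tokens: &[Vec<f32>], dim: usize) -> Vec<f32> {
//!     let pooled = mean_pool(tokens);
//!     let normed = layer_norm_no_affine(&pooled, 1e-5);
//!     let kept = &normed[..dim.min(normed.len())];
//!     l2_normalize(kept)
//! }
//! ```
//!
//! For example, two token rows `[1, 2, 3, 4]` and `[3, 4, 5, 6]` with
//! `dim = 2` yield a 2-element vector of unit L2 norm.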
use crate::;
/// Optional affine weight/bias forwarded to a fused-norm call.
///
/// mlx-c's `mlx_fast_layer_norm` / `mlx_fast_rms_norm` accept their
/// affine handles (`weight` for both, plus `bias` for layer norm) as
/// "may be null". Per the mlx-c convention, a fresh empty `mlx_array`
/// (`mlx_array_new()`) *is* the null handle, so `None` maps to it and
/// the kernel takes the non-affine path.
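///
/// A std-only sketch of that `Option`-to-nullable-handle mapping, for
/// illustration only. `Handle`, `affine_handle`, and `apply_affine` are
/// hypothetical stand-ins that do not touch mlx-c: an inner `None` models
/// the empty `mlx_array_new()` handle, and a null weight/bias behaves as
/// scale 1 / shift 0, i.e. the non-affine path.
///
/// ```rust
/// // Hypothetical stand-in for an mlx-c array handle; `Handle(None)` models
/// // the empty handle that mlx-c treats as null.
/// #[derive(Debug, Clone, PartialEq)]
/// struct Handle(Option<Vec<f32>>);
///
/// impl Handle {
///     // Analogue of `mlx_array_new()`: an empty (null) handle.
///     fn empty() -> Self {
///         Handle(None)
///     }
///
///     fn is_null(&self) -> bool {
///         self.0.is_none()
///     }
/// }
///
/// // Map an optional affine parameter onto the "may be null" handle.
/// fn affine_handle(param: Option<&[f32]>) -> Handle {
///     match param {
///         Some(p) => Handle(Some(p.to_vec())),
///         None => Handle::empty(),
///     }
/// }
///
/// // Null weight/bias handles act as scale 1 / shift 0 (the non-affine path).
/// fn apply_affine(x: &[f32], weight: &Handle, bias: &Handle) -> Vec<f32> {
///     x.iter()
///         .enumerate()
///         .map(|(i, &v)| {
///             let w = weight.0.as_ref().map_or(1.0, |w| w[i]);
///             let b = bias.0.as_ref().map_or(0.0, |b| b[i]);
///             v * w + b
///         })
///         .collect()
/// }
/// ```
///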
/// Fused Layer Normalization over the last axis: `mlx_fast_layer_norm`.
///
/// `(x - mean) / sqrt(var + eps)`, optionally affine-scaled by `weight`
/// and shifted by `bias` (both `None` ⇒ the plain normalize path, which
/// is what the pooling dispatcher's `apply_layer_norm` uses). Mirrors
/// swift `MLXEmbedders` `Pooling.callAsFunction(applyLayerNorm:)`'s
/// `MLXFast.layerNorm(pooled, eps: 1e-5)` — hence the `1e-5` default at
/// the call site.
///
/// - `x`: any float array; normalization is over the last dim.
/// - `weight` / `bias`: optional `(hidden,)` affine params.
/// - `eps`: added to the variance for numerical stability (swift uses
///   `1e-5`).
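///
/// A non-fused, std-only reference for the formula above, on a single
/// row (the last axis), can help when checking the wrapper's output.
/// `layer_norm_ref` is a hypothetical helper, not part of this module and
/// not the mlx kernel; it assumes the population (biased) variance. With
/// both affine params `None`, its output has zero mean and, up to `eps`,
/// unit variance.
///
/// ```rust
/// // Reference (non-fused) layer norm over one row, i.e. the last axis.
/// fn layer_norm_ref(
///     x: &[f32],
///     weight: Option<&[f32]>,
///     bias: Option<&[f32]>,
///     eps: f32,
/// ) -> Vec<f32> {
///     let n = x.len() as f32;
///     let mean = x.iter().sum::<f32>() / n;
///     // Population (biased) variance over the row.
///     let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
///     // `eps` is added to the variance before the square root.
///     let inv_std = 1.0 / (var + eps).sqrt();
///     x.iter()
///         .enumerate()
///         .map(|(i, &v)| {
///             let y = (v - mean) * inv_std;
///             // `None` acts as scale 1 / shift 0 (the plain normalize path).
///             let w = weight.map_or(1.0, |w| w[i]);
///             let b = bias.map_or(0.0, |b| b[i]);
///             y * w + b
///         })
///         .collect()
/// }
/// ```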
/// Fused Root-Mean-Square Normalization over the last axis:
/// `mlx_fast_rms_norm`.
///
/// `x / sqrt(mean(x^2) + eps)`, optionally affine-scaled by `weight`
/// (`None` ⇒ the plain RMSNorm path used by the dispatcher's
/// `apply_rms_norm`). RMSNorm has no `bias`. Provided because several
/// embedding backbones (gemma, llama-bidirec) normalize with RMSNorm
/// rather than LayerNorm.
///
/// - `x`: any float array; normalization is over the last dim.
/// - `weight`: optional `(hidden,)` affine scale.
/// - `eps`: added to the mean of squares for numerical stability.
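///
/// A non-fused, std-only reference for `x / sqrt(mean(x^2) + eps)` on a
/// single row, for comparison with the layer-norm path. `rms_norm_ref` is
/// a hypothetical helper, not part of this module and not the mlx kernel.
/// Unlike layer norm it does not subtract the mean, so signs and ratios
/// between entries are preserved.
///
/// ```rust
/// // Reference (non-fused) RMSNorm over one row, optionally scaled by
/// // `weight`. There is no bias term.
/// fn rms_norm_ref(x: &[f32], weight: Option<&[f32]>, eps: f32) -> Vec<f32> {
///     let n = x.len() as f32;
///     // Mean of squares; `eps` is added before the square root.
///     let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
///     let inv_rms = 1.0 / (mean_sq + eps).sqrt();
///     x.iter()
///         .enumerate()
///         // `None` acts as scale 1 (the plain RMSNorm path).
///         .map(|(i, &v)| v * inv_rms * weight.map_or(1.0, |w| w[i]))
///         .collect()
/// }
/// ```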