//! Cosine similarity helpers.
//!
//! The matrix form mirrors `mlx-embeddings` usage
//! (`mx.matmul(embeddings, embeddings.T)` over L2-normalized rows).
use crate::;
use ;
/// Validate the scalar [`cosine_similarity`] rank/length contract *before*
/// any arithmetic, so a wrong-rank caller gets a recoverable
/// [`Error::RankMismatch`] and an unequal-length caller gets
/// [`Error::LengthMismatch`] instead of a *silently broadcast*
/// (mathematically invalid) score. Mirrors the `pooling.rs`
/// `validate_token_embeddings_*` panic→`Err` precondition style.
///
/// Without this, MLX broadcasting lets e.g. `a=(3,)` against `b=(1,)`
/// produce a "cosine" `> 1` (the dot broadcasts `b`, but `||b||_2` is the
/// 1-element norm) — silent retrieval-ranking corruption on a dim/config
/// mismatch. Requires both `a` and `b` to be rank-1 with equal length.
/// Valid equal-length 1-D inputs are unaffected.
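///
/// A minimal sketch of this precondition on plain shapes, assuming slices of
/// `usize` dimensions instead of MLX arrays. `PairError`, `check_pair`, and
/// `broadcast_cosine` are hypothetical illustration names, not this crate's
/// API. `broadcast_cosine` reproduces the `a=(3,)` vs `b=(1,)` failure
/// described above.
///
/// ```rust
/// // Hypothetical stand-ins for `Error::RankMismatch` / `Error::LengthMismatch`.
/// #[derive(Debug, PartialEq)]
/// enum PairError {
///     RankMismatch { a_rank: usize, b_rank: usize },
///     LengthMismatch { a_len: usize, b_len: usize },
/// }
///
/// // Check rank first, then length, before any arithmetic runs.
/// fn check_pair(a_shape: &[usize], b_shape: &[usize]) -> Result<(), PairError> {
///     if a_shape.len() != 1 || b_shape.len() != 1 {
///         return Err(PairError::RankMismatch {
///             a_rank: a_shape.len(),
///             b_rank: b_shape.len(),
///         });
///     }
///     if a_shape[0] != b_shape[0] {
///         return Err(PairError::LengthMismatch {
///             a_len: a_shape[0],
///             b_len: b_shape[0],
///         });
///     }
///     Ok(())
/// }
///
/// // What an unchecked broadcast of `b = [b0]` computes: the dot sums
/// // `a_i * b0` over all of `a`, but `||b||_2` is only `|b0|`.
/// fn broadcast_cosine(a: &[f32], b0: f32) -> f32 {
///     let dot: f32 = a.iter().map(|x| x * b0).sum();
///     let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
///     dot / (norm_a * b0.abs())
/// }
/// ```
///
/// With `a = [1, 1, 1]` and `b0 = 1`, `broadcast_cosine` returns
/// `3 / sqrt(3)` (about `1.732`). `check_pair(&[3], &[1])` rejects the same
/// pair with `LengthMismatch` before any score is computed.
///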
/// Cosine similarity of two 1-D vectors:
/// `dot(a, b) / (||a||_2 * ||b||_2)`, computed with a **numerically-stable
/// max-abs-scaled** formulation.
///
/// Identical vectors give `≈ 1.0`, orthogonal vectors `≈ 0.0`. Both inputs
/// must be 1-D with the same length. Rather than letting MLX *silently
/// broadcast* a mismatch into an invalid score, a non-rank-1 input returns
/// [`Err(Error::RankMismatch)`](Error::RankMismatch) and an unequal length
/// returns [`Err(Error::LengthMismatch)`](Error::LengthMismatch). A valid
/// length-0 input (the rank/length validator treats `(0,)` vs `(0,)` as
/// equal-length rank-1) short-circuits to a finite `0.0` (the empty-vector
/// contract; no reduction over an empty array).
///
/// Accepts any float dtype (`f32`/`f16`/`bf16`); inputs are widened to
/// `f32` (lossless: f16/bf16→f32, a no-op for f32) and the result is
/// returned as `f32`.
///
/// ## Why max-abs scaling
/// A naïve `sqrt(sum(square(x)))` norm **underflows to `0`** for genuinely
/// tiny nonzero vectors (`square(1e-23) = 1e-46 → 0` in f32; the half
/// dtypes underflow far sooner) and **overflows to `+Inf`** for huge ones
/// (`square(f32::MAX) = +Inf`), so any zero/result predicate *derived from
/// that norm* misclassifies tiny nonzero vectors as zero or leaks
/// `0*Inf = NaN`. This fn instead first scales each vector by its
/// **max-abs (Chebyshev / ∞-norm)** `s = max(|x|)`:
/// - `s` is computed with `abs` + a full `max`-reduce only — **no
/// `square`**, so it is *exact* and free of underflow/overflow:
/// `s == 0.0` **iff** every element is *exactly* `0`.
/// - The **exact zero predicate** is therefore `max(|a|) == 0 ∨
/// max(|b|) == 0` (`logical_or(equal(s_a,0), equal(s_b,0))`), evaluated
/// on the max-abs scalars directly. It is `NaN`-free and **cannot** be
/// triggered by a nonzero vector no matter how tiny — that is the whole
/// point versus the prior L2-norm-derived predicate (which underflowed
/// `square` and misclassified tiny vectors, or produced `0*Inf = NaN`).
/// - After dividing by a div-by-zero-safe scale (`1.0` substituted where
/// the zero predicate holds, so the materialized branch never divides by
/// `0`), every scaled element has `|x̂_i| ≤ 1` with the max-magnitude
/// element exactly `1.0`, so `sum(square(x̂)) ∈ [1, n]` and
/// `‖x̂‖₂ ∈ [1, sqrt(n)]`: **no underflow to `0`, no overflow to
/// `+Inf`**. The dot/norms are thus computed entirely on `O(1)`-magnitude
/// data and `‖â‖₂·‖b̂‖₂` is bounded well away from both `0` and `+Inf`
/// for any realistic dimension.
/// - Cosine is **scale-invariant**, so dividing each vector by its own
/// positive scale leaves the cosine mathematically unchanged (up to
/// ordinary floating-point rounding): the result is the scale-invariant
/// cosine in `[-1, 1]` for **every** finite vector pair from `≈1e-23` to
/// `f32::MAX` (and f16/bf16), with underflow / overflow / `0*Inf` all
/// *structurally impossible*.
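///
/// The scaling argument in the list above can be checked on a plain `f32`
/// 2-norm. This is a minimal sketch on slices, not the MLX implementation;
/// `naive_norm`, `max_abs`, and `scaled_norm` are hypothetical names. Unlike
/// the cosine, a norm has to multiply the scale back in, so it can still
/// overflow when the true norm itself exceeds `f32::MAX`.
///
/// ```rust
/// // Naive 2-norm: squares first, so it can underflow to 0 or overflow to +Inf.
/// fn naive_norm(x: &[f32]) -> f32 {
///     x.iter().map(|v| v * v).sum::<f32>().sqrt()
/// }
///
/// // Chebyshev scale: `abs` + `max` only, no squaring, so it is exact and
/// // equals 0.0 iff every element is exactly 0.
/// fn max_abs(x: &[f32]) -> f32 {
///     x.iter().fold(0.0f32, |m, v| m.max(v.abs()))
/// }
///
/// // ||x||_2 = s * ||x / s||_2: each |x_i / s| <= 1 and the largest is
/// // exactly 1, so the inner sum of squares lies in [1, n].
/// fn scaled_norm(x: &[f32]) -> f32 {
///     let s = max_abs(x);
///     if s == 0.0 {
///         return 0.0; // exact all-zero (or empty) vector
///     }
///     let inner: f32 = x.iter().map(|v| (v / s) * (v / s)).sum();
///     s * inner.sqrt()
/// }
/// ```
///
/// For `[1e-23, 1e-23]`, `naive_norm` returns `0.0`: each square (`1e-46`)
/// is below half the smallest f32 subnormal (about `1.4e-45`) and rounds to
/// zero. `scaled_norm` returns about `1.414e-23`. For `[f32::MAX]`,
/// `naive_norm` is `+Inf` while `scaled_norm` is exactly `f32::MAX`.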
///
/// The conventional finite `0.0` is returned **only** for a genuine
/// all-zero vector (max-abs `== 0`, exact) or a length-0 input — never for
/// a nonzero vector, however tiny or huge. This is an **mlxrs-only
/// convenience** with **no python/swift/mlx-c reference** (parity-audited),
/// so it deliberately uses this stable formulation: there is no
/// bit-identity-to-reference constraint, only a correct, robust, terminal
/// cosine. This is *intentionally distinct* from the python-faithful
/// dtype-aware weak-scalar eps used by
/// [`normalize`](crate::embeddings::normalize())/[`l2_normalize`]/
/// [`cosine_similarity_matrix`], which mirror python `mx.maximum(norm,
/// eps)` (reference-faithful: an `f16` zero vector yields `NaN` there on
/// purpose). Their unconditional clamp is **not** replicated here.
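///
/// The contract above can be sketched on plain `f32` slices. This is a
/// minimal illustration, not this crate's implementation: `cosine_max_abs`
/// is a hypothetical name, `None` stands in for `Err(LengthMismatch)`, and
/// an early return replaces the MLX-style "substitute `1.0`, then select"
/// handling of the zero predicate.
///
/// ```rust
/// fn cosine_max_abs(a: &[f32], b: &[f32]) -> Option<f32> {
///     if a.len() != b.len() {
///         return None; // stand-in for Err(LengthMismatch)
///     }
///     if a.is_empty() {
///         return Some(0.0); // empty-vector contract
///     }
///     // Exact max-abs scales: abs + max only, no squaring.
///     let scale = |x: &[f32]| x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
///     let (sa, sb) = (scale(a), scale(b));
///     // Exact zero predicate on the scalars: true only for all-zero input.
///     if sa == 0.0 || sb == 0.0 {
///         return Some(0.0);
///     }
///     // Every scaled element lies in [-1, 1], so all sums stay O(1)-sized.
///     let (mut dot, mut na, mut nb) = (0.0f32, 0.0f32, 0.0f32);
///     for (x, y) in a.iter().zip(b) {
///         let (xs, ys) = (x / sa, y / sb);
///         dot += xs * ys;
///         na += xs * xs;
///         nb += ys * ys;
///     }
///     // na and nb lie in [1, n], so the denominator is finite and nonzero.
///     Some(dot / (na.sqrt() * nb.sqrt()))
/// }
/// ```
///
/// A pair at `≈1e-23` scale and a pair at `f32::MAX` scale both run through
/// the same `O(1)`-magnitude arithmetic; only an exactly all-zero input (or
/// an empty one) takes a `0.0` branch.
///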
/// Pairwise cosine similarity matrix for a `(n, d)` batch of row vectors.
///
/// Rows are L2-normalized, then `normalized @ normalized.T` yields the
/// `(n, n)` similarity matrix (diagonal `≈ 1.0`). Mirrors the
/// `mx.matmul(embeddings, embeddings.T)` pattern in `mlx-embeddings`.
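///
/// A minimal plain-`f32` sketch of that pipeline, assuming rows stored as
/// `Vec<f32>` rather than an MLX `(n, d)` array. `cosine_matrix` is a
/// hypothetical name; the norm clamp `max(||x||_2, eps)` follows the python
/// `mx.maximum(norm, eps)` pattern described for [`l2_normalize`].
///
/// ```rust
/// fn cosine_matrix(rows: &[Vec<f32>], eps: f32) -> Vec<Vec<f32>> {
///     // Step 1: L2-normalize each row, clamping its norm at `eps`.
///     let normalized: Vec<Vec<f32>> = rows
///         .iter()
///         .map(|r| {
///             let norm = r.iter().map(|v| v * v).sum::<f32>().sqrt();
///             r.iter().map(|v| v / norm.max(eps)).collect::<Vec<f32>>()
///         })
///         .collect();
///     // Step 2: normalized @ normalized.T, so entry (i, j) is the dot
///     // product of normalized rows i and j.
///     normalized
///         .iter()
///         .map(|ri| {
///             normalized
///                 .iter()
///                 .map(|rj| ri.iter().zip(rj).map(|(x, y)| x * y).sum::<f32>())
///                 .collect::<Vec<f32>>()
///         })
///         .collect()
/// }
/// ```
///
/// For rows `[3, 4]`, `[4, -3]`, and `[6, 8]`, the result is a symmetric
/// `3 x 3` matrix with a diagonal of about `1.0`, `0.0` between the
/// orthogonal first two rows, and about `1.0` between the parallel first
/// and third rows.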
///
/// The L2-normalize step uses [`l2_normalize`]'s python-faithful
/// dtype-aware weak-scalar eps, so an all-zero **f16**/**bf16** row
/// normalizes to `NaN` (the default `1e-9` eps underflows below the half
/// subnormal floor; see the [`normalize`](crate::embeddings::normalize())
/// body comment). This is **intentional python `mlx-embeddings` / MLX
/// per-dtype parity** (reference-faithful, explicitly asserted by
/// `tests/embeddings.rs::normalize_zero_vector_f16_bf16_eps_floor_in_dtype`),
/// NOT a defect: flooring in f32 here would diverge from the python
/// reference. This is *deliberately distinct* from the scalar
/// [`cosine_similarity`] (an mlxrs-only convenience with no python
/// reference), which DOES guarantee a finite `0.0` for f16/bf16 zero
/// vectors via its exact max-abs zero predicate on `f32`-widened inputs.
/// The two differ by design (matrix path = reference parity; scalar =
/// finite-`0.0` convenience); they do not contradict each other.
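///
/// The zero-row behavior can be reproduced in plain `f32` with a
/// hypothetical `l2_normalize_eps` helper (an illustration, not this
/// crate's `l2_normalize`). With `eps = 1e-9`, an all-zero row stays a
/// finite zero row. Passing `eps = 0.0` simulates the `f16` case, where
/// `1e-9` is below half the smallest `f16` subnormal (`2^-24`, about
/// `5.96e-8`) and rounds to `0`, so the clamp degenerates to `0 / 0 = NaN`.
///
/// ```rust
/// // x / max(||x||_2, eps), evaluated entirely in f32.
/// fn l2_normalize_eps(x: &[f32], eps: f32) -> Vec<f32> {
///     let norm = x.iter().map(|v| v * v).sum::<f32>().sqrt();
///     let denom = norm.max(eps); // 0.0 when both norm and eps are 0
///     x.iter().map(|v| v / denom).collect()
/// }
/// ```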