1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
//! Embedding normalization.
//!
//! Ported from `mlx-embeddings` `models/base.py::normalize_embeddings`
//! (`x / maximum(linalg.norm(x, ord=p, axis, keepdims), eps)`) and
//! `MLXEmbedders` `MLXArray+Helper.l2Normalized`.
//!
//! The general [`normalize`] is fully parameterized: the order `p` is
//! forwarded to `mlx.linalg.norm` (via [`crate::ops::linalg_full::norm`],
//! which wraps `mlx_linalg_norm`), so `p != 2` (L1, L∞, …) is a real
//! `mlx_linalg_norm` reduction, not a hand-rolled `sum(|x|^p)^(1/p)`.
use crate::;
use scalar_like;
/// Divide-by-zero floor used by default when normalizing embeddings.
///
/// Taken from python `mlx-embeddings` (`base.py`,
/// `normalize_embeddings(..., eps=1e-9)`). `mlx-embeddings` is the primary
/// embeddings reference, so this value (not the swift one) is the project
/// default.
pub const DEFAULT_NORMALIZE_EPS: f32 = 1.0e-9;
/// Divide-by-zero floor used by swift `MLXEmbedders` `l2Normalized`
/// (`MLXArray+Helper.swift`, `eps: Float = 1e-12`).
///
/// Provided for callers that need exact swift parity. The crate default,
/// [`DEFAULT_NORMALIZE_EPS`], follows python instead.
pub const SWIFT_L2_EPS: f32 = 1.0e-12;
/// Parameterized vector normalization: `x / max(||x||_p, eps)`.
///
/// Mirrors `mlx-embeddings` `normalize_embeddings(embeddings, p, axis,
/// keepdims, eps)`. The norm is computed by `mlx.linalg.norm` (real
/// `ord=p` reduction, not a hand-rolled p-norm), then clamped from below
/// by `eps` (clamp-then-divide, matching both references — more stable
/// than `norm + eps`).
///
/// - `p`: norm order forwarded to `mlx_linalg_norm` (`2.0` = L2,
/// `1.0` = L1, `f64::INFINITY` = L∞, …).
/// - `axis`: reduction axis (python default `-1`).
/// - `keepdims`: keep the reduced axis (python default `true`). Keep it
/// `true` so the norm broadcasts back over `x`; with `false` the divide
/// is only correct when `axis` is the leading axis (e.g. any 1-D `x`),
/// because broadcasting aligns trailing dimensions.
/// - `eps`: divide-by-zero floor. Pass [`DEFAULT_NORMALIZE_EPS`] for the
/// python default or [`SWIFT_L2_EPS`] for swift parity.
/// L2-normalize along the last axis: `x / max(||x||_2, eps)`.
///
/// Convenience for `normalize(x, 2.0, -1, true, eps)`. `eps` is an explicit
/// argument: pass [`DEFAULT_NORMALIZE_EPS`] (`1e-9`) for the python default,
/// or [`SWIFT_L2_EPS`] (`1e-12`, the value swift's `l2Normalized` uses) for
/// bit-exact swift parity. Mirrors `mlx-embeddings`
/// `base.normalize_embeddings` (`p=2`, `axis=-1`, `keepdims=True`).
/// L2-normalize along the last axis with the python default eps
/// (`1e-9`). Back-compat shim for the pre-existing public API; identical
/// to `l2_normalize_eps(embeddings, DEFAULT_NORMALIZE_EPS)`.