//! Attention-mask construction, ported 1:1 from mlx-lm.
//!
//! - [`create_causal_mask`] is `mlx_lm.models.base.create_causal_mask`,
//!   limited to the `offset` and optional sliding-`window_size` subset that
//!   PR-0 needs; the batched `left/right_padding` arguments land with the
//!   batched caches in a later PR.
//! - [`create_attention_mask`] is `mlx_lm.models.cache.create_attention_mask`
//!   (`cache.py:114-126`). It returns the symbolic [`MaskMode::Causal`] when
//!   a materialized array is unnecessary, mirroring mlx-lm's `"causal"`
//!   sentinel; the mode mapping is cross-checked against mlx-swift-lm's
//!   `ScaledDotProductAttentionMaskMode`.
//!
//! No implicit eval: every op is a pure [`crate::ops`] composition.
use crate::;
use format_smolstr;
/// The largest integer `f32` represents exactly. `f32` has a 24-bit
/// significand, so every integer in `[0, 2^24]` round-trips losslessly;
/// `2^24 + 1` is the first integer that does not: converting it to `f32`
/// rounds (to nearest, ties to even) *down* to `2^24`. [`iarange`] builds
/// positions through `f32` ([`Array::arange`] is `f32`-only) **and** casts
/// its own exclusive `stop` to `f32`, so a `stop > 2^24` would round the
/// bound and silently corrupt (truncate) the causal/window mask.
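///
/// A minimal, self-contained sketch of this boundary that uses only `std`
/// casts rather than this module's API; the helpers `roundtrips_through_f32`
/// and `first_aliasing_int` are hypothetical. `2^24` survives a
/// `usize -> f32 -> usize` round trip, while `2^24 + 1` does not.
///
/// ```rust
/// // Local copy of the bound for this standalone example.
/// const F32_EXACT_INT_MAX: usize = 1usize << 24;
///
/// // Hypothetical helper: does `n` survive `usize -> f32 -> usize`?
/// fn roundtrips_through_f32(n: usize) -> bool {
///     // `as f32` rounds to nearest, ties to even; `as usize` is exact for
///     // these non-negative integral values.
///     (n as f32) as usize == n
/// }
///
/// // Hypothetical helper: the first integer at or above 2^24 that aliases.
/// fn first_aliasing_int() -> usize {
///     (F32_EXACT_INT_MAX..=F32_EXACT_INT_MAX + 4)
///         .find(|&n| !roundtrips_through_f32(n))
///         .expect("2^24 + 1 is not representable in f32")
/// }
/// ```
///
/// Here `first_aliasing_int()` evaluates to `2^24 + 1`.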
const F32_EXACT_INT_MAX: usize = 1usize << 24;
/// A 1-element `I32` scalar (mlx-lm's weak Python int `window_size`),
/// broadcast against the index grids — built without eval.
pub
/// 1-D `I32` `[start, stop)` index vector: mlx-lm's integer
/// `mx.arange(...)`.
///
/// [`Array::arange`] is `f32`-only (the safe ops surface has no integer
/// `arange`; adding one is out of this PR's scope), so positions are built
/// through `f32` and cast back to `I32`. Crucially, the **exclusive `stop`
/// itself is cast to `f32`** to call `Array::arange::<f32>(start, stop, 1.0)`.
///
/// `f32` represents every integer in `[0, 2^24]` exactly (24-bit
/// significand) and rounds `2^24 + 1` *down* to `2^24`, so the bound rejects
/// `stop > 2^24` (strictly). For the maximum allowed `stop == 2^24`:
///
/// - the `stop` cast is exact, so the element count `stop - start` is exact;
/// - every produced value lies in `[start, stop - 1] ⊆ [0, 2^24 - 1]` and is
///   exactly representable, so the `f32 -> I32` round trip is lossless and
///   the result feeds an `I32`/`Bool` grid identical to mlx-lm's.
///
/// Bounding only the largest produced value (rejecting `stop - 1 > 2^24`)
/// would not be enough: `stop == 2^24 + 1` would pass, yet
/// `(2^24 + 1) as f32 == 2^24`, so `arange` would stop one element short and
/// silently emit a too-short (corrupt) mask. An out-of-range `stop` is
/// therefore **rejected** with a recoverable [`Error::OutOfRange`] rather
/// than truncated: a too-long cache context surfaces an error here instead of
/// a wrong mask.
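///
/// A minimal sketch of this guard on plain `Vec`s, assuming an `f32` arange
/// that derives its length from the `f32`-cast bounds. `arange_f32` and
/// `iarange_sketch` are hypothetical stand-ins for `Array::arange::<f32>`
/// and this function, and the `String` error replaces [`Error::OutOfRange`].
///
/// ```rust
/// const F32_EXACT_INT_MAX: usize = 1usize << 24;
///
/// // Hypothetical stand-in for an f32 arange with step 1.0: the element
/// // count comes from the (already rounded) f32 bounds.
/// fn arange_f32(start: f32, stop: f32) -> Vec<f32> {
///     let len = (stop - start).ceil().max(0.0) as usize;
///     (0..len).map(|i| start + i as f32).collect()
/// }
///
/// // Hypothetical analogue of `iarange`: reject `stop > 2^24` up front.
/// fn iarange_sketch(start: usize, stop: usize) -> Result<Vec<i32>, String> {
///     if stop > F32_EXACT_INT_MAX {
///         return Err(format!("iarange: stop {stop} exceeds 2^24"));
///     }
///     // All values lie in [0, 2^24 - 1], so f32 -> i32 is lossless.
///     let values = arange_f32(start as f32, stop as f32);
///     Ok(values.into_iter().map(|v| v as i32).collect())
/// }
/// ```
///
/// Without the guard, calling `arange_f32` with `start = 2^24 - 2` and
/// `stop = 2^24 + 1` (both cast to `f32`) yields 2 elements instead of 3:
/// the one-short mask the bound exists to prevent.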
pub
/// Port of `mx.roll(a, shift=shift)` for the 1-D `[L]` arrays
/// `RotatingKVCache.make_mask` rolls (`cache.py:577`). `crate::ops` has no
/// native `roll`, so it is composed faithfully: mlx defines
/// `out[i] = a[(i - shift) mod L]`, i.e. a positive `shift` moves elements
/// toward higher indices with wrap, which is exactly
/// `concat([a[L-s:], a[:L-s]])` for `s = shift mod L` (and the identity when
/// `s == 0`). Built with the same `crate::ops` slice/concatenate idioms the
/// rest of this module uses; no implicit eval.
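///
/// A minimal sketch checking that composition on plain slices. Both helper
/// names are hypothetical, and `mod` is taken as the non-negative (Euclidean)
/// remainder, an assumption matching NumPy-style `roll` for negative shifts.
///
/// ```rust
/// // Hypothetical reference: out[i] = a[(i - shift) mod L].
/// fn roll_by_definition<T: Clone>(a: &[T], shift: i64) -> Vec<T> {
///     let l = a.len() as i64;
///     (0..l).map(|i| a[(i - shift).rem_euclid(l) as usize].clone()).collect()
/// }
///
/// // Hypothetical slice/concatenate form: concat([a[L-s:], a[:L-s]]).
/// fn roll_by_concat<T: Clone>(a: &[T], shift: i64) -> Vec<T> {
///     if a.is_empty() {
///         return Vec::new();
///     }
///     let l = a.len();
///     let s = shift.rem_euclid(l as i64) as usize;
///     // With s == 0 this is concat([a[L..], a[..L]]) == a, the identity.
///     let mut out = a[l - s..].to_vec();
///     out.extend_from_slice(&a[..l - s]);
///     out
/// }
/// ```
///
/// For `a = [1, 2, 3, 4]`, `roll_by_concat(&a, 1)` is `[4, 1, 2, 3]`, and the
/// two forms agree for every shift, negative ones included.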
pub
/// Port of `mlx_lm.models.base.create_causal_mask` (the offset + sliding
/// `window_size` subset).
///
/// ```text
/// rinds = mx.arange(offset + N)
/// linds = mx.arange(offset, offset + N) if offset else rinds
/// linds = linds[:, None]; rinds = rinds[None]
/// mask = linds >= rinds
/// if window_size is not None:
///     mask = mask & (linds < rinds + window_size)
/// ```
///
/// Returns the boolean `[N, offset + N]` causal (optionally windowed) mask.
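///
/// The pseudocode above as a minimal sketch on a plain `Vec<Vec<bool>>`
/// rather than an MLX array; `causal_mask_sketch` is a hypothetical name.
/// Row `r` is query position `offset + r` (`linds`) and column `c` is key
/// position `c` (`rinds`). `checked_add` stands in for the unbounded Python
/// int in `rinds + window_size`: an overflowing sum counts as "in window".
///
/// ```rust
/// // Hypothetical plain-Vec port of create_causal_mask: shape [N, offset + N].
/// fn causal_mask_sketch(n: usize, offset: usize, window_size: Option<usize>) -> Vec<Vec<bool>> {
///     let total = offset + n;
///     (0..n)
///         .map(|r| {
///             let l = offset + r; // linds[r]
///             (0..total)
///                 .map(|c| {
///                     let causal = l >= c; // linds >= rinds
///                     // linds < rinds + window_size, with no wrap-around.
///                     let in_window = window_size
///                         .map_or(true, |w| c.checked_add(w).map_or(true, |bound| l < bound));
///                     causal && in_window
///                 })
///                 .collect::<Vec<bool>>()
///         })
///         .collect()
/// }
/// ```
///
/// With `N = 2`, `offset = 2` and `window_size = Some(2)`, the rows are
/// `[false, true, true, false]` and `[false, false, true, true]`.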
///
/// `offset + N` is computed with [`usize::checked_add`] *before* any range is
/// built. A hostile or corrupt loaded `offset` (mlx-lm's prompt-cache
/// `set_meta_state`) could otherwise overflow: a panic in debug builds, or in
/// release builds a wrap to a small value that then *passes* `iarange`'s
/// `2^24` check and silently produces a wrong mask. The overflow is instead a
/// recoverable [`Error::ArithmeticOverflow`]. For every valid input the
/// behavior matches an unchecked add: any `offset + N ≤ 2^24` reaches
/// `iarange` unchanged.
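///
/// A minimal sketch of this validation order: overflow is checked first,
/// then the `2^24` bound. The `MaskError` enum (standing in for the crate's
/// `Error::ArithmeticOverflow` and `Error::OutOfRange`) and the
/// `checked_total` helper are hypothetical.
///
/// ```rust
/// const F32_EXACT_INT_MAX: usize = 1usize << 24;
///
/// // Hypothetical error type mirroring the two recoverable failures.
/// #[derive(Debug, PartialEq)]
/// enum MaskError {
///     ArithmeticOverflow,
///     OutOfRange { stop: usize },
/// }
///
/// // Hypothetical helper: `offset + N` via checked_add, then iarange's bound.
/// fn checked_total(n: usize, offset: usize) -> Result<usize, MaskError> {
///     let total = offset.checked_add(n).ok_or(MaskError::ArithmeticOverflow)?;
///     if total > F32_EXACT_INT_MAX {
///         return Err(MaskError::OutOfRange { stop: total });
///     }
///     Ok(total)
/// }
/// ```
///
/// An unchecked release-mode add would compute `usize::MAX.wrapping_add(2)`,
/// which is `1` and passes the bound; `checked_total(2, usize::MAX)` returns
/// `Err(MaskError::ArithmeticOverflow)` instead.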
///
/// `window_size` keeps mlx-lm's unbounded-Python-int semantics: there,
/// `mask & (linds < rinds + window_size)` is a no-op once `window_size`
/// covers the full index range (the term is always true). The smallest
/// `rinds + window_size` is `window_size` itself (at `rinds == 0`), while
/// every `linds` lies in `[0, total)`, so a `window_size >= total` cannot
/// mask any valid position and the windowing term is skipped entirely (the
/// plain causal mask). Otherwise `window_size < total ≤ 2^24 < i32::MAX`, so
/// the `as i32` cast is exact; this both mirrors mlx-lm and removes the
/// lossy-cast hazard of a `window_size > i32::MAX` wrapping when cast.
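///
/// A minimal sketch of the skip-or-cast decision, assuming `total ≤ 2^24`
/// has already been enforced. `window_term` and the brute-force checker
/// `window_term_all_true` are hypothetical helpers, not part of this module.
///
/// ```rust
/// // Hypothetical helper: the i32 window to apply, or None when the term
/// // `linds < rinds + window_size` is always true and can be skipped.
/// fn window_term(window_size: usize, total: usize) -> Option<i32> {
///     if window_size >= total {
///         // Smallest right-hand side is window_size (rinds == 0); largest
///         // linds is total - 1 < window_size, so every entry is true.
///         return None;
///     }
///     // Assumes total <= 2^24, so window_size < i32::MAX and `as` is exact.
///     debug_assert!(total <= 1usize << 24);
///     Some(window_size as i32)
/// }
///
/// // Hypothetical brute-force check over every (linds, rinds) pair.
/// fn window_term_all_true(window_size: usize, total: usize) -> bool {
///     (0..total).all(|l| (0..total).all(|r| l < r + window_size))
/// }
/// ```
///
/// For `total = 8`, `window_term(8, 8)` and `window_term(usize::MAX, 8)` both
/// skip the term, and `window_term_all_true(8, 8)` confirms it is all-true.
///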
/// Port of `mlx_lm.models.cache.create_attention_mask` (`cache.py:114-126`):
///
/// ```text
/// if window_size is not None: -> create_causal_mask(N, offset, window_size)
/// elif N == 1: -> None
/// elif return_array: -> create_causal_mask(N, offset, None)
/// else: -> "causal"
/// ```
///
/// The `"causal"` sentinel maps to [`MaskMode::Causal`]; a materialized mask
/// to [`MaskMode::Array`]; the `N == 1` no-mask case to [`MaskMode::None`]
/// (mlx-swift-lm's `.none` / `.causal` / `.array`).
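///
/// A minimal sketch of that dispatch, in the same order as the `text` block.
/// `MaskModeSketch`, `causal_grid` and `attention_mask_sketch` are
/// hypothetical; the real [`MaskMode::Array`] carries an MLX array, replaced
/// here by a plain boolean grid.
///
/// ```rust
/// // Hypothetical mirror of MaskMode's three cases.
/// #[derive(Debug, PartialEq)]
/// enum MaskModeSketch {
///     None,
///     Causal,
///     Array(Vec<Vec<bool>>),
/// }
///
/// // Hypothetical causal (optionally windowed) grid of shape [N, offset + N].
/// fn causal_grid(n: usize, offset: usize, window_size: Option<usize>) -> Vec<Vec<bool>> {
///     (0..n)
///         .map(|r| {
///             let l = offset + r;
///             (0..offset + n)
///                 .map(|c| {
///                     l >= c
///                         && window_size
///                             .map_or(true, |w| c.checked_add(w).map_or(true, |b| l < b))
///                 })
///                 .collect::<Vec<bool>>()
///         })
///         .collect()
/// }
///
/// // Window first, then N == 1, then return_array, else the symbolic mode.
/// fn attention_mask_sketch(
///     n: usize,
///     offset: usize,
///     return_array: bool,
///     window_size: Option<usize>,
/// ) -> MaskModeSketch {
///     if let Some(w) = window_size {
///         MaskModeSketch::Array(causal_grid(n, offset, Some(w)))
///     } else if n == 1 {
///         MaskModeSketch::None
///     } else if return_array {
///         MaskModeSketch::Array(causal_grid(n, offset, None))
///     } else {
///         MaskModeSketch::Causal
///     }
/// }
/// ```
///
/// Because the window check comes first, a single-token step with a
/// `window_size` still materializes an array rather than returning `None`.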