mlxrs 0.1.0

Safe Rust bindings for Apple's MLX array framework, with LM, VLM, audio, and embeddings support
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
//! Happy-path + edge-case tests for `mlxrs::audio::dsp` (Hann window,
//! STFT, mel filterbank, mel + log-mel spectrogram).

#![cfg(feature = "audio")]

use std::f32::consts::PI;

use mlxrs::{
  Array, Dtype,
  audio::dsp::{
    LogFloor, MelPrecision, WindowPad, hann_window, log_mel_spectrogram, log_mel_spectrogram_with,
    mel_filter_bank, mel_filter_bank_with, mel_spectrogram, stft,
  },
};

/// 16-sample sine at 1 kHz, sample_rate 16 kHz (matches the `mlx-audio`
/// happy-path inputs used in upstream tests).
fn sine_1khz_16samples() -> Array {
  let sr = 16_000.0_f32;
  let f = 1_000.0_f32;
  let buf: Vec<f32> = (0..16)
    .map(|n| (2.0 * PI * f * n as f32 / sr).sin())
    .collect();
  Array::from_slice::<f32>(&buf, &[16i32]).unwrap()
}

#[test]
fn hann_window_endpoints_are_zero() {
  let mut w = hann_window(8).unwrap();
  let v = w.to_vec::<f32>().unwrap();
  assert!(
    (v[0]).abs() < 1e-6,
    "first sample should be 0, got {}",
    v[0]
  );
  assert!((v[7]).abs() < 1e-6, "last sample should be 0, got {}", v[7]);
}

#[test]
fn hann_window_is_symmetric() {
  let mut w = hann_window(9).unwrap();
  let v = w.to_vec::<f32>().unwrap();
  for k in 0..v.len() / 2 {
    let mirror = v.len() - 1 - k;
    assert!(
      (v[k] - v[mirror]).abs() < 1e-6,
      "asymmetric: v[{k}]={} vs v[{mirror}]={}",
      v[k],
      v[mirror]
    );
  }
}

#[test]
fn hann_window_rejects_n_lt_2() {
  assert!(matches!(hann_window(0), Err(mlxrs::Error::OutOfRange(_))));
  assert!(matches!(hann_window(1), Err(mlxrs::Error::OutOfRange(_))));
}

#[test]
fn hann_window_center_value_is_one() {
  // For odd n, the middle sample is `0.5 * (1 - cos(π)) = 1.0`.
  let mut w = hann_window(9).unwrap();
  let v = w.to_vec::<f32>().unwrap();
  assert!(
    (v[4] - 1.0).abs() < 1e-5,
    "center should be ~1, got {}",
    v[4]
  );
}

#[test]
fn stft_shape_matches_formula() {
  // For n_fft=8, hop=4, samples=16 with center=True (pad=4 each side),
  // padded_len = 24, num_frames = 1 + (24 - 8) / 4 = 5.
  let x = sine_1khz_16samples();
  // `stft` now returns a typed `Spectrum`; its transform array is `.data()`.
  let s = stft(&x, 8, 4, None, WindowPad::Center).unwrap();
  assert_eq!(s.data_ref().shape(), vec![5, 5]); // (num_frames, n_fft/2 + 1)
  assert_eq!(s.data_ref().dtype().unwrap(), Dtype::Complex64);
  // The metadata travels in the type (no inference downstream in istft).
  assert_eq!(s.n_fft(), 8);
  assert_eq!(s.hop_length(), 4);
  assert_eq!(s.win_length(), 8); // defaults to n_fft
  assert!(s.center());
}

#[test]
fn stft_rejects_zero_n_fft() {
  let x = sine_1khz_16samples();
  let r = stft(&x, 0, 4, None, WindowPad::Center);
  assert!(matches!(r, Err(mlxrs::Error::InvariantViolation(_))));
}

#[test]
fn stft_rejects_zero_hop_length() {
  let x = sine_1khz_16samples();
  let r = stft(&x, 8, 0, None, WindowPad::Center);
  assert!(matches!(r, Err(mlxrs::Error::InvariantViolation(_))));
}

#[test]
fn stft_rejects_win_length_greater_than_n_fft() {
  let x = sine_1khz_16samples();
  let r = stft(&x, 8, 4, Some(16), WindowPad::Center);
  assert!(matches!(r, Err(mlxrs::Error::OutOfRange(_))));
}

#[test]
fn stft_minimum_valid_input_boundary_padding_to_index_zero() {
  // Regression: when samples.len() == n_fft / 2 + 1, the reflect padding
  // suffix must include samples[0] (padding == samples.len() - 1). A
  // formulation using stop=-1 would have mlx slice post-normalize it to
  // len-1 with negative strides, silently dropping the right edge.
  //
  // For n_fft=8, samples_len = 8/2 + 1 = 5: padding = 4, suffix indices
  // are [3, 2, 1, 0] (4 elements). Post-pad length = 5 + 4 + 4 = 13.
  // num_frames = 1 + (13 - 8) / 4 = 2 with default hop = n_fft/4 = 2... let
  // we use hop=4 to keep the math obvious: num_frames = 1 + (13 - 8) / 4 = 2.
  let buf: Vec<f32> = (0..5).map(|i| i as f32).collect();
  let x = Array::from_slice::<f32>(&buf, &[5i32]).unwrap();
  let s = stft(&x, 8, 4, None, WindowPad::Center).unwrap();
  // Shape proves the suffix was the full 4 elements (otherwise padded_len
  // would have been 12 → num_frames = 2 still, but a value-level check via
  // the reflect-pad output itself is cleaner — assert via shape + dtype
  // and the fact that the call succeeds without going through the
  // too-short-error path).
  assert_eq!(s.data_ref().shape(), vec![2, 5]); // (num_frames, n_fft/2 + 1)
  assert_eq!(s.data_ref().dtype().unwrap(), Dtype::Complex64);
}

#[test]
fn stft_rejects_input_too_short_for_reflect_pad() {
  // n_fft=16, pad=8, but input has only 4 samples — reflect needs len >= pad+1.
  let buf = vec![0.0_f32, 0.1, 0.2, 0.3];
  let x = Array::from_slice::<f32>(&buf, &[4i32]).unwrap();
  let r = stft(&x, 16, 8, None, WindowPad::Center);
  assert!(matches!(r, Err(mlxrs::Error::OutOfRange(_))));
}

#[test]
fn stft_win_length_shorter_than_n_fft_zero_pads_window() {
  // win_length=4 with n_fft=8 + WindowPad::Right zero-pads the window to
  // length 8 on the right side. Shape stays `(num_frames, n_fft/2+1)`.
  let x = sine_1khz_16samples();
  let s = stft(&x, 8, 4, Some(4), WindowPad::Right).unwrap();
  assert_eq!(s.data_ref().shape(), vec![5, 5]);
  assert_eq!(s.win_length(), 4); // the short win_length is carried on the type
}

#[test]
fn mel_filter_bank_shape_matches_n_mels_x_n_freqs() {
  // n_fft=400 (Whisper default), n_mels=80 → shape (80, 201).
  let bank = mel_filter_bank(80, 400, 16_000, 0.0, None).unwrap();
  assert_eq!(bank.shape(), vec![80, 201]);
}

#[test]
fn mel_filter_bank_rejects_zero_n_mels() {
  let r = mel_filter_bank(0, 400, 16_000, 0.0, None);
  assert!(matches!(r, Err(mlxrs::Error::InvariantViolation(_))));
}

#[test]
fn mel_filter_bank_rejects_invalid_freq_range() {
  // f_max <= f_min is invalid.
  let r = mel_filter_bank(40, 400, 16_000, 1000.0, Some(500.0));
  assert!(matches!(r, Err(mlxrs::Error::OutOfRange(_))));
}

#[test]
fn mel_filter_bank_rejects_usize_overflow_inputs() {
  // n_mels = usize::MAX → n_mels + 2 overflows.
  let r = mel_filter_bank(usize::MAX, 400, 16_000, 0.0, None);
  assert!(matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))));
  // n_mels * n_freqs overflows (with n_fft very large so n_freqs is huge).
  // Pick n_mels and n_fft such that n_mels.checked_mul(n_freqs) returns None
  // but n_mels.checked_add(2) succeeds.
  let big_n_mels = 1usize << 33;
  let big_n_fft = 1usize << 34;
  let r = mel_filter_bank(big_n_mels, big_n_fft, 16_000, 0.0, Some(8_000.0));
  assert!(matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))));
}

#[test]
fn mel_filter_bank_values_are_nonneg() {
  let mut bank = mel_filter_bank(8, 64, 16_000, 0.0, None).unwrap();
  for v in bank.to_vec::<f32>().unwrap() {
    assert!(v >= 0.0, "negative mel weight: {v}");
  }
}

// ---- precise (f64) mel filterbank (#291) ---------------------------------

/// The precise (f64) bank has the same `(n_mels, n_freqs)` shape as the f32
/// path. `mel_filter_bank_with(.., Standard)` is also the same shape as the
/// shorthand `mel_filter_bank`.
#[test]
fn mel_filter_bank_precise_shape_matches() {
  let std_bank = mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Standard).unwrap();
  let precise = mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Precise).unwrap();
  assert_eq!(std_bank.shape(), vec![80, 201]);
  assert_eq!(precise.shape(), vec![80, 201]);
}

/// `mel_filter_bank_with(.., Standard)` is byte-identical to the historic
/// `mel_filter_bank` shorthand (the f32 path is unchanged by #291).
#[test]
fn mel_filter_bank_with_standard_matches_shorthand() {
  let mut shorthand = mel_filter_bank(80, 400, 16_000, 0.0, None).unwrap();
  let mut with_std =
    mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Standard).unwrap();
  assert_eq!(
    shorthand.to_vec::<f32>().unwrap(),
    with_std.to_vec::<f32>().unwrap(),
    "Standard precision must match the f32 shorthand bit-for-bit"
  );
}

/// The whole point of the precise mode: the f64-computed bank is NOT
/// bit-identical to the f32 bank (it tracks a float64 reference more
/// closely), yet stays within a small tolerance of it — the f32 path drifts
/// only ~5e-6 from the precise reference, so the two banks must be close.
#[test]
fn mel_filter_bank_precise_differs_but_close() {
  let mut std_bank =
    mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Standard).unwrap();
  let mut precise =
    mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Precise).unwrap();
  let s = std_bank.to_vec::<f32>().unwrap();
  let p = precise.to_vec::<f32>().unwrap();
  assert_eq!(s.len(), p.len(), "shape mismatch between precisions");

  // NOT bit-identical — at least one cell must differ (the precise path
  // recomputes the divisions / mel scale in f64, so rounding diverges).
  assert!(
    s.iter().zip(&p).any(|(a, b)| a.to_bits() != b.to_bits()),
    "precise bank must differ from the f32 bank (otherwise the f64 path is a no-op)"
  );

  // ...but close: the f32 path is documented to drift only ~5e-6 from the
  // float64 reference. 1e-4 is a comfortable ceiling on the [0, 1] weights.
  let max_abs = s
    .iter()
    .zip(&p)
    .map(|(a, b)| (a - b).abs())
    .fold(0.0_f32, f32::max);
  assert!(
    max_abs < 1e-4,
    "precise vs f32 max abs diff {max_abs} exceeds 1e-4"
  );
}

/// The precise bank stays non-negative (the triangular filter clamps via
/// `.max(0.0)` in f64 just as the f32 path does).
#[test]
fn mel_filter_bank_precise_values_are_nonneg() {
  let mut bank = mel_filter_bank_with(8, 64, 16_000, 0.0, None, MelPrecision::Precise).unwrap();
  for v in bank.to_vec::<f32>().unwrap() {
    assert!(v >= 0.0, "negative precise mel weight: {v}");
  }
}

/// The precise path enforces the same validation as the f32 path (shared
/// guards in `mel_filter_bank_with`).
#[test]
fn mel_filter_bank_precise_rejects_invalid_inputs() {
  assert!(matches!(
    mel_filter_bank_with(0, 400, 16_000, 0.0, None, MelPrecision::Precise),
    Err(mlxrs::Error::InvariantViolation(_))
  ));
  assert!(matches!(
    mel_filter_bank_with(40, 400, 16_000, 1000.0, Some(500.0), MelPrecision::Precise),
    Err(mlxrs::Error::OutOfRange(_))
  ));
}

#[test]
fn mel_spectrogram_is_nonneg_for_real_input() {
  let x = sine_1khz_16samples();
  let mut m = mel_spectrogram(&x, 8, 4, None, 4, 16_000, 0.0, None).unwrap();
  // Output: (n_mels, num_frames) = (4, 5).
  assert_eq!(m.shape(), vec![4, 5]);
  for v in m.to_vec::<f32>().unwrap() {
    assert!(v >= 0.0, "mel spec must be non-negative, got {v}");
  }
}

#[test]
fn log_mel_spectrogram_is_finite_for_silence() {
  // All-zero input — without the eps floor this would produce `log(0) = -inf`.
  let zeros = Array::zeros::<f32>(&(64usize,)).unwrap();
  let mut m = log_mel_spectrogram(&zeros, 16, 8, None, 4, 16_000, 0.0, None).unwrap();
  let v = m.to_vec::<f32>().unwrap();
  for x in &v {
    assert!(x.is_finite(), "log-mel must be finite (eps floor), got {x}");
  }
  // Every entry must equal `ln(1e-10) ≈ -23.0259`.
  let expected = (1e-10_f32).ln();
  for x in &v {
    assert!(
      (x - expected).abs() < 1e-3,
      "silence log-mel should equal ln(eps)={expected}, got {x}"
    );
  }
}

#[test]
fn log_mel_spectrogram_is_finite_for_sine_input() {
  let x = sine_1khz_16samples();
  let mut m = log_mel_spectrogram(&x, 8, 4, None, 4, 16_000, 0.0, None).unwrap();
  for v in m.to_vec::<f32>().unwrap() {
    assert!(v.is_finite(), "log-mel must be finite, got {v}");
  }
}

#[test]
fn log_floor_whisper_matches_1e_10() {
  // Bit-exact comparison: Whisper variant must produce exactly `1e-10_f32`.
  assert_eq!(LogFloor::Whisper.value(), 1e-10_f32);
  // Default should also be Whisper.
  assert_eq!(LogFloor::default().value(), 1e-10_f32);
}

#[test]
fn log_floor_kaldi_matches_mlx_audio_1e_8() {
  // Bit-exact comparison: Kaldi variant must produce exactly `1e-8_f32`,
  // matching `mlx-audio/mlx_audio/dsp.py:950` (the `get_mel_banks_kaldi`
  // path). This is deliberately mlx-audio's literal, NOT upstream
  // kaldi-asr's `f32::EPSILON` — see [`LogFloor::Kaldi`] docs.
  assert_eq!(LogFloor::Kaldi.value(), 1e-8_f32);
}

#[test]
fn log_floor_custom_clamps_nonpositive_and_nonfinite_to_min_positive() {
  // Negative, zero, NaN, and inf inputs all clamp to `f32::MIN_POSITIVE`
  // so the resulting `log(floor)` is always finite.
  assert_eq!(LogFloor::Custom(-1.0).value(), f32::MIN_POSITIVE);
  assert_eq!(LogFloor::Custom(0.0).value(), f32::MIN_POSITIVE);
  assert_eq!(LogFloor::Custom(-0.0).value(), f32::MIN_POSITIVE);
  assert_eq!(LogFloor::Custom(f32::NAN).value(), f32::MIN_POSITIVE);
  assert_eq!(LogFloor::Custom(f32::INFINITY).value(), f32::MIN_POSITIVE);
  assert_eq!(
    LogFloor::Custom(f32::NEG_INFINITY).value(),
    f32::MIN_POSITIVE
  );

  // Valid positive customs pass through unchanged.
  let v = LogFloor::Custom(1e-7).value();
  assert!((v - 1e-7).abs() < f32::EPSILON);
}

#[test]
fn log_mel_spectrogram_whisper_vs_kaldi_differ_at_silence() {
  // 64 samples of silence — every mel-spec bin is 0, so the floor is the
  // ONLY input to the final log. Whisper's `ln(1e-10) ≈ -23.0259` is more
  // negative than Kaldi's `ln(1e-8) ≈ -18.4207`, by `ln(100) ≈ 4.6052`.
  let zeros = Array::zeros::<f32>(&(64usize,)).unwrap();
  let mut whisper =
    log_mel_spectrogram_with(&zeros, 16, 8, None, 4, 16_000, 0.0, None, LogFloor::Whisper).unwrap();
  let mut kaldi =
    log_mel_spectrogram_with(&zeros, 16, 8, None, 4, 16_000, 0.0, None, LogFloor::Kaldi).unwrap();

  let w = whisper.to_vec::<f32>().unwrap();
  let k = kaldi.to_vec::<f32>().unwrap();
  assert_eq!(w.len(), k.len(), "shape mismatch between floors");

  let expected_w = (1e-10_f32).ln();
  let expected_k = (1e-8_f32).ln();
  let expected_delta = expected_k - expected_w; // ≈ +4.6052

  for (wi, ki) in w.iter().zip(k.iter()) {
    assert!(
      (wi - expected_w).abs() < 1e-3,
      "whisper silence entry should equal ln(1e-10)={expected_w}, got {wi}"
    );
    assert!(
      (ki - expected_k).abs() < 1e-3,
      "kaldi silence entry should equal ln(1e-8)={expected_k}, got {ki}"
    );
    assert!(*wi < *ki, "whisper floor must be more negative than kaldi");
    assert!(
      ((ki - wi) - expected_delta).abs() < 1e-3,
      "delta whisper-kaldi should be ~ln(100)={expected_delta}, got {}",
      ki - wi
    );
  }
}

#[test]
fn log_mel_spectrogram_default_matches_explicit_whisper() {
  // Backward-compat guarantee: the parameterless `log_mel_spectrogram`
  // must be byte-identical to `log_mel_spectrogram_with(.., Whisper)`.
  let x = sine_1khz_16samples();
  let mut a = log_mel_spectrogram(&x, 8, 4, None, 4, 16_000, 0.0, None).unwrap();
  let mut b =
    log_mel_spectrogram_with(&x, 8, 4, None, 4, 16_000, 0.0, None, LogFloor::Whisper).unwrap();
  let va = a.to_vec::<f32>().unwrap();
  let vb = b.to_vec::<f32>().unwrap();
  assert_eq!(va.len(), vb.len());
  for (i, (x, y)) in va.iter().zip(vb.iter()).enumerate() {
    assert_eq!(
      x.to_bits(),
      y.to_bits(),
      "bit-mismatch at {i}: default={x:?} explicit_whisper={y:?}"
    );
  }
}

/// Regression: `mel_spectrogram` must call
/// `mel_filter_bank_cached` (per-thread LRU bank cache) rather than the
/// uncached `mel_filter_bank` constructor on the hot path. Without this,
/// the cache would be wired only into its own unit tests while
/// `mel_spectrogram` (and every `log_mel_spectrogram` /
/// `log_mel_spectrogram_with` / STT log-mel callsite that flows through
/// it) rebuilds the bank on every call.
///
/// **Structural assertion** rather than a runtime-cache-hit observation:
/// the `mel_filter_bank_cached` thread-local store is private to the
/// module and the cached `Array` is value-equal to the uncached path
/// (so a runtime hit/miss probe would have to reach into the cache
/// internals — out of public-API scope). Reading the function body's
/// source text instead pins the structural property — `mel_spectrogram`
/// references `mel_filter_bank_cached` — without coupling to private
/// state.
#[test]
fn mel_spectrogram_uses_cached_filter_bank() {
  // Source of `mlxrs/src/audio/dsp.rs` — `include_str!` resolves
  // relative to THIS test file at `mlxrs/tests/audio_dsp.rs`.
  let src = include_str!("../src/audio/dsp.rs");

  // Locate the `mel_spectrogram` definition. The matmul-and-return at
  // the end (`ops::linalg_basic::matmul(&mel, &power_t)`) is the
  // canonical terminator: every word from `pub fn mel_spectrogram(`
  // through it is the function body.
  let body_start = src
    .find("pub fn mel_spectrogram(")
    .expect("dsp.rs must define `pub fn mel_spectrogram(`");
  let body_tail = &src[body_start..];
  let body_end_rel = body_tail
    .find("ops::linalg_basic::matmul(&mel, &power_t)")
    .expect("mel_spectrogram body must terminate with the canonical matmul-return");
  let body = &body_tail[..body_end_rel];

  assert!(
    body.contains("mel_filter_bank_cached("),
    "regression: `mel_spectrogram` must invoke \
     `mel_filter_bank_cached(...)` (per-thread LRU cache), not the \
     uncached `mel_filter_bank(...)`. Function body was:\n{body}"
  );

  // Belt-and-braces: no leftover uncached `mel_filter_bank(` call in
  // the body. The literal `mel_filter_bank(` substring excludes the
  // cached variant (`mel_filter_bank_cached(` has `_cached` between
  // `bank` and `(`), so any match is necessarily an uncached call. A
  // preceding-byte check rejects matches embedded in a longer
  // identifier (`not_mel_filter_bank(` etc.) — the preceding byte must
  // not be a Rust identifier continuation char.
  let uncached_calls = body
    .match_indices("mel_filter_bank(")
    .filter(|(idx, _)| {
      // Reject matches where the preceding byte is part of an
      // identifier (so `not_mel_filter_bank(` is excluded as a
      // different identifier).
      if *idx == 0 {
        return true;
      }
      let prev = body.as_bytes()[*idx - 1];
      !(prev.is_ascii_alphanumeric() || prev == b'_')
    })
    .count();
  assert_eq!(
    uncached_calls, 0,
    "regression: `mel_spectrogram` body must NOT \
     contain any direct `mel_filter_bank(` call; only the cached \
     variant `mel_filter_bank_cached(` is allowed. Found {uncached_calls} \
     uncached call(s).\nBody:\n{body}"
  );
}