//! Neural network modules for MLX.
//!
//! Provides common building blocks: `Linear`, `Embedding`, `LayerNorm`, and
//! `RMSNorm`. Each module stores its parameters as `Tensor` values and exposes
//! a `forward()` method.

mod attention;
mod dropout;
mod embedding;
mod linear;
mod norm;

pub use attention::MultiHeadAttention;
pub use dropout::Dropout;
pub use embedding::Embedding;
pub use linear::Linear;
pub use norm::{LayerNorm, RmsNorm};

use mlx_core::{Result, Tensor};

21/// Trait implemented by all NN modules.
22pub trait Module {
23    /// Run the forward pass.
24    fn forward(&self, input: &Tensor) -> Result<Tensor>;
25}
26
#[cfg(test)]
mod tests {
    use super::*;
    use mlx_core::{Device, Shape, Tensor};

32    fn cpu() -> Device {
33        Device::Cpu
34    }
35
36    fn s(dims: &[i64]) -> Shape {
37        Shape::new(dims.to_vec())
38    }
39
40    #[test]
41    fn test_linear_no_bias() {
42        // Linear(in=3, out=2, bias=false): y = x @ W^T
43        // W is [out, in] = [2, 3]
44        let weight =
45            Tensor::from_f32(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &s(&[2, 3]), &cpu()).unwrap();
46        let linear = Linear::new(weight, None);
47
48        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[1, 3]), &cpu()).unwrap();
49        let y = linear.forward(&x).unwrap();
50        // [1,2,3] @ [[1,0],[0,1],[0,0]] = [1, 2]
51        let result = y.to_vec_f32().unwrap();
52        mlx_conformance::assert_allclose(&result, &[1.0, 2.0], 1e-5, 1e-5);
53    }
54
55    #[test]
56    fn test_linear_with_bias() {
57        let weight =
58            Tensor::from_f32(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &s(&[2, 3]), &cpu()).unwrap();
59        let bias = Tensor::from_f32(&[0.5, -0.5], &s(&[2]), &cpu()).unwrap();
60        let linear = Linear::new(weight, Some(bias));
61
62        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[1, 3]), &cpu()).unwrap();
63        let y = linear.forward(&x).unwrap();
64        let result = y.to_vec_f32().unwrap();
65        mlx_conformance::assert_allclose(&result, &[1.5, 1.5], 1e-5, 1e-5);
66    }
67
68    #[test]
69    fn test_linear_batch() {
70        // Batch of 2 vectors, in=2, out=2
71        let weight = Tensor::from_f32(&[1.0, 0.0, 0.0, 1.0], &s(&[2, 2]), &cpu()).unwrap();
72        let linear = Linear::new(weight, None);
73
74        let x = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], &s(&[2, 2]), &cpu()).unwrap();
75        let y = linear.forward(&x).unwrap();
76        let result = y.to_vec_f32().unwrap();
77        // Identity weight: output = input
78        mlx_conformance::assert_allclose(&result, &[1.0, 2.0, 3.0, 4.0], 1e-5, 1e-5);
79    }
80
81    #[test]
82    fn test_layer_norm() {
83        let ln = LayerNorm::new(3, 1e-5);
84        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[1, 3]), &cpu()).unwrap();
85        let y = ln.forward(&x).unwrap();
86        let result = y.to_vec_f32().unwrap();
87        // Should be normalized: mean ≈ 0, std ≈ 1
88        let mean: f32 = result.iter().sum::<f32>() / result.len() as f32;
89        assert!(mean.abs() < 1e-4, "mean should be ~0, got {mean}");
90        mlx_conformance::assert_allclose(&result, &[-1.2247, 0.0, 1.2247], 1e-3, 1e-3);
91    }
92
93    #[test]
94    fn test_embedding() {
95        // Vocab=3, dim=2. Weight: [[10,11],[20,21],[30,31]]
96        let weight =
97            Tensor::from_f32(&[10.0, 11.0, 20.0, 21.0, 30.0, 31.0], &s(&[3, 2]), &cpu()).unwrap();
98        let emb = Embedding::new(weight);
99        // Look up indices [2, 0, 1]
100        let indices = Tensor::from_f32(&[2.0, 0.0, 1.0], &s(&[3]), &cpu()).unwrap();
101        let y = emb.forward(&indices).unwrap();
102        let result = y.to_vec_f32().unwrap();
103        mlx_conformance::assert_allclose(
104            &result,
105            &[30.0, 31.0, 10.0, 11.0, 20.0, 21.0],
106            1e-5,
107            1e-5,
108        );
109        assert_eq!(y.shape(), &s(&[3, 2]));
110    }
111
112    #[test]
113    fn test_dropout_eval_mode() {
114        let mut drop = Dropout::new(0.5);
115        drop.eval();
116        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[3]), &cpu()).unwrap();
117        let y = drop.forward(&x).unwrap();
118        let result = y.to_vec_f32().unwrap();
119        // In eval mode, output == input
120        mlx_conformance::assert_allclose(&result, &[1.0, 2.0, 3.0], 1e-5, 1e-5);
121    }
122
123    #[test]
124    fn test_dropout_zero_prob() {
125        let drop = Dropout::new(0.0);
126        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[3]), &cpu()).unwrap();
127        let y = drop.forward(&x).unwrap();
128        let result = y.to_vec_f32().unwrap();
129        mlx_conformance::assert_allclose(&result, &[1.0, 2.0, 3.0], 1e-5, 1e-5);
130    }
131
132    #[test]
133    fn test_dropout_training_mode() {
134        let drop = Dropout::new(0.5);
135        let x = Tensor::from_f32(&[1.0; 1000], &s(&[1000]), &cpu()).unwrap();
136        let y = drop.forward(&x).unwrap();
137        let result = y.to_vec_f32().unwrap();
138        // In training mode, roughly half should be zero, non-zero should be scaled by 2.0
139        let zeros = result.iter().filter(|&&v| v == 0.0).count();
140        let non_zeros = result.iter().filter(|&&v| v != 0.0).count();
141        assert!(zeros > 300, "expected many zeros, got {zeros}");
142        assert!(non_zeros > 300, "expected many non-zeros, got {non_zeros}");
143        // Non-zero values should be scaled by 1/(1-0.5) = 2.0
144        for &v in result.iter().filter(|&&v| v != 0.0) {
145            assert!((v - 2.0).abs() < 1e-5, "expected ~2.0, got {v}");
146        }
147    }
148
149    #[test]
150    fn test_dropout_full_drop() {
151        let drop = Dropout::new(1.0);
152        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[3]), &cpu()).unwrap();
153        let y = drop.forward(&x).unwrap();
154        let result = y.to_vec_f32().unwrap();
155        mlx_conformance::assert_allclose(&result, &[0.0, 0.0, 0.0], 1e-5, 1e-5);
156    }
157
158    #[test]
159    fn test_embedding_negative_index() {
160        let weight =
161            Tensor::from_f32(&[10.0, 11.0, 20.0, 21.0], &s(&[2, 2]), &cpu()).unwrap();
162        let emb = Embedding::new(weight);
163        let indices = Tensor::from_f32(&[-1.0], &s(&[1]), &cpu()).unwrap();
164        // forward() builds a lazy graph node (Ok), but the negative index
165        // validation fires at eval time when the CPU kernel runs.
166        let y = emb.forward(&indices).unwrap();
167        let result = y.to_vec_f32();
168        assert!(result.is_err(), "negative index should fail at eval time");
169    }
170
171    #[test]
172    fn test_rms_norm() {
173        let rn = RmsNorm::new(3, 1e-5);
174        let x = Tensor::from_f32(&[1.0, 2.0, 3.0], &s(&[1, 3]), &cpu()).unwrap();
175        let y = rn.forward(&x).unwrap();
176        let result = y.to_vec_f32().unwrap();
177        // RMS of [1,2,3] = sqrt((1+4+9)/3) = sqrt(14/3) ≈ 2.1602
178        let rms = (14.0f32 / 3.0).sqrt();
179        let expected = [1.0 / rms, 2.0 / rms, 3.0 / rms];
180        mlx_conformance::assert_allclose(&result, &expected, 1e-4, 1e-4);
181    }
182
183    #[test]
184    fn test_multi_head_attention_smoke() {
185        // model_dim=4, n_heads=2, head_dim=2, seq_len=2
186        // Identity-ish weights for simplicity
187        let wq_w = Tensor::from_f32(
188            &[
189                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
190            ],
191            &s(&[4, 4]),
192            &cpu(),
193        )
194        .unwrap();
195        let wo_w = wq_w.clone();
196        let wq = Linear::new(wq_w.clone(), None);
197        let wk = Linear::new(wq_w.clone(), None);
198        let wv = Linear::new(wq_w, None);
199        let wo = Linear::new(wo_w, None);
200
201        let mha = MultiHeadAttention::new(wq, wk, wv, wo, 2);
202
203        let x = Tensor::from_f32(
204            &[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
205            &s(&[2, 4]),
206            &cpu(),
207        )
208        .unwrap();
209        let y = mha.forward_causal(&x).unwrap();
210        assert_eq!(y.shape(), &s(&[2, 4]));
211        // Just verify it runs and produces correct shape
212        let result = y.to_vec_f32().unwrap();
213        assert_eq!(result.len(), 8);
214    }
}