1mod attention;
8mod dropout;
9mod embedding;
10mod linear;
11mod norm;
12
13pub use attention::MultiHeadAttention;
14pub use dropout::Dropout;
15pub use embedding::Embedding;
16pub use linear::Linear;
17pub use norm::{LayerNorm, RmsNorm};
18
19use mlx_core::{Result, Tensor};
20
/// Common interface implemented by every layer in this crate.
///
/// A `Module` maps an input tensor to an output tensor. Layers that need
/// more than one input (e.g. attention) expose additional inherent
/// methods alongside this trait.
pub trait Module {
    /// Runs the layer's forward pass on `input`, returning the output
    /// tensor or the error produced by the underlying tensor ops.
    fn forward(&self, input: &Tensor) -> Result<Tensor>;
}
26
#[cfg(test)]
mod tests {
    use super::*;
    use mlx_core::{Device, Shape, Tensor};

    /// All tests in this module run on the CPU backend.
    fn device() -> Device {
        Device::Cpu
    }

    /// Shorthand: build a `Shape` from a slice of dimensions.
    fn shape_of(dims: &[i64]) -> Shape {
        Shape::new(dims.to_vec())
    }

    /// Shorthand: build an f32 tensor on the CPU from raw data and dims.
    fn tensor(data: &[f32], dims: &[i64]) -> Tensor {
        Tensor::from_f32(data, &shape_of(dims), &device()).unwrap()
    }

    #[test]
    fn test_linear_no_bias() {
        // Weight rows pick out the first two input components.
        let weight = tensor(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]);
        let layer = Linear::new(weight, None);

        let input = tensor(&[1.0, 2.0, 3.0], &[1, 3]);
        let output = layer.forward(&input).unwrap();
        mlx_conformance::assert_allclose(&output.to_vec_f32().unwrap(), &[1.0, 2.0], 1e-5, 1e-5);
    }

    #[test]
    fn test_linear_with_bias() {
        let weight = tensor(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]);
        let bias = tensor(&[0.5, -0.5], &[2]);
        let layer = Linear::new(weight, Some(bias));

        // Projection gives [1, 2]; bias shifts it to [1.5, 1.5].
        let input = tensor(&[1.0, 2.0, 3.0], &[1, 3]);
        let output = layer.forward(&input).unwrap();
        mlx_conformance::assert_allclose(&output.to_vec_f32().unwrap(), &[1.5, 1.5], 1e-5, 1e-5);
    }

    #[test]
    fn test_linear_batch() {
        // Identity weight: a batch of two rows should pass through unchanged.
        let layer = Linear::new(tensor(&[1.0, 0.0, 0.0, 1.0], &[2, 2]), None);

        let input = tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let output = layer.forward(&input).unwrap();
        mlx_conformance::assert_allclose(
            &output.to_vec_f32().unwrap(),
            &[1.0, 2.0, 3.0, 4.0],
            1e-5,
            1e-5,
        );
    }

    #[test]
    fn test_layer_norm() {
        let norm = LayerNorm::new(3, 1e-5);
        let output = norm.forward(&tensor(&[1.0, 2.0, 3.0], &[1, 3])).unwrap();
        let values = output.to_vec_f32().unwrap();

        // Normalized output must be (approximately) zero-mean.
        let mean: f32 = values.iter().sum::<f32>() / values.len() as f32;
        assert!(mean.abs() < 1e-4, "mean should be ~0, got {mean}");
        // (x - 2) / sqrt(2/3) for x in [1, 2, 3].
        mlx_conformance::assert_allclose(&values, &[-1.2247, 0.0, 1.2247], 1e-3, 1e-3);
    }

    #[test]
    fn test_embedding() {
        // Three rows of two values each; lookup order 2, 0, 1.
        let table = tensor(&[10.0, 11.0, 20.0, 21.0, 30.0, 31.0], &[3, 2]);
        let emb = Embedding::new(table);

        let output = emb.forward(&tensor(&[2.0, 0.0, 1.0], &[3])).unwrap();
        mlx_conformance::assert_allclose(
            &output.to_vec_f32().unwrap(),
            &[30.0, 31.0, 10.0, 11.0, 20.0, 21.0],
            1e-5,
            1e-5,
        );
        assert_eq!(output.shape(), &shape_of(&[3, 2]));
    }

    #[test]
    fn test_dropout_eval_mode() {
        // In eval mode dropout is the identity regardless of probability.
        let mut layer = Dropout::new(0.5);
        layer.eval();
        let output = layer.forward(&tensor(&[1.0, 2.0, 3.0], &[3])).unwrap();
        mlx_conformance::assert_allclose(
            &output.to_vec_f32().unwrap(),
            &[1.0, 2.0, 3.0],
            1e-5,
            1e-5,
        );
    }

    #[test]
    fn test_dropout_zero_prob() {
        // p = 0 never drops anything, even outside eval mode.
        let layer = Dropout::new(0.0);
        let output = layer.forward(&tensor(&[1.0, 2.0, 3.0], &[3])).unwrap();
        mlx_conformance::assert_allclose(
            &output.to_vec_f32().unwrap(),
            &[1.0, 2.0, 3.0],
            1e-5,
            1e-5,
        );
    }

    #[test]
    fn test_dropout_training_mode() {
        let layer = Dropout::new(0.5);
        let output = layer.forward(&tensor(&[1.0; 1000], &[1000])).unwrap();
        let values = output.to_vec_f32().unwrap();

        // With p = 0.5 over 1000 ones, both populations should be sizeable.
        let zeros = values.iter().filter(|&&v| v == 0.0).count();
        let non_zeros = values.len() - zeros;
        assert!(zeros > 300, "expected many zeros, got {zeros}");
        assert!(non_zeros > 300, "expected many non-zeros, got {non_zeros}");

        // Inverted dropout scales survivors by 1 / (1 - p) = 2.
        for &v in values.iter().filter(|&&v| v != 0.0) {
            assert!((v - 2.0).abs() < 1e-5, "expected ~2.0, got {v}");
        }
    }

    #[test]
    fn test_dropout_full_drop() {
        // p = 1 zeroes every element.
        let layer = Dropout::new(1.0);
        let output = layer.forward(&tensor(&[1.0, 2.0, 3.0], &[3])).unwrap();
        mlx_conformance::assert_allclose(
            &output.to_vec_f32().unwrap(),
            &[0.0, 0.0, 0.0],
            1e-5,
            1e-5,
        );
    }

    #[test]
    fn test_embedding_negative_index() {
        let emb = Embedding::new(tensor(&[10.0, 11.0, 20.0, 21.0], &[2, 2]));
        // Graph construction is lazy: forward succeeds, but materializing
        // the result must reject the out-of-range index -1.
        let output = emb.forward(&tensor(&[-1.0], &[1])).unwrap();
        let result = output.to_vec_f32();
        assert!(result.is_err(), "negative index should fail at eval time");
    }

    #[test]
    fn test_rms_norm() {
        let norm = RmsNorm::new(3, 1e-5);
        let output = norm.forward(&tensor(&[1.0, 2.0, 3.0], &[1, 3])).unwrap();

        // RMS of [1, 2, 3] is sqrt((1 + 4 + 9) / 3).
        let rms = (14.0f32 / 3.0).sqrt();
        let expected = [1.0 / rms, 2.0 / rms, 3.0 / rms];
        mlx_conformance::assert_allclose(&output.to_vec_f32().unwrap(), &expected, 1e-4, 1e-4);
    }

    #[test]
    fn test_multi_head_attention_smoke() {
        // All four projections share the same 4x4 identity weight.
        let identity = tensor(
            &[
                1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
            ],
            &[4, 4],
        );
        let mha = MultiHeadAttention::new(
            Linear::new(identity.clone(), None),
            Linear::new(identity.clone(), None),
            Linear::new(identity.clone(), None),
            Linear::new(identity, None),
            2,
        );

        // Two one-hot tokens of width 4; only shape/size are checked here.
        let input = tensor(&[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 4]);
        let output = mha.forward_causal(&input).unwrap();
        assert_eq!(output.shape(), &shape_of(&[2, 4]));
        assert_eq!(output.to_vec_f32().unwrap().len(), 8);
    }
}