volta/lib.rs

//! # Volta
//!
//! A minimal automatic differentiation library implementing PyTorch-like tensor operations
//! from scratch in pure Rust. This library provides:
//! - Dynamic computation graphs for automatic differentiation
//! - Broadcasting support for tensor operations
//! - Common neural network operations (matmul, activations, etc.)
//! - Numerical gradient checking for validation
//!
//! ## Architecture
//!
//! The library uses reference-counted interior mutability (`Rc<RefCell<RawTensor>>`) to build
//! dynamic computation graphs. Each tensor operation creates new tensors and stores gradient
//! functions that know how to backpropagate through that operation.
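//!
//! ## Example
//!
//! A minimal sketch of typical usage, based on the operations exercised by the test
//! suite below (marked `ignore` since the exact trait imports, e.g. `TensorOps`, may vary):
//!
//! ```ignore
//! use volta::{RawTensor, TensorOps};
//!
//! let a = RawTensor::new(vec![2.0], &[1], true); // data, shape, requires_grad
//! let b = RawTensor::new(vec![3.0], &[1], true);
//! let c = a.add(&b); // records the op and its gradient function in the graph
//! c.backward();      // backpropagates through the recorded graph
//! assert_eq!(a.grad(), Some(vec![1.0]));
//! ```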

#[cfg(feature = "gpu")]
pub mod gpu;

pub mod dtype;
pub mod storage;

// Dtype and storage re-exports
pub use dtype::DType;
pub use storage::Storage;

#[cfg(feature = "gpu")]
pub use gpu::{
    GpuBuffer, GpuContext, get_gpu_context, gpu_cleanup, gpu_compact, gpu_pending_count,
    gpu_pool_stats, gpu_sync, gpu_sync_threshold, is_gpu_available,
};

pub mod autograd;
pub mod data;
pub mod device;
pub mod io;
pub mod nn;
pub mod ops;
pub mod tensor;
pub mod utils;

// Re-export main types for easy access
pub use autograd::GradFn;
pub use device::Device;
pub use nn::layers::Dropout;
pub use nn::layers::flatten::Flatten;
pub use nn::{
    Adam, BatchNorm1d, BatchNorm2d, Conv2d, ConvTranspose2d, Embedding, LSTMCell, Linear,
    MaxPool2d, Module, PixelShuffle, ReLU, SGD, Sequential, SequentialBuilder, Sigmoid, Tanh,
};
pub use tensor::{RawTensor, Tensor, TensorOps};

// Main entry points

pub use tensor::{
    DataLoader, bce_loss, bce_with_logits_loss, check_gradients, check_gradients_simple,
    cross_entropy_loss, kl_divergence_gaussian, manual_seed, max_dim, mse_loss, new_tensor,
    nll_loss, ones, rand, randn, randn_like, softmax, sum_dim, zeros,
};

pub use data::{load_mnist_images, load_mnist_labels, normalize, to_one_hot};
pub use io::{
    TypedTensorData, load_safetensors, load_safetensors_raw, load_safetensors_with_mapping,
    load_state_dict_with_mapping, mapping, save_safetensors, save_safetensors_typed,
};
pub use utils::ProgressBar;

pub use ops::{
    BinaryGradFn, BinaryOp, MatMulGradFn, MaxReduceGradFn, MeanGradFn, MovementGradFn, MovementOp,
    MulAccGradFn, ReduceOp, SumGradFn, TernaryOp, UnaryGradFn, UnaryOp, WhereGradFn,
};

// ===== TESTS =====
//
// The test suite validates:
// - Basic operations (add, mul, etc.)
// - Gradient correctness (chain rule, broadcasting)
// - Complex scenarios (neural networks, matmul variants)
// - Numerical gradient checking (validates all gradients)

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_add_backward() {
        let a = RawTensor::new(vec![2.0], &[1], true);
        let b = RawTensor::new(vec![3.0], &[1], true);
        let c = a.add(&b);
        c.backward();

        assert_eq!(a.grad(), Some(vec![1.0]));
        assert_eq!(b.grad(), Some(vec![1.0]));
    }

    #[test]
    fn test_enhanced_device_safety() {
        let a = RawTensor::new(vec![2.0], &[1], true);
        let b = RawTensor::new(vec![3.0], &[1], true);
        let c = a.add(&b);
        c.backward();

        // Test device handling safety
        let cpu_device = Device::CPU;
        assert!(cpu_device.is_cpu());
        assert!(!cpu_device.is_gpu());
        assert_eq!(cpu_device.name(), "CPU");

        let gpu_device = Device::GPU("CUDA".to_string());
        assert!(!gpu_device.is_cpu());
        assert!(gpu_device.is_gpu());
        assert_eq!(gpu_device.name(), "CUDA");

        // Test that tensor operations still work
        assert_eq!(a.grad(), Some(vec![1.0]));
        assert_eq!(b.grad(), Some(vec![1.0]));
    }

    #[test]
    fn test_multiply_backward() {
        let a = RawTensor::new(vec![3.0], &[1], true);
        let b = RawTensor::new(vec![4.0], &[1], true);
        let c = a.elem_mul(&b);
        c.backward();

        assert_eq!(a.grad(), Some(vec![4.0]));
        assert_eq!(b.grad(), Some(vec![3.0]));
    }

    #[test]
    fn test_chain_rule() {
        let a = RawTensor::new(vec![2.0], &[1], true);
        let b = RawTensor::new(vec![3.0], &[1], true);
        let c = a.add(&b);
        let d = c.elem_mul(&a);
        d.backward();

        // d = (a + b) * a, so ∂d/∂a = (a + b) + a = 5 + 2 = 7 and ∂d/∂b = a = 2
        assert_eq!(a.grad(), Some(vec![7.0]));
        assert_eq!(b.grad(), Some(vec![2.0]));
    }

    #[test]
    fn test_sum_backward() {
        let a = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let loss = a.sum();
        loss.backward();

        assert_eq!(a.grad(), Some(vec![1.0, 1.0, 1.0]));
    }

    #[test]
    fn test_multidim_ops() {
        let a = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let b = RawTensor::new(vec![0.5, 0.5, 0.5, 0.5], &[2, 2], true);
        let c = a.elem_mul(&b);
        let loss = c.sum();
        loss.backward();

        assert_eq!(a.grad(), Some(vec![0.5, 0.5, 0.5, 0.5]));
        assert_eq!(b.grad(), Some(vec![1.0, 2.0, 3.0, 4.0]));
    }
}

#[cfg(test)]
mod unary_tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_neg_forward_backward() {
        let x = RawTensor::new(vec![2.0, -3.0], &[2], true);
        let y = x.neg();

        // Forward
        assert_eq!(y.borrow().data, vec![-2.0, 3.0]);

        // Backward: ∂(-x)/∂x = -1
        y.backward();
        assert_eq!(x.grad(), Some(vec![-1.0, -1.0]));
    }

    #[test]
    fn test_sqrt_chain() {
        let x = RawTensor::new(vec![4.0], &[1], true);
        let y = x.sqrt(); // y = 2.0
        let z = y.elem_mul(&y); // z = 4.0
        z.backward();

        // ∂z/∂x = ∂z/∂y * ∂y/∂x = 2y * 1/(2√x) = 2*2 * 1/4 = 1.0
        assert_relative_eq!(
            x.grad().unwrap().first().copied().unwrap_or(f32::NAN),
            1.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_exp2_log2_inverse() {
        let x = RawTensor::new(vec![2.0], &[1], true);
        let y = x.exp2().log2(); // should recover x
        y.backward();

        assert_relative_eq!(
            y.borrow().data.first().copied().unwrap_or(f32::NAN),
            2.0,
            epsilon = 1e-6
        );
        // Chain rule: ∂(log2(2^x))/∂x = 1
        assert_relative_eq!(
            x.grad().unwrap().first().copied().unwrap_or(f32::NAN),
            1.0,
            epsilon = 1e-6
        );
    }
}

#[cfg(test)]
mod binary_tests {
    use super::*;

    #[test]
    fn test_div_backward() {
        let x = RawTensor::new(vec![6.0], &[1], true);
        let y = RawTensor::new(vec![2.0], &[1], true);
        let z = x.div(&y); // z = 3.0
        z.backward();

        // ∂(x/y)/∂x = 1/y = 0.5
        assert_eq!(x.grad(), Some(vec![0.5]));
        // ∂(x/y)/∂y = -x/y² = -6/4 = -1.5
        assert_eq!(y.grad(), Some(vec![-1.5]));
    }

    #[test]
    fn test_max_backward() {
        let x = RawTensor::new(vec![3.0, 1.0], &[2], true);
        let y = RawTensor::new(vec![2.0, 4.0], &[2], true);
        let z = x.max_elem(&y);
        let loss = z.sum();
        loss.backward();

        // max picks [3.0, 4.0], so grads flow to x[0] and y[1]
        assert_eq!(x.grad(), Some(vec![1.0, 0.0]));
        assert_eq!(y.grad(), Some(vec![0.0, 1.0]));
    }
}

#[cfg(test)]
mod reduce_tests {
    use super::*;

    #[test]
    fn test_reduce_max_backward() {
        let x = RawTensor::new(vec![1.0, 5.0, 3.0], &[3], true);
        let y = x.max_reduce(); // finds 5.0 at index 1
        y.backward();

        // Only the max element receives gradient; routing everything to the argmax
        // is the standard subgradient convention for max
        assert_eq!(x.grad(), Some(vec![0.0, 1.0, 0.0]));
    }
}

#[cfg(test)]
mod ternary_tests {
    use super::*;

    #[test]
    fn test_mulacc_backward() {
        // z = x*y + w
        let x = RawTensor::new(vec![2.0], &[1], true);
        let y = RawTensor::new(vec![3.0], &[1], true);
        let w = RawTensor::new(vec![1.0], &[1], true);
        let z = x.mulacc(&y, &w); // z = 7.0
        z.backward();

        assert_eq!(x.grad(), Some(vec![3.0])); // ∂z/∂x = y
        assert_eq!(y.grad(), Some(vec![2.0])); // ∂z/∂y = x
        assert_eq!(w.grad(), Some(vec![1.0])); // ∂z/∂w = 1
    }

    #[test]
    fn test_where_backward() {
        let cond = RawTensor::new(vec![1.0, 0.0], &[2], false);
        let x = RawTensor::new(vec![10.0, 20.0], &[2], true);
        let y = RawTensor::new(vec![30.0, 40.0], &[2], true);
        let z = cond.where_op(&x, &y); // picks [10.0, 40.0]
        z.backward();

        assert_eq!(x.grad(), Some(vec![1.0, 0.0])); // grad flows where cond=1
        assert_eq!(y.grad(), Some(vec![0.0, 1.0])); // grad flows where cond=0
    }

    #[test]
    fn test_where_broadcast_backward() {
        // condition shape (2,1), true branch (2,3), false branch (1,3)
        let cond = RawTensor::new(vec![1.0, 0.0], &[2, 1], false);
        let true_branch = RawTensor::new(vec![10.0, 11.0, 12.0, 20.0, 21.0, 22.0], &[2, 3], true);
        let false_branch = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
        let out = cond.where_op(&true_branch, &false_branch);
        let loss = out.sum();
        loss.backward();

        // Gradients: first row picks true branch, second row picks false branch
        assert_eq!(true_branch.grad(), Some(vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0]));
        assert_eq!(false_branch.grad(), Some(vec![1.0, 1.0, 1.0]));
    }
}

#[cfg(test)]
mod movement_tests {
    use super::*;

    #[test]
    fn test_reshape_backward() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
        let y = x.reshape(&[2, 2]);
        let loss = y.sum();
        loss.backward();

        // Gradient reshapes back to [4]
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0]));
    }

    #[test]
    fn test_permute_backward() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let y = x.permute(&[1, 0]); // transpose
        let loss = y.sum();
        loss.backward();

        // Gradient permutes back
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0]));
    }
}

#[cfg(test)]
mod misc_tests {
    use super::*;
    // ===== NEURAL NETWORK LAYER TEST =====

    #[test]
    fn test_linear_layer() {
        // Simple linear layer: y = xW + b
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
        let w = RawTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], true);
        let b = RawTensor::new(vec![0.1, 0.2], &[1, 2], true);

        let y = x.matmul(&w); // [1,3] @ [3,2] = [1,2]
        let out = y.add(&b);
        let loss = out.sum();

        loss.backward();

        // All should have gradients
        assert!(x.grad().is_some());
        assert!(w.grad().is_some());
        assert!(b.grad().is_some());

        // b gradient should be ones (direct path from sum)
        assert_eq!(b.grad(), Some(vec![1.0, 1.0]));
    }

    #[test]
    fn test_tensor_zero_grad() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let loss = x.sum();
        loss.backward();

        assert!(x.grad().is_some());

        x.zero_grad();
        assert!(x.grad().is_none());

        let rewind = x.sum();
        rewind.backward();
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0]));
    }

    // ===== BROADCASTING TESTS =====

    #[test]
    fn test_broadcast_shape() {
        // (3, 1) and (1, 4) -> (3, 4)
        let shape = RawTensor::broadcast_shape(&[3, 1], &[1, 4]);
        assert_eq!(shape, vec![3, 4]);

        // (5, 3, 1) and (1, 4) -> (5, 3, 4)
        let shape = RawTensor::broadcast_shape(&[5, 3, 1], &[1, 4]);
        assert_eq!(shape, vec![5, 3, 4]);

        // (1,) and (3, 4) -> (3, 4)
        let shape = RawTensor::broadcast_shape(&[1], &[3, 4]);
        assert_eq!(shape, vec![3, 4]);

        // (3, 4) and (4,) -> (3, 4)
        let shape = RawTensor::broadcast_shape(&[3, 4], &[4]);
        assert_eq!(shape, vec![3, 4]);
    }
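
    // Broadcasting here follows the usual NumPy/PyTorch rule assumed throughout these
    // tests: shapes are aligned from the trailing dimension, each dim pair must be
    // equal or contain a 1, and size-1 dims are stretched to the larger size.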

    #[test]
    #[should_panic(expected = "Cannot broadcast")]
    fn test_broadcast_incompatible() {
        let _ = RawTensor::broadcast_shape(&[3, 2], &[4, 3]);
    }

    #[test]
    fn test_broadcast_add_scalar() {
        // (2, 3) + scalar -> (2, 3)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let scalar = RawTensor::new(vec![10.0], &[1], true);
        let y = x.add(&scalar);

        assert_eq!(y.borrow().shape, vec![2, 3]);
        assert_eq!(y.borrow().data, vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);

        y.backward();

        // x gradient: all ones
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
        // scalar gradient: sum of all gradients = 6.0
        assert_eq!(scalar.grad(), Some(vec![6.0]));
    }

    #[test]
    fn test_broadcast_mul_vector() {
        // (2, 3) * (3,) -> (2, 3)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let v = RawTensor::new(vec![2.0, 3.0, 4.0], &[3], true);
        let y = x.elem_mul(&v);

        assert_eq!(y.borrow().shape, vec![2, 3]);
        assert_eq!(y.borrow().data, vec![2.0, 6.0, 12.0, 8.0, 15.0, 24.0]);

        y.backward();

        // x gradient: broadcast v
        assert_eq!(x.grad(), Some(vec![2.0, 3.0, 4.0, 2.0, 3.0, 4.0]));
        // v gradient: sum over rows
        assert_eq!(v.grad(), Some(vec![5.0, 7.0, 9.0])); // [1+4, 2+5, 3+6]
    }

    #[test]
    fn test_broadcast_add_matrix() {
        // (2, 1) + (1, 3) -> (2, 3)
        let x = RawTensor::new(vec![1.0, 2.0], &[2, 1], true);
        let y = RawTensor::new(vec![10.0, 20.0, 30.0], &[1, 3], true);
        let z = x.add(&y);

        assert_eq!(z.borrow().shape, vec![2, 3]);
        assert_eq!(z.borrow().data, vec![11.0, 21.0, 31.0, 12.0, 22.0, 32.0]);

        z.backward();

        // x gradient: sum over columns -> [3.0, 3.0]
        assert_eq!(x.grad(), Some(vec![3.0, 3.0]));
        // y gradient: sum over rows -> [2.0, 2.0, 2.0]
        assert_eq!(y.grad(), Some(vec![2.0, 2.0, 2.0]));
    }

    #[test]
    fn test_broadcast_batch_bias() {
        // Simulate batch with bias: (batch=3, features=2) + (features=2,)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], true);
        let bias = RawTensor::new(vec![0.5, 1.0], &[2], true);
        let y = x.add(&bias);

        assert_eq!(y.borrow().shape, vec![3, 2]);
        assert_eq!(y.borrow().data, vec![1.5, 3.0, 3.5, 5.0, 5.5, 7.0]);

        let loss = y.sum();
        loss.backward();

        // x gradient: all ones
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
        // bias gradient: sum over batch -> [3.0, 3.0]
        assert_eq!(bias.grad(), Some(vec![3.0, 3.0]));
    }

    #[test]
    fn test_broadcast_div() {
        // (2, 3) / (1, 3) -> (2, 3)
        let x = RawTensor::new(vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0], &[2, 3], true);
        let y = RawTensor::new(vec![2.0, 2.0, 2.0], &[1, 3], true);
        let z = x.div(&y);

        assert_eq!(z.borrow().shape, vec![2, 3]);
        assert_eq!(z.borrow().data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        z.backward();

        // x gradient: 1/y broadcast
        assert_eq!(x.grad(), Some(vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5]));
        // y gradient: sum(-x/y²) over rows
        // -2/4=-0.5, -4/4=-1.0, -6/4=-1.5 (row 1)
        // -8/4=-2.0, -10/4=-2.5, -12/4=-3.0 (row 2)
        // sum: [-2.5, -3.5, -4.5]
        assert_eq!(y.grad(), Some(vec![-2.5, -3.5, -4.5]));
    }

    #[test]
    fn test_broadcast_3d() {
        // (1, 2, 3) + (2, 1): the (2, 1) operand broadcasts across the last dim -> (1, 2, 3)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 2, 3], true);
        let y = RawTensor::new(vec![10.0, 20.0], &[2, 1], true);
        let z = x.add(&y);

        assert_eq!(z.borrow().shape, vec![1, 2, 3]);
        // Row 0: [1,2,3] + 10 = [11,12,13]
        // Row 1: [4,5,6] + 20 = [24,25,26]
        assert_eq!(z.borrow().data, vec![11.0, 12.0, 13.0, 24.0, 25.0, 26.0]);

        z.backward();

        // x gradient: all ones
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
        // y gradient: sum over last dimension -> [3.0, 3.0]
        assert_eq!(y.grad(), Some(vec![3.0, 3.0]));
    }

    #[test]
    fn test_broadcast_max() {
        // (2, 3) max (3,) -> (2, 3)
        let x = RawTensor::new(vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3], true);
        let y = RawTensor::new(vec![2.0, 3.0, 4.0], &[3], true);
        let z = x.max_elem(&y);

        assert_eq!(z.borrow().shape, vec![2, 3]);
        // [max(1,2), max(5,3), max(3,4)] = [2,5,4]
        // [max(4,2), max(2,3), max(6,4)] = [4,3,6]
        assert_eq!(z.borrow().data, vec![2.0, 5.0, 4.0, 4.0, 3.0, 6.0]);

        z.backward();

        // Gradient flows to max elements
        // x: [0, 1, 0, 1, 0, 1] (x wins at indices 1, 3, 5)
        assert_eq!(x.grad(), Some(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]));
        // y: sum over rows where y wins
        // y[0] wins at (0,0): 1
        // y[1] wins at (1,1): 1
        // y[2] wins at (0,2): 1
        // Total: [1, 1, 1]
        assert_eq!(y.grad(), Some(vec![1.0, 1.0, 1.0]));
    }

    #[test]
    fn test_broadcast_bias_add() {
        // Common pattern: batch matmul + bias
        // (batch=2, in=3) @ (3, 4) + (4,) -> (2, 4)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let w = RawTensor::new(
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
            &[3, 4],
            true,
        );
        let b = RawTensor::new(vec![0.01, 0.02, 0.03, 0.04], &[4], true);

        let y = x.matmul(&w);
        let z = y.add(&b); // Broadcasting happens here
        let loss = z.sum();

        loss.backward();

        // All should have gradients
        assert!(x.grad().is_some());
        assert!(w.grad().is_some());
        assert!(b.grad().is_some());

        // Bias gradient sums over the batch dimension, so each of the 4 output
        // features accumulates batch_size = 2 -> [2, 2, 2, 2]
        assert_eq!(b.grad(), Some(vec![2.0, 2.0, 2.0, 2.0]));
    }

    #[test]
    fn test_batched_matmul() {
        // (2, 2, 3) @ (2, 3, 2) -> (2, 2, 2)
        let x = RawTensor::ones(&[2, 2, 3]);
        let y = RawTensor::ones(&[2, 3, 2]);
        let z = x.matmul(&y);

        assert_eq!(z.borrow().shape, vec![2, 2, 2]);
        // dot product of two [1,1,1] vecs is 3.0
        assert_eq!(z.borrow().data.first().copied().unwrap_or(f32::NAN), 3.0);
        assert_eq!(z.borrow().data.get(7).copied().unwrap_or(f32::NAN), 3.0);
    }

    #[test]
    #[allow(clippy::identity_op)]
    // Allow identity ops so the dimension arithmetic below reads clearly
    fn test_batched_matmul_broadcasting() {
        // (2, 1, 2, 3) @ (1, 2, 3, 1) -> (2, 2, 2, 1)
        // Checks if batch dims [2, 1] and [1, 2] broadcast to [2, 2]

        let a_data = vec![1.0; 2 * 1 * 2 * 3]; // 12 elements, all 1s
        let b_data = vec![2.0; 1 * 2 * 3 * 1]; // 6 elements, all 2s

        let a = RawTensor::new(a_data, &[2, 1, 2, 3], true);
        let b = RawTensor::new(b_data, &[1, 2, 3, 1], true);

        let c = a.matmul(&b);

        // Output shape should be [2, 2, 2, 1]
        assert_eq!(c.borrow().shape, vec![2, 2, 2, 1]);

        // Values: Row (1,1,1) dot Col (2,2,2) = 3*2 = 6
        assert_eq!(c.borrow().data.first().copied().unwrap_or(f32::NAN), 6.0);

        let loss = c.sum();
        loss.backward();

        // Gradient check
        // C sum is 8 elements * 6.0 = 48.0
        // A grad should capture broadcasted dims
        assert!(a.grad().is_some());
        assert!(b.grad().is_some());
    }

    #[test]
    fn test_matmul_matrix_vector_backward() {
        // (m,n) @ (n,) -> (m,)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], true);
        let v = RawTensor::new(vec![0.5, -1.0], &[2], true);

        // z = X @ v
        let z = x.matmul(&v);
        // loss = sum(z) => ∂L/∂z = 1
        let loss = z.sum();
        loss.backward();

        // ∂L/∂X = outer(ones(m), v) = repeat v on each row
        assert_eq!(x.grad(), Some(vec![0.5, -1.0, 0.5, -1.0, 0.5, -1.0]));
        // ∂L/∂v = X^T @ ones(m) = column sums of X
        // sums: col0 = 1+3+5 = 9, col1 = 2+4+6 = 12
        assert_eq!(v.grad(), Some(vec![9.0, 12.0]));
    }

    #[test]
    fn test_dot_backward() {
        // (n,) @ (n,) -> scalar
        let a = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let b = RawTensor::new(vec![4.0, 5.0, 6.0], &[3], true);

        // loss = a · b = 1*4 + 2*5 + 3*6 = 32
        let loss = a.matmul(&b);
        loss.backward();

        // ∂L/∂a = b
        assert_eq!(a.grad(), Some(vec![4.0, 5.0, 6.0]));
        // ∂L/∂b = a
        assert_eq!(b.grad(), Some(vec![1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_gradcheck_matrix_vector_matmul() {
        // Check gradients numerically for X in (m,n) @ (n,)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let v = RawTensor::new(vec![0.3, -0.7], &[2], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| t.matmul(&v).sum());
        assert!(passed, "Matrix-vector matmul gradient check failed");
    }

    #[test]
    fn test_broadcast_sub() {
        // Test that sub also broadcasts correctly
        let x = RawTensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2], true);
        let y = RawTensor::new(vec![1.0, 2.0], &[2], true);
        let z = x.sub(&y);

        assert_eq!(z.borrow().shape, vec![2, 2]);
        assert_eq!(z.borrow().data, vec![4.0, 4.0, 6.0, 6.0]);

        z.backward();

        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0]));
        // Sub gradient for y is negative and summed
        assert_eq!(y.grad(), Some(vec![-2.0, -2.0]));
    }

    // ===== NUMERICAL GRADIENT CHECKING TESTS =====

    #[test]
    fn test_gradcheck_unary_ops() {
        // Test sqrt gradient
        let x = RawTensor::new(vec![4.0, 9.0, 16.0], &[3], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.sqrt();
            y.sum()
        });
        assert!(passed, "Sqrt gradient check failed");

        // Test sin gradient
        let x = RawTensor::new(vec![0.5, 1.0, 1.5], &[3], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.sin();
            y.sum()
        });
        assert!(passed, "Sin gradient check failed");

        // Test sigmoid gradient
        let x = RawTensor::new(vec![0.0, 1.0, -1.0], &[3], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.sigmoid();
            y.sum()
        });
        assert!(passed, "Sigmoid gradient check failed");
    }

    #[test]
    fn test_gradcheck_binary_ops() {
        // Test add gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let y = RawTensor::new(vec![4.0, 5.0, 6.0], &[3], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let z = t.add(&y);
            z.sum()
        });
        assert!(passed, "Add gradient check failed");

        // Test mul gradient
        let x = RawTensor::new(vec![2.0, 3.0], &[2], true);
        let y = RawTensor::new(vec![4.0, 5.0], &[2], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let z = t.elem_mul(&y);
            z.sum()
        });
        assert!(passed, "Mul gradient check failed");

        // Test div gradient
        let x = RawTensor::new(vec![6.0, 8.0], &[2], true);
        let y = RawTensor::new(vec![2.0, 4.0], &[2], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let z = t.div(&y);
            z.sum()
        });
        assert!(passed, "Div gradient check failed");
    }

    #[test]
    fn test_gradcheck_matmul() {
        // Test matmul gradient for first operand
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let w = RawTensor::new(
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
            &[3, 3],
            false,
        );
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.matmul(&w);
            y.sum()
        });
        assert!(passed, "Matmul gradient check failed");
    }

    #[test]
    fn test_gradcheck_broadcast() {
        // Test broadcasting gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let y = RawTensor::new(vec![0.5], &[1], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let z = t.elem_mul(&y);
            z.sum()
        });
        assert!(passed, "Broadcast gradient check failed");

        // Test with matrix broadcast
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let y = RawTensor::new(vec![0.5, 1.0], &[2], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let z = t.add(&y);
            z.sum()
        });
        assert!(passed, "Matrix broadcast gradient check failed");
    }

    #[test]
    fn test_gradcheck_movement_ops() {
        // Test reshape gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.reshape(&[2, 2]);
            y.sum()
        });
        assert!(passed, "Reshape gradient check failed");

        // Test permute gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.permute(&[1, 0]);
            y.sum()
        });
        assert!(passed, "Permute gradient check failed");

        // Test pad gradient
        let x = RawTensor::new(vec![1.0, 2.0], &[2], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.pad(&[(1, 1)]);
            y.sum()
        });
        assert!(passed, "Pad gradient check failed");

        // Test shrink gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.shrink(&[(1, 3)]);
            y.sum()
        });
        assert!(passed, "Shrink gradient check failed");
    }

    #[test]
    fn test_gradcheck_reduce_ops() {
        // Test mean gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| t.mean());
        assert!(passed, "Mean gradient check failed");

        // Test max gradient (more challenging due to discontinuity)
        let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0], &[4], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| t.max_reduce());
        assert!(passed, "Max gradient check failed");
    }

    #[test]
    fn test_gradcheck_ternary_ops() {
        // Test mulacc gradient
        let x = RawTensor::new(vec![1.0, 2.0], &[2], true);
        let y = RawTensor::new(vec![3.0, 4.0], &[2], false);
        let z = RawTensor::new(vec![0.5, 1.0], &[2], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let out = t.mulacc(&y, &z);
            out.sum()
        });
        assert!(passed, "MulAcc gradient check failed");
    }

    #[test]
    fn test_gradcheck_complex_chain() {
        // Test complex computation graph
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let w = RawTensor::new(vec![0.5, 1.0, 1.5], &[3], false);

        let passed = RawTensor::check_gradients_simple(&x, |t| {
            // y = sigmoid(x * w)
            let prod = t.elem_mul(&w);
            let y = prod.sigmoid();
            y.sum()
        });
        assert!(passed, "Complex chain gradient check failed");
    }

    #[test]
    fn test_gradcheck_neural_network_layer() {
        // Test full linear layer: y = xW + b
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
        let w = RawTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], false);
        let b = RawTensor::new(vec![0.1, 0.2], &[2], false);

        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.matmul(&w);
            let z = y.add(&b);
            z.sum()
        });
        assert!(passed, "Neural network layer gradient check failed");
    }

    #[test]
    fn test_gradcheck_with_tolerance() {
        // Test with custom epsilon and tolerance
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);

        let (max_err, mean_err, passed) = RawTensor::check_gradients(
            &x,
            |t| {
                let y = t.relu();
                y.sum()
            },
            1e-5, // smaller epsilon
            1e-2, // larger tolerance (ReLU has discontinuity at 0)
        );

        assert!(passed, "Custom tolerance gradient check failed");
        println!(
            "ReLU gradcheck: max_err={:.6e}, mean_err={:.6e}",
            max_err, mean_err
        );
    }
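
    // These numerical checks presumably compare the analytic gradient against a
    // central-difference estimate, (f(x + eps) - f(x - eps)) / (2 * eps); the looser
    // tolerance for ReLU above reflects the kink at 0, where that estimate is unreliable.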

    #[test]
    fn test_gradcheck_multidim() {
        // Test with 2D tensors
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);

        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.sqrt();
            let z = y.elem_mul(t);
            z.sum()
        });
        assert!(passed, "Multidim gradient check failed");
    }

    #[test]
    fn test_gradcheck_expand() {
        let x = RawTensor::new(vec![1.0, 2.0], &[2, 1], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.expand(&[2, 3]);
            y.sum()
        });
        assert!(passed, "Expand gradient check failed");
    }

    #[test]
    fn test_gradcheck_transpose() {
        // Test standalone transpose gradient
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.transpose();
            // Multiply by some weights to make gradient non-uniform
            let w = RawTensor::new(vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0], &[3, 2], false);
            y.elem_mul(&w).sum()
        });
        assert!(passed, "Transpose gradient check failed");
    }

    #[test]
    fn test_gradcheck_pad() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.pad(&[(1, 1)]);
            y.sum()
        });
        assert!(passed, "Pad gradient check failed");
    }

    #[test]
    fn test_gradcheck_shrink() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.shrink(&[(1, 4)]);
            y.sum()
        });
        assert!(passed, "Shrink gradient check failed");
    }

    #[test]
    fn test_gradcheck_stride() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.stride_op(&[2]);
            y.sum()
        });
        assert!(passed, "Stride gradient check failed");
    }

    #[test]
    fn test_gradcheck_matmul_vec() {
        // vec-mat: (n,) @ (n,p) -> (p,)
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let w = RawTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| {
            let y = t.matmul(&w);
            y.sum()
        });
        assert!(passed, "Vec-mat matmul gradient check failed");
    }

    #[test]
    fn test_broadcast_3d_fix() {
        // (2,1) broadcasted with (1,2,3) -> (1,2,3)
        let x = RawTensor::new(vec![10.0, 20.0], &[2, 1], true);
        let y = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 2, 3], true);
        let z = x.add(&y);

        assert_eq!(z.borrow().shape, vec![1, 2, 3]);
        // Row 0: [1,2,3] + 10 = [11,12,13]
        // Row 1: [4,5,6] + 20 = [24,25,26]
        assert_eq!(z.borrow().data, vec![11.0, 12.0, 13.0, 24.0, 25.0, 26.0]);

        z.backward();
        assert_eq!(x.grad(), Some(vec![3.0, 3.0])); // sum over last dim
        assert_eq!(y.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
    }

    #[test]
    fn test_broadcast_batch_channels() {
        // Typical conv bias: (B,C,H,W) + (C,1,1) -> (B,C,H,W)
        let x = RawTensor::new((0..16).map(|i| i as f32).collect(), &[2, 2, 2, 2], true);
        let bias = RawTensor::new(vec![0.1, 0.2], &[2, 1, 1], true);
        let z = x.add(&bias);

        assert_eq!(z.borrow().shape, vec![2, 2, 2, 2]);
        let loss = z.sum();
        loss.backward();

        // Bias grad should sum over B,H,W -> [8.0, 8.0]
        assert_eq!(bias.grad(), Some(vec![8.0, 8.0]));
    }

    #[test]
    fn test_gradcheck_broadcast_3d() {
        let x = RawTensor::new(vec![1.0, 2.0], &[2, 1], true);
        let y = RawTensor::new(vec![0.5; 6], &[1, 2, 3], false);
        let passed = RawTensor::check_gradients_simple(&x, |t| t.add(&y).sum());
        assert!(passed, "3D broadcast gradcheck failed");
    }

    #[test]
    fn test_sequential_forward() {
        let model = Sequential::new(vec![
            Box::new(Linear::new(3, 4, true)),
            Box::new(ReLU),
            Box::new(Linear::new(4, 2, true)),
        ]);

        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
        let y = model.forward(&x);

        assert_eq!(y.borrow().shape, vec![1, 2]);

        let loss = y.sum();
        loss.backward();

        // All layer params should have gradients
        for param in model.parameters() {
            assert!(param.grad().is_some(), "Missing gradient");
        }
    }

    #[test]
    fn test_sequential_zero_grad() {
        let mut model = Sequential::new(vec![Box::new(Linear::new(2, 3, true))]);

        let x = RawTensor::new(vec![1.0, 2.0], &[1, 2], true);
        model.forward(&x).sum().backward();

        // Params have grads
        assert!(model.parameters().first().unwrap().grad().is_some());

        model.zero_grad();

        // Grads cleared
        for p in model.parameters() {
            assert!(p.grad().is_none());
        }
    }

    #[test]
    fn test_adam_converges_faster() {
        // Learn a simple linear map, y = 2x, from a deliberately bad starting point.
        // SGD struggles when the learning rate is not tuned to the gradient scale;
        // Adam adapts per-parameter learning rates, so it should converge reliably here.

        let x_data: Vec<f32> = (0..10).map(|i| i as f32 * 0.1).collect(); // 10 samples
        let y_data: Vec<f32> = x_data.iter().map(|v| v * 2.0).collect();

        let x = RawTensor::new(x_data.clone(), &[10, 1], false);
        let y = RawTensor::new(y_data.clone(), &[10, 1], false);

        // Simple Linear model 1->1
        // Initialize deliberately far from solution (w=0)
        let layer = Linear::new(1, 1, false);
        // Force weight to 0.0
        *layer.weight.borrow_mut().data.first_mut().unwrap() = 0.0;

        let model = Sequential::new(vec![Box::new(layer)]);

        let params = model.parameters();
        // more aggressive learning rate, no weight decay
        let mut opt = Adam::new(params, 0.5, (0.9, 0.999), 1e-8, 0.0);

        let mut losses = vec![];
        for _ in 0..50 {
            opt.zero_grad();

            let pred = model.forward(&x);
            let loss = RawTensor::mse_loss(&pred, &y);
            loss.backward();
            opt.step();

            losses.push(loss.borrow().data.first().copied().unwrap_or(f32::NAN));
        }

        let final_loss = *losses.last().unwrap();
        assert!(
            final_loss < 0.01,
            "Adam failed simple regression convergence: {:.6}",
            final_loss
        );
    }
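
    // For reference, the standard Adam update this test exercises (assuming the
    // (0.9, 0.999) tuple supplies beta1/beta2 and 1e-8 is eps):
    //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    //   w_t = w_{t-1} - lr * m_hat_t / (sqrt(v_hat_t) + eps)
    // where m_hat/v_hat are the bias-corrected moment estimates.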

    #[test]
    fn test_adam_vs_sgd() {
        crate::manual_seed(42); // fixed seed for reproducibility
        // Same XOR setup for both optimizers; train one model with Adam, one with SGD
        fn train_model(use_adam: bool) -> f32 {
            let x_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
            let x = RawTensor::new(x_data, &[4, 2], false);
            let y_data = vec![0.0, 1.0, 1.0, 0.0];
            let y = RawTensor::new(y_data, &[4], false);

            let model = Sequential::new(vec![
                Box::new(Linear::new(2, 8, true)),
                Box::new(ReLU),
                Box::new(Linear::new(8, 1, true)),
            ]);

            let params = model.parameters();

            if use_adam {
                let mut opt = Adam::new(params, 0.05, (0.9, 0.999), 1e-8, 0.0);
                for _ in 0..50 {
                    opt.zero_grad();
                    let pred = model.forward(&x).reshape(&[4]);
                    let loss = RawTensor::mse_loss(&pred, &y);
                    loss.backward();
                    opt.step();
                }
            } else {
                let mut opt = SGD::new(params, 0.01, 0.0, 0.0);
                for _ in 0..50 {
                    opt.zero_grad();
                    let pred = model.forward(&x).reshape(&[4]);
                    let loss = RawTensor::mse_loss(&pred, &y);
                    loss.backward();
                    opt.step();
                }
            }

            // Return final loss
            let pred = model.forward(&x).reshape(&[4]);
            RawTensor::mse_loss(&pred, &y)
                .borrow()
                .data
                .first()
                .copied()
                .unwrap_or(f32::NAN)
        }

        let adam_loss = train_model(true);
        let sgd_loss = train_model(false);

        println!(
            "Adam final loss: {:.6}, SGD final loss: {:.6}",
            adam_loss, sgd_loss
        );

        // Adam should be significantly better
        assert!(adam_loss < sgd_loss * 0.75, "Adam not outperforming SGD");
    }

    #[test]
    fn test_dataloader_iteration() {
        // 8 samples, 2 features each
        let data = (0..16).map(|i| i as f32).collect();
        let targets = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];

        let mut loader = DataLoader::new(
            data,
            targets,
            &[2],  // 2 features per sample
            &[1],  // 1 target per sample
            3,     // batch_size
            false, // no shuffle for deterministic test
        );

        // First batch: samples 0,1,2
        let (x, y) = loader.next().unwrap();
        assert_eq!(x.borrow().shape, vec![3, 2]);
        assert_eq!(x.borrow().data, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(y.borrow().shape, vec![3, 1]);

        // Second batch: samples 3,4,5
        let (x, _y) = loader.next().unwrap();
        assert_eq!(x.borrow().shape, vec![3, 2]);

        // Third batch: samples 6,7 (partial)
        let (x, _y) = loader.next().unwrap();
        assert_eq!(x.borrow().shape, vec![2, 2]);

        // Done
        assert!(loader.next().is_none());

        // Reset
        loader.reset();
        let (x, _) = loader.next().unwrap();
        assert_eq!(x.borrow().shape, vec![3, 2]);
    }
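
    // Note the final partial batch (2 samples) is yielded rather than dropped, i.e.
    // the loader behaves like PyTorch's DataLoader with drop_last = false.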

    #[test]
    fn test_dataloader_in_training_loop() {
        let data = vec![0.0; 40]; // 10 samples, 4 features
        let targets = vec![1.0; 10];

        let model = Sequential::new(vec![Box::new(Linear::new(4, 2, true))]);

        let mut opt = SGD::new(model.parameters(), 0.1, 0.0, 0.0);

        for epoch in 0..2 {
            let loader = DataLoader::new(data.clone(), targets.clone(), &[4], &[1], 3, false);

            for (batch_x, _batch_y) in loader {
                opt.zero_grad();
                let pred = model.forward(&batch_x);
                // Dummy loss
                let loss = pred.sum();
                loss.backward();
                opt.step();
            }

            println!("Epoch {} complete", epoch);
        }
    }

    #[test]
    fn bench_matmul_speedup() {
        use std::time::Instant;

        let a = vec![1.0; 256 * 256];
        let b = vec![1.0; 256 * 256];

        let start = Instant::now();
        let _ = RawTensor::matmul_raw(&a, &b, 256, 256, 256);
        let duration = start.elapsed();

        println!("256x256 matmul: {:?}", duration);
        let max_duration_ms: u128 = if cfg!(all(feature = "accelerate", target_os = "macos")) {
            if cfg!(debug_assertions) { 50 } else { 10 }
        } else if cfg!(debug_assertions) {
            250
        } else {
            100
        };

        assert!(
            duration.as_millis() < max_duration_ms,
            "Matmul took {:?} (> {}ms threshold for this build configuration)",
            duration,
            max_duration_ms
        );
    }

    #[test]
    fn test_batchnorm_working() {
        let mut bn = BatchNorm2d::new(3);
        let x = RawTensor::randn(&[2, 3, 4, 4]);

        // Training mode
        bn.train(true);
        let y = bn.forward(&x);
        assert_eq!(y.borrow().shape, vec![2, 3, 4, 4]);

        // Test mode
        bn.train(false);
        let y2 = bn.forward(&x);
        assert_eq!(y2.borrow().shape, vec![2, 3, 4, 4]);
    }

    #[test]
    fn test_batchnorm1d_forward_shape() {
        let bn = BatchNorm1d::new(32);
        let x = RawTensor::randn(&[8, 32]); // batch=8, features=32
        let y = bn.forward(&x);
        assert_eq!(y.borrow().shape, vec![8, 32]);
    }

    #[test]
    fn test_batchnorm1d_train_vs_test_mode() {
        let mut bn = BatchNorm1d::new(4);

        // Training mode - run a few batches to populate running stats
        bn.train(true);
        for _ in 0..5 {
            let x = RawTensor::randn(&[16, 4]);
            let _ = bn.forward(&x);
        }

        // Now test that test mode uses different stats
        let test_input = RawTensor::randn(&[8, 4]);
        bn.train(true);
        let y_train = bn.forward(&test_input);

        bn.train(false);
        let y_test = bn.forward(&test_input);

        // Outputs should differ because train mode uses batch stats
        // while test mode uses running stats
        let train_data = &y_train.borrow().data;
        let test_data = &y_test.borrow().data;
        let differs = train_data
            .iter()
            .zip(test_data.iter())
            .any(|(a, b)| (a - b).abs() > 1e-5);
        assert!(differs, "Train and test outputs should differ");
    }

    #[test]
    fn test_batchnorm1d_parameters() {
        let bn = BatchNorm1d::new(16);
        let params = bn.parameters();
        // Should have gamma and beta
        assert_eq!(params.len(), 2);
        // gamma shape [16]
        assert_eq!(params.first().unwrap().borrow().shape, vec![16]);
        // beta shape [16]
        assert_eq!(params.get(1).unwrap().borrow().shape, vec![16]);
    }
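
    // BatchNorm normalizes each feature/channel as
    //   y = gamma * (x - mean) / sqrt(var + eps) + beta,
    // so gamma and beta are its only learnable parameters; the running mean/var are
    // presumably tracked as non-trainable state, consistent with params.len() == 2 above.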

    #[test]
    fn test_pixelshuffle_forward_shape() {
        // Test with upscale_factor=3: [2, 36, 4, 4] -> [2, 4, 12, 12]
        let layer = PixelShuffle::new(3);
        let x = RawTensor::randn(&[2, 36, 4, 4]); // 4 channels * 9
        let y = layer.forward(&x);
        assert_eq!(y.borrow().shape, vec![2, 4, 12, 12]);

        // Test with upscale_factor=2: [1, 12, 8, 8] -> [1, 3, 16, 16]
        let layer2 = PixelShuffle::new(2);
        let x2 = RawTensor::randn(&[1, 12, 8, 8]); // 3 channels * 4
        let y2 = layer2.forward(&x2);
        assert_eq!(y2.borrow().shape, vec![1, 3, 16, 16]);
    }

    #[test]
    fn test_pixelshuffle_backward_flow() {
        let layer = PixelShuffle::new(2);
        let x = RawTensor::randn(&[2, 4, 3, 3]); // 1 channel * 4
        x.borrow_mut().requires_grad = true;

        let y = layer.forward(&x);
        assert_eq!(y.borrow().shape, vec![2, 1, 6, 6]);

        let loss = y.sum();
        loss.backward();

        let grad = x.grad();
        assert!(
            grad.is_some(),
            "Gradient should flow back through PixelShuffle"
        );
        // Check gradient has correct number of elements: 2 * 4 * 3 * 3 = 72
        assert_eq!(grad.unwrap().len(), 72);
    }

    #[test]
    fn test_pixelshuffle_values() {
        // Small manual test with known values
        // Input: [1, 4, 2, 2] with upscale_factor=2
        // This should rearrange 4 channels of 2x2 into 1 channel of 4x4
        let layer = PixelShuffle::new(2);
        #[rustfmt::skip]
        let data = vec![
            // Channel 0
            1.0, 2.0,
            3.0, 4.0,
            // Channel 1
            5.0, 6.0,
            7.0, 8.0,
            // Channel 2
            9.0, 10.0,
            11.0, 12.0,
            // Channel 3
            13.0, 14.0,
            15.0, 16.0,
        ];
        let x = RawTensor::new(data, &[1, 4, 2, 2], false);
        let y = layer.forward(&x);

        assert_eq!(y.borrow().shape, vec![1, 1, 4, 4]);

        // After PixelShuffle, the output should interleave values from the 4 input channels
        // The exact pattern depends on the reshape/permute order
        let out_data = &y.borrow().data;
        assert_eq!(out_data.len(), 16);

        // Verify the transformation preserved all values (just check sum as a sanity check)
        let input_sum: f32 = (1..=16).map(|x| x as f32).sum();
        let output_sum: f32 = out_data.iter().sum();
        assert!((input_sum - output_sum).abs() < 1e-5);
    }
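
    // PixelShuffle rearranges (B, C*r^2, H, W) -> (B, C, H*r, W*r): each group of r^2
    // channels becomes an r x r spatial block, so values move but none are created or
    // lost, which is what the sum check above relies on.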

    #[test]
    fn test_embedding_forward_shape() {
        let embedding = Embedding::new(100, 32);
        let indices = vec![5, 12, 7, 99];
        let output = embedding.forward(&indices);
        assert_eq!(output.borrow().shape, vec![4, 32]);
    }

    #[test]
    fn test_embedding_backward_flow() {
        let embedding = Embedding::new(50, 16);
        let indices = vec![3, 10, 3]; // Note: repeated index 3
        let output = embedding.forward(&indices);

        // Sum and backward
        let loss = output.sum();
        loss.backward();

        // Check that weight has gradients
        let grad = embedding.weight.grad();
        assert!(grad.is_some(), "Weight should have gradients");

        let grad_data = grad.unwrap();
        // Each embedding contributes 1.0 per dimension from sum
        // Index 3 appears twice, so should have accumulated grad of 2.0 per dim
        let grad_at_idx3_sum: f32 = (0..16)
            .map(|d| grad_data.get(3 * 16 + d).copied().unwrap_or(f32::NAN))
            .sum();
        let expected_sum = 2.0 * 16.0; // 2 occurrences * 16 dimensions
        assert!(
            (grad_at_idx3_sum - expected_sum).abs() < 1e-4,
            "Expected accumulated grad sum {}, got {}",
            expected_sum,
            grad_at_idx3_sum
        );
    }

    #[test]
    fn test_embedding_gradient_accumulation() {
        let embedding = Embedding::new(10, 4);
        let indices = vec![2, 5, 2, 2]; // Index 2 appears 3 times
        let output = embedding.forward(&indices);

        let loss = output.sum();
        loss.backward();

        let grad = embedding.weight.grad().unwrap();
        // Index 2 should have grad of 3.0 per dimension (appears 3 times)
        for d in 0..4 {
            let grad_val = grad.get(2 * 4 + d).copied().unwrap_or(f32::NAN);
            assert!(
                (grad_val - 3.0).abs() < 1e-5,
                "Expected grad 3.0 for index 2, got {}",
                grad_val
            );
        }

        // Index 5 should have grad of 1.0 per dimension (appears once)
        for d in 0..4 {
            let grad_val = grad.get(5 * 4 + d).copied().unwrap_or(f32::NAN);
            assert!(
                (grad_val - 1.0).abs() < 1e-5,
                "Expected grad 1.0 for index 5, got {}",
                grad_val
            );
        }
    }
}

#[cfg(test)]
mod axis_reduce_tests {
    use super::*;

    #[test]
    fn test_sum_dim_basic() {
        // [2,3] sum along dim=1 -> [2]
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let y = RawTensor::sum_dim(&x, 1, false);

        assert_eq!(y.borrow().shape, vec![2]);
        assert_eq!(y.borrow().data, vec![6.0, 15.0]); // [1+2+3, 4+5+6]
    }

    #[test]
    fn test_sum_dim_keepdim() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let y = RawTensor::sum_dim(&x, 0, true);

        assert_eq!(y.borrow().shape, vec![1, 2]);
        assert_eq!(y.borrow().data, vec![4.0, 6.0]); // [1+3, 2+4]
    }

    #[test]
    fn test_sum_dim_backward() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let y = RawTensor::sum_dim(&x, 1, false); // [6, 15]
        y.backward();

        // Gradient broadcasts back: each element contributed once
        assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
    }

    #[test]
    fn test_max_dim_basic() {
        let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0, 8.0, 4.0], &[2, 3], false);
        let y = RawTensor::max_dim(&x, 1, false);

        assert_eq!(y.borrow().shape, vec![2]);
        assert_eq!(y.borrow().data, vec![5.0, 8.0]); // max of each row
    }

    #[test]
    fn test_max_dim_backward() {
        let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0, 8.0, 4.0], &[2, 3], true);
        let y = RawTensor::max_dim(&x, 1, false);
        y.backward();

        // Only max elements get gradient
        assert_eq!(x.grad(), Some(vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]));
    }

    #[test]
    fn test_gradcheck_sum_dim() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let passed =
            RawTensor::check_gradients_simple(&x, |t| RawTensor::sum_dim(t, 0, false).sum());
        assert!(passed, "sum_dim gradient check failed");
    }

    #[test]
    fn test_gradcheck_max_dim() {
        let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0], &[2, 2], true);
        let passed =
            RawTensor::check_gradients_simple(&x, |t| RawTensor::max_dim(t, 1, false).sum());
        assert!(passed, "max_dim gradient check failed");
    }

    #[test]
    fn test_softmax_forward() {
        // Test softmax computation
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let y = RawTensor::softmax(&x, 1);

        // Each row should sum to 1.0
        let data = y.borrow();
        let row0_sum: f32 = data.data.get(0..3).unwrap().iter().sum();
        let row1_sum: f32 = data.data.get(3..6).unwrap().iter().sum();

        approx::assert_relative_eq!(row0_sum, 1.0, epsilon = 1e-6);
        approx::assert_relative_eq!(row1_sum, 1.0, epsilon = 1e-6);
    }
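
    // softmax(x)_i = exp(x_i) / sum_j exp(x_j) along the chosen dim, so every row of
    // the output is a probability distribution and sums to 1, as asserted above.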

    #[test]
    fn test_gradcheck_softmax() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let passed = RawTensor::check_gradients_simple(&x, |t| RawTensor::softmax(t, 1).sum());
        assert!(passed, "Softmax gradient check failed");
    }

    #[test]
    fn test_cross_entropy_loss() {
        // Simple 2-class, 2-sample batch
        let logits = RawTensor::new(vec![2.0, 1.0, 0.5, 2.5], &[2, 2], true);
        let targets = RawTensor::new(vec![1.0, 0.0, 0.0, 1.0], &[2, 2], false);

        let loss = RawTensor::cross_entropy_loss(&logits, &targets);
        loss.backward();

        // Loss should be positive scalar
        assert_eq!(loss.borrow().shape, vec![1]);
        assert!(loss.borrow().data.first().copied().unwrap_or(f32::NAN) > 0.0);

        // Gradients should exist and have correct shape
        assert_eq!(logits.grad().unwrap().len(), 4);
    }
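
    // With one-hot targets, cross-entropy is presumably -sum(target * log(softmax(logit)))
    // reduced over the batch; it picks out the log-probability of the correct class,
    // so the loss is strictly positive whenever that probability is below 1.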
    #[test]
    fn test_dropout_train_eval() {
        let mut dropout = Dropout::new(0.5);
        let x = RawTensor::ones(&[1000]);

        // Train mode: roughly 50% should be zero, others scaled by 2
        dropout.train(true);
        let y = dropout.forward(&x);
        let y_data = &y.borrow().data;
        let num_zeros = y_data.iter().filter(|&&v| v == 0.0).count();

        // Statistical check (allow some variance)
        assert!(
            num_zeros > 400 && num_zeros < 600,
            "Dropout ratio off: {}",
            num_zeros
        );

        // Check scaling: non-zeros should be 2.0
        let non_zeros_correct = y_data.iter().all(|&v| v == 0.0 || v == 2.0);
        assert!(non_zeros_correct, "Dropout scaling incorrect");

        // Eval mode: identity
        dropout.eval();
        let y_eval = dropout.forward(&x);
        let eval_correct = y_eval.borrow().data.iter().all(|&v| v == 1.0);
        assert!(eval_correct, "Dropout eval mode should be identity");
    }
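
    // Hedged sketch: inverted dropout scales survivors by 1/(1 - p), so it
    // preserves the expectation, E[y] = x. With p = 0.5 over 1000 ones the
    // sample mean is 2 * (#survivors) / 1000, which the ratio bounds in the
    // test above already confine to [0.8, 1.2].
    #[test]
    fn test_dropout_preserves_expectation() {
        let mut dropout = Dropout::new(0.5);
        dropout.train(true);

        let x = RawTensor::ones(&[1000]);
        let y = dropout.forward(&x);
        let mean = y.borrow().data.iter().sum::<f32>() / 1000.0;

        assert!(
            (mean - 1.0).abs() < 0.2,
            "inverted dropout should keep E[y] near 1.0, got {mean}"
        );
    }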

    #[test]
    fn test_weight_decay_sgd() {
        let w = RawTensor::new(vec![1.0], &[1], true);
        // SGD with 0.1 decay. Grad = 0.
        // Step should be: w = w - lr * (grad + decay * w) = 1.0 - 0.1 * (0 + 0.1 * 1.0) = 0.99
        let mut opt = SGD::new(vec![w.clone()], 0.1, 0.0, 0.1);

        w.borrow_mut().grad = Some(Storage::cpu(vec![0.0])); // Artificial zero gradient
        opt.step();

        let new_val = w.borrow().data.first().copied().unwrap_or(f32::NAN);
        approx::assert_relative_eq!(new_val, 0.99, epsilon = 1e-6);
    }
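
    // Hedged sketch: companion to the weight-decay test above, exercising the
    // momentum term on CPU. With a constant gradient, any standard momentum
    // formulation (e.g. v = momentum * v + grad, w -= lr * v) makes the second
    // update strictly larger in magnitude than the first, so the check below
    // avoids committing to one exact recurrence.
    #[test]
    fn test_sgd_momentum_accumulates_cpu() {
        let w = RawTensor::new(vec![1.0], &[1], true);
        let mut opt = SGD::new(vec![w.clone()], 0.1, 0.9, 0.0);

        w.borrow_mut().grad = Some(Storage::cpu(vec![1.0]));
        opt.step();
        let after_step1 = w.borrow().data.first().copied().unwrap_or(f32::NAN);
        let delta1 = (1.0 - after_step1).abs();

        w.borrow_mut().grad = Some(Storage::cpu(vec![1.0]));
        opt.step();
        let after_step2 = w.borrow().data.first().copied().unwrap_or(f32::NAN);
        let delta2 = (after_step1 - after_step2).abs();

        assert!(
            delta2 > delta1,
            "momentum should grow the update: {delta1} vs {delta2}"
        );
    }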

    #[test]
    fn test_mean_dim() {
        // [2, 3]
        // [[1, 2, 3],
        //  [4, 5, 6]]
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);

        // mean(dim=1) -> shape [2], values [2.0, 5.0]
        let m = x.mean_dim(1, false);
        assert_eq!(m.borrow().shape, vec![2]);
        assert!((m.borrow().data.first().copied().unwrap_or(f32::NAN) - 2.0).abs() < 1e-6);
        assert!((m.borrow().data.get(1).copied().unwrap_or(f32::NAN) - 5.0).abs() < 1e-6);

        // Check gradient
        m.sum().backward();
        // d(mean)/dx = 1/N with N = 3, so every element's gradient is 1/3.
        let grads = x.grad().unwrap();
        for g in grads {
            assert!((g - 1.0 / 3.0).abs() < 1e-6);
        }
    }
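
    // Hedged sketch: assumes the boolean flag on mean_dim is keepdim, mirroring
    // sum_dim and max_dim above, so the reduced axis is kept with size 1.
    #[test]
    fn test_mean_dim_keepdim() {
        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let m = x.mean_dim(1, true);

        assert_eq!(m.borrow().shape, vec![2, 1]);
        assert_eq!(m.borrow().data, vec![2.0, 5.0]);
    }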

    /// Integration test: Simulated PyTorch model → Volta loading workflow
    ///
    /// This test demonstrates the full end-to-end workflow of:
    /// 1. Creating a "PyTorch-style" state dict (weights stored as [out, in])
    /// 2. Saving it to disk
    /// 3. Loading it with weight mapping (transpose + rename)
    /// 4. Loading into a Volta model with named layers
    /// 5. Verifying the model works correctly
    #[test]
    fn test_external_model_loading_integration() {
        use crate::io::{TensorData, load_state_dict, mapping::StateDictMapper, save_state_dict};
        use crate::nn::{Linear, Module, ReLU, Sequential};
        use std::collections::BTreeMap;

        // Simulate a PyTorch model with 2 linear layers
        // PyTorch stores Linear weights as [out_features, in_features]
        let mut pytorch_state = BTreeMap::new();

        // Layer 1: Linear(2, 3) -> PyTorch shape [3, 2]
        pytorch_state.insert(
            "fc1.weight".to_string(),
            TensorData {
                // PyTorch: [out=3, in=2] row-major: [[w00,w01], [w10,w11], [w20,w21]]
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![3, 2],
            },
        );
        pytorch_state.insert(
            "fc1.bias".to_string(),
            TensorData {
                data: vec![0.1, 0.2, 0.3],
                shape: vec![3],
            },
        );

        // Layer 2: Linear(3, 1) -> PyTorch shape [1, 3]
        pytorch_state.insert(
            "fc2.weight".to_string(),
            TensorData {
                data: vec![0.5, 0.6, 0.7],
                shape: vec![1, 3],
            },
        );
        pytorch_state.insert(
            "fc2.bias".to_string(),
            TensorData {
                data: vec![0.01],
                shape: vec![1],
            },
        );

        // Save the "PyTorch" state dict
        let temp_path = std::env::temp_dir().join("test_pytorch_model.bin");
        save_state_dict(&pytorch_state, temp_path.to_str().unwrap()).unwrap();

        // Load with weight mapping
        let loaded = load_state_dict(temp_path.to_str().unwrap()).unwrap();

        // Create mapper: rename keys and transpose weights
        let mapper = StateDictMapper::new()
            .rename("fc1.weight", "encoder.weight")
            .rename("fc1.bias", "encoder.bias")
            .rename("fc2.weight", "decoder.weight")
            .rename("fc2.bias", "decoder.bias")
            .transpose("encoder.weight") // [3,2] -> [2,3]
            .transpose("decoder.weight"); // [1,3] -> [3,1]

        let volta_state = mapper.map(loaded);

        // Verify transformation
        assert!(volta_state.contains_key("encoder.weight"));
        assert!(volta_state.contains_key("decoder.weight"));
        assert_eq!(volta_state.get("encoder.weight").unwrap().shape, vec![2, 3]);
        assert_eq!(volta_state.get("decoder.weight").unwrap().shape, vec![3, 1]);

        // Verify transpose correctness for encoder.weight
        // Original PyTorch [3,2]: [1,2, 3,4, 5,6]
        // Transposed [2,3]: [1,3,5, 2,4,6]
        let encoder_weight = &volta_state.get("encoder.weight").unwrap().data;
        assert_eq!(encoder_weight.first().copied().unwrap_or(f32::NAN), 1.0);
        assert_eq!(encoder_weight.get(1).copied().unwrap_or(f32::NAN), 3.0);
        assert_eq!(encoder_weight.get(2).copied().unwrap_or(f32::NAN), 5.0);
        assert_eq!(encoder_weight.get(3).copied().unwrap_or(f32::NAN), 2.0);
        assert_eq!(encoder_weight.get(4).copied().unwrap_or(f32::NAN), 4.0);
        assert_eq!(encoder_weight.get(5).copied().unwrap_or(f32::NAN), 6.0);

        // Create Volta model with named layers
        let mut model = Sequential::builder()
            .add_named("encoder", Box::new(Linear::new(2, 3, true)))
            .add_unnamed(Box::new(ReLU))
            .add_named("decoder", Box::new(Linear::new(3, 1, true)))
            .build();

        // Load the mapped state dict
        model.load_state_dict(&volta_state);

        // Verify forward pass works
        let input = RawTensor::new(vec![1.0, 1.0], &[1, 2], false);
        let output = model.forward(&input);

        // Output should be deterministic based on loaded weights
        assert_eq!(output.borrow().shape, vec![1, 1]);
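
        // Worked expectation (hedged; assumes Linear computes x @ W + b with W
        // stored [in, out], which is exactly what the transposes above arrange):
        //   encoder: [1, 1] @ [[1, 3, 5], [2, 4, 6]] + [0.1, 0.2, 0.3] = [3.1, 7.2, 11.3]
        //   ReLU:    unchanged (all positive)
        //   decoder: 3.1 * 0.5 + 7.2 * 0.6 + 11.3 * 0.7 + 0.01 = 13.79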

        // Verify we can retrieve layers by name
        assert!(model.get_named("encoder").is_some());
        assert!(model.get_named("decoder").is_some());
        assert!(model.get_named("nonexistent").is_none());

        // Verify layer names
        let names = model.layer_names();
        assert_eq!(names.first().copied().unwrap_or(None), Some("encoder"));
        assert_eq!(names.get(1).copied().unwrap_or(None), None); // ReLU is unnamed
        assert_eq!(names.get(2).copied().unwrap_or(None), Some("decoder"));
    }
}

#[cfg(test)]
mod gpu_tests {
    use super::*;
    // Only the feature-gated GPU tests below use this import; gate it so the
    // module compiles warning-free when the gpu feature is disabled.
    #[cfg(feature = "gpu")]
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_device_gpu_returns_none_when_disabled() {
        // When the gpu feature is disabled, Device::gpu() should return None
        let gpu = Device::gpu();
        #[cfg(feature = "gpu")]
        {
            // Even with the feature enabled we can't guarantee a physical GPU
            // on every system, so only assert when one is actually available
            if is_gpu_available() {
                assert!(gpu.is_some());
                assert!(gpu.unwrap().is_gpu());
            }
        }
        #[cfg(not(feature = "gpu"))]
        {
            // Without the gpu feature, is_gpu_available() isn't even compiled
            // in, and gpu() must always be None
            assert!(gpu.is_none());
        }
    }

    #[test]
    fn test_to_device_cpu_to_cpu() {
        let t = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], false);
        let t_cpu = t.to_device(Device::CPU);

        // Same device should return same tensor (fast path)
        assert_eq!(t_cpu.borrow().device, Device::CPU);
        assert_eq!(t_cpu.borrow().data.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_to_device_cpu_to_gpu() {
        if !is_gpu_available() {
            return; // Skip test if GPU not available
        }

        let t = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], false);
        let gpu_device = Device::gpu().expect("GPU should be available");
        let t_gpu = t.to_device(gpu_device.clone());

        // Device should be GPU
        assert!(t_gpu.borrow().device.is_gpu());
        assert_eq!(t_gpu.borrow().device.name(), gpu_device.name());

        // Data should be preserved
        assert_eq!(t_gpu.borrow().data.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(t_gpu.borrow().shape, vec![2, 2]);
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_to_device_gpu_to_cpu() {
        if !is_gpu_available() {
            return; // Skip test if GPU not available
        }

        let gpu_device = Device::gpu().expect("GPU should be available");
        let t = RawTensor::new(vec![5.0, 6.0, 7.0], &[3], false);
        let t_gpu = t.to_device(gpu_device.clone());

        // Move back to CPU
        let t_cpu = t_gpu.to_device(Device::CPU);

        assert!(t_cpu.borrow().device.is_cpu());
        assert_eq!(t_cpu.borrow().data.to_vec(), vec![5.0, 6.0, 7.0]);
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_to_device_preserves_autograd_metadata() {
        if !is_gpu_available() {
            return; // Skip test if GPU not available
        }

        // Create a simple computation graph
        let a = RawTensor::new(vec![2.0], &[1], true);
        let b = RawTensor::new(vec![3.0], &[1], true);
        let c = a.add(&b);

        // Move result to GPU
        let gpu_device = Device::gpu().expect("GPU should be available");
        let c_gpu = c.to_device(gpu_device);

        // Autograd metadata should be preserved
        assert!(c_gpu.borrow().requires_grad);
        assert!(!c_gpu.borrow().parents.is_empty());
        assert!(c_gpu.borrow().grad_fn.is_some());

        // Note: Gradients are still computed on CPU
        // This is a known limitation documented in to_device()
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_matmul_backward_gpu() {
        if !is_gpu_available() {
            return; // Skip test if GPU not available
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Simple 2x3 @ 3x4 = 2x4 case
        let a = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true)
            .to_device(gpu_device.clone());
        let b = RawTensor::new(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[3, 4],
            true,
        )
        .to_device(gpu_device.clone());
        let c = a.matmul(&b);

        // Forward should be on GPU
        assert!(c.borrow().device.is_gpu());

        // Backward should compute gradients on GPU
        c.backward();

        // Gradients should be on GPU
        {
            let a_ref = a.borrow();
            let b_ref = b.borrow();
            let a_grad = a_ref.grad.as_ref().expect("a should have grad");
            let b_grad = b_ref.grad.as_ref().expect("b should have grad");
            assert!(a_grad.is_gpu());
            assert!(b_grad.is_gpu());
        }

        // Verify gradient values are correct by comparing to CPU computation
        let a_cpu = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let b_cpu = RawTensor::new(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[3, 4],
            true,
        );
        let c_cpu = a_cpu.matmul(&b_cpu);
        c_cpu.backward();
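
        // Closed-form reference: backward() seeds the [2, 4] output with ones, so
        // dL/dA = 1 @ B^T gives every row of a_grad = [10, 26, 42] (row sums of B),
        // and dL/dB = A^T @ 1 gives b_grad rows of 5s, 7s and 9s (column sums of A).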

        let a_grad_data;
        let b_grad_data;
        {
            let a_ref = a.borrow();
            let b_ref = b.borrow();
            a_grad_data = a_ref.grad.as_ref().unwrap().to_vec();
            b_grad_data = b_ref.grad.as_ref().unwrap().to_vec();
        }

        let a_grad_cpu_data;
        let b_grad_cpu_data;
        {
            let a_ref = a_cpu.borrow();
            let b_ref = b_cpu.borrow();
            a_grad_cpu_data = a_ref.grad.as_ref().unwrap().to_vec();
            b_grad_cpu_data = b_ref.grad.as_ref().unwrap().to_vec();
        }

        assert_eq!(a_grad_data.len(), a_grad_cpu_data.len());
        assert_eq!(b_grad_data.len(), b_grad_cpu_data.len());

        // Check values are approximately equal
        for (gpu_val, cpu_val) in a_grad_data.iter().zip(a_grad_cpu_data.iter()) {
            assert!((gpu_val - cpu_val).abs() < 1e-5);
        }
        for (gpu_val, cpu_val) in b_grad_data.iter().zip(b_grad_cpu_data.iter()) {
            assert!((gpu_val - cpu_val).abs() < 1e-5);
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_linear_layer_backward_gpu() {
        if !is_gpu_available() {
            return; // Skip test if GPU not available
        }

        use crate::nn::{Linear, Module};

        let gpu_device = Device::gpu().expect("GPU should be available");

        let layer = Linear::new(4, 3, true);

        // Move layer parameters to GPU
        let params = layer.parameters();
        for param in &params {
            let p = RawTensor::to_device(param, gpu_device.clone());
            *param.borrow_mut() = p.borrow().clone();
        }

        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4], true)
            .to_device(gpu_device);
        let out = layer.forward(&x);
        let loss = out.sum();

        loss.backward();

        // Check gradients exist and are on GPU
        let params = layer.parameters();
        assert!(!params.is_empty());

        for param in params {
            let param_ref = param.borrow();
            if let Some(grad) = &param_ref.grad {
                assert!(grad.is_gpu(), "Gradient should be on GPU");
            }
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_sum_backward_gpu() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true)
            .to_device(gpu_device.clone());
        let sum_result = x.sum();

        // Forward should be on GPU
        assert!(sum_result.borrow().device.is_gpu());

        // Backward should compute gradients on GPU
        sum_result.backward();

        // Gradient should be on GPU and all ones
        let grad_data;
        {
            let x_ref = x.borrow();
            let x_grad = x_ref.grad.as_ref().expect("x should have grad");
            assert!(x_grad.is_gpu());
            grad_data = x_grad.to_vec();
        }

        assert_eq!(grad_data.len(), 6);
        for &val in &grad_data {
            assert!((val - 1.0).abs() < 1e-5);
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_mean_backward_gpu() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true)
            .to_device(gpu_device.clone());
        let mean_result = x.mean();

        // Forward should be on GPU
        assert!(mean_result.borrow().device.is_gpu());

        // Backward should compute gradients on GPU
        mean_result.backward();

        // Gradient should be on GPU and all 1/6
        let grad_data;
        {
            let x_ref = x.borrow();
            let x_grad = x_ref.grad.as_ref().expect("x should have grad");
            assert!(x_grad.is_gpu());
            grad_data = x_grad.to_vec();
        }

        assert_eq!(grad_data.len(), 6);
        let expected = 1.0 / 6.0;
        for &val in &grad_data {
            assert!((val - expected).abs() < 1e-5);
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_max_backward_gpu() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        let x = RawTensor::new(vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3], true)
            .to_device(gpu_device.clone());
        let max_result = x.max_reduce();

        // Forward should be on GPU
        assert!(max_result.borrow().device.is_gpu());

        // Backward should compute gradients on GPU
        max_result.backward();

        // Gradient should be on GPU and only the max element (6.0 at index 5) gets grad
        let grad_data;
        {
            let x_ref = x.borrow();
            let x_grad = x_ref.grad.as_ref().expect("x should have grad");
            assert!(x_grad.is_gpu());
            grad_data = x_grad.to_vec();
        }

        assert_eq!(grad_data.len(), 6);

        // Only the max element (6.0 at linear index 5) should have gradient 1.0
        for (i, &val) in grad_data.iter().enumerate() {
            if i == 5 {
                assert!((val - 1.0).abs() < 1e-5);
            } else {
                assert!(val.abs() < 1e-5);
            }
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_reduction_backward_gpu_cpu_equivalence() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let shape = &[2, 4];

        // Test sum
        let x_cpu = RawTensor::new(data.clone(), shape, true);
        let sum_cpu = x_cpu.sum();
        sum_cpu.backward();
        let sum_grad_cpu = x_cpu.borrow().grad.as_ref().unwrap().to_vec();

        let x_gpu = RawTensor::new(data.clone(), shape, true).to_device(gpu_device.clone());
        let sum_gpu = x_gpu.sum();
        sum_gpu.backward();
        let sum_grad_gpu = x_gpu.borrow().grad.as_ref().unwrap().to_vec();

        assert_eq!(sum_grad_cpu.len(), sum_grad_gpu.len());
        for (cpu_val, gpu_val) in sum_grad_cpu.iter().zip(sum_grad_gpu.iter()) {
            assert!((cpu_val - gpu_val).abs() < 1e-5);
        }

        // Test mean
        let x_cpu2 = RawTensor::new(data.clone(), shape, true);
        let mean_cpu = x_cpu2.mean();
        mean_cpu.backward();
        let mean_grad_cpu = x_cpu2.borrow().grad.as_ref().unwrap().to_vec();

        let x_gpu2 = RawTensor::new(data.clone(), shape, true).to_device(gpu_device.clone());
        let mean_gpu = x_gpu2.mean();
        mean_gpu.backward();
        let mean_grad_gpu = x_gpu2.borrow().grad.as_ref().unwrap().to_vec();

        assert_eq!(mean_grad_cpu.len(), mean_grad_gpu.len());
        for (cpu_val, gpu_val) in mean_grad_cpu.iter().zip(mean_grad_gpu.iter()) {
            assert!((cpu_val - gpu_val).abs() < 1e-5);
        }

        // Test max
        let x_cpu3 = RawTensor::new(data.clone(), shape, true);
        let max_cpu = x_cpu3.max_reduce();
        max_cpu.backward();
        let max_grad_cpu = x_cpu3.borrow().grad.as_ref().unwrap().to_vec();

        let x_gpu3 = RawTensor::new(data.clone(), shape, true).to_device(gpu_device);
        let max_gpu = x_gpu3.max_reduce();
        max_gpu.backward();
        let max_grad_gpu = x_gpu3.borrow().grad.as_ref().unwrap().to_vec();

        assert_eq!(max_grad_cpu.len(), max_grad_gpu.len());
        for (cpu_val, gpu_val) in max_grad_cpu.iter().zip(max_grad_gpu.iter()) {
            assert!((cpu_val - gpu_val).abs() < 1e-5);
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_sgd_optimizer_gpu() {
        use crate::nn::optim::SGD;

        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Create a simple parameter on GPU
        let param = RawTensor::new(vec![0.5; 10], &[10], true).to_device(gpu_device.clone());

        // Create optimizer with the parameter
        let mut opt = SGD::new(vec![param.clone()], 0.01, 0.0, 0.0);

        // Manually set a gradient on GPU
        {
            let mut p = param.borrow_mut();
            let grad_data = vec![0.1; 10];
            p.grad = Some(Storage::gpu(grad_data));
        }

        // Step should update the parameter
        let param_before = param.borrow().data.to_vec();
        opt.step();
        let param_after = param.borrow().data.to_vec();

        // Each element should decrease by lr * grad = 0.01 * 0.1 = 0.001
        for (before, after) in param_before.iter().zip(param_after.iter()) {
            assert!((after - (before - 0.001)).abs() < 1e-5);
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_adam_optimizer_gpu() {
        use crate::nn::optim::Adam;

        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Create a simple parameter on GPU
        let param = RawTensor::new(vec![0.5; 10], &[10], true).to_device(gpu_device.clone());

        // Create Adam optimizer with the parameter
        let mut opt = Adam::new(vec![param.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);

        // Manually set a gradient on GPU
        {
            let mut p = param.borrow_mut();
            let grad_data = vec![0.1; 10];
            p.grad = Some(Storage::gpu(grad_data));
        }

        // Step should update the parameter
        let param_before = param.borrow().data.to_vec();
        opt.step();
        let param_after = param.borrow().data.to_vec();
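
        // Hedged first-step arithmetic: with zero-initialized moments and bias
        // correction, m_hat = g and v_hat = g^2, so the update is roughly
        // lr * g / (sqrt(g^2) + eps), i.e. about lr = 0.01 per element,
        // moving each value from 0.5 to roughly 0.49.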

        // Parameter should have changed
        for (before, after) in param_before.iter().zip(param_after.iter()) {
            assert_ne!(before, after, "Parameter should change after Adam step");
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_sgd_momentum_optimizer_gpu() {
        use crate::nn::optim::SGD;

        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Create a simple parameter on GPU
        let param = RawTensor::new(vec![0.5; 10], &[10], true).to_device(gpu_device.clone());

        // Create SGD optimizer with momentum
        let mut opt = SGD::new(vec![param.clone()], 0.01, 0.9, 0.0);

        // Manually set a gradient on GPU
        {
            let mut p = param.borrow_mut();
            let grad_data = vec![0.1; 10];
            p.grad = Some(Storage::gpu(grad_data));
        }

        // First step
        let param_before = param.borrow().data.to_vec();
        opt.step();
        let param_after_step1 = param.borrow().data.to_vec();

        // Parameter should have changed
        for (before, after) in param_before.iter().zip(param_after_step1.iter()) {
            assert_ne!(before, after);
        }

        // Set the same gradient again
        {
            let mut p = param.borrow_mut();
            p.grad = Some(Storage::gpu(vec![0.1; 10]));
        }

        // Second step applies a different update due to momentum
        opt.step();
        let param_after_step2 = param.borrow().data.to_vec();

        // With momentum, the second step should differ from the first
        // (momentum accumulates gradient velocity)
        let changes_step1: Vec<f32> = param_after_step1
            .iter()
            .zip(param_before.iter())
            .map(|(a, b)| a - b)
            .collect();
        let changes_step2: Vec<f32> = param_after_step2
            .iter()
            .zip(param_after_step1.iter())
            .map(|(a, b)| a - b)
            .collect();

        // Changes should be different due to momentum accumulation
        for (c1, c2) in changes_step1.iter().zip(changes_step2.iter()) {
            assert_ne!(c1, c2);
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_optimizer_gpu_cpu_equivalence() {
        use crate::nn::optim::Adam;

        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let grad_data = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        // CPU parameter and optimizer
        let param_cpu = RawTensor::new(data.clone(), &[5], true);
        let mut opt_cpu = Adam::new(vec![param_cpu.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);
        {
            let mut p = param_cpu.borrow_mut();
            p.grad = Some(Storage::cpu(grad_data.clone()));
        }

        // GPU parameter and optimizer
        let param_gpu = RawTensor::new(data, &[5], true).to_device(gpu_device.clone());
        let mut opt_gpu = Adam::new(vec![param_gpu.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);
        {
            let mut p = param_gpu.borrow_mut();
            // Create GPU gradient
            p.grad = Some(Storage::gpu(grad_data));
        }

        // Take one step on both
        opt_cpu.step();
        opt_gpu.step();

        // Results should be approximately equal
        let result_cpu = param_cpu.borrow().data.to_vec();
        let result_gpu = param_gpu.borrow().data.to_vec();

        assert_eq!(result_cpu.len(), result_gpu.len());
        for (cpu_val, gpu_val) in result_cpu.iter().zip(result_gpu.iter()) {
            assert!(
                (cpu_val - gpu_val).abs() < 1e-4,
                "CPU={cpu_val}, GPU={gpu_val}"
            );
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_optimizer_state_stays_on_gpu() {
        use crate::nn::optim::Adam;

        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Create a parameter on GPU
        let param = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5], true).to_device(gpu_device);

        // Create Adam optimizer - state should be initialized on GPU
        let mut opt = Adam::new(vec![param.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);

        // We can't inspect the moment buffers m and v directly since they're
        // private, so we verify indirectly: repeated steps with GPU-resident
        // gradients should complete without errors.

        // Set gradient on GPU
        {
            let mut p = param.borrow_mut();
            p.grad = Some(Storage::gpu(vec![0.1, 0.2, 0.3, 0.4, 0.5]));
        }

        // Take multiple steps
        for _ in 0..5 {
            // Update gradient each step
            {
                let mut p = param.borrow_mut();
                p.grad = Some(Storage::gpu(vec![0.1, 0.2, 0.3, 0.4, 0.5]));
            }
            opt.step();
        }

        // Even if the state were bounced through the CPU each step this would
        // still pass, just more slowly, so this test only verifies that
        // repeated steps complete successfully without errors.

        // Verify the parameter moved away from its initial values
        let result = param.borrow().data.to_vec();
        for (val, init) in result.iter().zip([1.0f32, 2.0, 3.0, 4.0, 5.0].iter()) {
            assert_ne!(val, init, "Parameter should have been updated");
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_sgd_momentum_state_stays_on_gpu() {
        use crate::nn::optim::SGD;

        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Create a parameter on GPU
        let param = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5], true).to_device(gpu_device);

        // Create SGD with momentum - velocity state should be on GPU
        let mut opt = SGD::new(vec![param.clone()], 0.01, 0.9, 0.0);

        // Set gradient on GPU and take multiple steps
        for _ in 0..5 {
            {
                let mut p = param.borrow_mut();
                p.grad = Some(Storage::gpu(vec![0.1, 0.2, 0.3, 0.4, 0.5]));
            }
            opt.step();
        }

        // Verify the parameter moved away from its initial values
        let result = param.borrow().data.to_vec();
        for (val, init) in result.iter().zip([1.0f32, 2.0, 3.0, 4.0, 5.0].iter()) {
            assert_ne!(val, init, "Parameter should have been updated");
        }
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_binary_backward_broadcast_add() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Test case: (3, 1) + (1, 4) -> (3, 4)
        // a_grad should sum over dim 1: (3, 1)
        // b_grad should sum over dim 0: (1, 4)
        let a = RawTensor::new(vec![1.0, 2.0, 3.0], &[3, 1], true).to_device(gpu_device.clone());
        let b = RawTensor::new(vec![10.0, 20.0, 30.0, 40.0], &[1, 4], true)
            .to_device(gpu_device.clone());

        let c = a.add(&b);
        c.backward();

        // Verify gradients exist
        assert!(a.grad().is_some());
        assert!(b.grad().is_some());

        let a_grad = a.grad().unwrap();
        let b_grad = b.grad().unwrap();

        // Each a element received gradient from 4 b elements
        assert_abs_diff_eq!(
            a_grad.first().copied().unwrap_or(f32::NAN),
            4.0,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            a_grad.get(1).copied().unwrap_or(f32::NAN),
            4.0,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            a_grad.get(2).copied().unwrap_or(f32::NAN),
            4.0,
            epsilon = 1e-3
        );

        // Each b element received gradient from 3 a elements
        assert_abs_diff_eq!(
            b_grad.first().copied().unwrap_or(f32::NAN),
            3.0,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            b_grad.get(1).copied().unwrap_or(f32::NAN),
            3.0,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            b_grad.get(2).copied().unwrap_or(f32::NAN),
            3.0,
            epsilon = 1e-3
        );
        assert_abs_diff_eq!(
            b_grad.get(3).copied().unwrap_or(f32::NAN),
            3.0,
            epsilon = 1e-3
        );
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_binary_backward_broadcast_mul() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Test broadcasting multiplication
        let a = RawTensor::new(vec![1.0, 2.0], &[2, 1], true).to_device(gpu_device.clone());
        let b = RawTensor::new(vec![10.0, 20.0, 30.0], &[1, 3], true).to_device(gpu_device.clone());

        let c = a.elem_mul(&b);
        c.backward();

        // Verify gradients exist and are finite
        assert!(a.grad().is_some());
        assert!(b.grad().is_some());

        let a_grad = a.grad().unwrap();
        let b_grad = b.grad().unwrap();

        // All gradients should be finite
        for g in a_grad.iter() {
            assert!(g.is_finite(), "a_grad contains non-finite value");
        }
        for g in b_grad.iter() {
            assert!(g.is_finite(), "b_grad contains non-finite value");
        }
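
        // Closed form for this case: with an all-ones upstream gradient,
        // a_grad[i] = sum_k b[k] = 60 and b_grad[k] = sum_i a[i] = 3, so we
        // can check values too rather than relying on finiteness alone.
        for g in a_grad.iter() {
            assert_abs_diff_eq!(*g, 60.0, epsilon = 1e-3);
        }
        for g in b_grad.iter() {
            assert_abs_diff_eq!(*g, 3.0, epsilon = 1e-3);
        }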
    }

    #[cfg(feature = "gpu")]
    #[test]
    fn test_gpu_binary_backward_broadcast_stress() {
        if !is_gpu_available() {
            return;
        }

        let gpu_device = Device::gpu().expect("GPU should be available");

        // Stress test with many output positions mapping to the same input
        let a = RawTensor::new(vec![1.0], &[1], true).to_device(gpu_device.clone());
        let b = RawTensor::new(
            (0..1000).map(|i| i as f32).collect::<Vec<_>>(),
            &[1000],
            true,
        )
        .to_device(gpu_device.clone());

        let c = a.add(&b);
        c.backward();

        // a should accumulate gradient from all 1000 positions
        let a_grad = a.grad().unwrap();
        assert_abs_diff_eq!(
            a_grad.first().copied().unwrap_or(f32::NAN),
            1000.0,
            epsilon = 1e-2
        );
    }
}