// proof_engine/ml/tensor.rs

//! N-dimensional tensor operations for ML workloads.

use std::ops::Range;

/// An N-dimensional tensor stored in row-major order.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

impl Tensor {
    // ── helpers ──────────────────────────────────────────────────────────

    /// Total number of elements implied by a shape.
    fn numel(shape: &[usize]) -> usize {
        shape.iter().product()
    }

    /// Compute strides for row-major layout.
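    /// For example, shape `[2, 3, 4]` gives strides `[12, 4, 1]`: advancing
    /// one step along axis 0 skips 3 * 4 = 12 elements, one step along
    /// axis 1 skips 4, and the last axis is contiguous.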
    fn strides(shape: &[usize]) -> Vec<usize> {
        let mut s = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            s[i] = s[i + 1] * shape[i + 1];
        }
        s
    }

    /// Flat index from multi-dimensional indices.
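    /// e.g. indices `[1, 2]` in shape `[2, 3]` map to flat index
    /// 1 * 3 + 2 = 5.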
    fn flat_index(&self, indices: &[usize]) -> usize {
        assert_eq!(indices.len(), self.shape.len(), "index rank mismatch");
        let strides = Self::strides(&self.shape);
        indices.iter().zip(strides.iter()).map(|(i, s)| i * s).sum()
    }

    // ── creation ────────────────────────────────────────────────────────

    pub fn zeros(shape: Vec<usize>) -> Self {
        let n = Self::numel(&shape);
        Self { shape, data: vec![0.0; n] }
    }

    pub fn ones(shape: Vec<usize>) -> Self {
        let n = Self::numel(&shape);
        Self { shape, data: vec![1.0; n] }
    }

    /// Pseudo-random tensor using a simple xorshift seeded from `rng`.
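    ///
    /// A minimal usage sketch (the seed value is illustrative); the generator
    /// is deterministic, so equal seeds give equal tensors:
    ///
    /// ```ignore
    /// let a = Tensor::rand(vec![2, 2], 42);
    /// let b = Tensor::rand(vec![2, 2], 42);
    /// assert_eq!(a, b); // values lie in [0, 1]
    /// ```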
    pub fn rand(shape: Vec<usize>, rng: u64) -> Self {
        let n = Self::numel(&shape);
        let mut data = Vec::with_capacity(n);
        let mut state = rng.wrapping_add(1); // avoid zero
        for _ in 0..n {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            // map to 0..1
            data.push((state as u32 as f32) / (u32::MAX as f32));
        }
        Self { shape, data }
    }

    pub fn from_vec(data: Vec<f32>, shape: Vec<usize>) -> Self {
        assert_eq!(data.len(), Self::numel(&shape), "data length / shape mismatch");
        Self { shape, data }
    }

    /// Scalar tensor.
    pub fn scalar(v: f32) -> Self {
        Self { shape: vec![], data: vec![v] }
    }

    // ── indexing ─────────────────────────────────────────────────────────

    pub fn get(&self, indices: &[usize]) -> f32 {
        self.data[self.flat_index(indices)]
    }

    pub fn set(&mut self, indices: &[usize], val: f32) {
        let idx = self.flat_index(indices);
        self.data[idx] = val;
    }

    /// Slice along each axis with the given ranges. Produces a new tensor
    /// whose shape matches the range extents.
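    ///
    /// A small sketch mirroring the unit test below: taking rows `0..2` and
    /// columns `1..3` of a 3x3 tensor yields a 2x2 tensor.
    ///
    /// ```ignore
    /// let t = Tensor::from_vec((1..=9).map(|v| v as f32).collect(), vec![3, 3]);
    /// let s = t.slice(&[0..2, 1..3]);
    /// assert_eq!(s.shape, vec![2, 2]);
    /// assert_eq!(s.data, vec![2.0, 3.0, 5.0, 6.0]);
    /// ```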
    pub fn slice(&self, ranges: &[Range<usize>]) -> Tensor {
        assert_eq!(ranges.len(), self.shape.len());
        let new_shape: Vec<usize> = ranges.iter().map(|r| r.end - r.start).collect();
        let n = Self::numel(&new_shape);
        let mut data = Vec::with_capacity(n);
        let strides = Self::strides(&self.shape);
        // walk the ranges axis by axis, copying elements in row-major order
        Self::slice_recursive(&self.data, &strides, ranges, 0, 0, &mut data);
        Tensor { shape: new_shape, data }
    }

    fn slice_recursive(
        src: &[f32],
        strides: &[usize],
        ranges: &[Range<usize>],
        dim: usize,
        base: usize,
        out: &mut Vec<f32>,
    ) {
        if dim == ranges.len() {
            out.push(src[base]);
            return;
        }
        for i in ranges[dim].clone() {
            Self::slice_recursive(src, strides, ranges, dim + 1, base + i * strides[dim], out);
        }
    }

    // ── element-wise math ───────────────────────────────────────────────

    pub fn add(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape, other.shape, "shape mismatch for add");
        let data: Vec<f32> = self.data.iter().zip(&other.data).map(|(a, b)| a + b).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    pub fn sub(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape, other.shape, "shape mismatch for sub");
        let data: Vec<f32> = self.data.iter().zip(&other.data).map(|(a, b)| a - b).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    pub fn mul(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape, other.shape, "shape mismatch for mul");
        let data: Vec<f32> = self.data.iter().zip(&other.data).map(|(a, b)| a * b).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    pub fn scale(&self, s: f32) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|v| v * s).collect(),
        }
    }

    /// 2-D matrix multiply: (M, K) x (K, N) -> (M, N).
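    ///
    /// Naive triple loop: out[i][j] = Σ_p a[i][p] * b[p][j]. A sketch,
    /// grounded in the unit tests below:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    /// let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
    /// let c = Tensor::matmul(&a, &b); // [[19, 22], [43, 50]]
    /// ```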
    pub fn matmul(a: &Tensor, b: &Tensor) -> Tensor {
        assert_eq!(a.shape.len(), 2, "matmul requires 2-D tensors");
        assert_eq!(b.shape.len(), 2, "matmul requires 2-D tensors");
        let m = a.shape[0];
        let k = a.shape[1];
        assert_eq!(b.shape[0], k, "inner dimensions must match");
        let n = b.shape[1];
        let mut data = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut s = 0.0f32;
                for p in 0..k {
                    s += a.data[i * k + p] * b.data[p * n + j];
                }
                data[i * n + j] = s;
            }
        }
        Tensor { shape: vec![m, n], data }
    }

    /// Transpose the last two dimensions. For 2-D tensors this is the
    /// standard matrix transpose.
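    ///
    /// Sketch: a (2, 3) tensor becomes (3, 2), with
    /// `at.get(&[j, i]) == a.get(&[i, j])`.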
    pub fn transpose(&self) -> Tensor {
        assert!(self.shape.len() >= 2, "transpose needs rank >= 2");
        let ndim = self.shape.len();
        let rows = self.shape[ndim - 2];
        let cols = self.shape[ndim - 1];
        let batch: usize = self.shape[..ndim - 2].iter().product();
        let mut new_shape = self.shape.clone();
        new_shape[ndim - 2] = cols;
        new_shape[ndim - 1] = rows;
        let mat_size = rows * cols;
        let mut data = vec![0.0f32; self.data.len()];
        for b in 0..batch {
            let base = b * mat_size;
            for r in 0..rows {
                for c in 0..cols {
                    data[base + c * rows + r] = self.data[base + r * cols + c];
                }
            }
        }
        Tensor { shape: new_shape, data }
    }

    // ── reductions ──────────────────────────────────────────────────────

    pub fn sum(&self) -> f32 {
        self.data.iter().sum()
    }

    pub fn mean(&self) -> f32 {
        self.sum() / self.data.len() as f32
    }

    pub fn max(&self) -> f32 {
        self.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
    }

    pub fn min(&self) -> f32 {
        self.data.iter().cloned().fold(f32::INFINITY, f32::min)
    }

    /// Argmax along a given axis, returning a tensor with that axis removed.
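    ///
    /// Indices come back as `f32` because `data` is a `Vec<f32>`. Sketch,
    /// grounded in the unit test below:
    ///
    /// ```ignore
    /// let t = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 4.0], vec![2, 3]);
    /// assert_eq!(t.argmax(1).data, vec![1.0, 0.0]); // per-row winners
    /// ```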
    pub fn argmax(&self, axis: usize) -> Tensor {
        assert!(axis < self.shape.len());
        let axis_len = self.shape[axis];
        let mut new_shape: Vec<usize> = self.shape.clone();
        new_shape.remove(axis);
        if new_shape.is_empty() {
            new_shape.push(1);
        }
        let outer: usize = self.shape[..axis].iter().product();
        let inner: usize = self.shape[axis + 1..].iter().product();
        let mut data = Vec::with_capacity(outer * inner);
        for o in 0..outer {
            for i in 0..inner {
                let mut best_idx = 0usize;
                let mut best_val = f32::NEG_INFINITY;
                for a in 0..axis_len {
                    let flat = o * axis_len * inner + a * inner + i;
                    if self.data[flat] > best_val {
                        best_val = self.data[flat];
                        best_idx = a;
                    }
                }
                data.push(best_idx as f32);
            }
        }
        Tensor { shape: new_shape, data }
    }

    // ── reshaping ───────────────────────────────────────────────────────

    pub fn reshape(&self, new_shape: Vec<usize>) -> Tensor {
        assert_eq!(Self::numel(&new_shape), self.data.len(), "reshape size mismatch");
        Tensor { shape: new_shape, data: self.data.clone() }
    }

    pub fn flatten(&self) -> Tensor {
        Tensor { shape: vec![self.data.len()], data: self.data.clone() }
    }

    /// Remove all size-1 dimensions.
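    /// e.g. shape `[1, 3, 1, 4]` becomes `[3, 4]`; an all-ones shape
    /// collapses to `[1]` rather than an empty shape.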
    pub fn squeeze(&self) -> Tensor {
        let new_shape: Vec<usize> = self.shape.iter().copied().filter(|&d| d != 1).collect();
        let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
        Tensor { shape: new_shape, data: self.data.clone() }
    }

    /// Insert a size-1 dimension at `dim`.
    pub fn unsqueeze(&self, dim: usize) -> Tensor {
        let mut new_shape = self.shape.clone();
        new_shape.insert(dim, 1);
        Tensor { shape: new_shape, data: self.data.clone() }
    }

    // ── broadcasting ────────────────────────────────────────────────────

    /// Broadcast this tensor to the target shape, repeating data as needed.
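    ///
    /// NumPy-style rule: the source shape is left-padded with 1s, then each
    /// source axis must either equal the target extent or be 1 (in which
    /// case it is repeated). Sketch, grounded in the unit test below:
    ///
    /// ```ignore
    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
    /// let b = t.broadcast_to(&[2, 3]);
    /// assert_eq!(b.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    /// ```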
    pub fn broadcast_to(&self, target: &[usize]) -> Tensor {
        assert!(target.len() >= self.shape.len());
        // left-pad shape with 1s
        let pad = target.len() - self.shape.len();
        let mut src_shape: Vec<usize> = vec![1; pad];
        src_shape.extend_from_slice(&self.shape);

        for (s, t) in src_shape.iter().zip(target.iter()) {
            assert!(*s == 1 || *s == *t, "cannot broadcast {src_shape:?} to {target:?}");
        }

        let n = Self::numel(target);
        let src_strides = Self::strides(&src_shape);
        let dst_strides = Self::strides(target);
        let mut data = vec![0.0f32; n];
        for flat in 0..n {
            let mut src_flat = 0usize;
            let mut rem = flat;
            for d in 0..target.len() {
                let coord = rem / dst_strides[d];
                rem %= dst_strides[d];
                let src_coord = if src_shape[d] == 1 { 0 } else { coord };
                src_flat += src_coord * src_strides[d];
            }
            data[flat] = self.data[src_flat];
        }
        Tensor { shape: target.to_vec(), data }
    }

    // ── activation functions ────────────────────────────────────────────

    pub fn relu(&self) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&v| v.max(0.0)).collect(),
        }
    }

    pub fn sigmoid(&self) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect(),
        }
    }

    pub fn tanh_act(&self) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&v| v.tanh()).collect(),
        }
    }

    /// Softmax along `axis`.
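    ///
    /// softmax(x)_a = exp(x_a - max(x)) / Σ_b exp(x_b - max(x)); subtracting
    /// the per-slice max leaves the result unchanged but keeps `exp` from
    /// overflowing. Sketch, grounded in the unit test below:
    ///
    /// ```ignore
    /// let s = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).softmax(1);
    /// assert!((s.data.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    /// ```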
    pub fn softmax(&self, axis: usize) -> Tensor {
        assert!(axis < self.shape.len());
        let axis_len = self.shape[axis];
        let outer: usize = self.shape[..axis].iter().product();
        let inner: usize = self.shape[axis + 1..].iter().product();
        let mut data = self.data.clone();
        for o in 0..outer {
            for i in 0..inner {
                // find max for numerical stability
                let mut mx = f32::NEG_INFINITY;
                for a in 0..axis_len {
                    let idx = o * axis_len * inner + a * inner + i;
                    mx = mx.max(data[idx]);
                }
                let mut sum = 0.0f32;
                for a in 0..axis_len {
                    let idx = o * axis_len * inner + a * inner + i;
                    let e = (data[idx] - mx).exp();
                    data[idx] = e;
                    sum += e;
                }
                for a in 0..axis_len {
                    let idx = o * axis_len * inner + a * inner + i;
                    data[idx] /= sum;
                }
            }
        }
        Tensor { shape: self.shape.clone(), data }
    }

    /// GELU activation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
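    /// (the common tanh approximation rather than the exact erf-based form).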
    pub fn gelu(&self) -> Tensor {
        let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&x| {
                let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }).collect(),
        }
    }

    // ── convolution ─────────────────────────────────────────────────────

    /// 2-D convolution. Input shape: (C_in, H, W). Kernel shape: (C_out, C_in, kH, kW).
    /// Returns shape (C_out, H_out, W_out).
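    ///
    /// Output spatial extent follows the usual formula,
    /// H_out = (H + 2 * padding - kH) / stride + 1 (likewise for W).
    /// Sketch, grounded in the unit test below:
    ///
    /// ```ignore
    /// let input = Tensor::ones(vec![1, 4, 4]);
    /// let kernel = Tensor::ones(vec![1, 1, 3, 3]);
    /// let out = input.conv2d(&kernel, 1, 0); // shape [1, 2, 2], every entry 9.0
    /// ```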
    pub fn conv2d(&self, kernel: &Tensor, stride: usize, padding: usize) -> Tensor {
        assert_eq!(self.shape.len(), 3, "conv2d input must be (C, H, W)");
        assert_eq!(kernel.shape.len(), 4, "conv2d kernel must be (C_out, C_in, kH, kW)");
        let c_in = self.shape[0];
        let h = self.shape[1];
        let w = self.shape[2];
        let c_out = kernel.shape[0];
        assert_eq!(kernel.shape[1], c_in);
        let kh = kernel.shape[2];
        let kw = kernel.shape[3];
        let h_out = (h + 2 * padding - kh) / stride + 1;
        let w_out = (w + 2 * padding - kw) / stride + 1;

        let mut out = vec![0.0f32; c_out * h_out * w_out];
        for co in 0..c_out {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut val = 0.0f32;
                    for ci in 0..c_in {
                        for fh in 0..kh {
                            for fw in 0..kw {
                                let ih = oh * stride + fh;
                                let iw = ow * stride + fw;
                                let ih = ih as isize - padding as isize;
                                let iw = iw as isize - padding as isize;
                                if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
                                    let ih = ih as usize;
                                    let iw = iw as usize;
                                    let in_idx = ci * h * w + ih * w + iw;
                                    let k_idx = co * c_in * kh * kw + ci * kh * kw + fh * kw + fw;
                                    val += self.data[in_idx] * kernel.data[k_idx];
                                }
                            }
                        }
                    }
                    out[co * h_out * w_out + oh * w_out + ow] = val;
                }
            }
        }
        Tensor { shape: vec![c_out, h_out, w_out], data: out }
    }

    // ── pooling ─────────────────────────────────────────────────────────

    /// Max pooling 2-D. Input shape: (C, H, W).
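    ///
    /// Sketch: a 2x2 window with stride 2 over a 4x4 map keeps the max of
    /// each quadrant (values from the unit test below).
    ///
    /// ```ignore
    /// let t = Tensor::from_vec((1..=16).map(|v| v as f32).collect(), vec![1, 4, 4]);
    /// assert_eq!(t.max_pool2d(2, 2).data, vec![6.0, 8.0, 14.0, 16.0]);
    /// ```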
    pub fn max_pool2d(&self, kernel_size: usize, stride: usize) -> Tensor {
        assert_eq!(self.shape.len(), 3);
        let c = self.shape[0];
        let h = self.shape[1];
        let w = self.shape[2];
        let h_out = (h - kernel_size) / stride + 1;
        let w_out = (w - kernel_size) / stride + 1;
        let mut out = vec![f32::NEG_INFINITY; c * h_out * w_out];
        for ch in 0..c {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut mx = f32::NEG_INFINITY;
                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            let ih = oh * stride + kh;
                            let iw = ow * stride + kw;
                            mx = mx.max(self.data[ch * h * w + ih * w + iw]);
                        }
                    }
                    out[ch * h_out * w_out + oh * w_out + ow] = mx;
                }
            }
        }
        Tensor { shape: vec![c, h_out, w_out], data: out }
    }

    /// Average pooling 2-D. Input shape: (C, H, W).
    pub fn avg_pool2d(&self, kernel_size: usize, stride: usize) -> Tensor {
        assert_eq!(self.shape.len(), 3);
        let c = self.shape[0];
        let h = self.shape[1];
        let w = self.shape[2];
        let h_out = (h - kernel_size) / stride + 1;
        let w_out = (w - kernel_size) / stride + 1;
        let area = (kernel_size * kernel_size) as f32;
        let mut out = vec![0.0f32; c * h_out * w_out];
        for ch in 0..c {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut s = 0.0f32;
                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            let ih = oh * stride + kh;
                            let iw = ow * stride + kw;
                            s += self.data[ch * h * w + ih * w + iw];
                        }
                    }
                    out[ch * h_out * w_out + oh * w_out + ow] = s / area;
                }
            }
        }
        Tensor { shape: vec![c, h_out, w_out], data: out }
    }

    // ── normalization ───────────────────────────────────────────────────

    /// Batch normalization: y = gamma * (x - mean) / sqrt(var + eps) + beta.
    /// All parameter tensors must have the same total length as `self`.
    pub fn batch_norm(&self, mean: &Tensor, var: &Tensor, gamma: &Tensor, beta: &Tensor, eps: f32) -> Tensor {
        // enforce the documented contract for every parameter, not just mean
        assert_eq!(self.data.len(), mean.data.len());
        assert_eq!(self.data.len(), var.data.len());
        assert_eq!(self.data.len(), gamma.data.len());
        assert_eq!(self.data.len(), beta.data.len());
        let data: Vec<f32> = self.data.iter().enumerate().map(|(i, &x)| {
            let m = mean.data[i];
            let v = var.data[i];
            let g = gamma.data[i];
            let b = beta.data[i];
            g * (x - m) / (v + eps).sqrt() + b
        }).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    /// Layer normalization over all dimensions from `axis` onward: each outer
    /// slice is normalized to zero mean and unit variance,
    /// y = (x - mean) / sqrt(var + eps). This variant has no learned gamma/beta.
    pub fn layer_norm(&self, axis: usize, eps: f32) -> Tensor {
        assert!(axis < self.shape.len());
        let outer: usize = self.shape[..axis].iter().product();
        let inner: usize = self.shape[axis..].iter().product();
        let mut data = self.data.clone();
        for o in 0..outer {
            let start = o * inner;
            let end = start + inner;
            let slice = &data[start..end];
            let mean: f32 = slice.iter().sum::<f32>() / inner as f32;
            let var: f32 = slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / inner as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for i in start..end {
                data[i] = (data[i] - mean) * inv_std;
            }
        }
        Tensor { shape: self.shape.clone(), data }
    }

    // ── dropout ─────────────────────────────────────────────────────────

    /// Dropout: randomly zero elements with probability `p` during training.
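    /// Uses inverted dropout: survivors are scaled by 1 / (1 - p) so the
    /// expected activation is unchanged and inference needs no rescaling.
    ///
    /// ```ignore
    /// let t = Tensor::ones(vec![4]);
    /// assert_eq!(t.dropout(0.5, 7, false).data, t.data); // eval mode: no-op
    /// ```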
    pub fn dropout(&self, p: f32, rng: u64, training: bool) -> Tensor {
        if !training || p == 0.0 {
            return self.clone();
        }
        let scale = 1.0 / (1.0 - p);
        let mut state = rng.wrapping_add(1);
        let data: Vec<f32> = self.data.iter().map(|&v| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let r = (state as u32 as f32) / (u32::MAX as f32);
            if r < p { 0.0 } else { v * scale }
        }).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    // ── concatenation / stacking ────────────────────────────────────────

    /// Concatenate tensors along an axis.
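    ///
    /// All shapes must match except along `axis`. Sketch, grounded in the
    /// unit test below:
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    /// let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
    /// let c = Tensor::concat(&[a, b], 0); // shape [4, 2]
    /// ```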
    pub fn concat(tensors: &[Tensor], axis: usize) -> Tensor {
        assert!(!tensors.is_empty());
        let ndim = tensors[0].shape.len();
        assert!(axis < ndim);
        // verify all shapes match except along `axis`
        for t in &tensors[1..] {
            assert_eq!(t.shape.len(), ndim);
            for d in 0..ndim {
                if d != axis {
                    assert_eq!(t.shape[d], tensors[0].shape[d]);
                }
            }
        }
        let mut new_shape = tensors[0].shape.clone();
        new_shape[axis] = tensors.iter().map(|t| t.shape[axis]).sum();

        let outer: usize = new_shape[..axis].iter().product();
        let inner: usize = new_shape[axis + 1..].iter().product();
        let total = Self::numel(&new_shape);
        let mut data = Vec::with_capacity(total);

        for o in 0..outer {
            for t in tensors {
                let t_axis = t.shape[axis];
                let t_inner: usize = t.shape[axis + 1..].iter().product();
                for a in 0..t_axis {
                    for i in 0..inner {
                        let idx = o * t_axis * t_inner + a * t_inner + i;
                        data.push(t.data[idx]);
                    }
                }
            }
        }
        Tensor { shape: new_shape, data }
    }

    /// Stack tensors along a new axis.
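    /// Inputs must share a shape; the result gains a new axis of length
    /// `tensors.len()` at `axis` (implemented as unsqueeze + concat).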
    pub fn stack(tensors: &[Tensor], axis: usize) -> Tensor {
        assert!(!tensors.is_empty());
        // unsqueeze each tensor at `axis`, then concat
        let unsqueezed: Vec<Tensor> = tensors.iter().map(|t| t.unsqueeze(axis)).collect();
        Self::concat(&unsqueezed, axis)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_creation() {
        let z = Tensor::zeros(vec![2, 3]);
        assert_eq!(z.data.len(), 6);
        assert!(z.data.iter().all(|&v| v == 0.0));

        let o = Tensor::ones(vec![3, 2]);
        assert!(o.data.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_indexing() {
        let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 2]), 6.0);
        t.set(&[0, 1], 99.0);
        assert_eq!(t.get(&[0, 1]), 99.0);
    }

    #[test]
    fn test_matmul() {
        // [[1,2],[3,4]] x [[5,6],[7,8]] = [[19,22],[43,50]]
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = Tensor::matmul(&a, &b);
        assert_eq!(c.shape, vec![2, 2]);
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }

    #[test]
    fn test_matmul_non_square() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let c = Tensor::matmul(&a, &b);
        assert_eq!(c.shape, vec![2, 2]);
        // [1*1+2*3+3*5, 1*2+2*4+3*6] = [22, 28]
        assert_eq!(c.get(&[0, 0]), 22.0);
        assert_eq!(c.get(&[0, 1]), 28.0);
    }

    #[test]
    fn test_transpose() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let at = a.transpose();
        assert_eq!(at.shape, vec![3, 2]);
        assert_eq!(at.get(&[0, 0]), 1.0);
        assert_eq!(at.get(&[0, 1]), 4.0);
        assert_eq!(at.get(&[2, 0]), 3.0);
        assert_eq!(at.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
        let s = t.softmax(1);
        let total: f32 = s.data.iter().sum();
        assert!((total - 1.0).abs() < 1e-5, "softmax sum = {total}");
        // all positive
        assert!(s.data.iter().all(|&v| v > 0.0));
    }

    #[test]
    fn test_relu_zeros_negatives() {
        let t = Tensor::from_vec(vec![-3.0, -1.0, 0.0, 1.0, 5.0], vec![5]);
        let r = t.relu();
        assert_eq!(r.data, vec![0.0, 0.0, 0.0, 1.0, 5.0]);
    }

    #[test]
    fn test_conv2d() {
        // 1 channel, 4x4 input, 1 filter 1x1x3x3, stride 1, no padding -> 2x2
        let input = Tensor::ones(vec![1, 4, 4]);
        let kernel = Tensor::ones(vec![1, 1, 3, 3]);
        let out = input.conv2d(&kernel, 1, 0);
        assert_eq!(out.shape, vec![1, 2, 2]);
        // each output element = sum of 3x3 ones = 9
        assert_eq!(out.data, vec![9.0, 9.0, 9.0, 9.0]);
    }

    #[test]
    fn test_conv2d_with_padding() {
        let input = Tensor::ones(vec![1, 3, 3]);
        let kernel = Tensor::ones(vec![1, 1, 3, 3]);
        let out = input.conv2d(&kernel, 1, 1);
        assert_eq!(out.shape, vec![1, 3, 3]);
        // center: 9, corners: 4, edges: 6
        assert_eq!(out.get(&[0, 1, 1]), 9.0);
        assert_eq!(out.get(&[0, 0, 0]), 4.0);
        assert_eq!(out.get(&[0, 0, 1]), 6.0);
    }

    #[test]
    fn test_pooling() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0];
        let t = Tensor::from_vec(data, vec![1, 4, 4]);
        let mp = t.max_pool2d(2, 2);
        assert_eq!(mp.shape, vec![1, 2, 2]);
        assert_eq!(mp.data, vec![6.0, 8.0, 14.0, 16.0]);

        let ap = t.avg_pool2d(2, 2);
        assert_eq!(ap.shape, vec![1, 2, 2]);
        assert_eq!(ap.data, vec![3.5, 5.5, 11.5, 13.5]);
    }

    #[test]
    fn test_reshape_flatten() {
        let t = Tensor::ones(vec![2, 3, 4]);
        let r = t.reshape(vec![6, 4]);
        assert_eq!(r.shape, vec![6, 4]);
        assert_eq!(r.data.len(), 24);
        let f = t.flatten();
        assert_eq!(f.shape, vec![24]);
    }

    #[test]
    fn test_squeeze_unsqueeze() {
        let t = Tensor::ones(vec![1, 3, 1, 4]);
        let s = t.squeeze();
        assert_eq!(s.shape, vec![3, 4]);
        let u = s.unsqueeze(0);
        assert_eq!(u.shape, vec![1, 3, 4]);
    }

    #[test]
    fn test_broadcast() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let b = t.broadcast_to(&[2, 3]);
        assert_eq!(b.shape, vec![2, 3]);
        assert_eq!(b.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_sigmoid() {
        let t = Tensor::from_vec(vec![0.0], vec![1]);
        let s = t.sigmoid();
        assert!((s.data[0] - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_gelu() {
        let t = Tensor::from_vec(vec![0.0, 1.0, -1.0], vec![3]);
        let g = t.gelu();
        assert!((g.data[0]).abs() < 1e-5); // gelu(0) = 0
        assert!(g.data[1] > 0.8); // gelu(1) ~ 0.841
        assert!(g.data[2] < 0.0); // gelu(-1) ~ -0.159
    }

    #[test]
    fn test_layer_norm() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
        let ln = t.layer_norm(1, 1e-5);
        // mean should be ~0
        let mean: f32 = ln.data.iter().sum::<f32>() / 4.0;
        assert!(mean.abs() < 1e-4);
    }

    #[test]
    fn test_concat() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = Tensor::concat(&[a, b], 0);
        assert_eq!(c.shape, vec![4, 2]);
        assert_eq!(c.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_stack() {
        let a = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        let b = Tensor::from_vec(vec![3.0, 4.0], vec![2]);
        let s = Tensor::stack(&[a, b], 0);
        assert_eq!(s.shape, vec![2, 2]);
        assert_eq!(s.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_slice() {
        let t = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            vec![3, 3],
        );
        let s = t.slice(&[0..2, 1..3]);
        assert_eq!(s.shape, vec![2, 2]);
        assert_eq!(s.data, vec![2.0, 3.0, 5.0, 6.0]);
    }

    #[test]
    fn test_dropout() {
        let t = Tensor::ones(vec![100]);
        let d = t.dropout(0.5, 42, true);
        let zeros = d.data.iter().filter(|&&v| v == 0.0).count();
        // with p=0.5 we expect roughly 50 zeros (allow wide margin)
        assert!(zeros > 10 && zeros < 90);
        // non-training should pass through
        let d2 = t.dropout(0.5, 42, false);
        assert_eq!(d2.data, t.data);
    }

    #[test]
    fn test_argmax() {
        let t = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 4.0], vec![2, 3]);
        let am = t.argmax(1);
        assert_eq!(am.shape, vec![2]);
        assert_eq!(am.data, vec![1.0, 0.0]); // argmax of [1,5,3]=1, [9,2,4]=0
    }
}
773}