use std::ops::Range;

/// A dense, row-major tensor of `f32` values.
///
/// `shape` lists the extent of each dimension; `data` holds the elements in
/// contiguous row-major (C) order, so `data.len()` always equals the product
/// of the entries of `shape`.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor {
    pub shape: Vec<usize>,
    pub data: Vec<f32>,
}

impl Tensor {
    /// Total number of elements for a given shape.
    /// An empty shape (a scalar) yields 1, since the empty product is 1.
    fn numel(shape: &[usize]) -> usize {
        shape.iter().product()
    }

    /// Row-major strides: the last dimension has stride 1, and each earlier
    /// stride is the product of all later dimension sizes.
    fn strides(shape: &[usize]) -> Vec<usize> {
        let mut s = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            s[i] = s[i + 1] * shape[i + 1];
        }
        s
    }

    /// Map a multi-dimensional index to an offset into `data`.
    fn flat_index(&self, indices: &[usize]) -> usize {
        assert_eq!(indices.len(), self.shape.len(), "index rank mismatch");
        let strides = Self::strides(&self.shape);
        indices.iter().zip(strides.iter()).map(|(i, s)| i * s).sum()
    }

    /// A tensor of the given shape filled with zeros.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let n = Self::numel(&shape);
        Self { shape, data: vec![0.0; n] }
    }

    /// A tensor of the given shape filled with ones.
    pub fn ones(shape: Vec<usize>) -> Self {
        let n = Self::numel(&shape);
        Self { shape, data: vec![1.0; n] }
    }

    /// A tensor of pseudo-random values in `[0, 1]`, drawn from a seeded
    /// xorshift64 stream so results are reproducible.
    pub fn rand(shape: Vec<usize>, rng: u64) -> Self {
        let n = Self::numel(&shape);
        let mut data = Vec::with_capacity(n);
        // Xorshift never leaves the all-zero state, so force a nonzero seed
        // (rng == u64::MAX would otherwise wrap to 0).
        let mut state = rng.wrapping_add(1).max(1);
        for _ in 0..n {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            data.push((state as u32 as f32) / (u32::MAX as f32));
        }
        Self { shape, data }
    }

    /// Wrap an existing buffer; `data.len()` must match the shape's element count.
    pub fn from_vec(data: Vec<f32>, shape: Vec<usize>) -> Self {
        assert_eq!(data.len(), Self::numel(&shape), "data length / shape mismatch");
        Self { shape, data }
    }

    /// A rank-0 tensor holding a single value.
    pub fn scalar(v: f32) -> Self {
        Self { shape: vec![], data: vec![v] }
    }

    pub fn get(&self, indices: &[usize]) -> f32 {
        self.data[self.flat_index(indices)]
    }

    pub fn set(&mut self, indices: &[usize], val: f32) {
        let idx = self.flat_index(indices);
        self.data[idx] = val;
    }

    /// Copy a rectangular sub-region; `ranges` must supply one half-open
    /// range per dimension.
    pub fn slice(&self, ranges: &[Range<usize>]) -> Tensor {
        assert_eq!(ranges.len(), self.shape.len());
        let new_shape: Vec<usize> = ranges.iter().map(|r| r.end - r.start).collect();
        let n = Self::numel(&new_shape);
        let mut data = Vec::with_capacity(n);
        let strides = Self::strides(&self.shape);
        Self::slice_recursive(&self.data, &strides, ranges, 0, 0, &mut data);
        Tensor { shape: new_shape, data }
    }

    /// Walk the requested ranges dimension by dimension, pushing elements
    /// in row-major order.
    fn slice_recursive(
        src: &[f32],
        strides: &[usize],
        ranges: &[Range<usize>],
        dim: usize,
        base: usize,
        out: &mut Vec<f32>,
    ) {
        if dim == ranges.len() {
            out.push(src[base]);
            return;
        }
        for i in ranges[dim].clone() {
            Self::slice_recursive(src, strides, ranges, dim + 1, base + i * strides[dim], out);
        }
    }

    /// Element-wise addition; shapes must match exactly (no implicit
    /// broadcasting here; see `broadcast_to`).
    pub fn add(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape, other.shape, "shape mismatch for add");
        let data: Vec<f32> = self.data.iter().zip(&other.data).map(|(a, b)| a + b).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    /// Element-wise subtraction; shapes must match exactly.
    pub fn sub(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape, other.shape, "shape mismatch for sub");
        let data: Vec<f32> = self.data.iter().zip(&other.data).map(|(a, b)| a - b).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    /// Element-wise (Hadamard) product; shapes must match exactly.
    pub fn mul(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.shape, other.shape, "shape mismatch for mul");
        let data: Vec<f32> = self.data.iter().zip(&other.data).map(|(a, b)| a * b).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    /// Multiply every element by the scalar `s`.
    pub fn scale(&self, s: f32) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|v| v * s).collect(),
        }
    }

    /// Naive triple-loop matrix multiply of two 2-D tensors:
    /// (m, k) x (k, n) -> (m, n).
    pub fn matmul(a: &Tensor, b: &Tensor) -> Tensor {
        assert_eq!(a.shape.len(), 2, "matmul requires 2-D tensors");
        assert_eq!(b.shape.len(), 2, "matmul requires 2-D tensors");
        let m = a.shape[0];
        let k = a.shape[1];
        assert_eq!(b.shape[0], k, "inner dimensions must match");
        let n = b.shape[1];
        let mut data = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut s = 0.0f32;
                for p in 0..k {
                    s += a.data[i * k + p] * b.data[p * n + j];
                }
                data[i * n + j] = s;
            }
        }
        Tensor { shape: vec![m, n], data }
    }

    /// Swap the last two dimensions. Leading dimensions, if any, are treated
    /// as a batch of independent matrices.
    pub fn transpose(&self) -> Tensor {
        assert!(self.shape.len() >= 2, "transpose needs rank >= 2");
        let ndim = self.shape.len();
        let rows = self.shape[ndim - 2];
        let cols = self.shape[ndim - 1];
        let batch: usize = self.shape[..ndim - 2].iter().product();
        let mut new_shape = self.shape.clone();
        new_shape[ndim - 2] = cols;
        new_shape[ndim - 1] = rows;
        let mat_size = rows * cols;
        let mut data = vec![0.0f32; self.data.len()];
        for b in 0..batch {
            let base = b * mat_size;
            for r in 0..rows {
                for c in 0..cols {
                    data[base + c * rows + r] = self.data[base + r * cols + c];
                }
            }
        }
        Tensor { shape: new_shape, data }
    }

    /// Sum of all elements.
    pub fn sum(&self) -> f32 {
        self.data.iter().sum()
    }

    /// Arithmetic mean of all elements.
    pub fn mean(&self) -> f32 {
        self.sum() / self.data.len() as f32
    }

    /// Largest element (`-inf` for an empty tensor).
    pub fn max(&self) -> f32 {
        self.data.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
    }

    /// Smallest element (`+inf` for an empty tensor).
    pub fn min(&self) -> f32 {
        self.data.iter().cloned().fold(f32::INFINITY, f32::min)
    }

    /// Indices of the maxima along `axis`, stored as `f32` (this library has
    /// no integer tensors). The output shape drops `axis`; reducing a 1-D
    /// tensor yields shape `[1]`.
    pub fn argmax(&self, axis: usize) -> Tensor {
        assert!(axis < self.shape.len());
        let axis_len = self.shape[axis];
        let mut new_shape: Vec<usize> = self.shape.clone();
        new_shape.remove(axis);
        if new_shape.is_empty() {
            new_shape.push(1);
        }
        let outer: usize = self.shape[..axis].iter().product();
        let inner: usize = self.shape[axis + 1..].iter().product();
        let mut data = Vec::with_capacity(outer * inner);
        for o in 0..outer {
            for i in 0..inner {
                let mut best_idx = 0usize;
                let mut best_val = f32::NEG_INFINITY;
                for a in 0..axis_len {
                    let flat = o * axis_len * inner + a * inner + i;
                    if self.data[flat] > best_val {
                        best_val = self.data[flat];
                        best_idx = a;
                    }
                }
                data.push(best_idx as f32);
            }
        }
        Tensor { shape: new_shape, data }
    }

    /// Reinterpret the buffer under a new shape with the same element count.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Tensor {
        assert_eq!(Self::numel(&new_shape), self.data.len(), "reshape size mismatch");
        Tensor { shape: new_shape, data: self.data.clone() }
    }

    /// Collapse to a 1-D tensor.
    pub fn flatten(&self) -> Tensor {
        Tensor { shape: vec![self.data.len()], data: self.data.clone() }
    }

    /// Drop every size-1 dimension, keeping at least rank 1.
    pub fn squeeze(&self) -> Tensor {
        let new_shape: Vec<usize> = self.shape.iter().copied().filter(|&d| d != 1).collect();
        let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
        Tensor { shape: new_shape, data: self.data.clone() }
    }

    /// Insert a size-1 dimension at `dim`.
    pub fn unsqueeze(&self, dim: usize) -> Tensor {
        let mut new_shape = self.shape.clone();
        new_shape.insert(dim, 1);
        Tensor { shape: new_shape, data: self.data.clone() }
    }

    /// NumPy-style broadcast: left-pad the source shape with 1s to the target
    /// rank, then repeat size-1 dimensions to match `target`.
    pub fn broadcast_to(&self, target: &[usize]) -> Tensor {
        assert!(target.len() >= self.shape.len());
        let pad = target.len() - self.shape.len();
        let mut src_shape: Vec<usize> = vec![1; pad];
        src_shape.extend_from_slice(&self.shape);

        for (s, t) in src_shape.iter().zip(target.iter()) {
            assert!(*s == 1 || *s == *t, "cannot broadcast {src_shape:?} to {target:?}");
        }

        let n = Self::numel(target);
        let src_strides = Self::strides(&src_shape);
        let dst_strides = Self::strides(target);
        let mut data = vec![0.0f32; n];
        for flat in 0..n {
            // Decode the destination coordinate; broadcast (size-1) source
            // dimensions always read coordinate 0.
            let mut src_flat = 0usize;
            let mut rem = flat;
            for d in 0..target.len() {
                let coord = rem / dst_strides[d];
                rem %= dst_strides[d];
                let src_coord = if src_shape[d] == 1 { 0 } else { coord };
                src_flat += src_coord * src_strides[d];
            }
            data[flat] = self.data[src_flat];
        }
        Tensor { shape: target.to_vec(), data }
    }

    /// Rectified linear unit: `max(x, 0)` element-wise.
    pub fn relu(&self) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&v| v.max(0.0)).collect(),
        }
    }

    /// Logistic sigmoid: `1 / (1 + e^-x)` element-wise.
    pub fn sigmoid(&self) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect(),
        }
    }

    /// Hyperbolic tangent, element-wise.
    pub fn tanh_act(&self) -> Tensor {
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&v| v.tanh()).collect(),
        }
    }

    /// Numerically stable softmax along `axis`: subtract each lane's maximum
    /// before exponentiating, then normalize so the lane sums to 1.
    pub fn softmax(&self, axis: usize) -> Tensor {
        assert!(axis < self.shape.len());
        let axis_len = self.shape[axis];
        let outer: usize = self.shape[..axis].iter().product();
        let inner: usize = self.shape[axis + 1..].iter().product();
        let mut data = self.data.clone();
        for o in 0..outer {
            for i in 0..inner {
                // Pass 1: lane maximum, for numerical stability.
                let mut mx = f32::NEG_INFINITY;
                for a in 0..axis_len {
                    let idx = o * axis_len * inner + a * inner + i;
                    mx = mx.max(data[idx]);
                }
                // Pass 2: exponentiate the shifted values and accumulate the sum.
                let mut sum = 0.0f32;
                for a in 0..axis_len {
                    let idx = o * axis_len * inner + a * inner + i;
                    let e = (data[idx] - mx).exp();
                    data[idx] = e;
                    sum += e;
                }
                // Pass 3: normalize the lane.
                for a in 0..axis_len {
                    let idx = o * axis_len * inner + a * inner + i;
                    data[idx] /= sum;
                }
            }
        }
        Tensor { shape: self.shape.clone(), data }
    }

    /// GELU with the usual tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    pub fn gelu(&self) -> Tensor {
        let sqrt_2_over_pi = (2.0f32 / std::f32::consts::PI).sqrt();
        Tensor {
            shape: self.shape.clone(),
            data: self.data.iter().map(|&x| {
                let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
                0.5 * x * (1.0 + inner.tanh())
            }).collect(),
        }
    }

    /// 2-D convolution (cross-correlation, as in most deep-learning
    /// libraries) of a (C_in, H, W) input with a (C_out, C_in, kH, kW)
    /// kernel. Zero padding is applied implicitly by skipping taps that
    /// fall outside the input.
    pub fn conv2d(&self, kernel: &Tensor, stride: usize, padding: usize) -> Tensor {
        assert_eq!(self.shape.len(), 3, "conv2d input must be (C, H, W)");
        assert_eq!(kernel.shape.len(), 4, "conv2d kernel must be (C_out, C_in, kH, kW)");
        let c_in = self.shape[0];
        let h = self.shape[1];
        let w = self.shape[2];
        let c_out = kernel.shape[0];
        assert_eq!(kernel.shape[1], c_in);
        let kh = kernel.shape[2];
        let kw = kernel.shape[3];
        let h_out = (h + 2 * padding - kh) / stride + 1;
        let w_out = (w + 2 * padding - kw) / stride + 1;

        let mut out = vec![0.0f32; c_out * h_out * w_out];
        for co in 0..c_out {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut val = 0.0f32;
                    for ci in 0..c_in {
                        for fh in 0..kh {
                            for fw in 0..kw {
                                // Shift by the padding; out-of-bounds taps
                                // contribute zero.
                                let ih = (oh * stride + fh) as isize - padding as isize;
                                let iw = (ow * stride + fw) as isize - padding as isize;
                                if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
                                    let ih = ih as usize;
                                    let iw = iw as usize;
                                    let in_idx = ci * h * w + ih * w + iw;
                                    let k_idx = co * c_in * kh * kw + ci * kh * kw + fh * kw + fw;
                                    val += self.data[in_idx] * kernel.data[k_idx];
                                }
                            }
                        }
                    }
                    out[co * h_out * w_out + oh * w_out + ow] = val;
                }
            }
        }
        Tensor { shape: vec![c_out, h_out, w_out], data: out }
    }

    /// Max pooling over a (C, H, W) input with square windows; no padding,
    /// so `H` and `W` must each be at least `kernel_size`.
    pub fn max_pool2d(&self, kernel_size: usize, stride: usize) -> Tensor {
        assert_eq!(self.shape.len(), 3);
        let c = self.shape[0];
        let h = self.shape[1];
        let w = self.shape[2];
        let h_out = (h - kernel_size) / stride + 1;
        let w_out = (w - kernel_size) / stride + 1;
        let mut out = vec![f32::NEG_INFINITY; c * h_out * w_out];
        for ch in 0..c {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut mx = f32::NEG_INFINITY;
                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            let ih = oh * stride + kh;
                            let iw = ow * stride + kw;
                            mx = mx.max(self.data[ch * h * w + ih * w + iw]);
                        }
                    }
                    out[ch * h_out * w_out + oh * w_out + ow] = mx;
                }
            }
        }
        Tensor { shape: vec![c, h_out, w_out], data: out }
    }

    /// Average pooling over a (C, H, W) input with square windows; no padding.
    pub fn avg_pool2d(&self, kernel_size: usize, stride: usize) -> Tensor {
        assert_eq!(self.shape.len(), 3);
        let c = self.shape[0];
        let h = self.shape[1];
        let w = self.shape[2];
        let h_out = (h - kernel_size) / stride + 1;
        let w_out = (w - kernel_size) / stride + 1;
        let area = (kernel_size * kernel_size) as f32;
        let mut out = vec![0.0f32; c * h_out * w_out];
        for ch in 0..c {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let mut s = 0.0f32;
                    for kh in 0..kernel_size {
                        for kw in 0..kernel_size {
                            let ih = oh * stride + kh;
                            let iw = ow * stride + kw;
                            s += self.data[ch * h * w + ih * w + iw];
                        }
                    }
                    out[ch * h_out * w_out + oh * w_out + ow] = s / area;
                }
            }
        }
        Tensor { shape: vec![c, h_out, w_out], data: out }
    }

    /// Normalize with precomputed per-element statistics:
    /// `gamma * (x - mean) / sqrt(var + eps) + beta`.
    /// All argument tensors must have the same element count as `self`.
    pub fn batch_norm(
        &self,
        mean: &Tensor,
        var: &Tensor,
        gamma: &Tensor,
        beta: &Tensor,
        eps: f32,
    ) -> Tensor {
        assert_eq!(self.data.len(), mean.data.len());
        let data: Vec<f32> = self.data.iter().enumerate().map(|(i, &x)| {
            let m = mean.data[i];
            let v = var.data[i];
            let g = gamma.data[i];
            let b = beta.data[i];
            g * (x - m) / (v + eps).sqrt() + b
        }).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    /// Normalize each slice spanned by dimensions `axis..` to zero mean and
    /// unit variance (biased variance, no learned affine parameters).
    pub fn layer_norm(&self, axis: usize, eps: f32) -> Tensor {
        assert!(axis < self.shape.len());
        let outer: usize = self.shape[..axis].iter().product();
        let inner: usize = self.shape[axis..].iter().product();
        let mut data = self.data.clone();
        for o in 0..outer {
            let start = o * inner;
            let end = start + inner;
            let slice = &data[start..end];
            let mean: f32 = slice.iter().sum::<f32>() / inner as f32;
            let var: f32 = slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / inner as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for i in start..end {
                data[i] = (data[i] - mean) * inv_std;
            }
        }
        Tensor { shape: self.shape.clone(), data }
    }

    /// Inverted dropout: zero each element with probability `p` and scale
    /// survivors by `1 / (1 - p)` so the expected value is unchanged.
    /// A no-op when not training or when `p == 0`.
    pub fn dropout(&self, p: f32, rng: u64, training: bool) -> Tensor {
        if !training || p == 0.0 {
            return self.clone();
        }
        assert!(0.0 < p && p < 1.0, "dropout probability must be in (0, 1)");
        let scale = 1.0 / (1.0 - p);
        // Same xorshift64 stream as `rand`; keep the state nonzero.
        let mut state = rng.wrapping_add(1).max(1);
        let data: Vec<f32> = self.data.iter().map(|&v| {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let r = (state as u32 as f32) / (u32::MAX as f32);
            if r < p { 0.0 } else { v * scale }
        }).collect();
        Tensor { shape: self.shape.clone(), data }
    }

    /// Concatenate along `axis`. All tensors must have the same rank and
    /// agree on every dimension except `axis`.
    pub fn concat(tensors: &[Tensor], axis: usize) -> Tensor {
        assert!(!tensors.is_empty());
        let ndim = tensors[0].shape.len();
        assert!(axis < ndim);
        for t in &tensors[1..] {
            assert_eq!(t.shape.len(), ndim);
            for d in 0..ndim {
                if d != axis {
                    assert_eq!(t.shape[d], tensors[0].shape[d]);
                }
            }
        }
        let mut new_shape = tensors[0].shape.clone();
        new_shape[axis] = tensors.iter().map(|t| t.shape[axis]).sum();

        let outer: usize = new_shape[..axis].iter().product();
        let inner: usize = new_shape[axis + 1..].iter().product();
        let total = Self::numel(&new_shape);
        let mut data = Vec::with_capacity(total);

        // For each outer index, copy every tensor's block in order.
        for o in 0..outer {
            for t in tensors {
                let t_axis = t.shape[axis];
                // Equal to `inner`, since all non-axis dimensions agree.
                let t_inner: usize = t.shape[axis + 1..].iter().product();
                for a in 0..t_axis {
                    for i in 0..inner {
                        let idx = o * t_axis * t_inner + a * t_inner + i;
                        data.push(t.data[idx]);
                    }
                }
            }
        }
        Tensor { shape: new_shape, data }
    }

    /// Stack along a new dimension: unsqueeze every tensor at `axis`, then
    /// concatenate. All inputs must share the same shape.
    pub fn stack(tensors: &[Tensor], axis: usize) -> Tensor {
        assert!(!tensors.is_empty());
        let unsqueezed: Vec<Tensor> = tensors.iter().map(|t| t.unsqueeze(axis)).collect();
        Self::concat(&unsqueezed, axis)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_creation() {
        let z = Tensor::zeros(vec![2, 3]);
        assert_eq!(z.data.len(), 6);
        assert!(z.data.iter().all(|&v| v == 0.0));

        let o = Tensor::ones(vec![3, 2]);
        assert!(o.data.iter().all(|&v| v == 1.0));
    }
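
    // Same seed, same stream: rand is a deterministic xorshift generator,
    // and every sample falls in [0, 1].
    #[test]
    fn test_rand_deterministic() {
        let a = Tensor::rand(vec![4, 4], 7);
        let b = Tensor::rand(vec![4, 4], 7);
        assert_eq!(a, b);
        assert!(a.data.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }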

    #[test]
    fn test_indexing() {
        let mut t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(t.get(&[0, 0]), 1.0);
        assert_eq!(t.get(&[1, 2]), 6.0);
        t.set(&[0, 1], 99.0);
        assert_eq!(t.get(&[0, 1]), 99.0);
    }
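
    // A scalar is a rank-0 tensor: empty shape, one element, indexed by &[].
    #[test]
    fn test_scalar() {
        let s = Tensor::scalar(2.5);
        assert!(s.shape.is_empty());
        assert_eq!(s.data.len(), 1);
        assert_eq!(s.get(&[]), 2.5);
    }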

    #[test]
    fn test_matmul() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = Tensor::matmul(&a, &b);
        assert_eq!(c.shape, vec![2, 2]);
        assert_eq!(c.get(&[0, 0]), 19.0);
        assert_eq!(c.get(&[0, 1]), 22.0);
        assert_eq!(c.get(&[1, 0]), 43.0);
        assert_eq!(c.get(&[1, 1]), 50.0);
    }
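
    // add/sub/mul are element-wise and require identical shapes; scale
    // multiplies every entry by a constant.
    #[test]
    fn test_elementwise_ops() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0], vec![2, 2]);
        assert_eq!(a.add(&b).data, vec![5.0, 5.0, 5.0, 5.0]);
        assert_eq!(a.sub(&b).data, vec![-3.0, -1.0, 1.0, 3.0]);
        assert_eq!(a.mul(&b).data, vec![4.0, 6.0, 6.0, 4.0]);
        assert_eq!(a.scale(2.0).data, vec![2.0, 4.0, 6.0, 8.0]);
    }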

    #[test]
    fn test_matmul_non_square() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
        let c = Tensor::matmul(&a, &b);
        assert_eq!(c.shape, vec![2, 2]);
        assert_eq!(c.get(&[0, 0]), 22.0);
        assert_eq!(c.get(&[0, 1]), 28.0);
    }

    #[test]
    fn test_transpose() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let at = a.transpose();
        assert_eq!(at.shape, vec![3, 2]);
        assert_eq!(at.get(&[0, 0]), 1.0);
        assert_eq!(at.get(&[0, 1]), 4.0);
        assert_eq!(at.get(&[2, 0]), 3.0);
        assert_eq!(at.get(&[2, 1]), 6.0);
    }
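
    // Full reductions fold over every element regardless of shape.
    #[test]
    fn test_reductions() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(t.sum(), 10.0);
        assert_eq!(t.mean(), 2.5);
        assert_eq!(t.max(), 4.0);
        assert_eq!(t.min(), 1.0);
    }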

    #[test]
    fn test_softmax_sums_to_one() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
        let s = t.softmax(1);
        let total: f32 = s.data.iter().sum();
        assert!((total - 1.0).abs() < 1e-5, "softmax sum = {total}");
        assert!(s.data.iter().all(|&v| v > 0.0));
    }

    #[test]
    fn test_relu_zeros_negatives() {
        let t = Tensor::from_vec(vec![-3.0, -1.0, 0.0, 1.0, 5.0], vec![5]);
        let r = t.relu();
        assert_eq!(r.data, vec![0.0, 0.0, 0.0, 1.0, 5.0]);
    }

    #[test]
    fn test_conv2d() {
        let input = Tensor::ones(vec![1, 4, 4]);
        let kernel = Tensor::ones(vec![1, 1, 3, 3]);
        let out = input.conv2d(&kernel, 1, 0);
        assert_eq!(out.shape, vec![1, 2, 2]);
        assert_eq!(out.data, vec![9.0, 9.0, 9.0, 9.0]);
    }

    #[test]
    fn test_conv2d_with_padding() {
        let input = Tensor::ones(vec![1, 3, 3]);
        let kernel = Tensor::ones(vec![1, 1, 3, 3]);
        let out = input.conv2d(&kernel, 1, 1);
        assert_eq!(out.shape, vec![1, 3, 3]);
        assert_eq!(out.get(&[0, 1, 1]), 9.0);
        assert_eq!(out.get(&[0, 0, 0]), 4.0);
        assert_eq!(out.get(&[0, 0, 1]), 6.0);
    }
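
    // Stride 2 over a 5x5 input with a 3x3 kernel: (5 - 3) / 2 + 1 = 2
    // output positions per side, each window summing nine ones.
    #[test]
    fn test_conv2d_stride() {
        let input = Tensor::ones(vec![1, 5, 5]);
        let kernel = Tensor::ones(vec![1, 1, 3, 3]);
        let out = input.conv2d(&kernel, 2, 0);
        assert_eq!(out.shape, vec![1, 2, 2]);
        assert!(out.data.iter().all(|&v| v == 9.0));
    }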

    #[test]
    fn test_pooling() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0];
        let t = Tensor::from_vec(data, vec![1, 4, 4]);
        let mp = t.max_pool2d(2, 2);
        assert_eq!(mp.shape, vec![1, 2, 2]);
        assert_eq!(mp.data, vec![6.0, 8.0, 14.0, 16.0]);

        let ap = t.avg_pool2d(2, 2);
        assert_eq!(ap.shape, vec![1, 2, 2]);
        assert_eq!(ap.data, vec![3.5, 5.5, 11.5, 13.5]);
    }

    #[test]
    fn test_reshape_flatten() {
        let t = Tensor::ones(vec![2, 3, 4]);
        let r = t.reshape(vec![6, 4]);
        assert_eq!(r.shape, vec![6, 4]);
        assert_eq!(r.data.len(), 24);
        let f = t.flatten();
        assert_eq!(f.shape, vec![24]);
    }

    #[test]
    fn test_squeeze_unsqueeze() {
        let t = Tensor::ones(vec![1, 3, 1, 4]);
        let s = t.squeeze();
        assert_eq!(s.shape, vec![3, 4]);
        let u = s.unsqueeze(0);
        assert_eq!(u.shape, vec![1, 3, 4]);
    }
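
    // squeeze never produces a rank-0 tensor: an all-ones shape collapses
    // to [1] rather than [].
    #[test]
    fn test_squeeze_all_ones() {
        let t = Tensor::ones(vec![1, 1, 1]);
        assert_eq!(t.squeeze().shape, vec![1]);
    }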

    #[test]
    fn test_broadcast() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
        let b = t.broadcast_to(&[2, 3]);
        assert_eq!(b.shape, vec![2, 3]);
        assert_eq!(b.data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }
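
    // Broadcasting left-pads the source shape with 1s, so a [2] vector
    // broadcast to [3, 2] repeats along the new leading axis.
    #[test]
    fn test_broadcast_rank_padding() {
        let t = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        let b = t.broadcast_to(&[3, 2]);
        assert_eq!(b.shape, vec![3, 2]);
        assert_eq!(b.data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
    }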

    #[test]
    fn test_sigmoid() {
        let t = Tensor::from_vec(vec![0.0], vec![1]);
        let s = t.sigmoid();
        assert!((s.data[0] - 0.5).abs() < 1e-5);
    }
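
    // tanh is odd and saturates at +/-1 for large-magnitude inputs.
    #[test]
    fn test_tanh_act() {
        let t = Tensor::from_vec(vec![0.0, 100.0, -100.0], vec![3]);
        let y = t.tanh_act();
        assert!(y.data[0].abs() < 1e-6);
        assert!((y.data[1] - 1.0).abs() < 1e-6);
        assert!((y.data[2] + 1.0).abs() < 1e-6);
    }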

    #[test]
    fn test_gelu() {
        let t = Tensor::from_vec(vec![0.0, 1.0, -1.0], vec![3]);
        let g = t.gelu();
        assert!((g.data[0]).abs() < 1e-5);
        assert!(g.data[1] > 0.8);
        assert!(g.data[2] < 0.0);
    }

    #[test]
    fn test_layer_norm() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]);
        let ln = t.layer_norm(1, 1e-5);
        let mean: f32 = ln.data.iter().sum::<f32>() / 4.0;
        assert!(mean.abs() < 1e-4);
    }
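
    // With mean 0, var 1, gamma 1, beta 0, and eps 0, batch_norm reduces to
    // the identity map.
    #[test]
    fn test_batch_norm_identity() {
        let x = Tensor::from_vec(vec![1.0, -2.0, 3.0], vec![3]);
        let y = x.batch_norm(
            &Tensor::zeros(vec![3]),
            &Tensor::ones(vec![3]),
            &Tensor::ones(vec![3]),
            &Tensor::zeros(vec![3]),
            0.0,
        );
        assert_eq!(y.data, x.data);
    }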

    #[test]
    fn test_concat() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = Tensor::concat(&[a, b], 0);
        assert_eq!(c.shape, vec![4, 2]);
        assert_eq!(c.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    }
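
    // Concatenation along axis 1 interleaves rows: each output row is the
    // corresponding row of `a` followed by the corresponding row of `b`.
    #[test]
    fn test_concat_axis1() {
        let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = Tensor::concat(&[a, b], 1);
        assert_eq!(c.shape, vec![2, 4]);
        assert_eq!(c.data, vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]);
    }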

    #[test]
    fn test_stack() {
        let a = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        let b = Tensor::from_vec(vec![3.0, 4.0], vec![2]);
        let s = Tensor::stack(&[a, b], 0);
        assert_eq!(s.shape, vec![2, 2]);
        assert_eq!(s.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_slice() {
        let t = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            vec![3, 3],
        );
        let s = t.slice(&[0..2, 1..3]);
        assert_eq!(s.shape, vec![2, 2]);
        assert_eq!(s.data, vec![2.0, 3.0, 5.0, 6.0]);
    }

    #[test]
    fn test_dropout() {
        let t = Tensor::ones(vec![100]);
        let d = t.dropout(0.5, 42, true);
        let zeros = d.data.iter().filter(|&&v| v == 0.0).count();
        assert!(zeros > 10 && zeros < 90);

        let d2 = t.dropout(0.5, 42, false);
        assert_eq!(d2.data, t.data);
    }
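
    // Inverted dropout scales survivors by 1 / (1 - p), so with p = 0.5
    // every kept entry of a ones tensor becomes exactly 2.0.
    #[test]
    fn test_dropout_scaling() {
        let t = Tensor::ones(vec![100]);
        let d = t.dropout(0.5, 7, true);
        assert!(d.data.iter().all(|&v| v == 0.0 || v == 2.0));
    }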

    #[test]
    fn test_argmax() {
        let t = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 4.0], vec![2, 3]);
        let am = t.argmax(1);
        assert_eq!(am.shape, vec![2]);
        assert_eq!(am.data, vec![1.0, 0.0]);
    }
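
    // argmax along axis 0 reduces over rows, leaving one index per column.
    #[test]
    fn test_argmax_axis0() {
        let t = Tensor::from_vec(vec![1.0, 5.0, 3.0, 9.0, 2.0, 4.0], vec![2, 3]);
        let am = t.argmax(0);
        assert_eq!(am.shape, vec![3]);
        assert_eq!(am.data, vec![1.0, 0.0, 1.0]);
    }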
}