1
2use cjc_repro::Rng;
3
4use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
5use cjc_repro::KahanAccumulatorF64;
6
7use crate::accumulator;
8use crate::buffer::Buffer;
9use crate::dispatch;
10use crate::error::RuntimeError;
11use crate::kernel as kernel_fns;
12use crate::tensor_simd::{self, BinOp, UnaryOp};
13use crate::tensor_tiled::TiledMatmul;
14
/// A dense, row-major n-dimensional array of `f64` backed by a shared [`Buffer`].
///
/// Views (slices, broadcasts, transposes) share the same buffer and differ only
/// in `shape`, `strides`, and `offset`.
#[derive(Debug, Clone)]
pub struct Tensor {
    // Shared flat storage; may hold more elements than this view exposes.
    pub buffer: Buffer<f64>,
    // Extent of each dimension.
    pub(crate) shape: Vec<usize>,
    // Element step (in buffer slots) per dimension; 0 marks a broadcast dim.
    pub(crate) strides: Vec<usize>,
    // Buffer index of the element at all-zero indices.
    pub(crate) offset: usize,
}
30
31impl Tensor {
32 pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
36 let mut strides = vec![1usize; shape.len()];
37 for i in (0..shape.len().saturating_sub(1)).rev() {
38 strides[i] = strides[i + 1] * shape[i + 1];
39 }
40 strides
41 }
42
43 fn shape_numel(shape: &[usize]) -> usize {
45 shape.iter().product()
46 }
47
48 pub fn zeros(shape: &[usize]) -> Self {
50 let numel = Self::shape_numel(shape);
51 Tensor {
52 buffer: Buffer::alloc(numel, 0.0),
53 shape: shape.to_vec(),
54 strides: Self::compute_strides(shape),
55 offset: 0,
56 }
57 }
58
59 pub fn ones(shape: &[usize]) -> Self {
61 let numel = Self::shape_numel(shape);
62 Tensor {
63 buffer: Buffer::alloc(numel, 1.0),
64 shape: shape.to_vec(),
65 strides: Self::compute_strides(shape),
66 offset: 0,
67 }
68 }
69
70 pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
73 let numel = Self::shape_numel(shape);
74 let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
75 Tensor {
76 buffer: Buffer::from_vec(data),
77 shape: shape.to_vec(),
78 strides: Self::compute_strides(shape),
79 offset: 0,
80 }
81 }
82
83 pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
86 let numel = Self::shape_numel(shape);
87 if data.len() != numel {
88 return Err(RuntimeError::ShapeMismatch {
89 expected: numel,
90 got: data.len(),
91 });
92 }
93 Ok(Tensor {
94 buffer: Buffer::from_vec(data),
95 shape: shape.to_vec(),
96 strides: Self::compute_strides(shape),
97 offset: 0,
98 })
99 }
100
    /// The dimension extents of this view.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of dimensions (rank).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of elements in this view (product of the shape).
    pub fn len(&self) -> usize {
        Self::shape_numel(&self.shape)
    }

    /// `true` when the view holds no elements (some dimension is 0).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
122
123 fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
125 if indices.len() != self.shape.len() {
126 return Err(RuntimeError::DimensionMismatch {
127 expected: self.shape.len(),
128 got: indices.len(),
129 });
130 }
131 let mut off = self.offset;
132 for (i, &idx) in indices.iter().enumerate() {
133 if idx >= self.shape[i] {
134 return Err(RuntimeError::IndexOutOfBounds {
135 index: idx,
136 length: self.shape[i],
137 });
138 }
139 off += idx * self.strides[i];
140 }
141 Ok(off)
142 }
143
144 pub fn is_contiguous(&self) -> bool {
146 if self.offset != 0 {
147 return false;
148 }
149 let expected = Self::compute_strides(&self.shape);
150 self.strides == expected
151 }
152
153 pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
156 if ranges.len() != self.shape.len() {
157 return Err(RuntimeError::DimensionMismatch {
158 expected: self.shape.len(),
159 got: ranges.len(),
160 });
161 }
162 let mut new_offset = self.offset;
163 let mut new_shape = Vec::with_capacity(ranges.len());
164 for (i, &(start, end)) in ranges.iter().enumerate() {
165 if end > self.shape[i] || start > end {
166 return Err(RuntimeError::IndexOutOfBounds {
167 index: end,
168 length: self.shape[i],
169 });
170 }
171 new_offset += start * self.strides[i];
172 new_shape.push(end - start);
173 }
174 Ok(Tensor {
175 buffer: self.buffer.clone(), shape: new_shape,
177 strides: self.strides.clone(),
178 offset: new_offset,
179 })
180 }
181
182 pub fn to_contiguous(&self) -> Tensor {
184 if self.is_contiguous() {
185 return self.clone();
186 }
187 let data = self.to_vec();
188 Tensor {
189 buffer: Buffer::from_vec(data),
190 shape: self.shape.clone(),
191 strides: Self::compute_strides(&self.shape),
192 offset: 0,
193 }
194 }
195
196 pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
199 let src_ndim = self.shape.len();
200 let tgt_ndim = target_shape.len();
201 if tgt_ndim < src_ndim {
202 return Err(RuntimeError::InvalidOperation(
203 "cannot broadcast to a smaller rank".to_string(),
204 ));
205 }
206 let pad = tgt_ndim - src_ndim;
207 let mut new_strides = vec![0usize; tgt_ndim];
208 for i in 0..tgt_ndim {
209 if i < pad {
210 new_strides[i] = 0;
212 } else {
213 let src_i = i - pad;
214 if self.shape[src_i] == target_shape[i] {
215 new_strides[i] = self.strides[src_i];
216 } else if self.shape[src_i] == 1 {
217 new_strides[i] = 0; } else {
219 return Err(RuntimeError::ShapeMismatch {
220 expected: target_shape[i],
221 got: self.shape[src_i],
222 });
223 }
224 }
225 }
226 Ok(Tensor {
227 buffer: self.buffer.clone(),
228 shape: target_shape.to_vec(),
229 strides: new_strides,
230 offset: self.offset,
231 })
232 }
233
    /// Reads the element at `indices`, validating rank and bounds first.
    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
        let offset = self.linear_index(indices)?;
        self.buffer
            .get(offset)
            .ok_or(RuntimeError::IndexOutOfBounds {
                index: offset,
                length: self.buffer.len(),
            })
    }

    /// Writes `val` at `indices`, validating rank and bounds first.
    // NOTE(review): whether sibling views observe this write depends on
    // `Buffer::set`'s sharing semantics — confirm against buffer.rs.
    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
        let offset = self.linear_index(indices)?;
        self.buffer.set(offset, val)
    }
250
    /// Copies this view's elements into a fresh `Vec` in row-major order.
    pub fn to_vec(&self) -> Vec<f64> {
        if self.is_contiguous() {
            let full = self.buffer.borrow_data();
            let numel = self.len();
            if full.len() == numel {
                return full.to_vec();
            }
            // Contiguous view over a prefix of a larger shared buffer.
            return full[..numel].to_vec();
        }
        // Strided path: walk an odometer over the indices and resolve each
        // element through offset + dot(indices, strides).
        let numel = self.len();
        let mut result = Vec::with_capacity(numel);
        let ndim = self.shape.len();
        let mut indices = vec![0usize; ndim];
        for _ in 0..numel {
            let mut off = self.offset;
            for d in 0..ndim {
                off += indices[d] * self.strides[d];
            }
            // Out-of-range reads are masked to 0.0 rather than panicking.
            result.push(self.buffer.get(off).unwrap_or(0.0));
            // Advance the odometer, carrying from the last dimension.
            for d in (0..ndim).rev() {
                indices[d] += 1;
                if indices[d] < self.shape[d] {
                    break;
                }
                indices[d] = 0;
            }
        }
        result
    }
285
286 pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
291 let new_numel = Self::shape_numel(new_shape);
292 if new_numel != self.len() {
293 return Err(RuntimeError::ShapeMismatch {
294 expected: self.len(),
295 got: new_numel,
296 });
297 }
298 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
300 Ok(Tensor {
301 buffer: tensor.buffer,
302 shape: new_shape.to_vec(),
303 strides: Self::compute_strides(new_shape),
304 offset: 0,
305 })
306 }
307
308 fn elementwise_binop(
312 &self,
313 other: &Tensor,
314 op: impl Fn(f64, f64) -> f64,
315 ) -> Result<Tensor, RuntimeError> {
316 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
317 let a = self.buffer.borrow_data();
319 let b = other.buffer.borrow_data();
320 let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
321 return Ok(Tensor {
322 buffer: Buffer::from_vec(data),
323 shape: self.shape.clone(),
324 strides: Self::compute_strides(&self.shape),
325 offset: 0,
326 });
327 }
328
329 let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
331 let a_broadcast = self.broadcast_to(&result_shape)?;
332 let b_broadcast = other.broadcast_to(&result_shape)?;
333
334 let numel = Self::shape_numel(&result_shape);
335 let ndim = result_shape.len();
336 let mut data = Vec::with_capacity(numel);
337 let mut indices = vec![0usize; ndim];
338
339 for _ in 0..numel {
340 let mut off_a = a_broadcast.offset;
341 let mut off_b = b_broadcast.offset;
342 for d in 0..ndim {
343 off_a += indices[d] * a_broadcast.strides[d];
344 off_b += indices[d] * b_broadcast.strides[d];
345 }
346 let va = a_broadcast.buffer.get(off_a).unwrap_or(0.0);
347 let vb = b_broadcast.buffer.get(off_b).unwrap_or(0.0);
348 data.push(op(va, vb));
349
350 for d in (0..ndim).rev() {
351 indices[d] += 1;
352 if indices[d] < result_shape[d] {
353 break;
354 }
355 indices[d] = 0;
356 }
357 }
358
359 Ok(Tensor {
360 buffer: Buffer::from_vec(data),
361 shape: result_shape.clone(),
362 strides: Self::compute_strides(&result_shape),
363 offset: 0,
364 })
365 }
366
367 fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
369 let max_ndim = a.len().max(b.len());
370 let mut result = Vec::with_capacity(max_ndim);
371 for i in 0..max_ndim {
372 let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
373 let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
374 if da == db {
375 result.push(da);
376 } else if da == 1 {
377 result.push(db);
378 } else if db == 1 {
379 result.push(da);
380 } else {
381 return Err(RuntimeError::ShapeMismatch {
382 expected: da,
383 got: db,
384 });
385 }
386 }
387 Ok(result)
388 }
389
390 fn elementwise_binop_simd(
395 &self,
396 other: &Tensor,
397 op: BinOp,
398 fallback: impl Fn(f64, f64) -> f64,
399 ) -> Result<Tensor, RuntimeError> {
400 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
401 let a = self.buffer.borrow_data();
403 let b = other.buffer.borrow_data();
404 let data = tensor_simd::simd_binop(&a, &b, op);
405 return Ok(Tensor {
406 buffer: Buffer::from_vec(data),
407 shape: self.shape.clone(),
408 strides: Self::compute_strides(&self.shape),
409 offset: 0,
410 });
411 }
412 self.elementwise_binop(other, fallback)
414 }
415
    /// Element-wise addition with broadcasting (SIMD on the dense fast path).
    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
    }

    /// Element-wise subtraction with broadcasting.
    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
    }

    /// Element-wise (Hadamard) product with broadcasting.
    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
    }

    /// Element-wise division with broadcasting (IEEE semantics: /0 → ±inf/NaN).
    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
    }
435
436 pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
442 if self.shape != b.shape || self.shape != c.shape {
443 return Err(RuntimeError::InvalidOperation(
444 "broadcast_fma: all three tensors must have the same shape".to_string(),
445 ));
446 }
447 if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
448 let a_data = self.buffer.borrow_data();
449 let b_data = b.buffer.borrow_data();
450 let c_data = c.buffer.borrow_data();
451 let n = a_data.len();
452 let mut out = vec![0.0f64; n];
453 for i in 0..n {
456 out[i] = a_data[i] * b_data[i] + c_data[i];
457 }
458 return Ok(Tensor {
459 buffer: Buffer::from_vec(out),
460 shape: self.shape.clone(),
461 strides: Self::compute_strides(&self.shape),
462 offset: 0,
463 });
464 }
465 let temp = self.mul_elem(b)?;
467 temp.add(c)
468 }
469
    /// Element-wise power `a^b` with broadcasting.
    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.powf(b))
    }

    /// Element-wise minimum with broadcasting (`f64::min` NaN semantics).
    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.min(b))
    }

    /// Element-wise maximum with broadcasting (`f64::max` NaN semantics).
    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.max(b))
    }

    /// Element-wise `atan2(self, other)` with broadcasting.
    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.atan2(b))
    }

    /// Element-wise `hypot(self, other)` (overflow-safe √(a²+b²)).
    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.hypot(b))
    }
496
497 pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
499 let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
500 Tensor {
501 buffer: Buffer::from_vec(data),
502 shape: self.shape.clone(),
503 strides: Self::compute_strides(&self.shape),
504 offset: 0,
505 }
506 }
507
508 pub fn map_simd(&self, op: UnaryOp) -> Tensor {
513 let src = self.to_vec();
514 let data = tensor_simd::simd_unary(&src, op);
515 Tensor {
516 buffer: Buffer::from_vec(data),
517 shape: self.shape.clone(),
518 strides: Self::compute_strides(&self.shape),
519 offset: 0,
520 }
521 }
522
523 pub fn sum(&self) -> f64 {
527 let data = self.buffer.borrow_data();
528 binned_sum_f64(&data)
529 }
530
531 pub fn binned_sum(&self) -> f64 {
535 let data = self.buffer.borrow_data();
536 accumulator::binned_sum_f64(&data)
537 }
538
539 pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
543 let data = self.buffer.borrow_data();
544 dispatch::dispatch_sum_f64(&data, ctx)
545 }
546
547 pub fn mean(&self) -> f64 {
549 let n = self.len();
550 if n == 0 {
551 return 0.0;
552 }
553 self.sum() / n as f64
554 }
555
556 pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
558 let n = self.len();
559 if n == 0 {
560 return 0.0;
561 }
562 self.dispatched_sum(ctx) / n as f64
563 }
564
565 pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
575 let ndim = self.ndim();
576 if axis >= ndim {
577 return Err(RuntimeError::IndexOutOfBounds {
578 index: axis,
579 length: ndim,
580 });
581 }
582
583 let mut out_shape = self.shape.clone();
585 out_shape[axis] = 1;
586 let out_numel = Self::shape_numel(&out_shape);
587 let out_strides = Self::compute_strides(&out_shape);
588
589 let data = self.to_vec();
590 let axis_len = self.shape[axis];
591 let mut result = vec![0.0f64; out_numel];
592
593 let mut indices = vec![0usize; ndim];
595 for out_idx in 0..out_numel {
596 {
598 let mut remaining = out_idx;
599 for d in 0..ndim {
600 indices[d] = remaining / out_strides[d];
601 remaining %= out_strides[d];
602 }
603 }
604
605 let mut acc = BinnedAccumulatorF64::new();
606 for k in 0..axis_len {
607 let mut flat = self.offset;
609 for d in 0..ndim {
610 let idx = if d == axis { k } else { indices[d] };
611 flat += idx * self.strides[d];
612 }
613 acc.add(data[flat]);
614 }
615 result[out_idx] = acc.finalize();
616 }
617
618 Tensor::from_vec(result, &out_shape)
619 }
620
621 pub fn neg(&self) -> Tensor {
625 self.map(|x| -x)
626 }
627
628 pub fn transpose(&self) -> Tensor {
631 let ndim = self.ndim();
632 if ndim <= 1 {
633 return self.clone();
634 }
635 let mut new_shape = self.shape.clone();
637 let mut new_strides = self.strides.clone();
638 new_shape.reverse();
639 new_strides.reverse();
640 Tensor {
641 buffer: self.buffer.clone(), shape: new_shape,
643 strides: new_strides,
644 offset: self.offset,
645 }
646 }
647
648 pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
652 let ndim = self.ndim();
653 if axes.len() != ndim {
654 return Err(RuntimeError::InvalidOperation(
655 format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
656 ));
657 }
658 let mut seen = vec![false; ndim];
660 for &ax in axes {
661 if ax >= ndim {
662 return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
663 }
664 if seen[ax] {
665 return Err(RuntimeError::InvalidOperation(
666 format!("transpose_axes: duplicate axis {ax}"),
667 ));
668 }
669 seen[ax] = true;
670 }
671 let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
672 let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
673 Ok(Tensor {
674 buffer: self.buffer.clone(),
675 shape: new_shape,
676 strides: new_strides,
677 offset: self.offset,
678 })
679 }
680
681 pub fn scalar_mul(&self, s: f64) -> Tensor {
683 self.map(|x| x * s)
684 }
685
    /// Panicking variant of [`Tensor::from_vec`] for inputs known valid.
    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
    }

    /// Panicking variant of [`Tensor::add`] for shapes known compatible.
    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
        self.add(other).expect("Tensor::add shape mismatch")
    }

    /// Panicking variant of [`Tensor::sub`] for shapes known compatible.
    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
        self.sub(other).expect("Tensor::sub shape mismatch")
    }

    /// Panicking variant of [`Tensor::mul_elem`] for shapes known compatible.
    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
    }

    /// Panicking variant of [`Tensor::div_elem`] for shapes known compatible.
    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
    }

    /// Panicking variant of [`Tensor::matmul`] for dims known compatible.
    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
        self.matmul(other).expect("Tensor::matmul dimension mismatch")
    }
718
    /// 2-D matrix product `self (m×k) · other (k×n)`.
    ///
    /// Dispatches by size: parallel banding (behind the `parallel` feature)
    /// when any dim ≥ 256, cache-tiled when any dim ≥ 64, otherwise a
    /// sequential Kahan-summed triple loop.
    ///
    /// # Errors
    /// `InvalidOperation` for non-2-D operands; `DimensionMismatch` when
    /// the inner dims disagree.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matmul requires 2-D tensors".to_string(),
            ));
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        // Materialize both operands densely so kernels can index flatly.
        let a = self.to_vec();
        let b = other.to_vec();

        #[cfg(feature = "parallel")]
        {
            if m >= 256 || n >= 256 || k >= 256 {
                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
            }
        }

        if m >= 64 || n >= 64 || k >= 64 {
            return Self::matmul_tiled(&a, &b, m, n, k);
        }

        Self::matmul_sequential(&a, &b, m, n, k)
    }
762
    /// Naive O(m·n·k) matmul with a Kahan-compensated accumulator per dot
    /// product; used for small matrices where tiling overhead isn't worth it.
    fn matmul_sequential(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        let mut result = vec![0.0f64; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = KahanAccumulatorF64::new();
                for p in 0..k {
                    acc.add(a[i * k + p] * b[p * n + j]);
                }
                result[i * n + j] = acc.finalize();
            }
        }
        Tensor::from_vec(result, &[m, n])
    }
779
    /// Cache-tiled matmul delegated to [`TiledMatmul`].
    fn matmul_tiled(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        let engine = TiledMatmul::new();
        let result = engine.matmul(a, m, k, b, n);
        Tensor::from_vec(result, &[m, n])
    }
793
    /// Parallel matmul (requires the `parallel` feature).
    ///
    /// For m, n ≥ 512 the output is split into horizontal row bands, each
    /// computed by one tiled-matmul call across the rayon pool. Otherwise
    /// each output row is an independent Kahan-summed task.
    #[cfg(feature = "parallel")]
    fn matmul_parallel_mode_a(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        use rayon::prelude::*;
        use cjc_repro::KahanAccumulatorF64;

        if m >= 512 && n >= 512 {
            // Rows per band: one band per thread, but at least 64 rows so
            // each tiled kernel call has enough work to amortize setup.
            let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
            let band_size = band_size.max(64);
            let mut result = vec![0.0f64; m * n];

            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k .. i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    // The final band may be short; copy only its rows.
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });

            return Tensor::from_vec(result, &[m, n]);
        }

        // Row-parallel fallback: each output row is an independent task.
        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });

        Tensor::from_vec(result, &[m, n])
    }
855
    /// Batched matrix multiply over the last two dimensions.
    ///
    /// 2-D × 2-D delegates to [`Tensor::matmul`]. Otherwise both operands
    /// must share rank and identical leading (batch) dims; each batch is
    /// multiplied independently — tiled kernel when any of m/n/k ≥ 64,
    /// Kahan triple loop below that — in parallel across batches when the
    /// `parallel` feature is enabled and the per-batch work is large enough.
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "bmm requires at least 2-D tensors".to_string(),
            ));
        }
        if self.ndim() == 2 && other.ndim() == 2 {
            return self.matmul(other);
        }
        if self.ndim() != other.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm requires same number of dimensions, got {} and {}",
                    self.ndim(),
                    other.ndim()
                ),
            ));
        }
        let nd = self.ndim();
        let batch_dims_a = &self.shape[..nd - 2];
        let batch_dims_b = &other.shape[..nd - 2];
        if batch_dims_a != batch_dims_b {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm batch dimensions mismatch: {:?} vs {:?}",
                    batch_dims_a, batch_dims_b
                ),
            ));
        }
        let m = self.shape[nd - 2];
        let k = self.shape[nd - 1];
        let k2 = other.shape[nd - 2];
        let n = other.shape[nd - 1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        let batch_size: usize = batch_dims_a.iter().product();
        let a = self.to_vec();
        let b = other.to_vec();
        // Flat strides between consecutive batch matrices.
        let mat_a_stride = m * k;
        let mat_b_stride = k * n;
        let mat_c_stride = m * n;
        let mut result = vec![0.0f64; batch_size * mat_c_stride];

        // Per-batch kernel shared by the parallel and serial paths.
        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];

            if m >= 64 || n >= 64 || k >= 64 {
                let engine = crate::tensor_tiled::TiledMatmul::new();
                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
                c_slice.copy_from_slice(&tiled);
            } else {
                for i in 0..m {
                    for j in 0..n {
                        let mut acc = KahanAccumulatorF64::new();
                        for p in 0..k {
                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
                        }
                        c_slice[i * n + j] = acc.finalize();
                    }
                }
            }
        };

        #[cfg(feature = "parallel")]
        {
            // Parallelize across batches only when each batch carries enough
            // work to amortize rayon's scheduling cost.
            if batch_size > 1 && m * k >= 4096 {
                use rayon::prelude::*;
                result
                    .par_chunks_mut(mat_c_stride)
                    .enumerate()
                    .for_each(|(batch, c_slice)| {
                        compute_batch(batch, c_slice);
                    });

                let mut out_shape = batch_dims_a.to_vec();
                out_shape.push(m);
                out_shape.push(n);
                return Tensor::from_vec(result, &out_shape);
            }
        }

        // Serial fallback.
        for batch in 0..batch_size {
            let c_off = batch * mat_c_stride;
            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
        }

        let mut out_shape = batch_dims_a.to_vec();
        out_shape.push(m);
        out_shape.push(n);
        Tensor::from_vec(result, &out_shape)
    }
963
964 pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
972 if self.ndim() == 0 {
973 return Err(RuntimeError::InvalidOperation(
974 "softmax requires at least 1-D tensor".to_string(),
975 ));
976 }
977 let data_ref;
979 let data_vec;
980 let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
981 data_ref = self.buffer.borrow_data();
982 &data_ref
983 } else {
984 data_vec = self.to_vec();
985 &data_vec
986 };
987 let n = *self.shape.last().unwrap(); let outer: usize = data.len() / n; let mut result = vec![0.0f64; data.len()];
990
991 for row in 0..outer {
992 let start = row * n;
993 let end = start + n;
994 let slice = &data[start..end];
995
996 let mut max_val = f64::NEG_INFINITY;
998 for &v in slice {
999 if v > max_val {
1000 max_val = v;
1001 }
1002 }
1003
1004 let mut exp_vals = vec![0.0f64; n];
1006 let mut sum = 0.0f64;
1007 let mut comp = 0.0f64; for i in 0..n {
1009 let e = (slice[i] - max_val).exp();
1010 exp_vals[i] = e;
1011 let y = e - comp;
1013 let t = sum + y;
1014 comp = (t - sum) - y;
1015 sum = t;
1016 }
1017
1018 if sum == 0.0 {
1020 let uniform = 1.0 / n as f64;
1022 for i in 0..n {
1023 result[start + i] = uniform;
1024 }
1025 } else {
1026 for i in 0..n {
1027 result[start + i] = exp_vals[i] / sum;
1028 }
1029 }
1030 }
1031
1032 Tensor::from_vec(result, &self.shape)
1033 }
1034
    /// Layer normalization over the LAST axis:
    /// `gamma * (x - mean) / sqrt(var + eps) + beta`, applied per row.
    ///
    /// Variance is the biased (population) estimate; mean and variance sums
    /// use the binned accumulator. `gamma` and `beta` must each have as
    /// many elements as the last dimension.
    ///
    /// # Errors
    /// `InvalidOperation` for a 0-D input or mismatched gamma/beta lengths.
    pub fn layer_norm(
        &self,
        gamma: &Tensor,
        beta: &Tensor,
        eps: f64,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() == 0 {
            return Err(RuntimeError::InvalidOperation(
                "layer_norm requires at least 1-D tensor".to_string(),
            ));
        }
        let d = *self.shape.last().unwrap();
        if gamma.len() != d || beta.len() != d {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "layer_norm: gamma/beta length {} must match last dim {}",
                    gamma.len(),
                    d
                ),
            ));
        }

        let data = self.to_vec();
        let gamma_data = gamma.to_vec();
        let beta_data = beta.to_vec();
        let outer = data.len() / d;
        let mut result = vec![0.0f64; data.len()];

        for row in 0..outer {
            let start = row * d;
            let slice = &data[start..start + d];

            let mean = binned_sum_f64(slice) / d as f64;

            // Population variance: mean of squared deviations.
            let diffs: Vec<f64> = slice.iter().map(|&x| {
                let diff = x - mean;
                diff * diff
            }).collect();
            let variance = binned_sum_f64(&diffs) / d as f64;

            // `eps` keeps the denominator away from zero.
            let inv_std = 1.0 / (variance + eps).sqrt();
            for i in 0..d {
                let normalized = (slice[i] - mean) * inv_std;
                result[start + i] = gamma_data[i] * normalized + beta_data[i];
            }
        }

        Tensor::from_vec(result, &self.shape)
    }
1097
1098 fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
1103 if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
1104 let mut data = self.buffer.borrow_data().clone();
1106 for x in data.iter_mut() {
1107 *x = f(*x);
1108 }
1109 Tensor::from_vec(data, &self.shape).unwrap()
1110 } else {
1111 let data = self.to_vec();
1113 let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
1114 Tensor::from_vec(result, &self.shape).unwrap()
1115 }
1116 }
1117
    /// ReLU: max(x, 0), element-wise.
    pub fn relu(&self) -> Tensor {
        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
    }

    /// Logistic sigmoid 1 / (1 + e^-x), element-wise.
    pub fn sigmoid(&self) -> Tensor {
        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
    }

    /// Hyperbolic tangent, element-wise.
    pub fn tanh_activation(&self) -> Tensor {
        self.map_elementwise(|x| x.tanh())
    }

    /// Leaky ReLU: x for x > 0, otherwise alpha·x.
    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
    }
1136
    /// SiLU (swish): x · sigmoid(x), element-wise.
    pub fn silu(&self) -> Tensor {
        let data = self.to_vec();
        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
        Tensor::from_vec(result, &self.shape).unwrap()
    }

    /// Mish: x · tanh(softplus(x)), where softplus(x) = ln(1 + e^x).
    pub fn mish(&self) -> Tensor {
        let data = self.to_vec();
        let result: Vec<f64> = data.iter().map(|&x| {
            let sp = (1.0 + x.exp()).ln();
            x * sp.tanh()
        }).collect();
        Tensor::from_vec(result, &self.shape).unwrap()
    }
1153
1154 pub fn argmax(&self) -> usize {
1156 let data = self.to_vec();
1157 let mut best_idx = 0;
1158 let mut best_val = f64::NEG_INFINITY;
1159 for (i, &v) in data.iter().enumerate() {
1160 if v > best_val || (v == best_val && i < best_idx) {
1161 best_val = v;
1162 best_idx = i;
1163 }
1164 }
1165 best_idx
1166 }
1167
1168 pub fn argmin(&self) -> usize {
1170 let data = self.to_vec();
1171 let mut best_idx = 0;
1172 let mut best_val = f64::INFINITY;
1173 for (i, &v) in data.iter().enumerate() {
1174 if v < best_val || (v == best_val && i < best_idx) {
1175 best_val = v;
1176 best_idx = i;
1177 }
1178 }
1179 best_idx
1180 }
1181
1182 pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1184 let data = self.to_vec();
1185 let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1186 Tensor::from_vec(result, &self.shape).unwrap()
1187 }
1188
1189 pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1192 let n = indices.len();
1193 let mut data = vec![0.0; n * depth];
1194 for (i, &idx) in indices.iter().enumerate() {
1195 if idx >= depth {
1196 return Err(RuntimeError::InvalidOperation(format!(
1197 "one_hot: index {idx} >= depth {depth}"
1198 )));
1199 }
1200 data[i * depth + idx] = 1.0;
1201 }
1202 Tensor::from_vec(data, &[n, depth])
1203 }
1204
    /// Concatenates tensors along `axis`. All inputs must share rank and
    /// every non-`axis` dimension; the output's `axis` dim is the sum of
    /// the inputs' `axis` dims.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty list, an out-of-range axis, or any
    /// rank/shape mismatch.
    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
        }
        let ndim = tensors[0].ndim();
        if axis >= ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
            ));
        }
        // Validate every later tensor against the first.
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.ndim() != ndim {
                return Err(RuntimeError::InvalidOperation(
                    format!("cat: tensor {i} has different ndim"),
                ));
            }
            for d in 0..ndim {
                if d != axis && t.shape[d] != tensors[0].shape[d] {
                    return Err(RuntimeError::InvalidOperation(
                        format!("cat: shape mismatch at dim {d}"),
                    ));
                }
            }
        }
        let mut out_shape = tensors[0].shape.clone();
        for t in tensors.iter().skip(1) {
            out_shape[axis] += t.shape[axis];
        }
        let total = out_shape.iter().product::<usize>();
        let mut result = vec![0.0; total];
        // Row-major strides of the output. `ndim >= 1` is guaranteed by the
        // axis check above, so `ndim - 1` cannot underflow.
        let mut out_strides = vec![1usize; ndim];
        for d in (0..ndim - 1).rev() {
            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
        }
        // Copy each input, shifting its `axis` coordinate by the running
        // total of all tensors placed before it.
        let mut offset = 0;
        for t in tensors {
            let t_data = t.to_vec();
            let t_total: usize = t.shape.iter().product();
            let mut t_strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
            }
            for idx in 0..t_total {
                // Decode source coordinates, re-encode into the output.
                let mut remaining = idx;
                let mut out_flat = 0;
                for d in 0..ndim {
                    let coord = remaining / t_strides[d];
                    remaining %= t_strides[d];
                    let out_coord = if d == axis { coord + offset } else { coord };
                    out_flat += out_coord * out_strides[d];
                }
                result[out_flat] = t_data[idx];
            }
            offset += t.shape[axis];
        }
        Tensor::from_vec(result, &out_shape)
    }
1267
    /// Stacks same-shaped tensors along a NEW axis inserted at position
    /// `axis` (0 ≤ axis ≤ ndim); output rank is input rank + 1.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty list, an out-of-range axis, or any
    /// shape mismatch.
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
        }
        let base_shape = &tensors[0].shape;
        let ndim = base_shape.len();
        if axis > ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("stack: axis {axis} out of bounds"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if &t.shape != base_shape {
                return Err(RuntimeError::InvalidOperation(
                    format!("stack: tensor {i} shape mismatch"),
                ));
            }
        }
        // Output shape: base shape with `tensors.len()` spliced in at `axis`.
        let mut out_shape = Vec::with_capacity(ndim + 1);
        for d in 0..axis { out_shape.push(base_shape[d]); }
        out_shape.push(tensors.len());
        for d in axis..ndim { out_shape.push(base_shape[d]); }
        let total: usize = out_shape.iter().product();
        let mut result = vec![0.0; total];
        // Split each input into (outer, inner) blocks around the insertion
        // point; `.max(1)` makes empty slice products behave as 1.
        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
        for (t_idx, t) in tensors.iter().enumerate() {
            let t_data = t.to_vec();
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let src = outer * inner_size + inner;
                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
                    // Bounds guard: skip rather than panic on any mismatch.
                    if src < t_data.len() && dst < result.len() {
                        result[dst] = t_data[src];
                    }
                }
            }
        }
        Tensor::from_vec(result, &out_shape)
    }
1309
1310 pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1312 let data = self.to_vec();
1313 let n = data.len();
1314 if k > n {
1315 return Err(RuntimeError::InvalidOperation(
1316 format!("topk: k={k} exceeds data length {n}"),
1317 ));
1318 }
1319 let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1320 indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1321 let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1322 let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1323 let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1324 Ok((Tensor::from_vec(values, &[k])?, indices))
1325 }
1326
1327 pub fn gelu(&self) -> Tensor {
1329 let data = self.to_vec();
1330 let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1331 let result: Vec<f64> = data.iter().map(|&x| {
1332 let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1333 0.5 * x * (1.0 + inner.tanh())
1334 }).collect();
1335 Tensor::from_vec(result, &self.shape).unwrap()
1336 }
1337
    /// Affine transform over the last axis: `y = x · Wᵀ + b`.
    ///
    /// `weight` is `[out_features, in_features]`, `bias` has `out_features`
    /// elements, and the input's last dim must equal `in_features`. Dot
    /// products use the binned accumulator.
    ///
    /// # Errors
    /// `InvalidOperation` for a non-2-D weight, 0-D input, or wrong bias
    /// length; `DimensionMismatch` when the last dim differs from
    /// `in_features`.
    pub fn linear(
        &self,
        weight: &Tensor,
        bias: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if weight.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "linear: weight must be 2-D [out_features, in_features]".to_string(),
            ));
        }
        let out_features = weight.shape[0];
        let in_features = weight.shape[1];
        let last_dim = *self.shape.last().ok_or_else(|| {
            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
        })?;
        if last_dim != in_features {
            return Err(RuntimeError::DimensionMismatch {
                expected: in_features,
                got: last_dim,
            });
        }
        if bias.len() != out_features {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "linear: bias length {} must match out_features {}",
                    bias.len(),
                    out_features
                ),
            ));
        }

        let data = self.to_vec();
        let w = weight.to_vec();
        let b = bias.to_vec();
        // Treat all leading dims as a flat batch of rows.
        let outer = data.len() / in_features;
        let mut result = vec![0.0f64; outer * out_features];

        for row in 0..outer {
            let x_start = row * in_features;
            let x_slice = &data[x_start..x_start + in_features];
            let y_start = row * out_features;
            for j in 0..out_features {
                // Row j of the weight matrix dotted with the input row.
                let w_start = j * in_features;
                let mut acc = BinnedAccumulatorF64::new();
                for p in 0..in_features {
                    acc.add(x_slice[p] * w[w_start + p]);
                }
                result[y_start + j] = acc.finalize() + b[j];
            }
        }

        // Output keeps the leading dims, swapping the last for out_features.
        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
        out_shape.push(out_features);
        Tensor::from_vec(result, &out_shape)
    }
1398
    /// Valid (no-padding, stride-1) 1-D convolution of a 1-D signal with a
    /// bank of filters plus a per-channel bias.
    ///
    /// `filters` is `[out_channels, kernel_size]`; the output shape is
    /// `[out_channels, signal_len - kernel_size + 1]`. The inner loop is
    /// delegated to `kernel_fns::conv1d_raw` (kernel orientation — true
    /// convolution vs. cross-correlation — is defined there, not here).
    ///
    /// # Errors
    /// `InvalidOperation` when the input is not 1-D, the filters are not
    /// 2-D, the signal is shorter than the kernel, or the bias length does
    /// not match `out_channels`.
    pub fn conv1d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 1 {
            return Err(RuntimeError::InvalidOperation(
                "conv1d: input must be 1-D [signal_len]".to_string(),
            ));
        }
        if filters.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
            ));
        }
        let signal_len = self.shape[0];
        let out_channels = filters.shape[0];
        let kernel_size = filters.shape[1];
        if signal_len < kernel_size {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "conv1d: signal_len {} < kernel_size {}",
                    signal_len, kernel_size
                ),
            ));
        }
        if bias.len() != out_channels {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "conv1d: bias length {} must match out_channels {}",
                    bias.len(), out_channels
                ),
            ));
        }
        // "Valid" output length: one position per full kernel placement.
        let out_len = signal_len - kernel_size + 1;
        let s = self.to_vec();
        let f = filters.to_vec();
        let b = bias.to_vec();
        let mut result = vec![0.0; out_channels * out_len];
        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
        Tensor::from_vec(result, &[out_channels, out_len])
    }
1444
    /// Valid (no-padding) strided 2-D convolution over a batched, multi-
    /// channel input.
    ///
    /// Input is `[N, C_in, H, W]`, `filters` is `[C_out, C_in, kH, kW]`,
    /// and `bias` has `C_out` entries. Output is
    /// `[N, C_out, (H-kH)/stride + 1, (W-kW)/stride + 1]`. The arithmetic
    /// is delegated to `kernel_fns::conv2d_raw`.
    ///
    /// # Errors
    /// `InvalidOperation` for wrong ranks, `stride == 0`, channel-count
    /// mismatch, a kernel larger than the spatial extent, or a bias length
    /// different from `C_out`.
    pub fn conv2d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
        stride: usize,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
            ));
        }
        if filters.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
            ));
        }
        if stride == 0 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: stride must be >= 1".to_string(),
            ));
        }

        let n = self.shape[0];
        let c_in = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        let c_out = filters.shape[0];
        let c_in_check = filters.shape[1];
        let kh = filters.shape[2];
        let kw = filters.shape[3];

        if c_in != c_in_check {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input C_in={} does not match filter C_in={}",
                c_in, c_in_check
            )));
        }
        if h_in < kh || w_in < kw {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
                h_in, w_in, kh, kw
            )));
        }
        if bias.len() != c_out {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: bias length {} must match C_out={}",
                bias.len(), c_out
            )));
        }

        // "Valid" output extents; integer division drops partial windows.
        let h_out = (h_in - kh) / stride + 1;
        let w_out = (w_in - kw) / stride + 1;

        let inp = self.to_vec();
        let flt = filters.to_vec();
        let b = bias.to_vec();
        let mut result = vec![0.0f64; n * c_out * h_out * w_out];

        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
            n, c_in, h_in, w_in, c_out, kh, kw, stride);

        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
    }
1522
    /// Non-overlapping 2-D max pooling with window `ph × pw`.
    ///
    /// Input is `[N, C, H, W]`; output is `[N, C, H/ph, W/pw]` (integer
    /// division, so trailing rows/columns that do not fill a full window
    /// are dropped). The reduction itself is delegated to
    /// `kernel_fns::maxpool2d_raw`.
    ///
    /// # Errors
    /// `InvalidOperation` for a non-4-D input, a zero pool size, or a pool
    /// window larger than the spatial extent.
    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
            ));
        }
        if ph == 0 || pw == 0 {
            return Err(RuntimeError::InvalidOperation(
                "maxpool2d: pool size must be >= 1".to_string(),
            ));
        }

        let n = self.shape[0];
        let c = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        if h_in < ph || w_in < pw {
            return Err(RuntimeError::InvalidOperation(format!(
                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
                h_in, w_in, ph, pw
            )));
        }

        // Windows tile the input without overlap; remainders are discarded.
        let h_out = h_in / ph;
        let w_out = w_in / pw;

        let inp = self.to_vec();
        let mut result = vec![0.0f64; n * c * h_out * w_out];

        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);

        Tensor::from_vec(result, &[n, c, h_out, w_out])
    }
1563
1564 pub fn scaled_dot_product_attention(
1573 queries: &Tensor,
1574 keys: &Tensor,
1575 values: &Tensor,
1576 ) -> Result<Tensor, RuntimeError> {
1577 if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1578 return Err(RuntimeError::InvalidOperation(
1579 "attention: Q, K, V must be at least 2-D".to_string(),
1580 ));
1581 }
1582 let nd = queries.ndim();
1583 let d_k = queries.shape[nd - 1];
1584 let scale = 1.0 / (d_k as f64).sqrt();
1585
1586 let keys_t = keys.transpose_last_two()?;
1588
1589 let scores = queries.bmm(&keys_t)?;
1591
1592 let scores_scaled = scores.scalar_mul(scale);
1594
1595 let attn_weights = scores_scaled.softmax()?;
1597
1598 attn_weights.bmm(values)
1600 }
1601
1602 pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1606 if self.ndim() < 2 {
1607 return Err(RuntimeError::InvalidOperation(
1608 "transpose_last_two requires at least 2-D tensor".to_string(),
1609 ));
1610 }
1611 let nd = self.ndim();
1612 let rows = self.shape[nd - 2];
1613 let cols = self.shape[nd - 1];
1614 let data = self.to_vec();
1615 let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1616 let mat_size = rows * cols;
1617 let mut result = vec![0.0f64; data.len()];
1618
1619 for b in 0..batch_size {
1620 let off = b * mat_size;
1621 for i in 0..rows {
1622 for j in 0..cols {
1623 result[off + j * rows + i] = data[off + i * cols + j];
1624 }
1625 }
1626 }
1627
1628 let mut out_shape = self.shape.clone();
1629 out_shape[nd - 2] = cols;
1630 out_shape[nd - 1] = rows;
1631 Tensor::from_vec(result, &out_shape)
1632 }
1633
1634 pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1650 let numel = Self::shape_numel(shape);
1651 match dtype {
1652 "f64" => {
1653 let expected = numel * 8;
1654 if bytes.len() != expected {
1655 return Err(RuntimeError::ShapeMismatch {
1656 expected,
1657 got: bytes.len(),
1658 });
1659 }
1660 let mut data = Vec::with_capacity(numel);
1661 for i in 0..numel {
1662 let off = i * 8;
1663 let mut buf = [0u8; 8];
1664 buf.copy_from_slice(&bytes[off..off + 8]);
1665 data.push(f64::from_le_bytes(buf));
1666 }
1667 Ok(Tensor {
1668 buffer: Buffer::from_vec(data),
1669 shape: shape.to_vec(),
1670 strides: Self::compute_strides(shape),
1671 offset: 0,
1672 })
1673 }
1674 "f32" => {
1675 let expected = numel * 4;
1676 if bytes.len() != expected {
1677 return Err(RuntimeError::ShapeMismatch {
1678 expected,
1679 got: bytes.len(),
1680 });
1681 }
1682 let mut data = Vec::with_capacity(numel);
1683 for i in 0..numel {
1684 let off = i * 4;
1685 let mut buf = [0u8; 4];
1686 buf.copy_from_slice(&bytes[off..off + 4]);
1687 data.push(f32::from_le_bytes(buf) as f64);
1688 }
1689 Ok(Tensor {
1690 buffer: Buffer::from_vec(data),
1691 shape: shape.to_vec(),
1692 strides: Self::compute_strides(shape),
1693 offset: 0,
1694 })
1695 }
1696 _ => Err(RuntimeError::InvalidOperation(
1697 format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1698 )),
1699 }
1700 }
1701
1702 pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1710 if self.ndim() != 3 {
1711 return Err(RuntimeError::DimensionMismatch {
1712 expected: 3,
1713 got: self.ndim(),
1714 });
1715 }
1716 let batch = self.shape[0];
1717 let seq = self.shape[1];
1718 let model_dim = self.shape[2];
1719 if model_dim % num_heads != 0 {
1720 return Err(RuntimeError::InvalidOperation(
1721 format!(
1722 "split_heads: model_dim {} not divisible by num_heads {}",
1723 model_dim, num_heads
1724 ),
1725 ));
1726 }
1727 let head_dim = model_dim / num_heads;
1728 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1730 let reshaped = Tensor {
1732 buffer: tensor.buffer.clone(),
1733 shape: vec![batch, seq, num_heads, head_dim],
1734 strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1735 offset: 0,
1736 };
1737 Ok(Tensor {
1740 buffer: reshaped.buffer,
1741 shape: vec![batch, num_heads, seq, head_dim],
1742 strides: vec![
1743 reshaped.strides[0], reshaped.strides[2], reshaped.strides[1], reshaped.strides[3], ],
1748 offset: 0,
1749 })
1750 }
1751
    /// Inverse of `split_heads`: folds `[batch, num_heads, seq, head_dim]`
    /// back into `[batch, seq, num_heads * head_dim]`.
    ///
    /// Builds a stride-permuted view that swaps the head and seq axes, then
    /// materializes it contiguously before collapsing the trailing two
    /// dimensions.
    ///
    /// # Errors
    /// `DimensionMismatch` unless the input is 4-D.
    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::DimensionMismatch {
                expected: 4,
                got: self.ndim(),
            });
        }
        let batch = self.shape[0];
        let num_heads = self.shape[1];
        let seq = self.shape[2];
        let head_dim = self.shape[3];
        // View with axes 1 and 2 swapped: strides permuted, data shared.
        let transposed = Tensor {
            buffer: self.buffer.clone(),
            shape: vec![batch, seq, num_heads, head_dim],
            strides: vec![
                self.strides[0],
                self.strides[2], self.strides[1], self.strides[3],
            ],
            offset: self.offset,
        };
        // Materialize so the final reshape can use plain row-major strides.
        let contig = transposed.to_contiguous();
        let model_dim = num_heads * head_dim;
        Ok(Tensor {
            buffer: contig.buffer,
            shape: vec![batch, seq, model_dim],
            strides: Self::compute_strides(&[batch, seq, model_dim]),
            offset: 0,
        })
    }
1788
    /// Alias for `reshape`, kept for API parity with view-style callers;
    /// delegates directly with no additional checks.
    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.reshape(new_shape)
    }
1794
1795 pub fn argsort(&self) -> Tensor {
1802 let data = self.to_vec();
1803 let mut indices: Vec<usize> = (0..data.len()).collect();
1804 indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1805 let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1806 Tensor::from_vec_unchecked(result, &[data.len()])
1807 }
1808
1809 pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1814 let data = self.to_vec();
1815 let idx_data = indices.to_vec();
1816 if self.ndim() == 1 {
1817 let mut result = Vec::with_capacity(idx_data.len());
1818 for &idx in &idx_data {
1819 let i = idx as usize;
1820 if i >= data.len() {
1821 return Err(RuntimeError::InvalidOperation(
1822 format!("gather: index {} out of bounds for size {}", i, data.len()),
1823 ));
1824 }
1825 result.push(data[i]);
1826 }
1827 Ok(Tensor::from_vec_unchecked(result, indices.shape()))
1828 } else if self.ndim() == 2 {
1829 let rows = self.shape[0];
1830 let cols = self.shape[1];
1831 let idx_shape = indices.shape();
1832 let out_rows = idx_shape[0];
1833 let out_cols = idx_shape[1];
1834 let mut result = vec![0.0; out_rows * out_cols];
1835 for i in 0..out_rows {
1836 for j in 0..out_cols {
1837 let idx = idx_data[i * out_cols + j] as usize;
1838 let val = if dim == 0 {
1839 if idx >= rows {
1840 return Err(RuntimeError::InvalidOperation(
1841 format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
1842 ));
1843 }
1844 data[idx * cols + j]
1845 } else {
1846 if idx >= cols {
1847 return Err(RuntimeError::InvalidOperation(
1848 format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
1849 ));
1850 }
1851 data[i * cols + idx]
1852 };
1853 result[i * out_cols + j] = val;
1854 }
1855 }
1856 Ok(Tensor::from_vec_unchecked(result, idx_shape))
1857 } else {
1858 Err(RuntimeError::InvalidOperation(
1859 "gather: only 1D and 2D tensors supported".into(),
1860 ))
1861 }
1862 }
1863
1864 pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
1869 let mut result = self.to_vec();
1870 let idx_data = indices.to_vec();
1871 let src_data = src.to_vec();
1872 if self.ndim() == 1 {
1873 for (k, &idx) in idx_data.iter().enumerate() {
1874 let i = idx as usize;
1875 if i >= result.len() {
1876 return Err(RuntimeError::InvalidOperation(
1877 format!("scatter: index {} out of bounds for size {}", i, result.len()),
1878 ));
1879 }
1880 result[i] = src_data[k];
1881 }
1882 Ok(Tensor::from_vec_unchecked(result, self.shape()))
1883 } else if self.ndim() == 2 {
1884 let cols = self.shape[1];
1885 let idx_shape = indices.shape();
1886 let out_cols = idx_shape[1];
1887 let out_rows = idx_shape[0];
1888 for i in 0..out_rows {
1889 for j in 0..out_cols {
1890 let idx = idx_data[i * out_cols + j] as usize;
1891 let src_val = src_data[i * out_cols + j];
1892 if dim == 0 {
1893 if idx >= self.shape[0] {
1894 return Err(RuntimeError::InvalidOperation(
1895 format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
1896 ));
1897 }
1898 result[idx * cols + j] = src_val;
1899 } else {
1900 if idx >= cols {
1901 return Err(RuntimeError::InvalidOperation(
1902 format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
1903 ));
1904 }
1905 result[i * cols + idx] = src_val;
1906 }
1907 }
1908 }
1909 Ok(Tensor::from_vec_unchecked(result, self.shape()))
1910 } else {
1911 Err(RuntimeError::InvalidOperation(
1912 "scatter: only 1D and 2D tensors supported".into(),
1913 ))
1914 }
1915 }
1916
1917 pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1921 let data = self.to_vec();
1922 let idx_data = indices.to_vec();
1923 if self.ndim() == 1 {
1924 let mut result = Vec::with_capacity(idx_data.len());
1925 for &idx in &idx_data {
1926 let i = idx as usize;
1927 if i >= data.len() {
1928 return Err(RuntimeError::InvalidOperation(
1929 format!("index_select: index {} out of bounds for size {}", i, data.len()),
1930 ));
1931 }
1932 result.push(data[i]);
1933 }
1934 Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
1935 } else if self.ndim() == 2 {
1936 let rows = self.shape[0];
1937 let cols = self.shape[1];
1938 let n = idx_data.len();
1939 if dim == 0 {
1940 let mut result = Vec::with_capacity(n * cols);
1941 for &idx in &idx_data {
1942 let i = idx as usize;
1943 if i >= rows {
1944 return Err(RuntimeError::InvalidOperation(
1945 format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
1946 ));
1947 }
1948 for j in 0..cols {
1949 result.push(data[i * cols + j]);
1950 }
1951 }
1952 Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
1953 } else {
1954 let mut result = Vec::with_capacity(rows * n);
1955 for i in 0..rows {
1956 for &idx in &idx_data {
1957 let j = idx as usize;
1958 if j >= cols {
1959 return Err(RuntimeError::InvalidOperation(
1960 format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
1961 ));
1962 }
1963 result.push(data[i * cols + j]);
1964 }
1965 }
1966 Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
1967 }
1968 } else {
1969 Err(RuntimeError::InvalidOperation(
1970 "index_select: only 1D and 2D tensors supported".into(),
1971 ))
1972 }
1973 }
1974
1975 pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
1982 if self.shape() != condition.shape() || self.shape() != other.shape() {
1983 return Err(RuntimeError::InvalidOperation(
1984 format!("where: shape mismatch self={:?} cond={:?} other={:?}",
1985 self.shape(), condition.shape(), other.shape()),
1986 ));
1987 }
1988 let s = self.to_vec();
1989 let c = condition.to_vec();
1990 let o = other.to_vec();
1991 let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
1992 .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
1993 .collect();
1994 Tensor::from_vec(result, self.shape())
1995 }
1996
1997 pub fn any(&self) -> bool {
1999 let data = self.to_vec();
2000 data.iter().any(|&x| x != 0.0)
2001 }
2002
2003 pub fn all(&self) -> bool {
2005 let data = self.to_vec();
2006 data.iter().all(|&x| x != 0.0)
2007 }
2008
2009 pub fn nonzero(&self) -> Tensor {
2011 let data = self.to_vec();
2012 let indices: Vec<f64> = data.iter().enumerate()
2013 .filter(|(_, &v)| v != 0.0)
2014 .map(|(i, _)| i as f64)
2015 .collect();
2016 let len = indices.len();
2017 if len == 0 {
2018 Tensor::from_vec(vec![], &[0]).unwrap()
2019 } else {
2020 Tensor::from_vec(indices, &[len]).unwrap()
2021 }
2022 }
2023
2024 pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2026 if self.shape() != mask.shape() {
2027 return Err(RuntimeError::InvalidOperation(
2028 format!("masked_fill: shape mismatch self={:?} mask={:?}",
2029 self.shape(), mask.shape()),
2030 ));
2031 }
2032 let data = self.to_vec();
2033 let m = mask.to_vec();
2034 let result: Vec<f64> = data.iter().zip(m.iter())
2035 .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2036 .collect();
2037 Tensor::from_vec(result, self.shape())
2038 }
2039
    /// Generic axis reduction: for every position along all axes except
    /// `axis`, collects the values along `axis` into a slice and maps it
    /// through `reduce_fn` to a single f64.
    ///
    /// With `keepdim` the reduced axis stays in the output shape with size
    /// 1; otherwise it is removed (a fully-reduced tensor keeps shape [1]).
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
        -> Result<Tensor, RuntimeError>
    where
        F: Fn(&[f64]) -> f64,
    {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds {
                index: axis,
                length: ndim,
            });
        }

        let axis_len = self.shape[axis];
        // Output shape with the reduced axis collapsed to size 1; its
        // contiguous strides drive the index decomposition below.
        let mut out_shape: Vec<usize> = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        let data = self.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            {
                // Decode the linear output index into multi-dim coordinates
                // (the `axis` coordinate decodes to 0 since its size is 1).
                let mut remaining = out_idx;
                for d in 0..ndim {
                    indices[d] = remaining / out_strides[d];
                    remaining %= out_strides[d];
                }
            }

            // Gather the lane along `axis` using the source's own strides
            // and offset, so views are read correctly.
            let mut vals = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                vals.push(data[flat]);
            }
            result.push(reduce_fn(&vals));
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; never return a 0-dim shape.
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis)
                .map(|(_, &v)| v)
                .collect();
            if s.is_empty() {
                s.push(1); }
            s
        };

        Tensor::from_vec(result, &final_shape)
    }
2109
2110 pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2112 self.reduce_axis(axis, keepdim, |vals| {
2113 let mut acc = BinnedAccumulatorF64::new();
2114 for &v in vals { acc.add(v); }
2115 acc.finalize() / vals.len() as f64
2116 })
2117 }
2118
    /// Maximum along `axis`, returned as `(values, argmax-indices)` with
    /// the indices encoded as f64. Ties resolve to the first (lowest-k)
    /// occurrence; a lane of all -inf/NaN yields `(-inf, 0)` since NaN
    /// never satisfies `v > best_val`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the linear output index into per-axis coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mut best_val = f64::NEG_INFINITY;
            let mut best_idx = 0usize;
            // Scan the lane along `axis` via the source strides/offset.
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let v = data[flat];
                if v > best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; keep at least one dimension.
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }
2172
    /// Minimum along `axis`, returned as `(values, argmin-indices)` with
    /// the indices encoded as f64. Ties resolve to the first (lowest-k)
    /// occurrence; a lane of all +inf/NaN yields `(+inf, 0)` since NaN
    /// never satisfies `v < best_val`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the linear output index into per-axis coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mut best_val = f64::INFINITY;
            let mut best_idx = 0usize;
            // Scan the lane along `axis` via the source strides/offset.
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let v = data[flat];
                if v < best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; keep at least one dimension.
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }
2226
    /// Population variance along `axis` (divides by the axis length, not
    /// N-1). The mean is computed first with `keepdim = true` so the mean
    /// buffer and the output iterate in the same linear order.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim` (also surfaced by the inner
    /// `mean_axis` call).
    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        let mean_t = self.mean_axis(axis, true)?;
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mean_data = mean_t.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the linear output index into per-axis coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            // mean_t was built keepdim, so it shares this linear indexing.
            let mu = mean_data[out_idx];
            let mut acc = BinnedAccumulatorF64::new();
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let diff = data[flat] - mu;
                acc.add(diff * diff);
            }
            result.push(acc.finalize() / axis_len as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; keep at least one dimension.
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Tensor::from_vec(result, &final_shape)
    }
2274
2275 pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2277 let var = self.var_axis(axis, keepdim)?;
2278 Ok(var.map(|x| x.sqrt()))
2279 }
2280
2281 pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2283 self.reduce_axis(axis, keepdim, |vals| {
2284 let mut product = 1.0f64;
2287 for &v in vals { product *= v; }
2288 product
2289 })
2290 }
2291
2292 pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2299 let ndim = self.ndim();
2300 if axis >= ndim {
2301 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2302 }
2303 let data = self.to_vec();
2304 let axis_len = self.shape[axis];
2305 let out_shape = self.shape.clone();
2306 let out_numel = Self::shape_numel(&out_shape);
2307
2308 let mut iter_shape: Vec<usize> = Vec::new();
2310 for (i, &s) in self.shape.iter().enumerate() {
2311 if i != axis { iter_shape.push(s); }
2312 }
2313 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2314
2315 let mut result = vec![0.0f64; out_numel];
2316
2317 let mut pos = vec![0usize; ndim];
2319 for slice_idx in 0..n_slices {
2320 let mut remaining = slice_idx;
2322 let mut dim_idx = 0;
2323 for d in 0..ndim {
2324 if d == axis {
2325 pos[d] = 0;
2326 } else {
2327 let stride = {
2328 let mut s = 1usize;
2329 let mut di = 0;
2330 for d2 in 0..ndim {
2331 if d2 == axis { continue; }
2332 if di > dim_idx { s *= self.shape[d2]; }
2333 di += 1;
2334 }
2335 s
2336 };
2337 pos[d] = remaining / stride;
2338 remaining %= stride;
2339 dim_idx += 1;
2340 }
2341 }
2342
2343 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2345 for k in 0..axis_len {
2346 let mut flat = self.offset;
2347 for d in 0..ndim {
2348 let idx = if d == axis { k } else { pos[d] };
2349 flat += idx * self.strides[d];
2350 }
2351 vals.push((data[flat], k));
2352 }
2353
2354 if descending {
2356 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2357 .then(a.1.cmp(&b.1)));
2358 } else {
2359 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2360 .then(a.1.cmp(&b.1)));
2361 }
2362
2363 for (k, &(v, _)) in vals.iter().enumerate() {
2365 let mut flat = 0;
2366 let out_strides_local = Self::compute_strides(&out_shape);
2367 for d in 0..ndim {
2368 let idx = if d == axis { k } else { pos[d] };
2369 flat += idx * out_strides_local[d];
2370 }
2371 result[flat] = v;
2372 }
2373 }
2374
2375 Tensor::from_vec(result, &out_shape)
2376 }
2377
2378 pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2380 let ndim = self.ndim();
2381 if axis >= ndim {
2382 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2383 }
2384 let data = self.to_vec();
2385 let axis_len = self.shape[axis];
2386 let out_shape = self.shape.clone();
2387 let out_numel = Self::shape_numel(&out_shape);
2388
2389 let mut iter_shape: Vec<usize> = Vec::new();
2390 for (i, &s) in self.shape.iter().enumerate() {
2391 if i != axis { iter_shape.push(s); }
2392 }
2393 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2394
2395 let mut result = vec![0.0f64; out_numel];
2396 let mut pos = vec![0usize; ndim];
2397
2398 for slice_idx in 0..n_slices {
2399 let mut remaining = slice_idx;
2400 let mut dim_idx = 0;
2401 for d in 0..ndim {
2402 if d == axis {
2403 pos[d] = 0;
2404 } else {
2405 let stride = {
2406 let mut s = 1usize;
2407 let mut di = 0;
2408 for d2 in 0..ndim {
2409 if d2 == axis { continue; }
2410 if di > dim_idx { s *= self.shape[d2]; }
2411 di += 1;
2412 }
2413 s
2414 };
2415 pos[d] = remaining / stride;
2416 remaining %= stride;
2417 dim_idx += 1;
2418 }
2419 }
2420
2421 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2422 for k in 0..axis_len {
2423 let mut flat = self.offset;
2424 for d in 0..ndim {
2425 let idx = if d == axis { k } else { pos[d] };
2426 flat += idx * self.strides[d];
2427 }
2428 vals.push((data[flat], k));
2429 }
2430
2431 if descending {
2432 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2433 .then(a.1.cmp(&b.1)));
2434 } else {
2435 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2436 .then(a.1.cmp(&b.1)));
2437 }
2438
2439 for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2440 let out_strides_local = Self::compute_strides(&out_shape);
2441 let mut flat = 0;
2442 for d in 0..ndim {
2443 let idx = if d == axis { k } else { pos[d] };
2444 flat += idx * out_strides_local[d];
2445 }
2446 result[flat] = orig_idx as f64;
2447 }
2448 }
2449
2450 Tensor::from_vec(result, &out_shape)
2451 }
2452
    /// Minimal Einstein-summation evaluator for explicit notation, e.g.
    /// `"ij,jk->ik"`. Labels present in inputs but absent from the output
    /// are contracted (summed) over; sums use the binned accumulator for a
    /// reproducible order.
    ///
    /// Supports any number of inputs but requires the explicit `->` output
    /// spec; no broadcasting, ellipses, or repeated labels within one input
    /// spec (a repeated label would just resolve to the same coordinate —
    /// no diagonal extraction is performed).
    ///
    /// # Errors
    /// `InvalidOperation` for malformed notation, a spec/tensor count or
    /// rank mismatch, conflicting sizes for the same label, or an output
    /// label that appears in no input.
    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
        let parts: Vec<&str> = notation.split("->").collect();
        if parts.len() != 2 {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
            ));
        }
        let input_specs: Vec<&str> = parts[0].split(',').collect();
        let output_spec = parts[1];

        if input_specs.len() != inputs.len() {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
            ));
        }

        // Map each label to its dimension size, checking consistency across
        // all occurrences. BTreeMap keeps label iteration order deterministic.
        let mut label_size = std::collections::BTreeMap::new();
        for (i, &spec) in input_specs.iter().enumerate() {
            let chars: Vec<char> = spec.chars().collect();
            if chars.len() != inputs[i].ndim() {
                return Err(RuntimeError::InvalidOperation(
                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
                ));
            }
            for (d, &c) in chars.iter().enumerate() {
                let sz = inputs[i].shape()[d];
                if let Some(&prev) = label_size.get(&c) {
                    if prev != sz {
                        return Err(RuntimeError::InvalidOperation(
                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
                        ));
                    }
                } else {
                    label_size.insert(c, sz);
                }
            }
        }

        // Resolve the output shape from the output labels.
        let output_chars: Vec<char> = output_spec.chars().collect();
        let output_shape: Vec<usize> = output_chars.iter()
            .map(|c| label_size.get(c).copied().ok_or_else(||
                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
            .collect::<Result<_, _>>()?;
        let out_numel = Self::shape_numel(&output_shape);

        // Labels not in the output are summed over (contracted).
        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
        let contract_labels: Vec<char> = label_size.keys()
            .filter(|c| !output_set.contains(c))
            .copied()
            .collect();
        let contract_sizes: Vec<usize> = contract_labels.iter()
            .map(|c| label_size[c])
            .collect();
        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);

        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();

        let out_strides = Self::compute_strides(&output_shape);
        let mut result = vec![0.0f64; out_numel];

        // Snapshot each input's data, strides, and offset once up front.
        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();

        for out_idx in 0..out_numel {
            // Fix the output labels' coordinates for this output element.
            let mut label_vals = std::collections::BTreeMap::new();
            let mut remaining = out_idx;
            for (d, &c) in output_chars.iter().enumerate() {
                let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
                label_vals.insert(c, remaining / stride);
                remaining %= stride;
            }

            // Sum the product of all inputs over every combination of the
            // contracted labels' coordinates.
            let mut acc = BinnedAccumulatorF64::new();
            for cidx in 0..contract_numel {
                let mut cr = cidx;
                for (ci, &cl) in contract_labels.iter().enumerate() {
                    let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
                    label_vals.insert(cl, cr / stride);
                    cr %= stride;
                }

                let mut product = 1.0f64;
                for (inp_idx, chars) in input_chars.iter().enumerate() {
                    // Resolve each input's flat index from its labels.
                    let mut flat = input_offsets[inp_idx];
                    for (d, &c) in chars.iter().enumerate() {
                        flat += label_vals[&c] * input_strides[inp_idx][d];
                    }
                    product *= input_data[inp_idx][flat];
                }
                acc.add(product);
            }
            result[out_idx] = acc.finalize();
        }

        // A scalar result (empty output spec) is represented as shape [1].
        if output_shape.is_empty() {
            Tensor::from_vec(result, &[1])
        } else {
            Tensor::from_vec(result, &output_shape)
        }
    }
2573
2574 pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2580 let ndim = self.ndim();
2581 if dim > ndim {
2582 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2583 }
2584 let mut new_shape = self.shape.clone();
2585 new_shape.insert(dim, 1);
2586 self.reshape(&new_shape)
2587 }
2588
2589 pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2592 match dim {
2593 Some(d) => {
2594 if d >= self.ndim() {
2595 return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2596 }
2597 if self.shape[d] != 1 {
2598 return Err(RuntimeError::InvalidOperation(
2599 format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2600 ));
2601 }
2602 let mut new_shape = self.shape.clone();
2603 new_shape.remove(d);
2604 if new_shape.is_empty() {
2605 new_shape.push(1); }
2607 self.reshape(&new_shape)
2608 }
2609 None => {
2610 let new_shape: Vec<usize> = self.shape.iter()
2611 .filter(|&&s| s != 1)
2612 .copied()
2613 .collect();
2614 let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2615 self.reshape(&new_shape)
2616 }
2617 }
2618 }
2619
    /// Expands this tensor to `target_shape`.
    ///
    /// Thin alias that delegates to `broadcast_to`; see that method for the
    /// exact broadcasting rules and error conditions.
    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.broadcast_to(target_shape)
    }
2625
2626 pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2628 if start_dim > end_dim || end_dim >= self.ndim() {
2629 return Err(RuntimeError::InvalidOperation(
2630 format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2631 ));
2632 }
2633 let mut new_shape = Vec::new();
2634 for i in 0..start_dim {
2635 new_shape.push(self.shape[i]);
2636 }
2637 let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2638 new_shape.push(flat_size);
2639 for i in (end_dim + 1)..self.ndim() {
2640 new_shape.push(self.shape[i]);
2641 }
2642 self.reshape(&new_shape)
2643 }
2644
2645 pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2647 if dim >= self.ndim() {
2648 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2649 }
2650 if n == 0 {
2651 return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2652 }
2653 let dim_size = self.shape[dim];
2654 let chunk_size = (dim_size + n - 1) / n;
2655 let mut sizes = Vec::new();
2656 let mut remaining = dim_size;
2657 while remaining > 0 {
2658 let s = remaining.min(chunk_size);
2659 sizes.push(s);
2660 remaining -= s;
2661 }
2662 self.split(&sizes, dim)
2663 }
2664
2665 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2667 if dim >= self.ndim() {
2668 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2669 }
2670 let total: usize = sizes.iter().sum();
2671 if total != self.shape[dim] {
2672 return Err(RuntimeError::InvalidOperation(
2673 format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2674 ));
2675 }
2676
2677 let mut results = Vec::new();
2678 let mut offset = 0;
2679
2680 for &sz in sizes {
2681 let ranges: Vec<(usize, usize)> = self.shape.iter()
2682 .enumerate()
2683 .map(|(i, &s)| {
2684 if i == dim { (offset, offset + sz) } else { (0, s) }
2685 })
2686 .collect();
2687 let chunk = self.slice(&ranges)?;
2688 results.push(chunk.to_contiguous());
2690 offset += sz;
2691 }
2692
2693 Ok(results)
2694 }
2695
2696 pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2701 if self.shape != other.shape {
2702 return Err(RuntimeError::InvalidOperation(
2703 "scale_add: shape mismatch".to_string(),
2704 ));
2705 }
2706 let a = self.to_vec();
2707 let b = other.to_vec();
2708 let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2709 Tensor::from_vec(result, &self.shape)
2710 }
2711}
2712