1use cjc_repro::Rng;
43
44use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
45use cjc_repro::KahanAccumulatorF64;
46
47use crate::accumulator;
48use crate::buffer::Buffer;
49use crate::dispatch;
50use crate::error::RuntimeError;
51use crate::kernel as kernel_fns;
52use crate::tensor_simd::{self, BinOp, UnaryOp};
53use crate::tensor_tiled::TiledMatmul;
54
55#[derive(Debug, Clone)]
73pub struct Tensor {
74 pub buffer: Buffer<f64>,
76 pub(crate) shape: Vec<usize>,
78 pub(crate) strides: Vec<usize>,
82 pub(crate) offset: usize,
85}
86
87impl Tensor {
88 pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
92 let mut strides = vec![1usize; shape.len()];
93 for i in (0..shape.len().saturating_sub(1)).rev() {
94 strides[i] = strides[i + 1] * shape[i + 1];
95 }
96 strides
97 }
98
99 fn shape_numel(shape: &[usize]) -> usize {
101 shape.iter().product()
102 }
103
104 pub fn zeros(shape: &[usize]) -> Self {
106 let numel = Self::shape_numel(shape);
107 Tensor {
108 buffer: Buffer::alloc(numel, 0.0),
109 shape: shape.to_vec(),
110 strides: Self::compute_strides(shape),
111 offset: 0,
112 }
113 }
114
115 pub fn ones(shape: &[usize]) -> Self {
117 let numel = Self::shape_numel(shape);
118 Tensor {
119 buffer: Buffer::alloc(numel, 1.0),
120 shape: shape.to_vec(),
121 strides: Self::compute_strides(shape),
122 offset: 0,
123 }
124 }
125
126 pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
129 let numel = Self::shape_numel(shape);
130 let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
131 Tensor {
132 buffer: Buffer::from_vec(data),
133 shape: shape.to_vec(),
134 strides: Self::compute_strides(shape),
135 offset: 0,
136 }
137 }
138
139 pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
142 let numel = Self::shape_numel(shape);
143 if data.len() != numel {
144 return Err(RuntimeError::ShapeMismatch {
145 expected: numel,
146 got: data.len(),
147 });
148 }
149 Ok(Tensor {
150 buffer: Buffer::from_vec(data),
151 shape: shape.to_vec(),
152 strides: Self::compute_strides(shape),
153 offset: 0,
154 })
155 }
156
157 pub fn shape(&self) -> &[usize] {
161 &self.shape
162 }
163
164 pub fn ndim(&self) -> usize {
166 self.shape.len()
167 }
168
169 pub fn len(&self) -> usize {
171 Self::shape_numel(&self.shape)
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.len() == 0
177 }
178
179 fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
181 if indices.len() != self.shape.len() {
182 return Err(RuntimeError::DimensionMismatch {
183 expected: self.shape.len(),
184 got: indices.len(),
185 });
186 }
187 let mut off = self.offset;
188 for (i, &idx) in indices.iter().enumerate() {
189 if idx >= self.shape[i] {
190 return Err(RuntimeError::IndexOutOfBounds {
191 index: idx,
192 length: self.shape[i],
193 });
194 }
195 off += idx * self.strides[i];
196 }
197 Ok(off)
198 }
199
200 pub fn is_contiguous(&self) -> bool {
202 if self.offset != 0 {
203 return false;
204 }
205 let expected = Self::compute_strides(&self.shape);
206 self.strides == expected
207 }
208
209 pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
212 if ranges.len() != self.shape.len() {
213 return Err(RuntimeError::DimensionMismatch {
214 expected: self.shape.len(),
215 got: ranges.len(),
216 });
217 }
218 let mut new_offset = self.offset;
219 let mut new_shape = Vec::with_capacity(ranges.len());
220 for (i, &(start, end)) in ranges.iter().enumerate() {
221 if end > self.shape[i] || start > end {
222 return Err(RuntimeError::IndexOutOfBounds {
223 index: end,
224 length: self.shape[i],
225 });
226 }
227 new_offset += start * self.strides[i];
228 new_shape.push(end - start);
229 }
230 Ok(Tensor {
231 buffer: self.buffer.clone(), shape: new_shape,
233 strides: self.strides.clone(),
234 offset: new_offset,
235 })
236 }
237
238 pub fn to_contiguous(&self) -> Tensor {
240 if self.is_contiguous() {
241 return self.clone();
242 }
243 let data = self.to_vec();
244 Tensor {
245 buffer: Buffer::from_vec(data),
246 shape: self.shape.clone(),
247 strides: Self::compute_strides(&self.shape),
248 offset: 0,
249 }
250 }
251
252 pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
255 let src_ndim = self.shape.len();
256 let tgt_ndim = target_shape.len();
257 if tgt_ndim < src_ndim {
258 return Err(RuntimeError::InvalidOperation(
259 "cannot broadcast to a smaller rank".to_string(),
260 ));
261 }
262 let pad = tgt_ndim - src_ndim;
263 let mut new_strides = vec![0usize; tgt_ndim];
264 for i in 0..tgt_ndim {
265 if i < pad {
266 new_strides[i] = 0;
268 } else {
269 let src_i = i - pad;
270 if self.shape[src_i] == target_shape[i] {
271 new_strides[i] = self.strides[src_i];
272 } else if self.shape[src_i] == 1 {
273 new_strides[i] = 0; } else {
275 return Err(RuntimeError::ShapeMismatch {
276 expected: target_shape[i],
277 got: self.shape[src_i],
278 });
279 }
280 }
281 }
282 Ok(Tensor {
283 buffer: self.buffer.clone(),
284 shape: target_shape.to_vec(),
285 strides: new_strides,
286 offset: self.offset,
287 })
288 }
289
290 pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
292 let offset = self.linear_index(indices)?;
293 self.buffer
294 .get(offset)
295 .ok_or(RuntimeError::IndexOutOfBounds {
296 index: offset,
297 length: self.buffer.len(),
298 })
299 }
300
301 pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
303 let offset = self.linear_index(indices)?;
304 self.buffer.set(offset, val)
305 }
306
307 pub fn to_vec(&self) -> Vec<f64> {
309 if self.is_contiguous() {
310 let full = self.buffer.borrow_data();
311 let numel = self.len();
312 if full.len() == numel {
313 return full.to_vec();
314 }
315 return full[..numel].to_vec();
318 }
319 let numel = self.len();
321 let mut result = Vec::with_capacity(numel);
322 let ndim = self.shape.len();
323 let mut indices = vec![0usize; ndim];
324 for _ in 0..numel {
325 let mut off = self.offset;
326 for d in 0..ndim {
327 off += indices[d] * self.strides[d];
328 }
329 result.push(self.buffer.get(off).unwrap_or(0.0));
330 for d in (0..ndim).rev() {
332 indices[d] += 1;
333 if indices[d] < self.shape[d] {
334 break;
335 }
336 indices[d] = 0;
337 }
338 }
339 result
340 }
341
342 pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
347 let new_numel = Self::shape_numel(new_shape);
348 if new_numel != self.len() {
349 return Err(RuntimeError::ShapeMismatch {
350 expected: self.len(),
351 got: new_numel,
352 });
353 }
354 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
356 Ok(Tensor {
357 buffer: tensor.buffer,
358 shape: new_shape.to_vec(),
359 strides: Self::compute_strides(new_shape),
360 offset: 0,
361 })
362 }
363
364 fn elementwise_binop(
368 &self,
369 other: &Tensor,
370 op: impl Fn(f64, f64) -> f64,
371 ) -> Result<Tensor, RuntimeError> {
372 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
373 let a = self.buffer.borrow_data();
375 let b = other.buffer.borrow_data();
376 let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
377 return Ok(Tensor {
378 buffer: Buffer::from_vec(data),
379 shape: self.shape.clone(),
380 strides: Self::compute_strides(&self.shape),
381 offset: 0,
382 });
383 }
384
385 let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
387 let a_broadcast = self.broadcast_to(&result_shape)?;
388 let b_broadcast = other.broadcast_to(&result_shape)?;
389
390 let numel = Self::shape_numel(&result_shape);
391 let ndim = result_shape.len();
392 let mut data = Vec::with_capacity(numel);
393 let mut indices = vec![0usize; ndim];
394
395 for _ in 0..numel {
396 let mut off_a = a_broadcast.offset;
397 let mut off_b = b_broadcast.offset;
398 for d in 0..ndim {
399 off_a += indices[d] * a_broadcast.strides[d];
400 off_b += indices[d] * b_broadcast.strides[d];
401 }
402 let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
403 RuntimeError::InvalidOperation(format!(
404 "broadcast binop: left operand index {} out of bounds (buffer len {})",
405 off_a,
406 a_broadcast.buffer.len()
407 ))
408 })?;
409 let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
410 RuntimeError::InvalidOperation(format!(
411 "broadcast binop: right operand index {} out of bounds (buffer len {})",
412 off_b,
413 b_broadcast.buffer.len()
414 ))
415 })?;
416 data.push(op(va, vb));
417
418 for d in (0..ndim).rev() {
419 indices[d] += 1;
420 if indices[d] < result_shape[d] {
421 break;
422 }
423 indices[d] = 0;
424 }
425 }
426
427 Ok(Tensor {
428 buffer: Buffer::from_vec(data),
429 shape: result_shape.clone(),
430 strides: Self::compute_strides(&result_shape),
431 offset: 0,
432 })
433 }
434
435 fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
437 let max_ndim = a.len().max(b.len());
438 let mut result = Vec::with_capacity(max_ndim);
439 for i in 0..max_ndim {
440 let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
441 let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
442 if da == db {
443 result.push(da);
444 } else if da == 1 {
445 result.push(db);
446 } else if db == 1 {
447 result.push(da);
448 } else {
449 return Err(RuntimeError::ShapeMismatch {
450 expected: da,
451 got: db,
452 });
453 }
454 }
455 Ok(result)
456 }
457
458 fn elementwise_binop_simd(
463 &self,
464 other: &Tensor,
465 op: BinOp,
466 fallback: impl Fn(f64, f64) -> f64,
467 ) -> Result<Tensor, RuntimeError> {
468 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
469 let a = self.buffer.borrow_data();
471 let b = other.buffer.borrow_data();
472 let data = tensor_simd::simd_binop(&a, &b, op);
473 return Ok(Tensor {
474 buffer: Buffer::from_vec(data),
475 shape: self.shape.clone(),
476 strides: Self::compute_strides(&self.shape),
477 offset: 0,
478 });
479 }
480 self.elementwise_binop(other, fallback)
482 }
483
484 pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486 self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
487 }
488
489 pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491 self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
492 }
493
494 pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496 self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
497 }
498
499 pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501 self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
502 }
503
504 pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
510 if self.shape != b.shape || self.shape != c.shape {
511 return Err(RuntimeError::InvalidOperation(
512 "broadcast_fma: all three tensors must have the same shape".to_string(),
513 ));
514 }
515 if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
516 let a_data = self.buffer.borrow_data();
517 let b_data = b.buffer.borrow_data();
518 let c_data = c.buffer.borrow_data();
519 let n = a_data.len();
520 let mut out = vec![0.0f64; n];
521 for i in 0..n {
524 out[i] = a_data[i] * b_data[i] + c_data[i];
525 }
526 return Ok(Tensor {
527 buffer: Buffer::from_vec(out),
528 shape: self.shape.clone(),
529 strides: Self::compute_strides(&self.shape),
530 offset: 0,
531 });
532 }
533 let temp = self.mul_elem(b)?;
535 temp.add(c)
536 }
537
538 pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
542 self.elementwise_binop(other, |a, b| a.powf(b))
543 }
544
545 pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
547 self.elementwise_binop(other, |a, b| a.min(b))
548 }
549
550 pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
552 self.elementwise_binop(other, |a, b| a.max(b))
553 }
554
555 pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
557 self.elementwise_binop(other, |a, b| a.atan2(b))
558 }
559
560 pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
562 self.elementwise_binop(other, |a, b| a.hypot(b))
563 }
564
565 pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
567 let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
568 Tensor {
569 buffer: Buffer::from_vec(data),
570 shape: self.shape.clone(),
571 strides: Self::compute_strides(&self.shape),
572 offset: 0,
573 }
574 }
575
576 pub fn map_simd(&self, op: UnaryOp) -> Tensor {
581 let src = self.to_vec();
582 let data = tensor_simd::simd_unary(&src, op);
583 Tensor {
584 buffer: Buffer::from_vec(data),
585 shape: self.shape.clone(),
586 strides: Self::compute_strides(&self.shape),
587 offset: 0,
588 }
589 }
590
591 pub fn sum(&self) -> f64 {
595 let data = self.buffer.borrow_data();
596 binned_sum_f64(&data)
597 }
598
599 pub fn binned_sum(&self) -> f64 {
603 let data = self.buffer.borrow_data();
604 accumulator::binned_sum_f64(&data)
605 }
606
607 pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
611 let data = self.buffer.borrow_data();
612 dispatch::dispatch_sum_f64(&data, ctx)
613 }
614
615 pub fn mean(&self) -> f64 {
617 let n = self.len();
618 if n == 0 {
619 return 0.0;
620 }
621 self.sum() / n as f64
622 }
623
624 pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
626 let n = self.len();
627 if n == 0 {
628 return 0.0;
629 }
630 self.dispatched_sum(ctx) / n as f64
631 }
632
633 pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
643 let ndim = self.ndim();
644 if axis >= ndim {
645 return Err(RuntimeError::IndexOutOfBounds {
646 index: axis,
647 length: ndim,
648 });
649 }
650
651 let mut out_shape = self.shape.clone();
653 out_shape[axis] = 1;
654 let out_numel = Self::shape_numel(&out_shape);
655 let out_strides = Self::compute_strides(&out_shape);
656
657 let data = self.to_vec();
658 let axis_len = self.shape[axis];
659 let mut result = vec![0.0f64; out_numel];
660
661 let mut indices = vec![0usize; ndim];
663 for out_idx in 0..out_numel {
664 {
666 let mut remaining = out_idx;
667 for d in 0..ndim {
668 indices[d] = remaining / out_strides[d];
669 remaining %= out_strides[d];
670 }
671 }
672
673 let mut acc = BinnedAccumulatorF64::new();
674 for k in 0..axis_len {
675 let mut flat = self.offset;
677 for d in 0..ndim {
678 let idx = if d == axis { k } else { indices[d] };
679 flat += idx * self.strides[d];
680 }
681 acc.add(data[flat]);
682 }
683 result[out_idx] = acc.finalize();
684 }
685
686 Tensor::from_vec(result, &out_shape)
687 }
688
689 pub fn neg(&self) -> Tensor {
693 self.map(|x| -x)
694 }
695
696 pub fn transpose(&self) -> Tensor {
699 let ndim = self.ndim();
700 if ndim <= 1 {
701 return self.clone();
702 }
703 let mut new_shape = self.shape.clone();
705 let mut new_strides = self.strides.clone();
706 new_shape.reverse();
707 new_strides.reverse();
708 Tensor {
709 buffer: self.buffer.clone(), shape: new_shape,
711 strides: new_strides,
712 offset: self.offset,
713 }
714 }
715
716 pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
720 let ndim = self.ndim();
721 if axes.len() != ndim {
722 return Err(RuntimeError::InvalidOperation(
723 format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
724 ));
725 }
726 let mut seen = vec![false; ndim];
728 for &ax in axes {
729 if ax >= ndim {
730 return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
731 }
732 if seen[ax] {
733 return Err(RuntimeError::InvalidOperation(
734 format!("transpose_axes: duplicate axis {ax}"),
735 ));
736 }
737 seen[ax] = true;
738 }
739 let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
740 let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
741 Ok(Tensor {
742 buffer: self.buffer.clone(),
743 shape: new_shape,
744 strides: new_strides,
745 offset: self.offset,
746 })
747 }
748
749 pub fn scalar_mul(&self, s: f64) -> Tensor {
751 self.map(|x| x * s)
752 }
753
754 pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
759 Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
760 }
761
762 pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
764 self.add(other).expect("Tensor::add shape mismatch")
765 }
766
767 pub fn add_assign_unchecked(&mut self, other: &Tensor) {
772 assert_eq!(self.shape, other.shape, "Tensor::add_assign shape mismatch");
773 self.buffer.make_unique();
774 let other_data = other.buffer.borrow_data();
775 let mut self_data = self.buffer.borrow_data_mut();
776 for (a, b) in self_data.iter_mut().zip(other_data.iter()) {
777 *a += b;
778 }
779 }
780
781 pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
783 self.sub(other).expect("Tensor::sub shape mismatch")
784 }
785
786 pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
788 self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
789 }
790
791 pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
793 self.div_elem(other).expect("Tensor::div_elem shape mismatch")
794 }
795
796 pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
798 self.matmul(other).expect("Tensor::matmul dimension mismatch")
799 }
800
801 pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
805 if self.ndim() != 2 || other.ndim() != 2 {
806 return Err(RuntimeError::InvalidOperation(
807 "matmul requires 2-D tensors".to_string(),
808 ));
809 }
810 let m = self.shape[0];
811 let k = self.shape[1];
812 let k2 = other.shape[0];
813 let n = other.shape[1];
814 if k != k2 {
815 return Err(RuntimeError::DimensionMismatch {
816 expected: k,
817 got: k2,
818 });
819 }
820
821 let a = self.to_vec();
822 let b = other.to_vec();
823
824 #[cfg(feature = "parallel")]
827 {
828 if m >= 256 || n >= 256 || k >= 256 {
829 return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
830 }
831 }
832
833 if m >= 64 || n >= 64 || k >= 64 {
838 return Self::matmul_tiled(&a, &b, m, n, k);
839 }
840
841 Self::matmul_sequential(&a, &b, m, n, k)
843 }
844
845 fn matmul_sequential(
847 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
848 ) -> Result<Tensor, RuntimeError> {
849 let mut result = vec![0.0f64; m * n];
850 for i in 0..m {
851 for j in 0..n {
852 let mut acc = KahanAccumulatorF64::new();
853 for p in 0..k {
854 acc.add(a[i * k + p] * b[p * n + j]);
855 }
856 result[i * n + j] = acc.finalize();
857 }
858 }
859 Tensor::from_vec(result, &[m, n])
860 }
861
862 fn matmul_tiled(
869 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
870 ) -> Result<Tensor, RuntimeError> {
871 let engine = TiledMatmul::new();
872 let result = engine.matmul(a, m, k, b, n);
873 Tensor::from_vec(result, &[m, n])
874 }
875
876 #[cfg(feature = "parallel")]
887 fn matmul_parallel_mode_a(
888 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
889 ) -> Result<Tensor, RuntimeError> {
890 use rayon::prelude::*;
891 use cjc_repro::KahanAccumulatorF64;
892
893 if m >= 512 && n >= 512 {
898 let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
901 let band_size = band_size.max(64); let mut result = vec![0.0f64; m * n];
903
904 result
905 .par_chunks_mut(band_size * n)
906 .enumerate()
907 .for_each(|(band_idx, band)| {
908 let i_start = band_idx * band_size;
909 let i_end = (i_start + band_size).min(m);
910 let band_m = i_end - i_start;
911 let a_band = &a[i_start * k .. i_end * k];
912 let engine = crate::tensor_tiled::TiledMatmul::new();
913 let tiled_result = engine.matmul(a_band, band_m, k, b, n);
914 band[..band_m * n].copy_from_slice(&tiled_result);
915 });
916
917 return Tensor::from_vec(result, &[m, n]);
918 }
919
920 let mut result = vec![0.0f64; m * n];
922 result
923 .par_chunks_mut(n)
924 .enumerate()
925 .for_each(|(i, row)| {
926 for j in 0..n {
927 let mut acc = KahanAccumulatorF64::new();
928 for p in 0..k {
929 acc.add(a[i * k + p] * b[p * n + j]);
930 }
931 row[j] = acc.finalize();
932 }
933 });
934
935 Tensor::from_vec(result, &[m, n])
936 }
937
938 pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
946 if self.ndim() < 2 || other.ndim() < 2 {
947 return Err(RuntimeError::InvalidOperation(
948 "bmm requires at least 2-D tensors".to_string(),
949 ));
950 }
951 if self.ndim() == 2 && other.ndim() == 2 {
952 return self.matmul(other);
953 }
954 if self.ndim() != other.ndim() {
955 return Err(RuntimeError::InvalidOperation(
956 format!(
957 "bmm requires same number of dimensions, got {} and {}",
958 self.ndim(),
959 other.ndim()
960 ),
961 ));
962 }
963 let nd = self.ndim();
964 let batch_dims_a = &self.shape[..nd - 2];
965 let batch_dims_b = &other.shape[..nd - 2];
966 if batch_dims_a != batch_dims_b {
967 return Err(RuntimeError::InvalidOperation(
968 format!(
969 "bmm batch dimensions mismatch: {:?} vs {:?}",
970 batch_dims_a, batch_dims_b
971 ),
972 ));
973 }
974 let m = self.shape[nd - 2];
975 let k = self.shape[nd - 1];
976 let k2 = other.shape[nd - 2];
977 let n = other.shape[nd - 1];
978 if k != k2 {
979 return Err(RuntimeError::DimensionMismatch {
980 expected: k,
981 got: k2,
982 });
983 }
984
985 let batch_size: usize = batch_dims_a.iter().product();
986 let a = self.to_vec();
987 let b = other.to_vec();
988 let mat_a_stride = m * k;
989 let mat_b_stride = k * n;
990 let mat_c_stride = m * n;
991 let mut result = vec![0.0f64; batch_size * mat_c_stride];
992
993 let compute_batch = |batch: usize, c_slice: &mut [f64]| {
995 let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
996 let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];
997
998 if m >= 64 || n >= 64 || k >= 64 {
999 let engine = crate::tensor_tiled::TiledMatmul::new();
1000 let tiled = engine.matmul(a_slice, m, k, b_slice, n);
1001 c_slice.copy_from_slice(&tiled);
1002 } else {
1003 for i in 0..m {
1004 for j in 0..n {
1005 let mut acc = KahanAccumulatorF64::new();
1006 for p in 0..k {
1007 acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
1008 }
1009 c_slice[i * n + j] = acc.finalize();
1010 }
1011 }
1012 }
1013 };
1014
1015 #[cfg(feature = "parallel")]
1017 {
1018 if batch_size > 1 && m * k >= 4096 {
1019 use rayon::prelude::*;
1020 result
1021 .par_chunks_mut(mat_c_stride)
1022 .enumerate()
1023 .for_each(|(batch, c_slice)| {
1024 compute_batch(batch, c_slice);
1025 });
1026
1027 let mut out_shape = batch_dims_a.to_vec();
1028 out_shape.push(m);
1029 out_shape.push(n);
1030 return Tensor::from_vec(result, &out_shape);
1031 }
1032 }
1033
1034 for batch in 0..batch_size {
1036 let c_off = batch * mat_c_stride;
1037 compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
1038 }
1039
1040 let mut out_shape = batch_dims_a.to_vec();
1041 out_shape.push(m);
1042 out_shape.push(n);
1043 Tensor::from_vec(result, &out_shape)
1044 }
1045
1046 pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
1054 if self.ndim() == 0 {
1055 return Err(RuntimeError::InvalidOperation(
1056 "softmax requires at least 1-D tensor".to_string(),
1057 ));
1058 }
1059 let data_ref;
1061 let data_vec;
1062 let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
1063 data_ref = self.buffer.borrow_data();
1064 &data_ref
1065 } else {
1066 data_vec = self.to_vec();
1067 &data_vec
1068 };
1069 let n = *self.shape.last().unwrap(); let outer: usize = data.len() / n; let mut result = vec![0.0f64; data.len()];
1072
1073 for row in 0..outer {
1074 let start = row * n;
1075 let end = start + n;
1076 let slice = &data[start..end];
1077
1078 let mut max_val = f64::NEG_INFINITY;
1080 for &v in slice {
1081 if v > max_val {
1082 max_val = v;
1083 }
1084 }
1085
1086 let mut exp_vals = vec![0.0f64; n];
1088 let mut sum = 0.0f64;
1089 let mut comp = 0.0f64; for i in 0..n {
1091 let e = (slice[i] - max_val).exp();
1092 exp_vals[i] = e;
1093 let y = e - comp;
1095 let t = sum + y;
1096 comp = (t - sum) - y;
1097 sum = t;
1098 }
1099
1100 if sum == 0.0 {
1102 let uniform = 1.0 / n as f64;
1104 for i in 0..n {
1105 result[start + i] = uniform;
1106 }
1107 } else {
1108 for i in 0..n {
1109 result[start + i] = exp_vals[i] / sum;
1110 }
1111 }
1112 }
1113
1114 Tensor::from_vec(result, &self.shape)
1115 }
1116
1117 pub fn layer_norm(
1128 &self,
1129 gamma: &Tensor,
1130 beta: &Tensor,
1131 eps: f64,
1132 ) -> Result<Tensor, RuntimeError> {
1133 if self.ndim() == 0 {
1134 return Err(RuntimeError::InvalidOperation(
1135 "layer_norm requires at least 1-D tensor".to_string(),
1136 ));
1137 }
1138 let d = *self.shape.last().unwrap();
1139 if gamma.len() != d || beta.len() != d {
1140 return Err(RuntimeError::InvalidOperation(
1141 format!(
1142 "layer_norm: gamma/beta length {} must match last dim {}",
1143 gamma.len(),
1144 d
1145 ),
1146 ));
1147 }
1148
1149 let data = self.to_vec();
1150 let gamma_data = gamma.to_vec();
1151 let beta_data = beta.to_vec();
1152 let outer = data.len() / d;
1153 let mut result = vec![0.0f64; data.len()];
1154
1155 for row in 0..outer {
1156 let start = row * d;
1157 let slice = &data[start..start + d];
1158
1159 let mean = binned_sum_f64(slice) / d as f64;
1161
1162 let diffs: Vec<f64> = slice.iter().map(|&x| {
1164 let diff = x - mean;
1165 diff * diff
1166 }).collect();
1167 let variance = binned_sum_f64(&diffs) / d as f64;
1168
1169 let inv_std = 1.0 / (variance + eps).sqrt();
1171 for i in 0..d {
1172 let normalized = (slice[i] - mean) * inv_std;
1173 result[start + i] = gamma_data[i] * normalized + beta_data[i];
1174 }
1175 }
1176
1177 Tensor::from_vec(result, &self.shape)
1178 }
1179
1180 fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
1186 if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
1187 let mut data = self.buffer.borrow_data().clone();
1189 for x in data.iter_mut() {
1190 *x = f(*x);
1191 }
1192 Tensor::from_vec(data, &self.shape).unwrap()
1193 } else {
1194 let data = self.to_vec();
1196 let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
1197 Tensor::from_vec(result, &self.shape).unwrap()
1198 }
1199 }
1200
1201 pub fn relu(&self) -> Tensor {
1203 self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1204 }
1205
1206 pub fn sigmoid(&self) -> Tensor {
1208 self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1209 }
1210
1211 pub fn tanh_activation(&self) -> Tensor {
1213 self.map_elementwise(|x| x.tanh())
1214 }
1215
1216 pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1218 self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1219 }
1220
1221 pub fn silu(&self) -> Tensor {
1223 let data = self.to_vec();
1224 let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1225 Tensor::from_vec(result, &self.shape).unwrap()
1226 }
1227
1228 pub fn mish(&self) -> Tensor {
1230 let data = self.to_vec();
1231 let result: Vec<f64> = data.iter().map(|&x| {
1232 let sp = (1.0 + x.exp()).ln();
1233 x * sp.tanh()
1234 }).collect();
1235 Tensor::from_vec(result, &self.shape).unwrap()
1236 }
1237
1238 pub fn argmax(&self) -> usize {
1240 let data = self.to_vec();
1241 let mut best_idx = 0;
1242 let mut best_val = f64::NEG_INFINITY;
1243 for (i, &v) in data.iter().enumerate() {
1244 if v > best_val || (v == best_val && i < best_idx) {
1245 best_val = v;
1246 best_idx = i;
1247 }
1248 }
1249 best_idx
1250 }
1251
1252 pub fn argmin(&self) -> usize {
1254 let data = self.to_vec();
1255 let mut best_idx = 0;
1256 let mut best_val = f64::INFINITY;
1257 for (i, &v) in data.iter().enumerate() {
1258 if v < best_val || (v == best_val && i < best_idx) {
1259 best_val = v;
1260 best_idx = i;
1261 }
1262 }
1263 best_idx
1264 }
1265
1266 pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1268 let data = self.to_vec();
1269 let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1270 Tensor::from_vec(result, &self.shape).unwrap()
1271 }
1272
1273 pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1276 let n = indices.len();
1277 let mut data = vec![0.0; n * depth];
1278 for (i, &idx) in indices.iter().enumerate() {
1279 if idx >= depth {
1280 return Err(RuntimeError::InvalidOperation(format!(
1281 "one_hot: index {idx} >= depth {depth}"
1282 )));
1283 }
1284 data[i * depth + idx] = 1.0;
1285 }
1286 Tensor::from_vec(data, &[n, depth])
1287 }
1288
1289 pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1295 if tensors.is_empty() {
1296 return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
1297 }
1298 let ndim = tensors[0].ndim();
1299 if axis >= ndim {
1300 return Err(RuntimeError::InvalidOperation(
1301 format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
1302 ));
1303 }
1304 for (i, t) in tensors.iter().enumerate().skip(1) {
1305 if t.ndim() != ndim {
1306 return Err(RuntimeError::InvalidOperation(
1307 format!("cat: tensor {i} has different ndim"),
1308 ));
1309 }
1310 for d in 0..ndim {
1311 if d != axis && t.shape[d] != tensors[0].shape[d] {
1312 return Err(RuntimeError::InvalidOperation(
1313 format!("cat: shape mismatch at dim {d}"),
1314 ));
1315 }
1316 }
1317 }
1318 let mut out_shape = tensors[0].shape.clone();
1319 for t in tensors.iter().skip(1) {
1320 out_shape[axis] += t.shape[axis];
1321 }
1322 let total = out_shape.iter().product::<usize>();
1323 let mut result = vec![0.0; total];
1324 let mut out_strides = vec![1usize; ndim];
1325 for d in (0..ndim - 1).rev() {
1326 out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
1327 }
1328 let mut offset = 0;
1329 for t in tensors {
1330 let t_data = t.to_vec();
1331 let t_total: usize = t.shape.iter().product();
1332 let mut t_strides = vec![1usize; ndim];
1333 for d in (0..ndim - 1).rev() {
1334 t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
1335 }
1336 for idx in 0..t_total {
1337 let mut remaining = idx;
1338 let mut out_flat = 0;
1339 for d in 0..ndim {
1340 let coord = remaining / t_strides[d];
1341 remaining %= t_strides[d];
1342 let out_coord = if d == axis { coord + offset } else { coord };
1343 out_flat += out_coord * out_strides[d];
1344 }
1345 result[out_flat] = t_data[idx];
1346 }
1347 offset += t.shape[axis];
1348 }
1349 Tensor::from_vec(result, &out_shape)
1350 }
1351
1352 pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1354 if tensors.is_empty() {
1355 return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
1356 }
1357 let base_shape = &tensors[0].shape;
1358 let ndim = base_shape.len();
1359 if axis > ndim {
1360 return Err(RuntimeError::InvalidOperation(
1361 format!("stack: axis {axis} out of bounds"),
1362 ));
1363 }
1364 for (i, t) in tensors.iter().enumerate().skip(1) {
1365 if &t.shape != base_shape {
1366 return Err(RuntimeError::InvalidOperation(
1367 format!("stack: tensor {i} shape mismatch"),
1368 ));
1369 }
1370 }
1371 let mut out_shape = Vec::with_capacity(ndim + 1);
1372 for d in 0..axis { out_shape.push(base_shape[d]); }
1373 out_shape.push(tensors.len());
1374 for d in axis..ndim { out_shape.push(base_shape[d]); }
1375 let total: usize = out_shape.iter().product();
1376 let mut result = vec![0.0; total];
1377 let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
1378 let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
1379 for (t_idx, t) in tensors.iter().enumerate() {
1380 let t_data = t.to_vec();
1381 for outer in 0..outer_size {
1382 for inner in 0..inner_size {
1383 let src = outer * inner_size + inner;
1384 let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
1385 if src < t_data.len() && dst < result.len() {
1386 result[dst] = t_data[src];
1387 }
1388 }
1389 }
1390 }
1391 Tensor::from_vec(result, &out_shape)
1392 }
1393
1394 pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1396 let data = self.to_vec();
1397 let n = data.len();
1398 if k > n {
1399 return Err(RuntimeError::InvalidOperation(
1400 format!("topk: k={k} exceeds data length {n}"),
1401 ));
1402 }
1403 let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1404 indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1405 let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1406 let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1407 let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1408 Ok((Tensor::from_vec(values, &[k])?, indices))
1409 }
1410
1411 pub fn gelu(&self) -> Tensor {
1413 let data = self.to_vec();
1414 let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1415 let result: Vec<f64> = data.iter().map(|&x| {
1416 let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1417 0.5 * x * (1.0 + inner.tanh())
1418 }).collect();
1419 Tensor::from_vec(result, &self.shape).unwrap()
1420 }
1421
1422 pub fn linear(
1428 &self,
1429 weight: &Tensor,
1430 bias: &Tensor,
1431 ) -> Result<Tensor, RuntimeError> {
1432 if weight.ndim() != 2 {
1433 return Err(RuntimeError::InvalidOperation(
1434 "linear: weight must be 2-D [out_features, in_features]".to_string(),
1435 ));
1436 }
1437 let out_features = weight.shape[0];
1438 let in_features = weight.shape[1];
1439 let last_dim = *self.shape.last().ok_or_else(|| {
1440 RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
1441 })?;
1442 if last_dim != in_features {
1443 return Err(RuntimeError::DimensionMismatch {
1444 expected: in_features,
1445 got: last_dim,
1446 });
1447 }
1448 if bias.len() != out_features {
1449 return Err(RuntimeError::InvalidOperation(
1450 format!(
1451 "linear: bias length {} must match out_features {}",
1452 bias.len(),
1453 out_features
1454 ),
1455 ));
1456 }
1457
1458 let data = self.to_vec();
1459 let w = weight.to_vec();
1460 let b = bias.to_vec();
1461 let outer = data.len() / in_features;
1462 let mut result = vec![0.0f64; outer * out_features];
1463
1464 for row in 0..outer {
1465 let x_start = row * in_features;
1466 let x_slice = &data[x_start..x_start + in_features];
1467 let y_start = row * out_features;
1468 for j in 0..out_features {
1469 let w_start = j * in_features;
1470 let mut acc = BinnedAccumulatorF64::new();
1471 for p in 0..in_features {
1472 acc.add(x_slice[p] * w[w_start + p]);
1473 }
1474 result[y_start + j] = acc.finalize() + b[j];
1475 }
1476 }
1477
1478 let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
1479 out_shape.push(out_features);
1480 Tensor::from_vec(result, &out_shape)
1481 }
1482
1483 pub fn conv1d(
1487 &self,
1488 filters: &Tensor,
1489 bias: &Tensor,
1490 ) -> Result<Tensor, RuntimeError> {
1491 if self.ndim() != 1 {
1492 return Err(RuntimeError::InvalidOperation(
1493 "conv1d: input must be 1-D [signal_len]".to_string(),
1494 ));
1495 }
1496 if filters.ndim() != 2 {
1497 return Err(RuntimeError::InvalidOperation(
1498 "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1499 ));
1500 }
1501 let signal_len = self.shape[0];
1502 let out_channels = filters.shape[0];
1503 let kernel_size = filters.shape[1];
1504 if signal_len < kernel_size {
1505 return Err(RuntimeError::InvalidOperation(
1506 format!(
1507 "conv1d: signal_len {} < kernel_size {}",
1508 signal_len, kernel_size
1509 ),
1510 ));
1511 }
1512 if bias.len() != out_channels {
1513 return Err(RuntimeError::InvalidOperation(
1514 format!(
1515 "conv1d: bias length {} must match out_channels {}",
1516 bias.len(), out_channels
1517 ),
1518 ));
1519 }
1520 let out_len = signal_len - kernel_size + 1;
1521 let s = self.to_vec();
1522 let f = filters.to_vec();
1523 let b = bias.to_vec();
1524 let mut result = vec![0.0; out_channels * out_len];
1525 kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1526 Tensor::from_vec(result, &[out_channels, out_len])
1527 }
1528
1529 pub fn conv2d(
1543 &self,
1544 filters: &Tensor,
1545 bias: &Tensor,
1546 stride: usize,
1547 ) -> Result<Tensor, RuntimeError> {
1548 if self.ndim() != 4 {
1549 return Err(RuntimeError::InvalidOperation(
1550 "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
1551 ));
1552 }
1553 if filters.ndim() != 4 {
1554 return Err(RuntimeError::InvalidOperation(
1555 "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
1556 ));
1557 }
1558 if stride == 0 {
1559 return Err(RuntimeError::InvalidOperation(
1560 "conv2d: stride must be >= 1".to_string(),
1561 ));
1562 }
1563
1564 let n = self.shape[0];
1565 let c_in = self.shape[1];
1566 let h_in = self.shape[2];
1567 let w_in = self.shape[3];
1568
1569 let c_out = filters.shape[0];
1570 let c_in_check = filters.shape[1];
1571 let kh = filters.shape[2];
1572 let kw = filters.shape[3];
1573
1574 if c_in != c_in_check {
1575 return Err(RuntimeError::InvalidOperation(format!(
1576 "conv2d: input C_in={} does not match filter C_in={}",
1577 c_in, c_in_check
1578 )));
1579 }
1580 if h_in < kh || w_in < kw {
1581 return Err(RuntimeError::InvalidOperation(format!(
1582 "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
1583 h_in, w_in, kh, kw
1584 )));
1585 }
1586 if bias.len() != c_out {
1587 return Err(RuntimeError::InvalidOperation(format!(
1588 "conv2d: bias length {} must match C_out={}",
1589 bias.len(), c_out
1590 )));
1591 }
1592
1593 let h_out = (h_in - kh) / stride + 1;
1594 let w_out = (w_in - kw) / stride + 1;
1595
1596 let inp = self.to_vec();
1597 let flt = filters.to_vec();
1598 let b = bias.to_vec();
1599 let mut result = vec![0.0f64; n * c_out * h_out * w_out];
1600
1601 kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
1602 n, c_in, h_in, w_in, c_out, kh, kw, stride);
1603
1604 Tensor::from_vec(result, &[n, c_out, h_out, w_out])
1605 }
1606
1607 pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1614 if self.ndim() != 4 {
1615 return Err(RuntimeError::InvalidOperation(
1616 "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1617 ));
1618 }
1619 if ph == 0 || pw == 0 {
1620 return Err(RuntimeError::InvalidOperation(
1621 "maxpool2d: pool size must be >= 1".to_string(),
1622 ));
1623 }
1624
1625 let n = self.shape[0];
1626 let c = self.shape[1];
1627 let h_in = self.shape[2];
1628 let w_in = self.shape[3];
1629
1630 if h_in < ph || w_in < pw {
1631 return Err(RuntimeError::InvalidOperation(format!(
1632 "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1633 h_in, w_in, ph, pw
1634 )));
1635 }
1636
1637 let h_out = h_in / ph;
1638 let w_out = w_in / pw;
1639
1640 let inp = self.to_vec();
1641 let mut result = vec![0.0f64; n * c * h_out * w_out];
1642
1643 kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1644
1645 Tensor::from_vec(result, &[n, c, h_out, w_out])
1646 }
1647
1648 pub fn avgpool2d(&self, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, RuntimeError> {
1663 let shape = self.shape();
1664 if shape.len() != 3 {
1665 return Err(RuntimeError::InvalidOperation(format!("avgpool2d requires 3-D [C,H,W], got {:?}", shape)));
1666 }
1667 let (c, h, w) = (shape[0], shape[1], shape[2]);
1668 if kernel_h > h || kernel_w > w {
1669 return Err(RuntimeError::InvalidOperation("avgpool2d: kernel larger than input".into()));
1670 }
1671 let out_h = (h - kernel_h) / stride_h + 1;
1672 let out_w = (w - kernel_w) / stride_w + 1;
1673 let data = self.to_vec();
1674 let mut out = Vec::with_capacity(c * out_h * out_w);
1675 let pool_size = (kernel_h * kernel_w) as f64;
1676
1677 for ch in 0..c {
1678 for oh in 0..out_h {
1679 for ow in 0..out_w {
1680 let mut sum = 0.0f64;
1681 for kh in 0..kernel_h {
1682 for kw in 0..kernel_w {
1683 let ih = oh * stride_h + kh;
1684 let iw = ow * stride_w + kw;
1685 sum += data[ch * h * w + ih * w + iw];
1686 }
1687 }
1688 out.push(sum / pool_size);
1689 }
1690 }
1691 }
1692 Tensor::from_vec(out, &[c, out_h, out_w])
1693 }
1694
1695 pub fn scaled_dot_product_attention(
1704 queries: &Tensor,
1705 keys: &Tensor,
1706 values: &Tensor,
1707 ) -> Result<Tensor, RuntimeError> {
1708 if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1709 return Err(RuntimeError::InvalidOperation(
1710 "attention: Q, K, V must be at least 2-D".to_string(),
1711 ));
1712 }
1713 let nd = queries.ndim();
1714 let d_k = queries.shape[nd - 1];
1715 let scale = 1.0 / (d_k as f64).sqrt();
1716
1717 let keys_t = keys.transpose_last_two()?;
1719
1720 let scores = queries.bmm(&keys_t)?;
1722
1723 let scores_scaled = scores.scalar_mul(scale);
1725
1726 let attn_weights = scores_scaled.softmax()?;
1728
1729 attn_weights.bmm(values)
1731 }
1732
1733 pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1737 if self.ndim() < 2 {
1738 return Err(RuntimeError::InvalidOperation(
1739 "transpose_last_two requires at least 2-D tensor".to_string(),
1740 ));
1741 }
1742 let nd = self.ndim();
1743 let rows = self.shape[nd - 2];
1744 let cols = self.shape[nd - 1];
1745 let data = self.to_vec();
1746 let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1747 let mat_size = rows * cols;
1748 let mut result = vec![0.0f64; data.len()];
1749
1750 for b in 0..batch_size {
1751 let off = b * mat_size;
1752 for i in 0..rows {
1753 for j in 0..cols {
1754 result[off + j * rows + i] = data[off + i * cols + j];
1755 }
1756 }
1757 }
1758
1759 let mut out_shape = self.shape.clone();
1760 out_shape[nd - 2] = cols;
1761 out_shape[nd - 1] = rows;
1762 Tensor::from_vec(result, &out_shape)
1763 }
1764
1765 pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1781 let numel = Self::shape_numel(shape);
1782 match dtype {
1783 "f64" => {
1784 let expected = numel * 8;
1785 if bytes.len() != expected {
1786 return Err(RuntimeError::ShapeMismatch {
1787 expected,
1788 got: bytes.len(),
1789 });
1790 }
1791 let mut data = Vec::with_capacity(numel);
1792 for i in 0..numel {
1793 let off = i * 8;
1794 let mut buf = [0u8; 8];
1795 buf.copy_from_slice(&bytes[off..off + 8]);
1796 data.push(f64::from_le_bytes(buf));
1797 }
1798 Ok(Tensor {
1799 buffer: Buffer::from_vec(data),
1800 shape: shape.to_vec(),
1801 strides: Self::compute_strides(shape),
1802 offset: 0,
1803 })
1804 }
1805 "f32" => {
1806 let expected = numel * 4;
1807 if bytes.len() != expected {
1808 return Err(RuntimeError::ShapeMismatch {
1809 expected,
1810 got: bytes.len(),
1811 });
1812 }
1813 let mut data = Vec::with_capacity(numel);
1814 for i in 0..numel {
1815 let off = i * 4;
1816 let mut buf = [0u8; 4];
1817 buf.copy_from_slice(&bytes[off..off + 4]);
1818 data.push(f32::from_le_bytes(buf) as f64);
1819 }
1820 Ok(Tensor {
1821 buffer: Buffer::from_vec(data),
1822 shape: shape.to_vec(),
1823 strides: Self::compute_strides(shape),
1824 offset: 0,
1825 })
1826 }
1827 _ => Err(RuntimeError::InvalidOperation(
1828 format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1829 )),
1830 }
1831 }
1832
1833 pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1841 if self.ndim() != 3 {
1842 return Err(RuntimeError::DimensionMismatch {
1843 expected: 3,
1844 got: self.ndim(),
1845 });
1846 }
1847 let batch = self.shape[0];
1848 let seq = self.shape[1];
1849 let model_dim = self.shape[2];
1850 if model_dim % num_heads != 0 {
1851 return Err(RuntimeError::InvalidOperation(
1852 format!(
1853 "split_heads: model_dim {} not divisible by num_heads {}",
1854 model_dim, num_heads
1855 ),
1856 ));
1857 }
1858 let head_dim = model_dim / num_heads;
1859 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1861 let reshaped = Tensor {
1863 buffer: tensor.buffer.clone(),
1864 shape: vec![batch, seq, num_heads, head_dim],
1865 strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1866 offset: 0,
1867 };
1868 Ok(Tensor {
1871 buffer: reshaped.buffer,
1872 shape: vec![batch, num_heads, seq, head_dim],
1873 strides: vec![
1874 reshaped.strides[0], reshaped.strides[2], reshaped.strides[1], reshaped.strides[3], ],
1879 offset: 0,
1880 })
1881 }
1882
1883 pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1886 if self.ndim() != 4 {
1887 return Err(RuntimeError::DimensionMismatch {
1888 expected: 4,
1889 got: self.ndim(),
1890 });
1891 }
1892 let batch = self.shape[0];
1893 let num_heads = self.shape[1];
1894 let seq = self.shape[2];
1895 let head_dim = self.shape[3];
1896 let transposed = Tensor {
1899 buffer: self.buffer.clone(),
1900 shape: vec![batch, seq, num_heads, head_dim],
1901 strides: vec![
1902 self.strides[0],
1903 self.strides[2], self.strides[1], self.strides[3],
1906 ],
1907 offset: self.offset,
1908 };
1909 let contig = transposed.to_contiguous();
1911 let model_dim = num_heads * head_dim;
1912 Ok(Tensor {
1913 buffer: contig.buffer,
1914 shape: vec![batch, seq, model_dim],
1915 strides: Self::compute_strides(&[batch, seq, model_dim]),
1916 offset: 0,
1917 })
1918 }
1919
1920 pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
1923 self.reshape(new_shape)
1924 }
1925
1926 pub fn argsort(&self) -> Tensor {
1933 let data = self.to_vec();
1934 let mut indices: Vec<usize> = (0..data.len()).collect();
1935 indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1936 let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1937 Tensor::from_vec_unchecked(result, &[data.len()])
1938 }
1939
1940 pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1945 let data = self.to_vec();
1946 let idx_data = indices.to_vec();
1947 if self.ndim() == 1 {
1948 let mut result = Vec::with_capacity(idx_data.len());
1949 for &idx in &idx_data {
1950 let i = idx as usize;
1951 if i >= data.len() {
1952 return Err(RuntimeError::InvalidOperation(
1953 format!("gather: index {} out of bounds for size {}", i, data.len()),
1954 ));
1955 }
1956 result.push(data[i]);
1957 }
1958 Ok(Tensor::from_vec_unchecked(result, indices.shape()))
1959 } else if self.ndim() == 2 {
1960 let rows = self.shape[0];
1961 let cols = self.shape[1];
1962 let idx_shape = indices.shape();
1963 let out_rows = idx_shape[0];
1964 let out_cols = idx_shape[1];
1965 let mut result = vec![0.0; out_rows * out_cols];
1966 for i in 0..out_rows {
1967 for j in 0..out_cols {
1968 let idx = idx_data[i * out_cols + j] as usize;
1969 let val = if dim == 0 {
1970 if idx >= rows {
1971 return Err(RuntimeError::InvalidOperation(
1972 format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
1973 ));
1974 }
1975 data[idx * cols + j]
1976 } else {
1977 if idx >= cols {
1978 return Err(RuntimeError::InvalidOperation(
1979 format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
1980 ));
1981 }
1982 data[i * cols + idx]
1983 };
1984 result[i * out_cols + j] = val;
1985 }
1986 }
1987 Ok(Tensor::from_vec_unchecked(result, idx_shape))
1988 } else {
1989 Err(RuntimeError::InvalidOperation(
1990 "gather: only 1D and 2D tensors supported".into(),
1991 ))
1992 }
1993 }
1994
1995 pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
2000 let mut result = self.to_vec();
2001 let idx_data = indices.to_vec();
2002 let src_data = src.to_vec();
2003 if self.ndim() == 1 {
2004 for (k, &idx) in idx_data.iter().enumerate() {
2005 let i = idx as usize;
2006 if i >= result.len() {
2007 return Err(RuntimeError::InvalidOperation(
2008 format!("scatter: index {} out of bounds for size {}", i, result.len()),
2009 ));
2010 }
2011 result[i] = src_data[k];
2012 }
2013 Ok(Tensor::from_vec_unchecked(result, self.shape()))
2014 } else if self.ndim() == 2 {
2015 let cols = self.shape[1];
2016 let idx_shape = indices.shape();
2017 let out_cols = idx_shape[1];
2018 let out_rows = idx_shape[0];
2019 for i in 0..out_rows {
2020 for j in 0..out_cols {
2021 let idx = idx_data[i * out_cols + j] as usize;
2022 let src_val = src_data[i * out_cols + j];
2023 if dim == 0 {
2024 if idx >= self.shape[0] {
2025 return Err(RuntimeError::InvalidOperation(
2026 format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
2027 ));
2028 }
2029 result[idx * cols + j] = src_val;
2030 } else {
2031 if idx >= cols {
2032 return Err(RuntimeError::InvalidOperation(
2033 format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
2034 ));
2035 }
2036 result[i * cols + idx] = src_val;
2037 }
2038 }
2039 }
2040 Ok(Tensor::from_vec_unchecked(result, self.shape()))
2041 } else {
2042 Err(RuntimeError::InvalidOperation(
2043 "scatter: only 1D and 2D tensors supported".into(),
2044 ))
2045 }
2046 }
2047
2048 pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2052 let data = self.to_vec();
2053 let idx_data = indices.to_vec();
2054 if self.ndim() == 1 {
2055 let mut result = Vec::with_capacity(idx_data.len());
2056 for &idx in &idx_data {
2057 let i = idx as usize;
2058 if i >= data.len() {
2059 return Err(RuntimeError::InvalidOperation(
2060 format!("index_select: index {} out of bounds for size {}", i, data.len()),
2061 ));
2062 }
2063 result.push(data[i]);
2064 }
2065 Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
2066 } else if self.ndim() == 2 {
2067 let rows = self.shape[0];
2068 let cols = self.shape[1];
2069 let n = idx_data.len();
2070 if dim == 0 {
2071 let mut result = Vec::with_capacity(n * cols);
2072 for &idx in &idx_data {
2073 let i = idx as usize;
2074 if i >= rows {
2075 return Err(RuntimeError::InvalidOperation(
2076 format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
2077 ));
2078 }
2079 for j in 0..cols {
2080 result.push(data[i * cols + j]);
2081 }
2082 }
2083 Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
2084 } else {
2085 let mut result = Vec::with_capacity(rows * n);
2086 for i in 0..rows {
2087 for &idx in &idx_data {
2088 let j = idx as usize;
2089 if j >= cols {
2090 return Err(RuntimeError::InvalidOperation(
2091 format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
2092 ));
2093 }
2094 result.push(data[i * cols + j]);
2095 }
2096 }
2097 Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
2098 }
2099 } else {
2100 Err(RuntimeError::InvalidOperation(
2101 "index_select: only 1D and 2D tensors supported".into(),
2102 ))
2103 }
2104 }
2105
2106 pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
2113 if self.shape() != condition.shape() || self.shape() != other.shape() {
2114 return Err(RuntimeError::InvalidOperation(
2115 format!("where: shape mismatch self={:?} cond={:?} other={:?}",
2116 self.shape(), condition.shape(), other.shape()),
2117 ));
2118 }
2119 let s = self.to_vec();
2120 let c = condition.to_vec();
2121 let o = other.to_vec();
2122 let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2123 .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2124 .collect();
2125 Tensor::from_vec(result, self.shape())
2126 }
2127
2128 pub fn any(&self) -> bool {
2130 let data = self.to_vec();
2131 data.iter().any(|&x| x != 0.0)
2132 }
2133
2134 pub fn all(&self) -> bool {
2136 let data = self.to_vec();
2137 data.iter().all(|&x| x != 0.0)
2138 }
2139
2140 pub fn nonzero(&self) -> Tensor {
2144 let data = self.to_vec();
2145 let indices: Vec<f64> = data.iter().enumerate()
2146 .filter(|(_, &v)| v != 0.0)
2147 .map(|(i, _)| i as f64)
2148 .collect();
2149 let len = indices.len();
2150 if len == 0 {
2151 Tensor::from_vec(vec![], &[0]).unwrap()
2152 } else {
2153 Tensor::from_vec(indices, &[len]).unwrap()
2154 }
2155 }
2156
2157 pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2159 if self.shape() != mask.shape() {
2160 return Err(RuntimeError::InvalidOperation(
2161 format!("masked_fill: shape mismatch self={:?} mask={:?}",
2162 self.shape(), mask.shape()),
2163 ));
2164 }
2165 let data = self.to_vec();
2166 let m = mask.to_vec();
2167 let result: Vec<f64> = data.iter().zip(m.iter())
2168 .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2169 .collect();
2170 Tensor::from_vec(result, self.shape())
2171 }
2172
2173 fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
2183 -> Result<Tensor, RuntimeError>
2184 where
2185 F: Fn(&[f64]) -> f64,
2186 {
2187 let ndim = self.ndim();
2188 if axis >= ndim {
2189 return Err(RuntimeError::IndexOutOfBounds {
2190 index: axis,
2191 length: ndim,
2192 });
2193 }
2194
2195 let axis_len = self.shape[axis];
2196 let mut out_shape: Vec<usize> = self.shape.clone();
2198 out_shape[axis] = 1;
2199 let out_numel = Self::shape_numel(&out_shape);
2200 let out_strides = Self::compute_strides(&out_shape);
2201
2202 let data = self.to_vec();
2203 let mut result = Vec::with_capacity(out_numel);
2204 let mut indices = vec![0usize; ndim];
2205
2206 for out_idx in 0..out_numel {
2207 {
2209 let mut remaining = out_idx;
2210 for d in 0..ndim {
2211 indices[d] = remaining / out_strides[d];
2212 remaining %= out_strides[d];
2213 }
2214 }
2215
2216 let mut vals = Vec::with_capacity(axis_len);
2218 for k in 0..axis_len {
2219 let mut flat = self.offset;
2220 for d in 0..ndim {
2221 let idx = if d == axis { k } else { indices[d] };
2222 flat += idx * self.strides[d];
2223 }
2224 vals.push(data[flat]);
2225 }
2226 result.push(reduce_fn(&vals));
2227 }
2228
2229 let final_shape = if keepdim {
2230 out_shape
2231 } else {
2232 let mut s: Vec<usize> = self.shape.iter().enumerate()
2234 .filter(|&(i, _)| i != axis)
2235 .map(|(_, &v)| v)
2236 .collect();
2237 if s.is_empty() {
2238 s.push(1); }
2240 s
2241 };
2242
2243 Tensor::from_vec(result, &final_shape)
2244 }
2245
2246 pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2251 self.reduce_axis(axis, keepdim, |vals| {
2252 let mut acc = BinnedAccumulatorF64::new();
2253 for &v in vals { acc.add(v); }
2254 acc.finalize() / vals.len() as f64
2255 })
2256 }
2257
2258 pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2262 let ndim = self.ndim();
2263 if axis >= ndim {
2264 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2265 }
2266 let axis_len = self.shape[axis];
2267 let mut out_shape = self.shape.clone();
2268 out_shape[axis] = 1;
2269 let out_numel = Self::shape_numel(&out_shape);
2270 let out_strides = Self::compute_strides(&out_shape);
2271 let data = self.to_vec();
2272 let mut values = Vec::with_capacity(out_numel);
2273 let mut idx_vals = Vec::with_capacity(out_numel);
2274 let mut indices = vec![0usize; ndim];
2275
2276 for out_idx in 0..out_numel {
2277 let mut remaining = out_idx;
2278 for d in 0..ndim {
2279 indices[d] = remaining / out_strides[d];
2280 remaining %= out_strides[d];
2281 }
2282 let mut best_val = f64::NEG_INFINITY;
2283 let mut best_idx = 0usize;
2284 for k in 0..axis_len {
2285 let mut flat = self.offset;
2286 for d in 0..ndim {
2287 let idx = if d == axis { k } else { indices[d] };
2288 flat += idx * self.strides[d];
2289 }
2290 let v = data[flat];
2291 if v > best_val {
2292 best_val = v;
2293 best_idx = k;
2294 }
2295 }
2296 values.push(best_val);
2297 idx_vals.push(best_idx as f64);
2298 }
2299
2300 let final_shape = if keepdim {
2301 out_shape
2302 } else {
2303 let mut s: Vec<usize> = self.shape.iter().enumerate()
2304 .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2305 if s.is_empty() { s.push(1); }
2306 s
2307 };
2308 Ok((
2309 Tensor::from_vec(values, &final_shape)?,
2310 Tensor::from_vec(idx_vals, &final_shape)?,
2311 ))
2312 }
2313
2314 pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2318 let ndim = self.ndim();
2319 if axis >= ndim {
2320 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2321 }
2322 let axis_len = self.shape[axis];
2323 let mut out_shape = self.shape.clone();
2324 out_shape[axis] = 1;
2325 let out_numel = Self::shape_numel(&out_shape);
2326 let out_strides = Self::compute_strides(&out_shape);
2327 let data = self.to_vec();
2328 let mut values = Vec::with_capacity(out_numel);
2329 let mut idx_vals = Vec::with_capacity(out_numel);
2330 let mut indices = vec![0usize; ndim];
2331
2332 for out_idx in 0..out_numel {
2333 let mut remaining = out_idx;
2334 for d in 0..ndim {
2335 indices[d] = remaining / out_strides[d];
2336 remaining %= out_strides[d];
2337 }
2338 let mut best_val = f64::INFINITY;
2339 let mut best_idx = 0usize;
2340 for k in 0..axis_len {
2341 let mut flat = self.offset;
2342 for d in 0..ndim {
2343 let idx = if d == axis { k } else { indices[d] };
2344 flat += idx * self.strides[d];
2345 }
2346 let v = data[flat];
2347 if v < best_val {
2348 best_val = v;
2349 best_idx = k;
2350 }
2351 }
2352 values.push(best_val);
2353 idx_vals.push(best_idx as f64);
2354 }
2355
2356 let final_shape = if keepdim {
2357 out_shape
2358 } else {
2359 let mut s: Vec<usize> = self.shape.iter().enumerate()
2360 .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2361 if s.is_empty() { s.push(1); }
2362 s
2363 };
2364 Ok((
2365 Tensor::from_vec(values, &final_shape)?,
2366 Tensor::from_vec(idx_vals, &final_shape)?,
2367 ))
2368 }
2369
2370 pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2376 let mean_t = self.mean_axis(axis, true)?;
2377 let ndim = self.ndim();
2378 if axis >= ndim {
2379 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2380 }
2381 let axis_len = self.shape[axis];
2382 let mut out_shape = self.shape.clone();
2383 out_shape[axis] = 1;
2384 let out_numel = Self::shape_numel(&out_shape);
2385 let out_strides = Self::compute_strides(&out_shape);
2386 let data = self.to_vec();
2387 let mean_data = mean_t.to_vec();
2388 let mut result = Vec::with_capacity(out_numel);
2389 let mut indices = vec![0usize; ndim];
2390
2391 for out_idx in 0..out_numel {
2392 let mut remaining = out_idx;
2393 for d in 0..ndim {
2394 indices[d] = remaining / out_strides[d];
2395 remaining %= out_strides[d];
2396 }
2397 let mu = mean_data[out_idx];
2398 let mut acc = BinnedAccumulatorF64::new();
2399 for k in 0..axis_len {
2400 let mut flat = self.offset;
2401 for d in 0..ndim {
2402 let idx = if d == axis { k } else { indices[d] };
2403 flat += idx * self.strides[d];
2404 }
2405 let diff = data[flat] - mu;
2406 acc.add(diff * diff);
2407 }
2408 result.push(acc.finalize() / axis_len as f64);
2409 }
2410
2411 let final_shape = if keepdim {
2412 out_shape
2413 } else {
2414 let mut s: Vec<usize> = self.shape.iter().enumerate()
2415 .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2416 if s.is_empty() { s.push(1); }
2417 s
2418 };
2419 Tensor::from_vec(result, &final_shape)
2420 }
2421
2422 pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2426 let var = self.var_axis(axis, keepdim)?;
2427 Ok(var.map(|x| x.sqrt()))
2428 }
2429
2430 pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2434 self.reduce_axis(axis, keepdim, |vals| {
2435 let mut product = 1.0f64;
2438 for &v in vals { product *= v; }
2439 product
2440 })
2441 }
2442
2443 pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2453 let ndim = self.ndim();
2454 if axis >= ndim {
2455 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2456 }
2457 let data = self.to_vec();
2458 let axis_len = self.shape[axis];
2459 let out_shape = self.shape.clone();
2460 let out_numel = Self::shape_numel(&out_shape);
2461
2462 let mut iter_shape: Vec<usize> = Vec::new();
2464 for (i, &s) in self.shape.iter().enumerate() {
2465 if i != axis { iter_shape.push(s); }
2466 }
2467 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2468
2469 let mut result = vec![0.0f64; out_numel];
2470
2471 let mut pos = vec![0usize; ndim];
2473 for slice_idx in 0..n_slices {
2474 let mut remaining = slice_idx;
2476 let mut dim_idx = 0;
2477 for d in 0..ndim {
2478 if d == axis {
2479 pos[d] = 0;
2480 } else {
2481 let stride = {
2482 let mut s = 1usize;
2483 let mut di = 0;
2484 for d2 in 0..ndim {
2485 if d2 == axis { continue; }
2486 if di > dim_idx { s *= self.shape[d2]; }
2487 di += 1;
2488 }
2489 s
2490 };
2491 pos[d] = remaining / stride;
2492 remaining %= stride;
2493 dim_idx += 1;
2494 }
2495 }
2496
2497 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2499 for k in 0..axis_len {
2500 let mut flat = self.offset;
2501 for d in 0..ndim {
2502 let idx = if d == axis { k } else { pos[d] };
2503 flat += idx * self.strides[d];
2504 }
2505 vals.push((data[flat], k));
2506 }
2507
2508 if descending {
2510 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2511 .then(a.1.cmp(&b.1)));
2512 } else {
2513 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2514 .then(a.1.cmp(&b.1)));
2515 }
2516
2517 for (k, &(v, _)) in vals.iter().enumerate() {
2519 let mut flat = 0;
2520 let out_strides_local = Self::compute_strides(&out_shape);
2521 for d in 0..ndim {
2522 let idx = if d == axis { k } else { pos[d] };
2523 flat += idx * out_strides_local[d];
2524 }
2525 result[flat] = v;
2526 }
2527 }
2528
2529 Tensor::from_vec(result, &out_shape)
2530 }
2531
2532 pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2537 let ndim = self.ndim();
2538 if axis >= ndim {
2539 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2540 }
2541 let data = self.to_vec();
2542 let axis_len = self.shape[axis];
2543 let out_shape = self.shape.clone();
2544 let out_numel = Self::shape_numel(&out_shape);
2545
2546 let mut iter_shape: Vec<usize> = Vec::new();
2547 for (i, &s) in self.shape.iter().enumerate() {
2548 if i != axis { iter_shape.push(s); }
2549 }
2550 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2551
2552 let mut result = vec![0.0f64; out_numel];
2553 let mut pos = vec![0usize; ndim];
2554
2555 for slice_idx in 0..n_slices {
2556 let mut remaining = slice_idx;
2557 let mut dim_idx = 0;
2558 for d in 0..ndim {
2559 if d == axis {
2560 pos[d] = 0;
2561 } else {
2562 let stride = {
2563 let mut s = 1usize;
2564 let mut di = 0;
2565 for d2 in 0..ndim {
2566 if d2 == axis { continue; }
2567 if di > dim_idx { s *= self.shape[d2]; }
2568 di += 1;
2569 }
2570 s
2571 };
2572 pos[d] = remaining / stride;
2573 remaining %= stride;
2574 dim_idx += 1;
2575 }
2576 }
2577
2578 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2579 for k in 0..axis_len {
2580 let mut flat = self.offset;
2581 for d in 0..ndim {
2582 let idx = if d == axis { k } else { pos[d] };
2583 flat += idx * self.strides[d];
2584 }
2585 vals.push((data[flat], k));
2586 }
2587
2588 if descending {
2589 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2590 .then(a.1.cmp(&b.1)));
2591 } else {
2592 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2593 .then(a.1.cmp(&b.1)));
2594 }
2595
2596 for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2597 let out_strides_local = Self::compute_strides(&out_shape);
2598 let mut flat = 0;
2599 for d in 0..ndim {
2600 let idx = if d == axis { k } else { pos[d] };
2601 flat += idx * out_strides_local[d];
2602 }
2603 result[flat] = orig_idx as f64;
2604 }
2605 }
2606
2607 Tensor::from_vec(result, &out_shape)
2608 }
2609
2610 pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
2619 let parts: Vec<&str> = notation.split("->").collect();
2621 if parts.len() != 2 {
2622 return Err(RuntimeError::InvalidOperation(
2623 format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
2624 ));
2625 }
2626 let input_specs: Vec<&str> = parts[0].split(',').collect();
2627 let output_spec = parts[1];
2628
2629 if input_specs.len() != inputs.len() {
2630 return Err(RuntimeError::InvalidOperation(
2631 format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
2632 ));
2633 }
2634
2635 let mut label_size = std::collections::BTreeMap::new();
2637 for (i, &spec) in input_specs.iter().enumerate() {
2638 let chars: Vec<char> = spec.chars().collect();
2639 if chars.len() != inputs[i].ndim() {
2640 return Err(RuntimeError::InvalidOperation(
2641 format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
2642 ));
2643 }
2644 for (d, &c) in chars.iter().enumerate() {
2645 let sz = inputs[i].shape()[d];
2646 if let Some(&prev) = label_size.get(&c) {
2647 if prev != sz {
2648 return Err(RuntimeError::InvalidOperation(
2649 format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
2650 ));
2651 }
2652 } else {
2653 label_size.insert(c, sz);
2654 }
2655 }
2656 }
2657
2658 let output_chars: Vec<char> = output_spec.chars().collect();
2660 let output_shape: Vec<usize> = output_chars.iter()
2661 .map(|c| label_size.get(c).copied().ok_or_else(||
2662 RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
2663 .collect::<Result<_, _>>()?;
2664 let out_numel = Self::shape_numel(&output_shape);
2665
2666 let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
2668 let contract_labels: Vec<char> = label_size.keys()
2669 .filter(|c| !output_set.contains(c))
2670 .copied()
2671 .collect();
2672 let contract_sizes: Vec<usize> = contract_labels.iter()
2673 .map(|c| label_size[c])
2674 .collect();
2675 let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);
2676
2677 let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();
2679
2680 let out_strides = Self::compute_strides(&output_shape);
2682 let mut result = vec![0.0f64; out_numel];
2683
2684 let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
2686 let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
2687 let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();
2688
2689 for out_idx in 0..out_numel {
2690 let mut label_vals = std::collections::BTreeMap::new();
2692 let mut remaining = out_idx;
2693 for (d, &c) in output_chars.iter().enumerate() {
2694 let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
2695 label_vals.insert(c, remaining / stride);
2696 remaining %= stride;
2697 }
2698
2699 let mut acc = BinnedAccumulatorF64::new();
2700 for cidx in 0..contract_numel {
2702 let mut cr = cidx;
2704 for (ci, &cl) in contract_labels.iter().enumerate() {
2705 let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
2706 label_vals.insert(cl, cr / stride);
2707 cr %= stride;
2708 }
2709
2710 let mut product = 1.0f64;
2712 for (inp_idx, chars) in input_chars.iter().enumerate() {
2713 let mut flat = input_offsets[inp_idx];
2714 for (d, &c) in chars.iter().enumerate() {
2715 flat += label_vals[&c] * input_strides[inp_idx][d];
2716 }
2717 product *= input_data[inp_idx][flat];
2718 }
2719 acc.add(product);
2720 }
2721 result[out_idx] = acc.finalize();
2722 }
2723
2724 if output_shape.is_empty() {
2725 Tensor::from_vec(result, &[1])
2726 } else {
2727 Tensor::from_vec(result, &output_shape)
2728 }
2729 }
2730
2731 pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2740 let ndim = self.ndim();
2741 if dim > ndim {
2742 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2743 }
2744 let mut new_shape = self.shape.clone();
2745 new_shape.insert(dim, 1);
2746 self.reshape(&new_shape)
2747 }
2748
2749 pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2752 match dim {
2753 Some(d) => {
2754 if d >= self.ndim() {
2755 return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2756 }
2757 if self.shape[d] != 1 {
2758 return Err(RuntimeError::InvalidOperation(
2759 format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2760 ));
2761 }
2762 let mut new_shape = self.shape.clone();
2763 new_shape.remove(d);
2764 if new_shape.is_empty() {
2765 new_shape.push(1); }
2767 self.reshape(&new_shape)
2768 }
2769 None => {
2770 let new_shape: Vec<usize> = self.shape.iter()
2771 .filter(|&&s| s != 1)
2772 .copied()
2773 .collect();
2774 let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2775 self.reshape(&new_shape)
2776 }
2777 }
2778 }
2779
2780 pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2784 self.broadcast_to(target_shape)
2785 }
2786
2787 pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2789 if start_dim > end_dim || end_dim >= self.ndim() {
2790 return Err(RuntimeError::InvalidOperation(
2791 format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2792 ));
2793 }
2794 let mut new_shape = Vec::new();
2795 for i in 0..start_dim {
2796 new_shape.push(self.shape[i]);
2797 }
2798 let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2799 new_shape.push(flat_size);
2800 for i in (end_dim + 1)..self.ndim() {
2801 new_shape.push(self.shape[i]);
2802 }
2803 self.reshape(&new_shape)
2804 }
2805
2806 pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2808 if dim >= self.ndim() {
2809 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2810 }
2811 if n == 0 {
2812 return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2813 }
2814 let dim_size = self.shape[dim];
2815 let chunk_size = (dim_size + n - 1) / n;
2816 let mut sizes = Vec::new();
2817 let mut remaining = dim_size;
2818 while remaining > 0 {
2819 let s = remaining.min(chunk_size);
2820 sizes.push(s);
2821 remaining -= s;
2822 }
2823 self.split(&sizes, dim)
2824 }
2825
2826 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2828 if dim >= self.ndim() {
2829 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2830 }
2831 let total: usize = sizes.iter().sum();
2832 if total != self.shape[dim] {
2833 return Err(RuntimeError::InvalidOperation(
2834 format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2835 ));
2836 }
2837
2838 let mut results = Vec::new();
2839 let mut offset = 0;
2840
2841 for &sz in sizes {
2842 let ranges: Vec<(usize, usize)> = self.shape.iter()
2843 .enumerate()
2844 .map(|(i, &s)| {
2845 if i == dim { (offset, offset + sz) } else { (0, s) }
2846 })
2847 .collect();
2848 let chunk = self.slice(&ranges)?;
2849 results.push(chunk.to_contiguous());
2851 offset += sz;
2852 }
2853
2854 Ok(results)
2855 }
2856
2857 pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2862 if self.shape != other.shape {
2863 return Err(RuntimeError::InvalidOperation(
2864 "scale_add: shape mismatch".to_string(),
2865 ));
2866 }
2867 let a = self.to_vec();
2868 let b = other.to_vec();
2869 let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2870 Tensor::from_vec(result, &self.shape)
2871 }
2872}
2873