1use cjc_repro::Rng;
43
44use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
45use cjc_repro::KahanAccumulatorF64;
46
47use crate::accumulator;
48use crate::buffer::Buffer;
49use crate::dispatch;
50use crate::error::RuntimeError;
51use crate::kernel as kernel_fns;
52use crate::tensor_simd::{self, BinOp, UnaryOp};
53use crate::tensor_tiled::TiledMatmul;
54
/// Dense f64 n-dimensional array backed by a shared `Buffer`.
///
/// Views produced by `slice`, `transpose`, and `broadcast_to` share the
/// same buffer and differ only in `shape`/`strides`/`offset`, so the
/// buffer may hold more elements than the view describes.
#[derive(Debug, Clone)]
pub struct Tensor {
    // Flat storage; possibly shared between several views.
    pub buffer: Buffer<f64>,
    // Extent of each dimension.
    pub(crate) shape: Vec<usize>,
    // Buffer step per dimension; 0 marks a broadcast axis (see broadcast_to).
    pub(crate) strides: Vec<usize>,
    // Buffer index of this view's first element.
    pub(crate) offset: usize,
}
86
87impl Tensor {
88 pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
92 let mut strides = vec![1usize; shape.len()];
93 for i in (0..shape.len().saturating_sub(1)).rev() {
94 strides[i] = strides[i + 1] * shape[i + 1];
95 }
96 strides
97 }
98
99 fn shape_numel(shape: &[usize]) -> usize {
101 shape.iter().product()
102 }
103
104 pub fn zeros(shape: &[usize]) -> Self {
106 let numel = Self::shape_numel(shape);
107 Tensor {
108 buffer: Buffer::alloc(numel, 0.0),
109 shape: shape.to_vec(),
110 strides: Self::compute_strides(shape),
111 offset: 0,
112 }
113 }
114
115 pub fn ones(shape: &[usize]) -> Self {
117 let numel = Self::shape_numel(shape);
118 Tensor {
119 buffer: Buffer::alloc(numel, 1.0),
120 shape: shape.to_vec(),
121 strides: Self::compute_strides(shape),
122 offset: 0,
123 }
124 }
125
126 pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
129 let numel = Self::shape_numel(shape);
130 let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
131 Tensor {
132 buffer: Buffer::from_vec(data),
133 shape: shape.to_vec(),
134 strides: Self::compute_strides(shape),
135 offset: 0,
136 }
137 }
138
139 pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
142 let numel = Self::shape_numel(shape);
143 if data.len() != numel {
144 return Err(RuntimeError::ShapeMismatch {
145 expected: numel,
146 got: data.len(),
147 });
148 }
149 Ok(Tensor {
150 buffer: Buffer::from_vec(data),
151 shape: shape.to_vec(),
152 strides: Self::compute_strides(shape),
153 offset: 0,
154 })
155 }
156
157 pub fn shape(&self) -> &[usize] {
161 &self.shape
162 }
163
164 pub fn ndim(&self) -> usize {
166 self.shape.len()
167 }
168
169 pub fn len(&self) -> usize {
171 Self::shape_numel(&self.shape)
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.len() == 0
177 }
178
179 fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
181 if indices.len() != self.shape.len() {
182 return Err(RuntimeError::DimensionMismatch {
183 expected: self.shape.len(),
184 got: indices.len(),
185 });
186 }
187 let mut off = self.offset;
188 for (i, &idx) in indices.iter().enumerate() {
189 if idx >= self.shape[i] {
190 return Err(RuntimeError::IndexOutOfBounds {
191 index: idx,
192 length: self.shape[i],
193 });
194 }
195 off += idx * self.strides[i];
196 }
197 Ok(off)
198 }
199
200 pub fn is_contiguous(&self) -> bool {
202 if self.offset != 0 {
203 return false;
204 }
205 let expected = Self::compute_strides(&self.shape);
206 self.strides == expected
207 }
208
209 pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
212 if ranges.len() != self.shape.len() {
213 return Err(RuntimeError::DimensionMismatch {
214 expected: self.shape.len(),
215 got: ranges.len(),
216 });
217 }
218 let mut new_offset = self.offset;
219 let mut new_shape = Vec::with_capacity(ranges.len());
220 for (i, &(start, end)) in ranges.iter().enumerate() {
221 if end > self.shape[i] || start > end {
222 return Err(RuntimeError::IndexOutOfBounds {
223 index: end,
224 length: self.shape[i],
225 });
226 }
227 new_offset += start * self.strides[i];
228 new_shape.push(end - start);
229 }
230 Ok(Tensor {
231 buffer: self.buffer.clone(), shape: new_shape,
233 strides: self.strides.clone(),
234 offset: new_offset,
235 })
236 }
237
238 pub fn to_contiguous(&self) -> Tensor {
240 if self.is_contiguous() {
241 return self.clone();
242 }
243 let data = self.to_vec();
244 Tensor {
245 buffer: Buffer::from_vec(data),
246 shape: self.shape.clone(),
247 strides: Self::compute_strides(&self.shape),
248 offset: 0,
249 }
250 }
251
    /// Zero-copy broadcast of `self` to `target_shape` (NumPy-style).
    ///
    /// Dimensions are aligned on the right; leading target dims with no
    /// source counterpart and size-1 source dims get stride 0 so the same
    /// element is re-read along that axis.
    ///
    /// # Errors
    /// `InvalidOperation` if the target rank is smaller than the source's;
    /// `ShapeMismatch` if a source dim is neither 1 nor equal to the target.
    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        let src_ndim = self.shape.len();
        let tgt_ndim = target_shape.len();
        if tgt_ndim < src_ndim {
            return Err(RuntimeError::InvalidOperation(
                "cannot broadcast to a smaller rank".to_string(),
            ));
        }
        // The first `pad` target dims have no source dim at all.
        let pad = tgt_ndim - src_ndim;
        let mut new_strides = vec![0usize; tgt_ndim];
        for i in 0..tgt_ndim {
            if i < pad {
                new_strides[i] = 0;
            } else {
                let src_i = i - pad;
                if self.shape[src_i] == target_shape[i] {
                    new_strides[i] = self.strides[src_i];
                } else if self.shape[src_i] == 1 {
                    // Size-1 dim repeats: stride 0 pins the index in place.
                    new_strides[i] = 0;
                } else {
                    return Err(RuntimeError::ShapeMismatch {
                        expected: target_shape[i],
                        got: self.shape[src_i],
                    });
                }
            }
        }
        Ok(Tensor {
            buffer: self.buffer.clone(),
            shape: target_shape.to_vec(),
            strides: new_strides,
            offset: self.offset,
        })
    }
289
290 pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
292 let offset = self.linear_index(indices)?;
293 self.buffer
294 .get(offset)
295 .ok_or(RuntimeError::IndexOutOfBounds {
296 index: offset,
297 length: self.buffer.len(),
298 })
299 }
300
301 pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
303 let offset = self.linear_index(indices)?;
304 self.buffer.set(offset, val)
305 }
306
    /// Materializes the view into a dense row-major `Vec` of exactly
    /// `self.len()` elements.
    pub fn to_vec(&self) -> Vec<f64> {
        if self.is_contiguous() {
            let full = self.buffer.borrow_data();
            let numel = self.len();
            if full.len() == numel {
                return full.to_vec();
            }
            // Contiguous view inside a larger shared buffer: copy only the
            // first `numel` elements (is_contiguous implies offset == 0).
            return full[..numel].to_vec();
        }
        // Strided view: enumerate every index tuple with an odometer and
        // gather element by element.
        let numel = self.len();
        let mut result = Vec::with_capacity(numel);
        let ndim = self.shape.len();
        let mut indices = vec![0usize; ndim];
        for _ in 0..numel {
            let mut off = self.offset;
            for d in 0..ndim {
                off += indices[d] * self.strides[d];
            }
            // An out-of-range read (inconsistent view) degrades to 0.0.
            result.push(self.buffer.get(off).unwrap_or(0.0));
            // Increment the rightmost index; carry leftwards on overflow.
            for d in (0..ndim).rev() {
                indices[d] += 1;
                if indices[d] < self.shape[d] {
                    break;
                }
                indices[d] = 0;
            }
        }
        result
    }
341
342 pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
347 let new_numel = Self::shape_numel(new_shape);
348 if new_numel != self.len() {
349 return Err(RuntimeError::ShapeMismatch {
350 expected: self.len(),
351 got: new_numel,
352 });
353 }
354 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
356 Ok(Tensor {
357 buffer: tensor.buffer,
358 shape: new_shape.to_vec(),
359 strides: Self::compute_strides(new_shape),
360 offset: 0,
361 })
362 }
363
364 fn elementwise_binop(
368 &self,
369 other: &Tensor,
370 op: impl Fn(f64, f64) -> f64,
371 ) -> Result<Tensor, RuntimeError> {
372 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
373 let a = self.buffer.borrow_data();
375 let b = other.buffer.borrow_data();
376 let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
377 return Ok(Tensor {
378 buffer: Buffer::from_vec(data),
379 shape: self.shape.clone(),
380 strides: Self::compute_strides(&self.shape),
381 offset: 0,
382 });
383 }
384
385 let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
387 let a_broadcast = self.broadcast_to(&result_shape)?;
388 let b_broadcast = other.broadcast_to(&result_shape)?;
389
390 let numel = Self::shape_numel(&result_shape);
391 let ndim = result_shape.len();
392 let mut data = Vec::with_capacity(numel);
393 let mut indices = vec![0usize; ndim];
394
395 for _ in 0..numel {
396 let mut off_a = a_broadcast.offset;
397 let mut off_b = b_broadcast.offset;
398 for d in 0..ndim {
399 off_a += indices[d] * a_broadcast.strides[d];
400 off_b += indices[d] * b_broadcast.strides[d];
401 }
402 let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
403 RuntimeError::InvalidOperation(format!(
404 "broadcast binop: left operand index {} out of bounds (buffer len {})",
405 off_a,
406 a_broadcast.buffer.len()
407 ))
408 })?;
409 let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
410 RuntimeError::InvalidOperation(format!(
411 "broadcast binop: right operand index {} out of bounds (buffer len {})",
412 off_b,
413 b_broadcast.buffer.len()
414 ))
415 })?;
416 data.push(op(va, vb));
417
418 for d in (0..ndim).rev() {
419 indices[d] += 1;
420 if indices[d] < result_shape[d] {
421 break;
422 }
423 indices[d] = 0;
424 }
425 }
426
427 Ok(Tensor {
428 buffer: Buffer::from_vec(data),
429 shape: result_shape.clone(),
430 strides: Self::compute_strides(&result_shape),
431 offset: 0,
432 })
433 }
434
435 fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
437 let max_ndim = a.len().max(b.len());
438 let mut result = Vec::with_capacity(max_ndim);
439 for i in 0..max_ndim {
440 let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
441 let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
442 if da == db {
443 result.push(da);
444 } else if da == 1 {
445 result.push(db);
446 } else if db == 1 {
447 result.push(da);
448 } else {
449 return Err(RuntimeError::ShapeMismatch {
450 expected: da,
451 got: db,
452 });
453 }
454 }
455 Ok(result)
456 }
457
458 fn elementwise_binop_simd(
463 &self,
464 other: &Tensor,
465 op: BinOp,
466 fallback: impl Fn(f64, f64) -> f64,
467 ) -> Result<Tensor, RuntimeError> {
468 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
469 let a = self.buffer.borrow_data();
471 let b = other.buffer.borrow_data();
472 let data = tensor_simd::simd_binop(&a, &b, op);
473 return Ok(Tensor {
474 buffer: Buffer::from_vec(data),
475 shape: self.shape.clone(),
476 strides: Self::compute_strides(&self.shape),
477 offset: 0,
478 });
479 }
480 self.elementwise_binop(other, fallback)
482 }
483
484 pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486 self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
487 }
488
489 pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491 self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
492 }
493
494 pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496 self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
497 }
498
499 pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501 self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
502 }
503
504 pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
510 if self.shape != b.shape || self.shape != c.shape {
511 return Err(RuntimeError::InvalidOperation(
512 "broadcast_fma: all three tensors must have the same shape".to_string(),
513 ));
514 }
515 if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
516 let a_data = self.buffer.borrow_data();
517 let b_data = b.buffer.borrow_data();
518 let c_data = c.buffer.borrow_data();
519 let n = a_data.len();
520 let mut out = vec![0.0f64; n];
521 for i in 0..n {
524 out[i] = a_data[i] * b_data[i] + c_data[i];
525 }
526 return Ok(Tensor {
527 buffer: Buffer::from_vec(out),
528 shape: self.shape.clone(),
529 strides: Self::compute_strides(&self.shape),
530 offset: 0,
531 });
532 }
533 let temp = self.mul_elem(b)?;
535 temp.add(c)
536 }
537
538 pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
542 self.elementwise_binop(other, |a, b| a.powf(b))
543 }
544
545 pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
547 self.elementwise_binop(other, |a, b| a.min(b))
548 }
549
550 pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
552 self.elementwise_binop(other, |a, b| a.max(b))
553 }
554
555 pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
557 self.elementwise_binop(other, |a, b| a.atan2(b))
558 }
559
560 pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
562 self.elementwise_binop(other, |a, b| a.hypot(b))
563 }
564
565 pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
567 let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
568 Tensor {
569 buffer: Buffer::from_vec(data),
570 shape: self.shape.clone(),
571 strides: Self::compute_strides(&self.shape),
572 offset: 0,
573 }
574 }
575
576 pub fn map_simd(&self, op: UnaryOp) -> Tensor {
581 let src = self.to_vec();
582 let data = tensor_simd::simd_unary(&src, op);
583 Tensor {
584 buffer: Buffer::from_vec(data),
585 shape: self.shape.clone(),
586 strides: Self::compute_strides(&self.shape),
587 offset: 0,
588 }
589 }
590
591 pub fn sum(&self) -> f64 {
595 let data = self.buffer.borrow_data();
596 binned_sum_f64(&data)
597 }
598
599 pub fn binned_sum(&self) -> f64 {
603 let data = self.buffer.borrow_data();
604 accumulator::binned_sum_f64(&data)
605 }
606
607 pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
611 let data = self.buffer.borrow_data();
612 dispatch::dispatch_sum_f64(&data, ctx)
613 }
614
615 pub fn mean(&self) -> f64 {
617 let n = self.len();
618 if n == 0 {
619 return 0.0;
620 }
621 self.sum() / n as f64
622 }
623
624 pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
626 let n = self.len();
627 if n == 0 {
628 return 0.0;
629 }
630 self.dispatched_sum(ctx) / n as f64
631 }
632
633 pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
643 let ndim = self.ndim();
644 if axis >= ndim {
645 return Err(RuntimeError::IndexOutOfBounds {
646 index: axis,
647 length: ndim,
648 });
649 }
650
651 let mut out_shape = self.shape.clone();
653 out_shape[axis] = 1;
654 let out_numel = Self::shape_numel(&out_shape);
655 let out_strides = Self::compute_strides(&out_shape);
656
657 let data = self.to_vec();
658 let axis_len = self.shape[axis];
659 let mut result = vec![0.0f64; out_numel];
660
661 let mut indices = vec![0usize; ndim];
663 for out_idx in 0..out_numel {
664 {
666 let mut remaining = out_idx;
667 for d in 0..ndim {
668 indices[d] = remaining / out_strides[d];
669 remaining %= out_strides[d];
670 }
671 }
672
673 let mut acc = BinnedAccumulatorF64::new();
674 for k in 0..axis_len {
675 let mut flat = self.offset;
677 for d in 0..ndim {
678 let idx = if d == axis { k } else { indices[d] };
679 flat += idx * self.strides[d];
680 }
681 acc.add(data[flat]);
682 }
683 result[out_idx] = acc.finalize();
684 }
685
686 Tensor::from_vec(result, &out_shape)
687 }
688
689 pub fn neg(&self) -> Tensor {
693 self.map(|x| -x)
694 }
695
696 pub fn transpose(&self) -> Tensor {
699 let ndim = self.ndim();
700 if ndim <= 1 {
701 return self.clone();
702 }
703 let mut new_shape = self.shape.clone();
705 let mut new_strides = self.strides.clone();
706 new_shape.reverse();
707 new_strides.reverse();
708 Tensor {
709 buffer: self.buffer.clone(), shape: new_shape,
711 strides: new_strides,
712 offset: self.offset,
713 }
714 }
715
716 pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
720 let ndim = self.ndim();
721 if axes.len() != ndim {
722 return Err(RuntimeError::InvalidOperation(
723 format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
724 ));
725 }
726 let mut seen = vec![false; ndim];
728 for &ax in axes {
729 if ax >= ndim {
730 return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
731 }
732 if seen[ax] {
733 return Err(RuntimeError::InvalidOperation(
734 format!("transpose_axes: duplicate axis {ax}"),
735 ));
736 }
737 seen[ax] = true;
738 }
739 let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
740 let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
741 Ok(Tensor {
742 buffer: self.buffer.clone(),
743 shape: new_shape,
744 strides: new_strides,
745 offset: self.offset,
746 })
747 }
748
749 pub fn scalar_mul(&self, s: f64) -> Tensor {
751 self.map(|x| x * s)
752 }
753
    /// Panicking variant of [`Tensor::from_vec`]; use only when the data
    /// length is known to match the shape.
    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
    }
761
    /// Panicking variant of [`Tensor::add`]; use when shapes are known to
    /// be compatible.
    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
        self.add(other).expect("Tensor::add shape mismatch")
    }
766
    /// Panicking variant of [`Tensor::sub`]; use when shapes are known to
    /// be compatible.
    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
        self.sub(other).expect("Tensor::sub shape mismatch")
    }
771
    /// Panicking variant of [`Tensor::mul_elem`]; use when shapes are known
    /// to be compatible.
    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
    }
776
    /// Panicking variant of [`Tensor::div_elem`]; use when shapes are known
    /// to be compatible.
    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
    }
781
    /// Panicking variant of [`Tensor::matmul`]; use when both operands are
    /// known to be 2-D with matching inner dimensions.
    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
        self.matmul(other).expect("Tensor::matmul dimension mismatch")
    }
786
    /// 2-D matrix product `self [m,k] × other [k,n] -> [m,n]`, dispatching
    /// to a parallel, tiled, or sequential kernel by problem size.
    ///
    /// # Errors
    /// `InvalidOperation` for non-2-D inputs; `DimensionMismatch` when the
    /// inner dimensions disagree.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matmul requires 2-D tensors".to_string(),
            ));
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        // Densify both operands so the kernels can assume row-major slices.
        let a = self.to_vec();
        let b = other.to_vec();

        #[cfg(feature = "parallel")]
        {
            // Large problems: rayon-parallel kernel.
            if m >= 256 || n >= 256 || k >= 256 {
                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
            }
        }

        // Mid-size problems: cache-friendly tiled kernel.
        if m >= 64 || n >= 64 || k >= 64 {
            return Self::matmul_tiled(&a, &b, m, n, k);
        }

        // Small problems: simple triple loop with compensated summation.
        Self::matmul_sequential(&a, &b, m, n, k)
    }
830
831 fn matmul_sequential(
833 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
834 ) -> Result<Tensor, RuntimeError> {
835 let mut result = vec![0.0f64; m * n];
836 for i in 0..m {
837 for j in 0..n {
838 let mut acc = KahanAccumulatorF64::new();
839 for p in 0..k {
840 acc.add(a[i * k + p] * b[p * n + j]);
841 }
842 result[i * n + j] = acc.finalize();
843 }
844 }
845 Tensor::from_vec(result, &[m, n])
846 }
847
848 fn matmul_tiled(
855 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
856 ) -> Result<Tensor, RuntimeError> {
857 let engine = TiledMatmul::new();
858 let result = engine.matmul(a, m, k, b, n);
859 Tensor::from_vec(result, &[m, n])
860 }
861
    /// Rayon-parallel matmul kernel (only built with feature = "parallel").
    ///
    /// Very large problems (m, n ≥ 512) are split into horizontal row
    /// bands, each band computed by the tiled kernel on its own thread;
    /// otherwise rows are distributed one per rayon task with
    /// Kahan-compensated dot products.
    #[cfg(feature = "parallel")]
    fn matmul_parallel_mode_a(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        use rayon::prelude::*;
        use cjc_repro::KahanAccumulatorF64;

        if m >= 512 && n >= 512 {
            // One band of rows per thread, but at least 64 rows so each
            // tiled call has enough work to amortize setup.
            let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
            let band_size = band_size.max(64);
            let mut result = vec![0.0f64; m * n];

            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    // The final band may be shorter than band_size rows.
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k .. i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });

            return Tensor::from_vec(result, &[m, n]);
        }

        // Moderately sized problems: one output row per rayon task.
        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });

        Tensor::from_vec(result, &[m, n])
    }
923
    /// Batched matrix multiply: both operands are `[batch..., m, k]` and
    /// `[batch..., k, n]` with identical leading batch dims; plain 2-D
    /// inputs are forwarded to `matmul`. Each batch uses the tiled kernel
    /// for large matrices and a compensated triple loop otherwise; with
    /// feature = "parallel", sufficiently heavy batches run on rayon.
    ///
    /// # Errors
    /// `InvalidOperation` on rank < 2, rank disagreement, or batch-dim
    /// mismatch; `DimensionMismatch` when the inner dims disagree.
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "bmm requires at least 2-D tensors".to_string(),
            ));
        }
        if self.ndim() == 2 && other.ndim() == 2 {
            return self.matmul(other);
        }
        if self.ndim() != other.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm requires same number of dimensions, got {} and {}",
                    self.ndim(),
                    other.ndim()
                ),
            ));
        }
        let nd = self.ndim();
        let batch_dims_a = &self.shape[..nd - 2];
        let batch_dims_b = &other.shape[..nd - 2];
        if batch_dims_a != batch_dims_b {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm batch dimensions mismatch: {:?} vs {:?}",
                    batch_dims_a, batch_dims_b
                ),
            ));
        }
        let m = self.shape[nd - 2];
        let k = self.shape[nd - 1];
        let k2 = other.shape[nd - 2];
        let n = other.shape[nd - 1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        let batch_size: usize = batch_dims_a.iter().product();
        // Densify both operands; each batch is a contiguous sub-slice.
        let a = self.to_vec();
        let b = other.to_vec();
        let mat_a_stride = m * k;
        let mat_b_stride = k * n;
        let mat_c_stride = m * n;
        let mut result = vec![0.0f64; batch_size * mat_c_stride];

        // Shared per-batch kernel, used by both the serial and rayon paths.
        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];

            if m >= 64 || n >= 64 || k >= 64 {
                // Large per-batch matrices: cache-blocked kernel.
                let engine = crate::tensor_tiled::TiledMatmul::new();
                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
                c_slice.copy_from_slice(&tiled);
            } else {
                // Small matrices: compensated triple loop.
                for i in 0..m {
                    for j in 0..n {
                        let mut acc = KahanAccumulatorF64::new();
                        for p in 0..k {
                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
                        }
                        c_slice[i * n + j] = acc.finalize();
                    }
                }
            }
        };

        #[cfg(feature = "parallel")]
        {
            // Parallelize across batches only when each batch carries
            // enough work (m*k elements) to amortize scheduling.
            if batch_size > 1 && m * k >= 4096 {
                use rayon::prelude::*;
                result
                    .par_chunks_mut(mat_c_stride)
                    .enumerate()
                    .for_each(|(batch, c_slice)| {
                        compute_batch(batch, c_slice);
                    });

                let mut out_shape = batch_dims_a.to_vec();
                out_shape.push(m);
                out_shape.push(n);
                return Tensor::from_vec(result, &out_shape);
            }
        }

        for batch in 0..batch_size {
            let c_off = batch * mat_c_stride;
            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
        }

        let mut out_shape = batch_dims_a.to_vec();
        out_shape.push(m);
        out_shape.push(n);
        Tensor::from_vec(result, &out_shape)
    }
1031
1032 pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
1040 if self.ndim() == 0 {
1041 return Err(RuntimeError::InvalidOperation(
1042 "softmax requires at least 1-D tensor".to_string(),
1043 ));
1044 }
1045 let data_ref;
1047 let data_vec;
1048 let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
1049 data_ref = self.buffer.borrow_data();
1050 &data_ref
1051 } else {
1052 data_vec = self.to_vec();
1053 &data_vec
1054 };
1055 let n = *self.shape.last().unwrap(); let outer: usize = data.len() / n; let mut result = vec![0.0f64; data.len()];
1058
1059 for row in 0..outer {
1060 let start = row * n;
1061 let end = start + n;
1062 let slice = &data[start..end];
1063
1064 let mut max_val = f64::NEG_INFINITY;
1066 for &v in slice {
1067 if v > max_val {
1068 max_val = v;
1069 }
1070 }
1071
1072 let mut exp_vals = vec![0.0f64; n];
1074 let mut sum = 0.0f64;
1075 let mut comp = 0.0f64; for i in 0..n {
1077 let e = (slice[i] - max_val).exp();
1078 exp_vals[i] = e;
1079 let y = e - comp;
1081 let t = sum + y;
1082 comp = (t - sum) - y;
1083 sum = t;
1084 }
1085
1086 if sum == 0.0 {
1088 let uniform = 1.0 / n as f64;
1090 for i in 0..n {
1091 result[start + i] = uniform;
1092 }
1093 } else {
1094 for i in 0..n {
1095 result[start + i] = exp_vals[i] / sum;
1096 }
1097 }
1098 }
1099
1100 Tensor::from_vec(result, &self.shape)
1101 }
1102
    /// Layer normalization over the last dimension: each length-`d` row is
    /// standardized to zero mean / unit variance (with `eps` added to the
    /// variance before the square root), then scaled by `gamma` and
    /// shifted by `beta` elementwise.
    ///
    /// # Errors
    /// `InvalidOperation` for a 0-dim input or when `gamma`/`beta` do not
    /// match the last dimension's extent.
    pub fn layer_norm(
        &self,
        gamma: &Tensor,
        beta: &Tensor,
        eps: f64,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() == 0 {
            return Err(RuntimeError::InvalidOperation(
                "layer_norm requires at least 1-D tensor".to_string(),
            ));
        }
        let d = *self.shape.last().unwrap();
        if gamma.len() != d || beta.len() != d {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "layer_norm: gamma/beta length {} must match last dim {}",
                    gamma.len(),
                    d
                ),
            ));
        }

        // NOTE(review): a last dim of 0 passes the check above (0 == 0)
        // and then divides by zero below — confirm upstream shapes
        // exclude that case.
        let data = self.to_vec();
        let gamma_data = gamma.to_vec();
        let beta_data = beta.to_vec();
        let outer = data.len() / d;
        let mut result = vec![0.0f64; data.len()];

        for row in 0..outer {
            let start = row * d;
            let slice = &data[start..start + d];

            // Mean and variance via compensated binned summation.
            let mean = binned_sum_f64(slice) / d as f64;

            let diffs: Vec<f64> = slice.iter().map(|&x| {
                let diff = x - mean;
                diff * diff
            }).collect();
            let variance = binned_sum_f64(&diffs) / d as f64;

            let inv_std = 1.0 / (variance + eps).sqrt();
            for i in 0..d {
                let normalized = (slice[i] - mean) * inv_std;
                result[start + i] = gamma_data[i] * normalized + beta_data[i];
            }
        }

        Tensor::from_vec(result, &self.shape)
    }
1165
1166 fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
1172 if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
1173 let mut data = self.buffer.borrow_data().clone();
1175 for x in data.iter_mut() {
1176 *x = f(*x);
1177 }
1178 Tensor::from_vec(data, &self.shape).unwrap()
1179 } else {
1180 let data = self.to_vec();
1182 let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
1183 Tensor::from_vec(result, &self.shape).unwrap()
1184 }
1185 }
1186
1187 pub fn relu(&self) -> Tensor {
1189 self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1190 }
1191
1192 pub fn sigmoid(&self) -> Tensor {
1194 self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1195 }
1196
1197 pub fn tanh_activation(&self) -> Tensor {
1199 self.map_elementwise(|x| x.tanh())
1200 }
1201
1202 pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1204 self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1205 }
1206
1207 pub fn silu(&self) -> Tensor {
1209 let data = self.to_vec();
1210 let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1211 Tensor::from_vec(result, &self.shape).unwrap()
1212 }
1213
1214 pub fn mish(&self) -> Tensor {
1216 let data = self.to_vec();
1217 let result: Vec<f64> = data.iter().map(|&x| {
1218 let sp = (1.0 + x.exp()).ln();
1219 x * sp.tanh()
1220 }).collect();
1221 Tensor::from_vec(result, &self.shape).unwrap()
1222 }
1223
1224 pub fn argmax(&self) -> usize {
1226 let data = self.to_vec();
1227 let mut best_idx = 0;
1228 let mut best_val = f64::NEG_INFINITY;
1229 for (i, &v) in data.iter().enumerate() {
1230 if v > best_val || (v == best_val && i < best_idx) {
1231 best_val = v;
1232 best_idx = i;
1233 }
1234 }
1235 best_idx
1236 }
1237
1238 pub fn argmin(&self) -> usize {
1240 let data = self.to_vec();
1241 let mut best_idx = 0;
1242 let mut best_val = f64::INFINITY;
1243 for (i, &v) in data.iter().enumerate() {
1244 if v < best_val || (v == best_val && i < best_idx) {
1245 best_val = v;
1246 best_idx = i;
1247 }
1248 }
1249 best_idx
1250 }
1251
1252 pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1254 let data = self.to_vec();
1255 let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1256 Tensor::from_vec(result, &self.shape).unwrap()
1257 }
1258
1259 pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1262 let n = indices.len();
1263 let mut data = vec![0.0; n * depth];
1264 for (i, &idx) in indices.iter().enumerate() {
1265 if idx >= depth {
1266 return Err(RuntimeError::InvalidOperation(format!(
1267 "one_hot: index {idx} >= depth {depth}"
1268 )));
1269 }
1270 data[i * depth + idx] = 1.0;
1271 }
1272 Tensor::from_vec(data, &[n, depth])
1273 }
1274
    /// Concatenates tensors along an EXISTING `axis`. All inputs must share
    /// rank and every non-`axis` extent; the output's `axis` extent is the
    /// sum of the inputs'.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty list, an out-of-range axis, or any
    /// rank/shape mismatch.
    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
        }
        let ndim = tensors[0].ndim();
        if axis >= ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.ndim() != ndim {
                return Err(RuntimeError::InvalidOperation(
                    format!("cat: tensor {i} has different ndim"),
                ));
            }
            for d in 0..ndim {
                if d != axis && t.shape[d] != tensors[0].shape[d] {
                    return Err(RuntimeError::InvalidOperation(
                        format!("cat: shape mismatch at dim {d}"),
                    ));
                }
            }
        }
        let mut out_shape = tensors[0].shape.clone();
        for t in tensors.iter().skip(1) {
            out_shape[axis] += t.shape[axis];
        }
        let total = out_shape.iter().product::<usize>();
        let mut result = vec![0.0; total];
        // Row-major strides of the output shape.
        let mut out_strides = vec![1usize; ndim];
        for d in (0..ndim - 1).rev() {
            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
        }
        // Copy each input element-by-element: decode its flat index into
        // coordinates, shift the `axis` coordinate by the running offset
        // of previously placed tensors, and re-encode for the output.
        let mut offset = 0;
        for t in tensors {
            let t_data = t.to_vec();
            let t_total: usize = t.shape.iter().product();
            let mut t_strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
            }
            for idx in 0..t_total {
                let mut remaining = idx;
                let mut out_flat = 0;
                for d in 0..ndim {
                    let coord = remaining / t_strides[d];
                    remaining %= t_strides[d];
                    let out_coord = if d == axis { coord + offset } else { coord };
                    out_flat += out_coord * out_strides[d];
                }
                result[out_flat] = t_data[idx];
            }
            offset += t.shape[axis];
        }
        Tensor::from_vec(result, &out_shape)
    }
1337
    /// Stacks same-shaped tensors along a NEW axis inserted at position
    /// `axis` (`axis == ndim` appends a trailing axis); output rank is
    /// ndim + 1 with extent `tensors.len()` at the new axis.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty list, `axis > ndim`, or any shape
    /// mismatch among the inputs.
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
        }
        let base_shape = &tensors[0].shape;
        let ndim = base_shape.len();
        if axis > ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("stack: axis {axis} out of bounds"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if &t.shape != base_shape {
                return Err(RuntimeError::InvalidOperation(
                    format!("stack: tensor {i} shape mismatch"),
                ));
            }
        }
        // Output shape: base shape with `tensors.len()` inserted at `axis`.
        let mut out_shape = Vec::with_capacity(ndim + 1);
        for d in 0..axis { out_shape.push(base_shape[d]); }
        out_shape.push(tensors.len());
        for d in axis..ndim { out_shape.push(base_shape[d]); }
        let total: usize = out_shape.iter().product();
        let mut result = vec![0.0; total];
        // View each input as outer_size × inner_size around the insertion
        // point; `.max(1)` keeps the empty products well-defined.
        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
        for (t_idx, t) in tensors.iter().enumerate() {
            let t_data = t.to_vec();
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let src = outer * inner_size + inner;
                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
                    // Defensive bounds guard; both hold by construction.
                    if src < t_data.len() && dst < result.len() {
                        result[dst] = t_data[src];
                    }
                }
            }
        }
        Tensor::from_vec(result, &out_shape)
    }
1379
1380 pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1382 let data = self.to_vec();
1383 let n = data.len();
1384 if k > n {
1385 return Err(RuntimeError::InvalidOperation(
1386 format!("topk: k={k} exceeds data length {n}"),
1387 ));
1388 }
1389 let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1390 indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1391 let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1392 let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1393 let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1394 Ok((Tensor::from_vec(values, &[k])?, indices))
1395 }
1396
1397 pub fn gelu(&self) -> Tensor {
1399 let data = self.to_vec();
1400 let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1401 let result: Vec<f64> = data.iter().map(|&x| {
1402 let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1403 0.5 * x * (1.0 + inner.tanh())
1404 }).collect();
1405 Tensor::from_vec(result, &self.shape).unwrap()
1406 }
1407
    /// Affine transform over the last dimension: `y = x · Wᵀ + b`.
    ///
    /// `weight` must be 2-D `[out_features, in_features]` and `bias` 1-D of
    /// length `out_features`. Every leading dimension of `self` is treated
    /// as a batch dimension; the output keeps those dimensions and replaces
    /// the last one with `out_features`.
    ///
    /// # Errors
    /// - `InvalidOperation` if `weight` is not 2-D, `self` is 0-D, or the
    ///   bias length does not match `out_features`.
    /// - `DimensionMismatch` if the input's last dimension != `in_features`.
    pub fn linear(
        &self,
        weight: &Tensor,
        bias: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if weight.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "linear: weight must be 2-D [out_features, in_features]".to_string(),
            ));
        }
        let out_features = weight.shape[0];
        let in_features = weight.shape[1];
        let last_dim = *self.shape.last().ok_or_else(|| {
            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
        })?;
        if last_dim != in_features {
            return Err(RuntimeError::DimensionMismatch {
                expected: in_features,
                got: last_dim,
            });
        }
        if bias.len() != out_features {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "linear: bias length {} must match out_features {}",
                    bias.len(),
                    out_features
                ),
            ));
        }

        let data = self.to_vec();
        let w = weight.to_vec();
        let b = bias.to_vec();
        // All leading dimensions flattened into one batch ("outer") axis.
        let outer = data.len() / in_features;
        let mut result = vec![0.0f64; outer * out_features];

        for row in 0..outer {
            let x_start = row * in_features;
            let x_slice = &data[x_start..x_start + in_features];
            let y_start = row * out_features;
            for j in 0..out_features {
                // Row j of the weight matrix (row-major layout).
                let w_start = j * in_features;
                // Binned accumulation reduces floating-point summation error.
                let mut acc = BinnedAccumulatorF64::new();
                for p in 0..in_features {
                    acc.add(x_slice[p] * w[w_start + p]);
                }
                result[y_start + j] = acc.finalize() + b[j];
            }
        }

        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
        out_shape.push(out_features);
        Tensor::from_vec(result, &out_shape)
    }
1468
1469 pub fn conv1d(
1473 &self,
1474 filters: &Tensor,
1475 bias: &Tensor,
1476 ) -> Result<Tensor, RuntimeError> {
1477 if self.ndim() != 1 {
1478 return Err(RuntimeError::InvalidOperation(
1479 "conv1d: input must be 1-D [signal_len]".to_string(),
1480 ));
1481 }
1482 if filters.ndim() != 2 {
1483 return Err(RuntimeError::InvalidOperation(
1484 "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1485 ));
1486 }
1487 let signal_len = self.shape[0];
1488 let out_channels = filters.shape[0];
1489 let kernel_size = filters.shape[1];
1490 if signal_len < kernel_size {
1491 return Err(RuntimeError::InvalidOperation(
1492 format!(
1493 "conv1d: signal_len {} < kernel_size {}",
1494 signal_len, kernel_size
1495 ),
1496 ));
1497 }
1498 if bias.len() != out_channels {
1499 return Err(RuntimeError::InvalidOperation(
1500 format!(
1501 "conv1d: bias length {} must match out_channels {}",
1502 bias.len(), out_channels
1503 ),
1504 ));
1505 }
1506 let out_len = signal_len - kernel_size + 1;
1507 let s = self.to_vec();
1508 let f = filters.to_vec();
1509 let b = bias.to_vec();
1510 let mut result = vec![0.0; out_channels * out_len];
1511 kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1512 Tensor::from_vec(result, &[out_channels, out_len])
1513 }
1514
1515 pub fn conv2d(
1529 &self,
1530 filters: &Tensor,
1531 bias: &Tensor,
1532 stride: usize,
1533 ) -> Result<Tensor, RuntimeError> {
1534 if self.ndim() != 4 {
1535 return Err(RuntimeError::InvalidOperation(
1536 "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
1537 ));
1538 }
1539 if filters.ndim() != 4 {
1540 return Err(RuntimeError::InvalidOperation(
1541 "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
1542 ));
1543 }
1544 if stride == 0 {
1545 return Err(RuntimeError::InvalidOperation(
1546 "conv2d: stride must be >= 1".to_string(),
1547 ));
1548 }
1549
1550 let n = self.shape[0];
1551 let c_in = self.shape[1];
1552 let h_in = self.shape[2];
1553 let w_in = self.shape[3];
1554
1555 let c_out = filters.shape[0];
1556 let c_in_check = filters.shape[1];
1557 let kh = filters.shape[2];
1558 let kw = filters.shape[3];
1559
1560 if c_in != c_in_check {
1561 return Err(RuntimeError::InvalidOperation(format!(
1562 "conv2d: input C_in={} does not match filter C_in={}",
1563 c_in, c_in_check
1564 )));
1565 }
1566 if h_in < kh || w_in < kw {
1567 return Err(RuntimeError::InvalidOperation(format!(
1568 "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
1569 h_in, w_in, kh, kw
1570 )));
1571 }
1572 if bias.len() != c_out {
1573 return Err(RuntimeError::InvalidOperation(format!(
1574 "conv2d: bias length {} must match C_out={}",
1575 bias.len(), c_out
1576 )));
1577 }
1578
1579 let h_out = (h_in - kh) / stride + 1;
1580 let w_out = (w_in - kw) / stride + 1;
1581
1582 let inp = self.to_vec();
1583 let flt = filters.to_vec();
1584 let b = bias.to_vec();
1585 let mut result = vec![0.0f64; n * c_out * h_out * w_out];
1586
1587 kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
1588 n, c_in, h_in, w_in, c_out, kh, kw, stride);
1589
1590 Tensor::from_vec(result, &[n, c_out, h_out, w_out])
1591 }
1592
1593 pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1600 if self.ndim() != 4 {
1601 return Err(RuntimeError::InvalidOperation(
1602 "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1603 ));
1604 }
1605 if ph == 0 || pw == 0 {
1606 return Err(RuntimeError::InvalidOperation(
1607 "maxpool2d: pool size must be >= 1".to_string(),
1608 ));
1609 }
1610
1611 let n = self.shape[0];
1612 let c = self.shape[1];
1613 let h_in = self.shape[2];
1614 let w_in = self.shape[3];
1615
1616 if h_in < ph || w_in < pw {
1617 return Err(RuntimeError::InvalidOperation(format!(
1618 "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1619 h_in, w_in, ph, pw
1620 )));
1621 }
1622
1623 let h_out = h_in / ph;
1624 let w_out = w_in / pw;
1625
1626 let inp = self.to_vec();
1627 let mut result = vec![0.0f64; n * c * h_out * w_out];
1628
1629 kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1630
1631 Tensor::from_vec(result, &[n, c, h_out, w_out])
1632 }
1633
1634 pub fn avgpool2d(&self, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, RuntimeError> {
1649 let shape = self.shape();
1650 if shape.len() != 3 {
1651 return Err(RuntimeError::InvalidOperation(format!("avgpool2d requires 3-D [C,H,W], got {:?}", shape)));
1652 }
1653 let (c, h, w) = (shape[0], shape[1], shape[2]);
1654 if kernel_h > h || kernel_w > w {
1655 return Err(RuntimeError::InvalidOperation("avgpool2d: kernel larger than input".into()));
1656 }
1657 let out_h = (h - kernel_h) / stride_h + 1;
1658 let out_w = (w - kernel_w) / stride_w + 1;
1659 let data = self.to_vec();
1660 let mut out = Vec::with_capacity(c * out_h * out_w);
1661 let pool_size = (kernel_h * kernel_w) as f64;
1662
1663 for ch in 0..c {
1664 for oh in 0..out_h {
1665 for ow in 0..out_w {
1666 let mut sum = 0.0f64;
1667 for kh in 0..kernel_h {
1668 for kw in 0..kernel_w {
1669 let ih = oh * stride_h + kh;
1670 let iw = ow * stride_w + kw;
1671 sum += data[ch * h * w + ih * w + iw];
1672 }
1673 }
1674 out.push(sum / pool_size);
1675 }
1676 }
1677 }
1678 Tensor::from_vec(out, &[c, out_h, out_w])
1679 }
1680
1681 pub fn scaled_dot_product_attention(
1690 queries: &Tensor,
1691 keys: &Tensor,
1692 values: &Tensor,
1693 ) -> Result<Tensor, RuntimeError> {
1694 if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1695 return Err(RuntimeError::InvalidOperation(
1696 "attention: Q, K, V must be at least 2-D".to_string(),
1697 ));
1698 }
1699 let nd = queries.ndim();
1700 let d_k = queries.shape[nd - 1];
1701 let scale = 1.0 / (d_k as f64).sqrt();
1702
1703 let keys_t = keys.transpose_last_two()?;
1705
1706 let scores = queries.bmm(&keys_t)?;
1708
1709 let scores_scaled = scores.scalar_mul(scale);
1711
1712 let attn_weights = scores_scaled.softmax()?;
1714
1715 attn_weights.bmm(values)
1717 }
1718
1719 pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1723 if self.ndim() < 2 {
1724 return Err(RuntimeError::InvalidOperation(
1725 "transpose_last_two requires at least 2-D tensor".to_string(),
1726 ));
1727 }
1728 let nd = self.ndim();
1729 let rows = self.shape[nd - 2];
1730 let cols = self.shape[nd - 1];
1731 let data = self.to_vec();
1732 let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1733 let mat_size = rows * cols;
1734 let mut result = vec![0.0f64; data.len()];
1735
1736 for b in 0..batch_size {
1737 let off = b * mat_size;
1738 for i in 0..rows {
1739 for j in 0..cols {
1740 result[off + j * rows + i] = data[off + i * cols + j];
1741 }
1742 }
1743 }
1744
1745 let mut out_shape = self.shape.clone();
1746 out_shape[nd - 2] = cols;
1747 out_shape[nd - 1] = rows;
1748 Tensor::from_vec(result, &out_shape)
1749 }
1750
1751 pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1767 let numel = Self::shape_numel(shape);
1768 match dtype {
1769 "f64" => {
1770 let expected = numel * 8;
1771 if bytes.len() != expected {
1772 return Err(RuntimeError::ShapeMismatch {
1773 expected,
1774 got: bytes.len(),
1775 });
1776 }
1777 let mut data = Vec::with_capacity(numel);
1778 for i in 0..numel {
1779 let off = i * 8;
1780 let mut buf = [0u8; 8];
1781 buf.copy_from_slice(&bytes[off..off + 8]);
1782 data.push(f64::from_le_bytes(buf));
1783 }
1784 Ok(Tensor {
1785 buffer: Buffer::from_vec(data),
1786 shape: shape.to_vec(),
1787 strides: Self::compute_strides(shape),
1788 offset: 0,
1789 })
1790 }
1791 "f32" => {
1792 let expected = numel * 4;
1793 if bytes.len() != expected {
1794 return Err(RuntimeError::ShapeMismatch {
1795 expected,
1796 got: bytes.len(),
1797 });
1798 }
1799 let mut data = Vec::with_capacity(numel);
1800 for i in 0..numel {
1801 let off = i * 4;
1802 let mut buf = [0u8; 4];
1803 buf.copy_from_slice(&bytes[off..off + 4]);
1804 data.push(f32::from_le_bytes(buf) as f64);
1805 }
1806 Ok(Tensor {
1807 buffer: Buffer::from_vec(data),
1808 shape: shape.to_vec(),
1809 strides: Self::compute_strides(shape),
1810 offset: 0,
1811 })
1812 }
1813 _ => Err(RuntimeError::InvalidOperation(
1814 format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1815 )),
1816 }
1817 }
1818
    /// Splits the model dimension of a `[batch, seq, model_dim]` tensor into
    /// attention heads, returning a `[batch, num_heads, seq, head_dim]` view
    /// (the seq/head axes are transposed via stride permutation, not copied).
    ///
    /// # Errors
    /// - `DimensionMismatch` if the input is not 3-D.
    /// - `InvalidOperation` if `model_dim` is not divisible by `num_heads`.
    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 3 {
            return Err(RuntimeError::DimensionMismatch {
                expected: 3,
                got: self.ndim(),
            });
        }
        let batch = self.shape[0];
        let seq = self.shape[1];
        let model_dim = self.shape[2];
        if model_dim % num_heads != 0 {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "split_heads: model_dim {} not divisible by num_heads {}",
                    model_dim, num_heads
                ),
            ));
        }
        let head_dim = model_dim / num_heads;
        // Work on a contiguous copy so row-major strides describe the buffer.
        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
        // Reinterpret [batch, seq, model_dim] as [batch, seq, num_heads, head_dim].
        // NOTE(review): offset is hard-coded to 0 here; this assumes a
        // contiguous tensor always starts at buffer offset 0 — TODO confirm
        // that is_contiguous()/to_contiguous() guarantee offset == 0.
        let reshaped = Tensor {
            buffer: tensor.buffer.clone(),
            shape: vec![batch, seq, num_heads, head_dim],
            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
            offset: 0,
        };
        // Transpose axes 1 and 2 by swapping their strides (zero-copy view).
        Ok(Tensor {
            buffer: reshaped.buffer,
            shape: vec![batch, num_heads, seq, head_dim],
            strides: vec![
                reshaped.strides[0], reshaped.strides[2], reshaped.strides[1], reshaped.strides[3], ],
            offset: 0,
        })
    }
1868
    /// Inverse of `split_heads`: folds a `[batch, num_heads, seq, head_dim]`
    /// tensor back into `[batch, seq, num_heads * head_dim]`, materializing
    /// a contiguous copy.
    ///
    /// # Errors
    /// `DimensionMismatch` if the input is not 4-D.
    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::DimensionMismatch {
                expected: 4,
                got: self.ndim(),
            });
        }
        let batch = self.shape[0];
        let num_heads = self.shape[1];
        let seq = self.shape[2];
        let head_dim = self.shape[3];
        // Zero-copy transpose to [batch, seq, num_heads, head_dim] by
        // swapping the strides of axes 1 and 2; offset is preserved.
        let transposed = Tensor {
            buffer: self.buffer.clone(),
            shape: vec![batch, seq, num_heads, head_dim],
            strides: vec![
                self.strides[0],
                self.strides[2], self.strides[1], self.strides[3],
            ],
            offset: self.offset,
        };
        // Materialize the permuted layout so the final reshape is valid.
        let contig = transposed.to_contiguous();
        let model_dim = num_heads * head_dim;
        Ok(Tensor {
            buffer: contig.buffer,
            shape: vec![batch, seq, model_dim],
            strides: Self::compute_strides(&[batch, seq, model_dim]),
            offset: 0,
        })
    }
1905
    /// Thin alias that delegates directly to `reshape` with the same
    /// arguments and error behavior.
    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.reshape(new_shape)
    }
1911
1912 pub fn argsort(&self) -> Tensor {
1919 let data = self.to_vec();
1920 let mut indices: Vec<usize> = (0..data.len()).collect();
1921 indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1922 let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1923 Tensor::from_vec_unchecked(result, &[data.len()])
1924 }
1925
    /// Gathers elements by index, PyTorch-style, for 1-D and 2-D tensors.
    ///
    /// 1-D: output[k] = self[indices[k]], with the output taking the shape
    /// of `indices`. 2-D with `dim == 0`: out[i][j] = self[idx[i][j]][j];
    /// any other `dim` value selects along columns: out[i][j] = self[i][idx[i][j]].
    ///
    /// NOTE(review): index values are f64 cast with `as usize` (saturating
    /// float→int cast), so negative or fractional indices are silently
    /// clamped/truncated rather than rejected — TODO confirm intended.
    ///
    /// # Errors
    /// `InvalidOperation` for an out-of-bounds index or an input of rank > 2.
    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
        let data = self.to_vec();
        let idx_data = indices.to_vec();
        if self.ndim() == 1 {
            let mut result = Vec::with_capacity(idx_data.len());
            for &idx in &idx_data {
                let i = idx as usize;
                if i >= data.len() {
                    return Err(RuntimeError::InvalidOperation(
                        format!("gather: index {} out of bounds for size {}", i, data.len()),
                    ));
                }
                result.push(data[i]);
            }
            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
        } else if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];
            // Output takes the shape of the index tensor.
            let idx_shape = indices.shape();
            let out_rows = idx_shape[0];
            let out_cols = idx_shape[1];
            let mut result = vec![0.0; out_rows * out_cols];
            for i in 0..out_rows {
                for j in 0..out_cols {
                    let idx = idx_data[i * out_cols + j] as usize;
                    let val = if dim == 0 {
                        if idx >= rows {
                            return Err(RuntimeError::InvalidOperation(
                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
                            ));
                        }
                        data[idx * cols + j]
                    } else {
                        if idx >= cols {
                            return Err(RuntimeError::InvalidOperation(
                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
                            ));
                        }
                        data[i * cols + idx]
                    };
                    result[i * out_cols + j] = val;
                }
            }
            Ok(Tensor::from_vec_unchecked(result, idx_shape))
        } else {
            Err(RuntimeError::InvalidOperation(
                "gather: only 1D and 2D tensors supported".into(),
            ))
        }
    }
1980
    /// Returns a copy of `self` with values from `src` written at positions
    /// given by `indices` (the inverse of `gather`), for 1-D and 2-D tensors.
    ///
    /// 1-D: out[indices[k]] = src[k]. 2-D with `dim == 0`:
    /// out[idx[i][j]][j] = src[i][j]; any other `dim` scatters along
    /// columns: out[i][idx[i][j]] = src[i][j]. Duplicate indices are not
    /// detected — the last write wins.
    ///
    /// NOTE(review): `src` is indexed with the same flat positions as
    /// `indices`; if `src` has fewer elements than `indices` this panics on
    /// the `src_data[...]` access instead of returning an error — TODO
    /// consider validating the shapes up front.
    ///
    /// # Errors
    /// `InvalidOperation` for an out-of-bounds index or an input of rank > 2.
    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
        let mut result = self.to_vec();
        let idx_data = indices.to_vec();
        let src_data = src.to_vec();
        if self.ndim() == 1 {
            for (k, &idx) in idx_data.iter().enumerate() {
                let i = idx as usize;
                if i >= result.len() {
                    return Err(RuntimeError::InvalidOperation(
                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
                    ));
                }
                result[i] = src_data[k];
            }
            Ok(Tensor::from_vec_unchecked(result, self.shape()))
        } else if self.ndim() == 2 {
            let cols = self.shape[1];
            // Iteration is driven by the index tensor's shape.
            let idx_shape = indices.shape();
            let out_cols = idx_shape[1];
            let out_rows = idx_shape[0];
            for i in 0..out_rows {
                for j in 0..out_cols {
                    let idx = idx_data[i * out_cols + j] as usize;
                    let src_val = src_data[i * out_cols + j];
                    if dim == 0 {
                        if idx >= self.shape[0] {
                            return Err(RuntimeError::InvalidOperation(
                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
                            ));
                        }
                        result[idx * cols + j] = src_val;
                    } else {
                        if idx >= cols {
                            return Err(RuntimeError::InvalidOperation(
                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
                            ));
                        }
                        result[i * cols + idx] = src_val;
                    }
                }
            }
            Ok(Tensor::from_vec_unchecked(result, self.shape()))
        } else {
            Err(RuntimeError::InvalidOperation(
                "scatter: only 1D and 2D tensors supported".into(),
            ))
        }
    }
2033
    /// Selects whole entries along `dim` by index, for 1-D and 2-D tensors.
    ///
    /// 1-D: returns `[indices.len()]` elements. 2-D with `dim == 0`:
    /// returns the selected rows as `[n, cols]`; any other `dim` returns the
    /// selected columns as `[rows, n]`. Indices may repeat, duplicating the
    /// corresponding row/column in the output.
    ///
    /// # Errors
    /// `InvalidOperation` for an out-of-bounds index or an input of rank > 2.
    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
        let data = self.to_vec();
        let idx_data = indices.to_vec();
        if self.ndim() == 1 {
            let mut result = Vec::with_capacity(idx_data.len());
            for &idx in &idx_data {
                let i = idx as usize;
                if i >= data.len() {
                    return Err(RuntimeError::InvalidOperation(
                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
                    ));
                }
                result.push(data[i]);
            }
            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
        } else if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];
            let n = idx_data.len();
            if dim == 0 {
                // Row selection: copy each chosen row in full.
                let mut result = Vec::with_capacity(n * cols);
                for &idx in &idx_data {
                    let i = idx as usize;
                    if i >= rows {
                        return Err(RuntimeError::InvalidOperation(
                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
                        ));
                    }
                    for j in 0..cols {
                        result.push(data[i * cols + j]);
                    }
                }
                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
            } else {
                // Column selection: for each row, emit the chosen columns.
                let mut result = Vec::with_capacity(rows * n);
                for i in 0..rows {
                    for &idx in &idx_data {
                        let j = idx as usize;
                        if j >= cols {
                            return Err(RuntimeError::InvalidOperation(
                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
                            ));
                        }
                        result.push(data[i * cols + j]);
                    }
                }
                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
            }
        } else {
            Err(RuntimeError::InvalidOperation(
                "index_select: only 1D and 2D tensors supported".into(),
            ))
        }
    }
2091
2092 pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
2099 if self.shape() != condition.shape() || self.shape() != other.shape() {
2100 return Err(RuntimeError::InvalidOperation(
2101 format!("where: shape mismatch self={:?} cond={:?} other={:?}",
2102 self.shape(), condition.shape(), other.shape()),
2103 ));
2104 }
2105 let s = self.to_vec();
2106 let c = condition.to_vec();
2107 let o = other.to_vec();
2108 let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2109 .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2110 .collect();
2111 Tensor::from_vec(result, self.shape())
2112 }
2113
2114 pub fn any(&self) -> bool {
2116 let data = self.to_vec();
2117 data.iter().any(|&x| x != 0.0)
2118 }
2119
2120 pub fn all(&self) -> bool {
2122 let data = self.to_vec();
2123 data.iter().all(|&x| x != 0.0)
2124 }
2125
2126 pub fn nonzero(&self) -> Tensor {
2130 let data = self.to_vec();
2131 let indices: Vec<f64> = data.iter().enumerate()
2132 .filter(|(_, &v)| v != 0.0)
2133 .map(|(i, _)| i as f64)
2134 .collect();
2135 let len = indices.len();
2136 if len == 0 {
2137 Tensor::from_vec(vec![], &[0]).unwrap()
2138 } else {
2139 Tensor::from_vec(indices, &[len]).unwrap()
2140 }
2141 }
2142
2143 pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2145 if self.shape() != mask.shape() {
2146 return Err(RuntimeError::InvalidOperation(
2147 format!("masked_fill: shape mismatch self={:?} mask={:?}",
2148 self.shape(), mask.shape()),
2149 ));
2150 }
2151 let data = self.to_vec();
2152 let m = mask.to_vec();
2153 let result: Vec<f64> = data.iter().zip(m.iter())
2154 .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2155 .collect();
2156 Tensor::from_vec(result, self.shape())
2157 }
2158
    /// Core axis-reduction driver: for every position in the output (the
    /// input shape with `shape[axis]` collapsed to 1), collects the slice of
    /// values along `axis` and applies `reduce_fn` to it.
    ///
    /// With `keepdim` the reduced axis stays as size 1; otherwise it is
    /// dropped (a fully-reduced tensor keeps shape `[1]`).
    ///
    /// NOTE(review): elements are addressed as `data[offset + Σ idx*stride]`
    /// into the buffer returned by `to_vec` — this assumes `to_vec` returns
    /// the raw underlying buffer addressable via `self.strides`/`self.offset`
    /// (not a compacted copy) — TODO confirm against `to_vec`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
        -> Result<Tensor, RuntimeError>
    where
        F: Fn(&[f64]) -> f64,
    {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds {
                index: axis,
                length: ndim,
            });
        }

        let axis_len = self.shape[axis];
        // Output shape before the optional keepdim squeeze.
        let mut out_shape: Vec<usize> = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        let data = self.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            {
                // Decode the flat output index into per-dimension coordinates
                // (the coordinate along `axis` decodes to 0).
                let mut remaining = out_idx;
                for d in 0..ndim {
                    indices[d] = remaining / out_strides[d];
                    remaining %= out_strides[d];
                }
            }

            // Gather the full lane along `axis` for this output position.
            let mut vals = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                vals.push(data[flat]);
            }
            result.push(reduce_fn(&vals));
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis entirely; keep [1] for a 0-D result.
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis)
                .map(|(_, &v)| v)
                .collect();
            if s.is_empty() {
                s.push(1); }
            s
        };

        Tensor::from_vec(result, &final_shape)
    }
2231
2232 pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2237 self.reduce_axis(axis, keepdim, |vals| {
2238 let mut acc = BinnedAccumulatorF64::new();
2239 for &v in vals { acc.add(v); }
2240 acc.finalize() / vals.len() as f64
2241 })
2242 }
2243
    /// Maximum along `axis`, returning `(values, argmax)` where the second
    /// tensor holds the winning positions along the axis encoded as f64.
    /// With `keepdim` the reduced axis stays as size 1; otherwise it is
    /// dropped (shape `[1]` for a full reduction).
    ///
    /// Ties keep the first (lowest) index; a lane containing only NaNs
    /// yields `NEG_INFINITY` with index 0, since `NaN > x` is always false.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the flat output index into per-dimension coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            // Scan the lane along `axis` for this output position.
            let mut best_val = f64::NEG_INFINITY;
            let mut best_idx = 0usize;
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let v = data[flat];
                if v > best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }
2299
    /// Minimum along `axis`, returning `(values, argmin)` where the second
    /// tensor holds the winning positions along the axis encoded as f64.
    /// With `keepdim` the reduced axis stays as size 1; otherwise it is
    /// dropped (shape `[1]` for a full reduction).
    ///
    /// Ties keep the first (lowest) index; a lane containing only NaNs
    /// yields `INFINITY` with index 0, since `NaN < x` is always false.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the flat output index into per-dimension coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            // Scan the lane along `axis` for this output position.
            let mut best_val = f64::INFINITY;
            let mut best_idx = 0usize;
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let v = data[flat];
                if v < best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }
2355
    /// Population variance along `axis` (squared deviations from the lane
    /// mean, divided by the lane length — no Bessel correction). With
    /// `keepdim` the reduced axis stays as size 1; otherwise it is dropped.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `axis >= ndim`.
    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        // Per-lane means, kept at the same positions as the output below.
        let mean_t = self.mean_axis(axis, true)?;
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mean_data = mean_t.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the flat output index into per-dimension coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mu = mean_data[out_idx];
            // Binned accumulation reduces floating-point summation error.
            let mut acc = BinnedAccumulatorF64::new();
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let diff = data[flat] - mu;
                acc.add(diff * diff);
            }
            result.push(acc.finalize() / axis_len as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Tensor::from_vec(result, &final_shape)
    }
2407
2408 pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2412 let var = self.var_axis(axis, keepdim)?;
2413 Ok(var.map(|x| x.sqrt()))
2414 }
2415
2416 pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2420 self.reduce_axis(axis, keepdim, |vals| {
2421 let mut product = 1.0f64;
2424 for &v in vals { product *= v; }
2425 product
2426 })
2427 }
2428
2429 pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2439 let ndim = self.ndim();
2440 if axis >= ndim {
2441 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2442 }
2443 let data = self.to_vec();
2444 let axis_len = self.shape[axis];
2445 let out_shape = self.shape.clone();
2446 let out_numel = Self::shape_numel(&out_shape);
2447
2448 let mut iter_shape: Vec<usize> = Vec::new();
2450 for (i, &s) in self.shape.iter().enumerate() {
2451 if i != axis { iter_shape.push(s); }
2452 }
2453 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2454
2455 let mut result = vec![0.0f64; out_numel];
2456
2457 let mut pos = vec![0usize; ndim];
2459 for slice_idx in 0..n_slices {
2460 let mut remaining = slice_idx;
2462 let mut dim_idx = 0;
2463 for d in 0..ndim {
2464 if d == axis {
2465 pos[d] = 0;
2466 } else {
2467 let stride = {
2468 let mut s = 1usize;
2469 let mut di = 0;
2470 for d2 in 0..ndim {
2471 if d2 == axis { continue; }
2472 if di > dim_idx { s *= self.shape[d2]; }
2473 di += 1;
2474 }
2475 s
2476 };
2477 pos[d] = remaining / stride;
2478 remaining %= stride;
2479 dim_idx += 1;
2480 }
2481 }
2482
2483 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2485 for k in 0..axis_len {
2486 let mut flat = self.offset;
2487 for d in 0..ndim {
2488 let idx = if d == axis { k } else { pos[d] };
2489 flat += idx * self.strides[d];
2490 }
2491 vals.push((data[flat], k));
2492 }
2493
2494 if descending {
2496 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2497 .then(a.1.cmp(&b.1)));
2498 } else {
2499 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2500 .then(a.1.cmp(&b.1)));
2501 }
2502
2503 for (k, &(v, _)) in vals.iter().enumerate() {
2505 let mut flat = 0;
2506 let out_strides_local = Self::compute_strides(&out_shape);
2507 for d in 0..ndim {
2508 let idx = if d == axis { k } else { pos[d] };
2509 flat += idx * out_strides_local[d];
2510 }
2511 result[flat] = v;
2512 }
2513 }
2514
2515 Tensor::from_vec(result, &out_shape)
2516 }
2517
2518 pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2523 let ndim = self.ndim();
2524 if axis >= ndim {
2525 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2526 }
2527 let data = self.to_vec();
2528 let axis_len = self.shape[axis];
2529 let out_shape = self.shape.clone();
2530 let out_numel = Self::shape_numel(&out_shape);
2531
2532 let mut iter_shape: Vec<usize> = Vec::new();
2533 for (i, &s) in self.shape.iter().enumerate() {
2534 if i != axis { iter_shape.push(s); }
2535 }
2536 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2537
2538 let mut result = vec![0.0f64; out_numel];
2539 let mut pos = vec![0usize; ndim];
2540
2541 for slice_idx in 0..n_slices {
2542 let mut remaining = slice_idx;
2543 let mut dim_idx = 0;
2544 for d in 0..ndim {
2545 if d == axis {
2546 pos[d] = 0;
2547 } else {
2548 let stride = {
2549 let mut s = 1usize;
2550 let mut di = 0;
2551 for d2 in 0..ndim {
2552 if d2 == axis { continue; }
2553 if di > dim_idx { s *= self.shape[d2]; }
2554 di += 1;
2555 }
2556 s
2557 };
2558 pos[d] = remaining / stride;
2559 remaining %= stride;
2560 dim_idx += 1;
2561 }
2562 }
2563
2564 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2565 for k in 0..axis_len {
2566 let mut flat = self.offset;
2567 for d in 0..ndim {
2568 let idx = if d == axis { k } else { pos[d] };
2569 flat += idx * self.strides[d];
2570 }
2571 vals.push((data[flat], k));
2572 }
2573
2574 if descending {
2575 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2576 .then(a.1.cmp(&b.1)));
2577 } else {
2578 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2579 .then(a.1.cmp(&b.1)));
2580 }
2581
2582 for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2583 let out_strides_local = Self::compute_strides(&out_shape);
2584 let mut flat = 0;
2585 for d in 0..ndim {
2586 let idx = if d == axis { k } else { pos[d] };
2587 flat += idx * out_strides_local[d];
2588 }
2589 result[flat] = orig_idx as f64;
2590 }
2591 }
2592
2593 Tensor::from_vec(result, &out_shape)
2594 }
2595
2596 pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
2605 let parts: Vec<&str> = notation.split("->").collect();
2607 if parts.len() != 2 {
2608 return Err(RuntimeError::InvalidOperation(
2609 format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
2610 ));
2611 }
2612 let input_specs: Vec<&str> = parts[0].split(',').collect();
2613 let output_spec = parts[1];
2614
2615 if input_specs.len() != inputs.len() {
2616 return Err(RuntimeError::InvalidOperation(
2617 format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
2618 ));
2619 }
2620
2621 let mut label_size = std::collections::BTreeMap::new();
2623 for (i, &spec) in input_specs.iter().enumerate() {
2624 let chars: Vec<char> = spec.chars().collect();
2625 if chars.len() != inputs[i].ndim() {
2626 return Err(RuntimeError::InvalidOperation(
2627 format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
2628 ));
2629 }
2630 for (d, &c) in chars.iter().enumerate() {
2631 let sz = inputs[i].shape()[d];
2632 if let Some(&prev) = label_size.get(&c) {
2633 if prev != sz {
2634 return Err(RuntimeError::InvalidOperation(
2635 format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
2636 ));
2637 }
2638 } else {
2639 label_size.insert(c, sz);
2640 }
2641 }
2642 }
2643
2644 let output_chars: Vec<char> = output_spec.chars().collect();
2646 let output_shape: Vec<usize> = output_chars.iter()
2647 .map(|c| label_size.get(c).copied().ok_or_else(||
2648 RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
2649 .collect::<Result<_, _>>()?;
2650 let out_numel = Self::shape_numel(&output_shape);
2651
2652 let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
2654 let contract_labels: Vec<char> = label_size.keys()
2655 .filter(|c| !output_set.contains(c))
2656 .copied()
2657 .collect();
2658 let contract_sizes: Vec<usize> = contract_labels.iter()
2659 .map(|c| label_size[c])
2660 .collect();
2661 let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);
2662
2663 let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();
2665
2666 let out_strides = Self::compute_strides(&output_shape);
2668 let mut result = vec![0.0f64; out_numel];
2669
2670 let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
2672 let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
2673 let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();
2674
2675 for out_idx in 0..out_numel {
2676 let mut label_vals = std::collections::BTreeMap::new();
2678 let mut remaining = out_idx;
2679 for (d, &c) in output_chars.iter().enumerate() {
2680 let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
2681 label_vals.insert(c, remaining / stride);
2682 remaining %= stride;
2683 }
2684
2685 let mut acc = BinnedAccumulatorF64::new();
2686 for cidx in 0..contract_numel {
2688 let mut cr = cidx;
2690 for (ci, &cl) in contract_labels.iter().enumerate() {
2691 let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
2692 label_vals.insert(cl, cr / stride);
2693 cr %= stride;
2694 }
2695
2696 let mut product = 1.0f64;
2698 for (inp_idx, chars) in input_chars.iter().enumerate() {
2699 let mut flat = input_offsets[inp_idx];
2700 for (d, &c) in chars.iter().enumerate() {
2701 flat += label_vals[&c] * input_strides[inp_idx][d];
2702 }
2703 product *= input_data[inp_idx][flat];
2704 }
2705 acc.add(product);
2706 }
2707 result[out_idx] = acc.finalize();
2708 }
2709
2710 if output_shape.is_empty() {
2711 Tensor::from_vec(result, &[1])
2712 } else {
2713 Tensor::from_vec(result, &output_shape)
2714 }
2715 }
2716
2717 pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2726 let ndim = self.ndim();
2727 if dim > ndim {
2728 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2729 }
2730 let mut new_shape = self.shape.clone();
2731 new_shape.insert(dim, 1);
2732 self.reshape(&new_shape)
2733 }
2734
2735 pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2738 match dim {
2739 Some(d) => {
2740 if d >= self.ndim() {
2741 return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2742 }
2743 if self.shape[d] != 1 {
2744 return Err(RuntimeError::InvalidOperation(
2745 format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2746 ));
2747 }
2748 let mut new_shape = self.shape.clone();
2749 new_shape.remove(d);
2750 if new_shape.is_empty() {
2751 new_shape.push(1); }
2753 self.reshape(&new_shape)
2754 }
2755 None => {
2756 let new_shape: Vec<usize> = self.shape.iter()
2757 .filter(|&&s| s != 1)
2758 .copied()
2759 .collect();
2760 let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2761 self.reshape(&new_shape)
2762 }
2763 }
2764 }
2765
2766 pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2770 self.broadcast_to(target_shape)
2771 }
2772
2773 pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2775 if start_dim > end_dim || end_dim >= self.ndim() {
2776 return Err(RuntimeError::InvalidOperation(
2777 format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2778 ));
2779 }
2780 let mut new_shape = Vec::new();
2781 for i in 0..start_dim {
2782 new_shape.push(self.shape[i]);
2783 }
2784 let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2785 new_shape.push(flat_size);
2786 for i in (end_dim + 1)..self.ndim() {
2787 new_shape.push(self.shape[i]);
2788 }
2789 self.reshape(&new_shape)
2790 }
2791
2792 pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2794 if dim >= self.ndim() {
2795 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2796 }
2797 if n == 0 {
2798 return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2799 }
2800 let dim_size = self.shape[dim];
2801 let chunk_size = (dim_size + n - 1) / n;
2802 let mut sizes = Vec::new();
2803 let mut remaining = dim_size;
2804 while remaining > 0 {
2805 let s = remaining.min(chunk_size);
2806 sizes.push(s);
2807 remaining -= s;
2808 }
2809 self.split(&sizes, dim)
2810 }
2811
2812 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2814 if dim >= self.ndim() {
2815 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2816 }
2817 let total: usize = sizes.iter().sum();
2818 if total != self.shape[dim] {
2819 return Err(RuntimeError::InvalidOperation(
2820 format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2821 ));
2822 }
2823
2824 let mut results = Vec::new();
2825 let mut offset = 0;
2826
2827 for &sz in sizes {
2828 let ranges: Vec<(usize, usize)> = self.shape.iter()
2829 .enumerate()
2830 .map(|(i, &s)| {
2831 if i == dim { (offset, offset + sz) } else { (0, s) }
2832 })
2833 .collect();
2834 let chunk = self.slice(&ranges)?;
2835 results.push(chunk.to_contiguous());
2837 offset += sz;
2838 }
2839
2840 Ok(results)
2841 }
2842
2843 pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2848 if self.shape != other.shape {
2849 return Err(RuntimeError::InvalidOperation(
2850 "scale_add: shape mismatch".to_string(),
2851 ));
2852 }
2853 let a = self.to_vec();
2854 let b = other.to_vec();
2855 let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2856 Tensor::from_vec(result, &self.shape)
2857 }
2858}
2859