use cjc_repro::{KahanAccumulatorF64, Rng};

use crate::accumulator::{self, binned_sum_f64, BinnedAccumulatorF64};
use crate::buffer::Buffer;
use crate::dispatch;
use crate::error::RuntimeError;
use crate::kernel as kernel_fns;
use crate::tensor_simd::{self, BinOp, UnaryOp};
use crate::tensor_tiled::TiledMatmul;

/// A dense, row-major tensor of `f64` values backed by a shared [`Buffer`].
///
/// Views (slices, transposes, broadcasts) reuse the underlying buffer and are
/// described entirely by `shape`, `strides`, and `offset`.
#[derive(Debug, Clone)]
pub struct Tensor {
    pub buffer: Buffer<f64>,
    pub(crate) shape: Vec<usize>,
    pub(crate) strides: Vec<usize>,
    pub(crate) offset: usize,
}

impl Tensor {
    /// Computes row-major (C-order) strides for `shape`; the last axis has
    /// stride 1.
    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }
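
    // A quick illustration of the layout (values here are hypothetical, for
    // exposition only): for shape [2, 3, 4] the strides come out as
    // [12, 4, 1], so element (i, j, k) lives at flat index i*12 + j*4 + k.
    //
    //     assert_eq!(Tensor::compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
    //     assert_eq!(Tensor::compute_strides(&[5]), vec![1]);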

    fn shape_numel(shape: &[usize]) -> usize {
        shape.iter().product()
    }

    /// Creates a tensor of the given shape filled with zeros.
    pub fn zeros(shape: &[usize]) -> Self {
        let numel = Self::shape_numel(shape);
        Tensor {
            buffer: Buffer::alloc(numel, 0.0),
            shape: shape.to_vec(),
            strides: Self::compute_strides(shape),
            offset: 0,
        }
    }

    /// Creates a tensor of the given shape filled with ones.
    pub fn ones(shape: &[usize]) -> Self {
        let numel = Self::shape_numel(shape);
        Tensor {
            buffer: Buffer::alloc(numel, 1.0),
            shape: shape.to_vec(),
            strides: Self::compute_strides(shape),
            offset: 0,
        }
    }

    /// Creates a tensor whose elements are drawn from a standard normal
    /// distribution using the supplied RNG.
    pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
        let numel = Self::shape_numel(shape);
        let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
        Tensor {
            buffer: Buffer::from_vec(data),
            shape: shape.to_vec(),
            strides: Self::compute_strides(shape),
            offset: 0,
        }
    }

    /// Builds a tensor from `data`, validating that its length matches the
    /// product of `shape`.
    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
        let numel = Self::shape_numel(shape);
        if data.len() != numel {
            return Err(RuntimeError::ShapeMismatch {
                expected: numel,
                got: data.len(),
            });
        }
        Ok(Tensor {
            buffer: Buffer::from_vec(data),
            shape: shape.to_vec(),
            strides: Self::compute_strides(shape),
            offset: 0,
        })
    }

    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Total number of elements (the product of all dimensions).
    pub fn len(&self) -> usize {
        Self::shape_numel(&self.shape)
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Maps multi-dimensional `indices` to a flat buffer offset, checking
    /// bounds along every axis.
    fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
        if indices.len() != self.shape.len() {
            return Err(RuntimeError::DimensionMismatch {
                expected: self.shape.len(),
                got: indices.len(),
            });
        }
        let mut off = self.offset;
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return Err(RuntimeError::IndexOutOfBounds {
                    index: idx,
                    length: self.shape[i],
                });
            }
            off += idx * self.strides[i];
        }
        Ok(off)
    }

    /// Returns true when the tensor starts at offset 0 and its strides match
    /// the canonical row-major layout for its shape.
    pub fn is_contiguous(&self) -> bool {
        if self.offset != 0 {
            return false;
        }
        let expected = Self::compute_strides(&self.shape);
        self.strides == expected
    }

    /// Returns a view over the half-open ranges `[start, end)` given per axis.
    /// The view shares this tensor's buffer; no data is copied.
    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
        if ranges.len() != self.shape.len() {
            return Err(RuntimeError::DimensionMismatch {
                expected: self.shape.len(),
                got: ranges.len(),
            });
        }
        let mut new_offset = self.offset;
        let mut new_shape = Vec::with_capacity(ranges.len());
        for (i, &(start, end)) in ranges.iter().enumerate() {
            if end > self.shape[i] || start > end {
                return Err(RuntimeError::IndexOutOfBounds {
                    index: end,
                    length: self.shape[i],
                });
            }
            new_offset += start * self.strides[i];
            new_shape.push(end - start);
        }
        Ok(Tensor {
            buffer: self.buffer.clone(),
            shape: new_shape,
            strides: self.strides.clone(),
            offset: new_offset,
        })
    }

    /// Returns a tensor with canonical row-major layout, copying only when
    /// this tensor is not already contiguous.
    pub fn to_contiguous(&self) -> Tensor {
        if self.is_contiguous() {
            return self.clone();
        }
        let data = self.to_vec();
        Tensor {
            buffer: Buffer::from_vec(data),
            shape: self.shape.clone(),
            strides: Self::compute_strides(&self.shape),
            offset: 0,
        }
    }

    /// Broadcasts this tensor to `target_shape` by padding leading dimensions
    /// and giving size-1 axes a stride of 0, so the view repeats data without
    /// copying it.
    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        let src_ndim = self.shape.len();
        let tgt_ndim = target_shape.len();
        if tgt_ndim < src_ndim {
            return Err(RuntimeError::InvalidOperation(
                "cannot broadcast to a smaller rank".to_string(),
            ));
        }
        let pad = tgt_ndim - src_ndim;
        let mut new_strides = vec![0usize; tgt_ndim];
        for i in 0..tgt_ndim {
            if i < pad {
                new_strides[i] = 0;
            } else {
                let src_i = i - pad;
                if self.shape[src_i] == target_shape[i] {
                    new_strides[i] = self.strides[src_i];
                } else if self.shape[src_i] == 1 {
                    new_strides[i] = 0;
                } else {
                    return Err(RuntimeError::ShapeMismatch {
                        expected: target_shape[i],
                        got: self.shape[src_i],
                    });
                }
            }
        }
        Ok(Tensor {
            buffer: self.buffer.clone(),
            shape: target_shape.to_vec(),
            strides: new_strides,
            offset: self.offset,
        })
    }
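
    // Sketch of the stride trick (illustrative shapes only): broadcasting a
    // [3, 1] tensor to [2, 3, 4] pads one leading axis and zeroes the strides
    // of the padded and size-1 axes, so every "repeated" element reads the
    // same buffer slot:
    //
    //     let t = Tensor::ones(&[3, 1]);
    //     let b = t.broadcast_to(&[2, 3, 4])?;  // shares t's buffer
    //     assert_eq!(b.shape(), &[2, 3, 4]);    // strides become [0, 1, 0]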

    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
        let offset = self.linear_index(indices)?;
        self.buffer
            .get(offset)
            .ok_or(RuntimeError::IndexOutOfBounds {
                index: offset,
                length: self.buffer.len(),
            })
    }

    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
        let offset = self.linear_index(indices)?;
        self.buffer.set(offset, val)
    }

    /// Copies the tensor's elements into a `Vec` in row-major order,
    /// materializing strided views along the way.
    pub fn to_vec(&self) -> Vec<f64> {
        if self.is_contiguous() {
            let full = self.buffer.borrow_data();
            let numel = self.len();
            if full.len() == numel {
                return full.to_vec();
            }
            // A contiguous view may sit in a larger buffer; take the prefix.
            return full[..numel].to_vec();
        }
        let numel = self.len();
        let mut result = Vec::with_capacity(numel);
        let ndim = self.shape.len();
        let mut indices = vec![0usize; ndim];
        for _ in 0..numel {
            let mut off = self.offset;
            for d in 0..ndim {
                off += indices[d] * self.strides[d];
            }
            result.push(self.buffer.get(off).unwrap_or(0.0));
            // Advance the multi-index like an odometer, last axis fastest.
            for d in (0..ndim).rev() {
                indices[d] += 1;
                if indices[d] < self.shape[d] {
                    break;
                }
                indices[d] = 0;
            }
        }
        result
    }

    /// Reshapes to `new_shape` (same element count). Non-contiguous tensors
    /// are materialized first, so this may copy.
    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        let new_numel = Self::shape_numel(new_shape);
        if new_numel != self.len() {
            return Err(RuntimeError::ShapeMismatch {
                expected: self.len(),
                got: new_numel,
            });
        }
        // `to_contiguous` already returns a cheap clone when possible.
        let tensor = self.to_contiguous();
        Ok(Tensor {
            buffer: tensor.buffer,
            shape: new_shape.to_vec(),
            strides: Self::compute_strides(new_shape),
            offset: 0,
        })
    }

    /// Applies `op` element-by-element, broadcasting the operands to a common
    /// shape when they differ. Equal-shape contiguous tensors take a fast
    /// path over the raw buffers.
    fn elementwise_binop(
        &self,
        other: &Tensor,
        op: impl Fn(f64, f64) -> f64,
    ) -> Result<Tensor, RuntimeError> {
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            // Restrict to `n` elements: a contiguous view can live in a
            // buffer that is larger than the tensor itself.
            let n = self.len();
            let a = self.buffer.borrow_data();
            let b = other.buffer.borrow_data();
            let data: Vec<f64> = a[..n]
                .iter()
                .zip(b[..n].iter())
                .map(|(&x, &y)| op(x, y))
                .collect();
            return Ok(Tensor {
                buffer: Buffer::from_vec(data),
                shape: self.shape.clone(),
                strides: Self::compute_strides(&self.shape),
                offset: 0,
            });
        }

        let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
        let a_broadcast = self.broadcast_to(&result_shape)?;
        let b_broadcast = other.broadcast_to(&result_shape)?;

        let numel = Self::shape_numel(&result_shape);
        let ndim = result_shape.len();
        let mut data = Vec::with_capacity(numel);
        let mut indices = vec![0usize; ndim];

        for _ in 0..numel {
            let mut off_a = a_broadcast.offset;
            let mut off_b = b_broadcast.offset;
            for d in 0..ndim {
                off_a += indices[d] * a_broadcast.strides[d];
                off_b += indices[d] * b_broadcast.strides[d];
            }
            let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
                RuntimeError::InvalidOperation(format!(
                    "broadcast binop: left operand index {} out of bounds (buffer len {})",
                    off_a,
                    a_broadcast.buffer.len()
                ))
            })?;
            let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
                RuntimeError::InvalidOperation(format!(
                    "broadcast binop: right operand index {} out of bounds (buffer len {})",
                    off_b,
                    b_broadcast.buffer.len()
                ))
            })?;
            data.push(op(va, vb));

            for d in (0..ndim).rev() {
                indices[d] += 1;
                if indices[d] < result_shape[d] {
                    break;
                }
                indices[d] = 0;
            }
        }

        Ok(Tensor {
            buffer: Buffer::from_vec(data),
            shape: result_shape.clone(),
            strides: Self::compute_strides(&result_shape),
            offset: 0,
        })
    }

    /// Resolves the NumPy-style broadcast shape of `a` and `b`: the shapes
    /// are right-aligned and each axis pair must be equal or contain a 1.
    fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
        let max_ndim = a.len().max(b.len());
        let mut result = Vec::with_capacity(max_ndim);
        for i in 0..max_ndim {
            let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
            let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
            if da == db {
                result.push(da);
            } else if da == 1 {
                result.push(db);
            } else if db == 1 {
                result.push(da);
            } else {
                return Err(RuntimeError::ShapeMismatch {
                    expected: da,
                    got: db,
                });
            }
        }
        Ok(result)
    }
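
    // Example of the resolution rule (hypothetical shapes): right-align the
    // two shapes, then take the larger of each aligned pair as long as the
    // smaller one is 1:
    //
    //     [2, 1, 4]  vs  [3, 1]   ->   [2, 1, 4]
    //                                  [1, 3, 1]   ->   result [2, 3, 4]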

    /// Like `elementwise_binop`, but routes the equal-shape contiguous case
    /// through the SIMD kernels; every other case falls back to `fallback`.
    fn elementwise_binop_simd(
        &self,
        other: &Tensor,
        op: BinOp,
        fallback: impl Fn(f64, f64) -> f64,
    ) -> Result<Tensor, RuntimeError> {
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            // As in `elementwise_binop`, only the first `len()` buffer slots
            // belong to this tensor.
            let n = self.len();
            let a = self.buffer.borrow_data();
            let b = other.buffer.borrow_data();
            let data = tensor_simd::simd_binop(&a[..n], &b[..n], op);
            return Ok(Tensor {
                buffer: Buffer::from_vec(data),
                shape: self.shape.clone(),
                strides: Self::compute_strides(&self.shape),
                offset: 0,
            });
        }
        self.elementwise_binop(other, fallback)
    }

    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
    }

    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
    }

    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
    }

    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
    }

    /// Computes `self * b + c` elementwise; all three tensors must share one
    /// shape. Contiguous inputs use a single fused loop.
    pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.shape != b.shape || self.shape != c.shape {
            return Err(RuntimeError::InvalidOperation(
                "fused_mul_add: all three tensors must have the same shape".to_string(),
            ));
        }
        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
            let a_data = self.buffer.borrow_data();
            let b_data = b.buffer.borrow_data();
            let c_data = c.buffer.borrow_data();
            // Use the tensor's element count, not the buffer length, so
            // contiguous prefix views stay in bounds.
            let n = self.len();
            let mut out = vec![0.0f64; n];
            for i in 0..n {
                out[i] = a_data[i] * b_data[i] + c_data[i];
            }
            return Ok(Tensor {
                buffer: Buffer::from_vec(out),
                shape: self.shape.clone(),
                strides: Self::compute_strides(&self.shape),
                offset: 0,
            });
        }
        let temp = self.mul_elem(b)?;
        temp.add(c)
    }

    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.powf(b))
    }

    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.min(b))
    }

    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.max(b))
    }

    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.atan2(b))
    }

    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        self.elementwise_binop(other, |a, b| a.hypot(b))
    }

    /// Applies `f` to every element, materializing the result contiguously.
    pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
        let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
        Tensor {
            buffer: Buffer::from_vec(data),
            shape: self.shape.clone(),
            strides: Self::compute_strides(&self.shape),
            offset: 0,
        }
    }

    /// SIMD variant of `map` for the predefined unary ops.
    pub fn map_simd(&self, op: UnaryOp) -> Tensor {
        let src = self.to_vec();
        let data = tensor_simd::simd_unary(&src, op);
        Tensor {
            buffer: Buffer::from_vec(data),
            shape: self.shape.clone(),
            strides: Self::compute_strides(&self.shape),
            offset: 0,
        }
    }

    /// Sums all elements using binned (error-compensated) accumulation.
    pub fn sum(&self) -> f64 {
        if self.is_contiguous() {
            // Only the first `len()` buffer slots belong to this tensor.
            let data = self.buffer.borrow_data();
            binned_sum_f64(&data[..self.len()])
        } else {
            binned_sum_f64(&self.to_vec())
        }
    }

    /// Explicitly binned sum, kept as a separate entry point so callers can
    /// request this accumulation strategy by name.
    pub fn binned_sum(&self) -> f64 {
        if self.is_contiguous() {
            let data = self.buffer.borrow_data();
            accumulator::binned_sum_f64(&data[..self.len()])
        } else {
            accumulator::binned_sum_f64(&self.to_vec())
        }
    }

    /// Sum whose accumulation strategy is chosen by the dispatch context.
    pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
        if self.is_contiguous() {
            let data = self.buffer.borrow_data();
            dispatch::dispatch_sum_f64(&data[..self.len()], ctx)
        } else {
            dispatch::dispatch_sum_f64(&self.to_vec(), ctx)
        }
    }

    /// Arithmetic mean of all elements; 0.0 for an empty tensor.
    pub fn mean(&self) -> f64 {
        let n = self.len();
        if n == 0 {
            return 0.0;
        }
        self.sum() / n as f64
    }

    pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
        let n = self.len();
        if n == 0 {
            return 0.0;
        }
        self.dispatched_sum(ctx) / n as f64
    }

    /// Sums along `axis`, keeping that axis with size 1 in the result shape.
    /// Each output element is accumulated with a `BinnedAccumulatorF64`.
    pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds {
                index: axis,
                length: ndim,
            });
        }

        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        let data = self.to_vec();
        // `data` is a row-major copy, so address it with canonical strides
        // rather than the tensor's own (possibly permuted) buffer strides.
        let data_strides = Self::compute_strides(&self.shape);
        let axis_len = self.shape[axis];
        let mut result = vec![0.0f64; out_numel];

        let mut indices = vec![0usize; ndim];
        for out_idx in 0..out_numel {
            // Decode the flat output index into a multi-index.
            {
                let mut remaining = out_idx;
                for d in 0..ndim {
                    indices[d] = remaining / out_strides[d];
                    remaining %= out_strides[d];
                }
            }

            let mut acc = BinnedAccumulatorF64::new();
            for k in 0..axis_len {
                let mut flat = 0usize;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * data_strides[d];
                }
                acc.add(data[flat]);
            }
            result[out_idx] = acc.finalize();
        }

        Tensor::from_vec(result, &out_shape)
    }
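
    // For example (illustrative values): summing a [2, 3] tensor over axis 1
    // keeps that axis with size 1:
    //
    //     let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
    //     let s = t.sum_axis(1)?;                  // shape [2, 1]
    //     assert_eq!(s.to_vec(), vec![6.0, 15.0]);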

    pub fn neg(&self) -> Tensor {
        self.map(|x| -x)
    }

    /// Full transpose: reverses all axes by reversing shape and strides; the
    /// result is a view sharing this tensor's buffer.
    pub fn transpose(&self) -> Tensor {
        let ndim = self.ndim();
        if ndim <= 1 {
            return self.clone();
        }
        let mut new_shape = self.shape.clone();
        let mut new_strides = self.strides.clone();
        new_shape.reverse();
        new_strides.reverse();
        Tensor {
            buffer: self.buffer.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        }
    }

    /// Permutes axes according to `axes`, which must be a permutation of
    /// `0..ndim`. Returns a view; no data is copied.
    pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
        let ndim = self.ndim();
        if axes.len() != ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
            ));
        }
        let mut seen = vec![false; ndim];
        for &ax in axes {
            if ax >= ndim {
                return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
            }
            if seen[ax] {
                return Err(RuntimeError::InvalidOperation(
                    format!("transpose_axes: duplicate axis {ax}"),
                ));
            }
            seen[ax] = true;
        }
        let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
        let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
        Ok(Tensor {
            buffer: self.buffer.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: self.offset,
        })
    }

    pub fn scalar_mul(&self, s: f64) -> Tensor {
        self.map(|x| x * s)
    }

    /// Panicking convenience wrapper around `from_vec` for callers that have
    /// already validated the shape.
    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
    }

    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
        self.add(other).expect("Tensor::add shape mismatch")
    }

    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
        self.sub(other).expect("Tensor::sub shape mismatch")
    }

    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
    }

    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
    }

    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
        self.matmul(other).expect("Tensor::matmul dimension mismatch")
    }

    /// 2-D matrix multiply. Dispatches by problem size: large problems go to
    /// the parallel path (when the `parallel` feature is enabled), medium
    /// ones to the cache-tiled kernel, and small ones to a simple triple loop.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matmul requires 2-D tensors".to_string(),
            ));
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        let a = self.to_vec();
        let b = other.to_vec();

        #[cfg(feature = "parallel")]
        {
            if m >= 256 || n >= 256 || k >= 256 {
                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
            }
        }

        if m >= 64 || n >= 64 || k >= 64 {
            return Self::matmul_tiled(&a, &b, m, n, k);
        }

        Self::matmul_sequential(&a, &b, m, n, k)
    }
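
    // Usage sketch (sizes hypothetical): callers never pick a kernel by hand.
    // A 2x3 times 3x2 product stays on the sequential path, while something
    // like 300x300 would take the parallel or tiled route instead.
    //
    //     let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])?;
    //     let b = a.transpose().to_contiguous();   // [3, 2]
    //     let c = a.matmul(&b)?;                   // [2, 2]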

    /// Naive triple loop with Kahan-compensated dot products; used for small
    /// matrices where tiling overhead is not worth paying.
    fn matmul_sequential(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        let mut result = vec![0.0f64; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = KahanAccumulatorF64::new();
                for p in 0..k {
                    acc.add(a[i * k + p] * b[p * n + j]);
                }
                result[i * n + j] = acc.finalize();
            }
        }
        Tensor::from_vec(result, &[m, n])
    }

    /// Cache-blocked matmul delegated to `TiledMatmul`.
    fn matmul_tiled(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        let engine = TiledMatmul::new();
        let result = engine.matmul(a, m, k, b, n);
        Tensor::from_vec(result, &[m, n])
    }

    /// Parallel matmul. Very large problems are split into horizontal bands
    /// of rows, each multiplied with the tiled kernel; smaller ones
    /// parallelize over individual output rows with Kahan accumulation.
    #[cfg(feature = "parallel")]
    fn matmul_parallel_mode_a(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        use rayon::prelude::*;

        if m >= 512 && n >= 512 {
            // One band per worker thread, but never bands thinner than 64 rows.
            let threads = rayon::current_num_threads();
            let band_size = ((m + threads - 1) / threads).max(64);
            let mut result = vec![0.0f64; m * n];

            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k..i_end * k];
                    let engine = TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });

            return Tensor::from_vec(result, &[m, n]);
        }

        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });

        Tensor::from_vec(result, &[m, n])
    }

    /// Batched matrix multiply over the last two axes; all leading (batch)
    /// dimensions must match exactly. 2-D inputs fall through to `matmul`.
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "bmm requires at least 2-D tensors".to_string(),
            ));
        }
        if self.ndim() == 2 && other.ndim() == 2 {
            return self.matmul(other);
        }
        if self.ndim() != other.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm requires same number of dimensions, got {} and {}",
                    self.ndim(),
                    other.ndim()
                ),
            ));
        }
        let nd = self.ndim();
        let batch_dims_a = &self.shape[..nd - 2];
        let batch_dims_b = &other.shape[..nd - 2];
        if batch_dims_a != batch_dims_b {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm batch dimensions mismatch: {:?} vs {:?}",
                    batch_dims_a, batch_dims_b
                ),
            ));
        }
        let m = self.shape[nd - 2];
        let k = self.shape[nd - 1];
        let k2 = other.shape[nd - 2];
        let n = other.shape[nd - 1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        let batch_size: usize = batch_dims_a.iter().product();
        let a = self.to_vec();
        let b = other.to_vec();
        let mat_a_stride = m * k;
        let mat_b_stride = k * n;
        let mat_c_stride = m * n;
        let mut result = vec![0.0f64; batch_size * mat_c_stride];

        // Per-batch worker: tiled kernel for large matrices, Kahan-compensated
        // triple loop for small ones.
        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];

            if m >= 64 || n >= 64 || k >= 64 {
                let engine = TiledMatmul::new();
                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
                c_slice.copy_from_slice(&tiled);
            } else {
                for i in 0..m {
                    for j in 0..n {
                        let mut acc = KahanAccumulatorF64::new();
                        for p in 0..k {
                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
                        }
                        c_slice[i * n + j] = acc.finalize();
                    }
                }
            }
        };

        #[cfg(feature = "parallel")]
        {
            if batch_size > 1 && m * k >= 4096 {
                use rayon::prelude::*;
                result
                    .par_chunks_mut(mat_c_stride)
                    .enumerate()
                    .for_each(|(batch, c_slice)| {
                        compute_batch(batch, c_slice);
                    });

                let mut out_shape = batch_dims_a.to_vec();
                out_shape.push(m);
                out_shape.push(n);
                return Tensor::from_vec(result, &out_shape);
            }
        }

        for batch in 0..batch_size {
            let c_off = batch * mat_c_stride;
            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
        }

        let mut out_shape = batch_dims_a.to_vec();
        out_shape.push(m);
        out_shape.push(n);
        Tensor::from_vec(result, &out_shape)
    }

    /// Softmax over the last axis, computed row by row with the usual
    /// max-subtraction trick so large logits do not overflow `exp`.
    pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() == 0 {
            return Err(RuntimeError::InvalidOperation(
                "softmax requires at least 1-D tensor".to_string(),
            ));
        }
        let data_ref;
        let data_vec;
        let data: &[f64] = if self.is_contiguous() {
            data_ref = self.buffer.borrow_data();
            // A contiguous view may occupy only a prefix of its buffer.
            &data_ref[..self.len()]
        } else {
            data_vec = self.to_vec();
            &data_vec
        };
        let n = *self.shape.last().unwrap();
        let outer: usize = data.len() / n;
        let mut result = vec![0.0f64; data.len()];

        for row in 0..outer {
            let start = row * n;
            let end = start + n;
            let slice = &data[start..end];

            let mut max_val = f64::NEG_INFINITY;
            for &v in slice {
                if v > max_val {
                    max_val = v;
                }
            }

            let mut exp_vals = vec![0.0f64; n];
            // Kahan-compensated sum of the exponentials.
            let mut sum = 0.0f64;
            let mut comp = 0.0f64;
            for i in 0..n {
                let e = (slice[i] - max_val).exp();
                exp_vals[i] = e;
                let y = e - comp;
                let t = sum + y;
                comp = (t - sum) - y;
                sum = t;
            }

            if sum == 0.0 {
                // Degenerate row (e.g. all -inf): fall back to uniform weights.
                let uniform = 1.0 / n as f64;
                for i in 0..n {
                    result[start + i] = uniform;
                }
            } else {
                for i in 0..n {
                    result[start + i] = exp_vals[i] / sum;
                }
            }
        }

        Tensor::from_vec(result, &self.shape)
    }
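
    // Worked example of the max-subtraction trick (values hypothetical): for
    // the row [1000.0, 1001.0], exp(1000.0) alone would overflow to inf, but
    // shifting by the row max computes exp(-1.0) and exp(0.0) instead. The
    // result is unchanged because the shift cancels in the ratio:
    //
    //     softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max)
    //                  ~ [0.2689, 0.7311] for the row above.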

    /// Layer normalization over the last axis: each row is normalized to zero
    /// mean and unit variance (with `eps` added for numerical safety), then
    /// scaled by `gamma` and shifted by `beta`.
    pub fn layer_norm(
        &self,
        gamma: &Tensor,
        beta: &Tensor,
        eps: f64,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() == 0 {
            return Err(RuntimeError::InvalidOperation(
                "layer_norm requires at least 1-D tensor".to_string(),
            ));
        }
        let d = *self.shape.last().unwrap();
        if gamma.len() != d || beta.len() != d {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "layer_norm: gamma/beta length {} must match last dim {}",
                    gamma.len(),
                    d
                ),
            ));
        }

        let data = self.to_vec();
        let gamma_data = gamma.to_vec();
        let beta_data = beta.to_vec();
        let outer = data.len() / d;
        let mut result = vec![0.0f64; data.len()];

        for row in 0..outer {
            let start = row * d;
            let slice = &data[start..start + d];

            let mean = binned_sum_f64(slice) / d as f64;

            // Population variance of the row around its mean.
            let diffs: Vec<f64> = slice.iter().map(|&x| {
                let diff = x - mean;
                diff * diff
            }).collect();
            let variance = binned_sum_f64(&diffs) / d as f64;

            let inv_std = 1.0 / (variance + eps).sqrt();
            for i in 0..d {
                let normalized = (slice[i] - mean) * inv_std;
                result[start + i] = gamma_data[i] * normalized + beta_data[i];
            }
        }

        Tensor::from_vec(result, &self.shape)
    }
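
    // The per-row formula, spelled out (mu and var are the row mean and the
    // population variance computed above):
    //
    //     y_i = gamma_i * (x_i - mu) / sqrt(var + eps) + beta_i
    //
    // e.g. the row [1.0, 3.0] has mu = 2.0 and var = 1.0, so with unit gamma,
    // zero beta, and eps = 0 it normalizes to roughly [-1.0, 1.0].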

    /// Unary map with a reuse fast path: when this tensor uniquely owns a
    /// contiguous buffer, mutate a single copy of its data; otherwise
    /// materialize via `to_vec`.
    fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
        if self.is_contiguous() && self.buffer.refcount() == 1 {
            // Truncate to `len()`: a contiguous view may be a buffer prefix.
            let mut data = self.buffer.borrow_data()[..self.len()].to_vec();
            for x in data.iter_mut() {
                *x = f(*x);
            }
            Tensor::from_vec(data, &self.shape).unwrap()
        } else {
            let data = self.to_vec();
            let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
            Tensor::from_vec(result, &self.shape).unwrap()
        }
    }

    /// ReLU: max(x, 0).
    pub fn relu(&self) -> Tensor {
        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
    }

    /// Logistic sigmoid: 1 / (1 + e^(-x)).
    pub fn sigmoid(&self) -> Tensor {
        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
    }

    pub fn tanh_activation(&self) -> Tensor {
        self.map_elementwise(|x| x.tanh())
    }

    /// Leaky ReLU: x for x > 0, alpha * x otherwise.
    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
    }

    /// SiLU (swish): x * sigmoid(x).
    pub fn silu(&self) -> Tensor {
        let data = self.to_vec();
        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
        Tensor::from_vec(result, &self.shape).unwrap()
    }

    /// Mish: x * tanh(softplus(x)), with softplus(x) = ln(1 + e^x).
    pub fn mish(&self) -> Tensor {
        let data = self.to_vec();
        let result: Vec<f64> = data.iter().map(|&x| {
            let sp = (1.0 + x.exp()).ln();
            x * sp.tanh()
        }).collect();
        Tensor::from_vec(result, &self.shape).unwrap()
    }

    /// Flat index of the maximum element; the strict `>` comparison means
    /// ties resolve to the first occurrence.
    pub fn argmax(&self) -> usize {
        let data = self.to_vec();
        let mut best_idx = 0;
        let mut best_val = f64::NEG_INFINITY;
        for (i, &v) in data.iter().enumerate() {
            if v > best_val {
                best_val = v;
                best_idx = i;
            }
        }
        best_idx
    }

    /// Flat index of the minimum element; ties resolve to the first occurrence.
    pub fn argmin(&self) -> usize {
        let data = self.to_vec();
        let mut best_idx = 0;
        let mut best_val = f64::INFINITY;
        for (i, &v) in data.iter().enumerate() {
            if v < best_val {
                best_val = v;
                best_idx = i;
            }
        }
        best_idx
    }

    /// Clamps every element into `[min, max]`.
    pub fn clamp(&self, min: f64, max: f64) -> Tensor {
        let data = self.to_vec();
        let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
        Tensor::from_vec(result, &self.shape).unwrap()
    }

    /// Builds an `[n, depth]` one-hot matrix from `n` class indices.
    pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
        let n = indices.len();
        let mut data = vec![0.0; n * depth];
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= depth {
                return Err(RuntimeError::InvalidOperation(format!(
                    "one_hot: index {idx} >= depth {depth}"
                )));
            }
            data[i * depth + idx] = 1.0;
        }
        Tensor::from_vec(data, &[n, depth])
    }

    /// Concatenates tensors along `axis`; all non-axis dimensions must agree.
    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
        }
        let ndim = tensors[0].ndim();
        if axis >= ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.ndim() != ndim {
                return Err(RuntimeError::InvalidOperation(
                    format!("cat: tensor {i} has different ndim"),
                ));
            }
            for d in 0..ndim {
                if d != axis && t.shape[d] != tensors[0].shape[d] {
                    return Err(RuntimeError::InvalidOperation(
                        format!("cat: shape mismatch at dim {d}"),
                    ));
                }
            }
        }
        let mut out_shape = tensors[0].shape.clone();
        for t in tensors.iter().skip(1) {
            out_shape[axis] += t.shape[axis];
        }
        let total = out_shape.iter().product::<usize>();
        let mut result = vec![0.0; total];
        let mut out_strides = vec![1usize; ndim];
        for d in (0..ndim - 1).rev() {
            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
        }
        // Copy each input, shifting its coordinate along `axis` by the
        // running extent of the tensors placed before it.
        let mut offset = 0;
        for t in tensors {
            let t_data = t.to_vec();
            let t_total: usize = t.shape.iter().product();
            let mut t_strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
            }
            for idx in 0..t_total {
                let mut remaining = idx;
                let mut out_flat = 0;
                for d in 0..ndim {
                    let coord = remaining / t_strides[d];
                    remaining %= t_strides[d];
                    let out_coord = if d == axis { coord + offset } else { coord };
                    out_flat += out_coord * out_strides[d];
                }
                result[out_flat] = t_data[idx];
            }
            offset += t.shape[axis];
        }
        Tensor::from_vec(result, &out_shape)
    }
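
    // Shape bookkeeping example (hypothetical inputs): concatenating a [2, 3]
    // and a [4, 3] tensor along axis 0 yields [6, 3]; along axis 1 the call
    // would fail because the non-axis dimension 0 differs (2 vs 4).
    //
    //     let c = Tensor::cat(&[&a, &b], 0)?;   // a: [2, 3], b: [4, 3]
    //     assert_eq!(c.shape(), &[6, 3]);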

    /// Stacks equal-shaped tensors along a new axis inserted at `axis`, so
    /// the result has one more dimension than the inputs.
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
        }
        let base_shape = &tensors[0].shape;
        let ndim = base_shape.len();
        if axis > ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("stack: axis {axis} out of bounds"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if &t.shape != base_shape {
                return Err(RuntimeError::InvalidOperation(
                    format!("stack: tensor {i} shape mismatch"),
                ));
            }
        }
        let mut out_shape = Vec::with_capacity(ndim + 1);
        out_shape.extend_from_slice(&base_shape[..axis]);
        out_shape.push(tensors.len());
        out_shape.extend_from_slice(&base_shape[axis..]);
        let total: usize = out_shape.iter().product();
        let mut result = vec![0.0; total];
        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
        for (t_idx, t) in tensors.iter().enumerate() {
            let t_data = t.to_vec();
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let src = outer * inner_size + inner;
                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
                    if src < t_data.len() && dst < result.len() {
                        result[dst] = t_data[src];
                    }
                }
            }
        }
        Tensor::from_vec(result, &out_shape)
    }

    /// Returns the `k` largest values (descending) together with their flat
    /// indices; ties resolve to the lower index.
    pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
        let data = self.to_vec();
        let n = data.len();
        if k > n {
            return Err(RuntimeError::InvalidOperation(
                format!("topk: k={k} exceeds data length {n}"),
            ));
        }
        let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
        indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
        let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
        let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
        let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
        Ok((Tensor::from_vec(values, &[k])?, indices))
    }

    /// GELU activation using the tanh approximation:
    /// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    pub fn gelu(&self) -> Tensor {
        let data = self.to_vec();
        let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
        let result: Vec<f64> = data.iter().map(|&x| {
            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
            0.5 * x * (1.0 + inner.tanh())
        }).collect();
        Tensor::from_vec(result, &self.shape).unwrap()
    }
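
    // Note: this is the common tanh approximation of GELU, not the exact erf
    // form 0.5 * x * (1 + erf(x / sqrt(2))); the two agree to within roughly
    // 1e-3 over typical activation ranges.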

    /// Affine map over the last axis: `y = x W^T + b`, where `weight` is
    /// `[out_features, in_features]` and every leading axis of `self` is
    /// treated as a batch dimension.
    pub fn linear(
        &self,
        weight: &Tensor,
        bias: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if weight.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "linear: weight must be 2-D [out_features, in_features]".to_string(),
            ));
        }
        let out_features = weight.shape[0];
        let in_features = weight.shape[1];
        let last_dim = *self.shape.last().ok_or_else(|| {
            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
        })?;
        if last_dim != in_features {
            return Err(RuntimeError::DimensionMismatch {
                expected: in_features,
                got: last_dim,
            });
        }
        if bias.len() != out_features {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "linear: bias length {} must match out_features {}",
                    bias.len(),
                    out_features
                ),
            ));
        }

        let data = self.to_vec();
        let w = weight.to_vec();
        let b = bias.to_vec();
        let outer = data.len() / in_features;
        let mut result = vec![0.0f64; outer * out_features];

        for row in 0..outer {
            let x_start = row * in_features;
            let x_slice = &data[x_start..x_start + in_features];
            let y_start = row * out_features;
            for j in 0..out_features {
                let w_start = j * in_features;
                let mut acc = BinnedAccumulatorF64::new();
                for p in 0..in_features {
                    acc.add(x_slice[p] * w[w_start + p]);
                }
                result[y_start + j] = acc.finalize() + b[j];
            }
        }

        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
        out_shape.push(out_features);
        Tensor::from_vec(result, &out_shape)
    }

    /// Valid (unpadded) 1-D cross-correlation of a `[signal_len]` input with
    /// `[out_channels, kernel_size]` filters; the output is
    /// `[out_channels, signal_len - kernel_size + 1]`.
    pub fn conv1d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 1 {
            return Err(RuntimeError::InvalidOperation(
                "conv1d: input must be 1-D [signal_len]".to_string(),
            ));
        }
        if filters.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
            ));
        }
        let signal_len = self.shape[0];
        let out_channels = filters.shape[0];
        let kernel_size = filters.shape[1];
        if signal_len < kernel_size {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "conv1d: signal_len {} < kernel_size {}",
                    signal_len, kernel_size
                ),
            ));
        }
        if bias.len() != out_channels {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "conv1d: bias length {} must match out_channels {}",
                    bias.len(), out_channels
                ),
            ));
        }
        let out_len = signal_len - kernel_size + 1;
        let s = self.to_vec();
        let f = filters.to_vec();
        let b = bias.to_vec();
        let mut result = vec![0.0; out_channels * out_len];
        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
        Tensor::from_vec(result, &[out_channels, out_len])
    }

    /// Valid (unpadded) strided 2-D convolution: input `[N, C_in, H, W]`,
    /// filters `[C_out, C_in, kH, kW]`, output `[N, C_out, H_out, W_out]`
    /// with `H_out = (H - kH) / stride + 1` (and likewise for width).
    pub fn conv2d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
        stride: usize,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
            ));
        }
        if filters.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
            ));
        }
        if stride == 0 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: stride must be >= 1".to_string(),
            ));
        }

        let n = self.shape[0];
        let c_in = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        let c_out = filters.shape[0];
        let c_in_check = filters.shape[1];
        let kh = filters.shape[2];
        let kw = filters.shape[3];

        if c_in != c_in_check {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input C_in={} does not match filter C_in={}",
                c_in, c_in_check
            )));
        }
        if h_in < kh || w_in < kw {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
                h_in, w_in, kh, kw
            )));
        }
        if bias.len() != c_out {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: bias length {} must match C_out={}",
                bias.len(), c_out
            )));
        }

        let h_out = (h_in - kh) / stride + 1;
        let w_out = (w_in - kw) / stride + 1;

        let inp = self.to_vec();
        let flt = filters.to_vec();
        let b = bias.to_vec();
        let mut result = vec![0.0f64; n * c_out * h_out * w_out];

        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
            n, c_in, h_in, w_in, c_out, kh, kw, stride);

        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
    }

    /// Non-overlapping 2-D max pooling over `[N, C, H, W]` with window
    /// `ph x pw`; trailing rows/columns that do not fill a window are
    /// dropped, so the output is `[N, C, H / ph, W / pw]`.
    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
            ));
        }
        if ph == 0 || pw == 0 {
            return Err(RuntimeError::InvalidOperation(
                "maxpool2d: pool size must be >= 1".to_string(),
            ));
        }

        let n = self.shape[0];
        let c = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        if h_in < ph || w_in < pw {
            return Err(RuntimeError::InvalidOperation(format!(
                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
                h_in, w_in, ph, pw
            )));
        }

        let h_out = h_in / ph;
        let w_out = w_in / pw;

        let inp = self.to_vec();
        let mut result = vec![0.0f64; n * c * h_out * w_out];

        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);

        Tensor::from_vec(result, &[n, c, h_out, w_out])
    }

    /// Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V, batched
    /// over all leading axes.
    pub fn scaled_dot_product_attention(
        queries: &Tensor,
        keys: &Tensor,
        values: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "attention: Q, K, V must be at least 2-D".to_string(),
            ));
        }
        let nd = queries.ndim();
        let d_k = queries.shape[nd - 1];
        let scale = 1.0 / (d_k as f64).sqrt();

        let keys_t = keys.transpose_last_two()?;
        let scores = queries.bmm(&keys_t)?;
        let scores_scaled = scores.scalar_mul(scale);
        let attn_weights = scores_scaled.softmax()?;
        attn_weights.bmm(values)
    }
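
    // Shape walk-through (hypothetical sizes): with Q, K, V all of shape
    // [batch, heads, seq, d_k] = [1, 2, 5, 8], the score matrix Q K^T is
    // [1, 2, 5, 5], softmax normalizes each length-5 row, and the final
    // product with V restores [1, 2, 5, 8].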

    /// Materialized transpose of the last two axes. Unlike `transpose`, this
    /// copies, so the result is contiguous and safe to feed to `bmm`.
    pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "transpose_last_two requires at least 2-D tensor".to_string(),
            ));
        }
        let nd = self.ndim();
        let rows = self.shape[nd - 2];
        let cols = self.shape[nd - 1];
        let data = self.to_vec();
        let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
        let mat_size = rows * cols;
        let mut result = vec![0.0f64; data.len()];

        for b in 0..batch_size {
            let off = b * mat_size;
            for i in 0..rows {
                for j in 0..cols {
                    result[off + j * rows + i] = data[off + i * cols + j];
                }
            }
        }

        let mut out_shape = self.shape.clone();
        out_shape[nd - 2] = cols;
        out_shape[nd - 1] = rows;
        Tensor::from_vec(result, &out_shape)
    }

    /// Decodes a little-endian byte buffer into a tensor. Supported dtypes
    /// are "f64" (8 bytes per element) and "f32" (4 bytes, widened to f64).
    pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
        let numel = Self::shape_numel(shape);
        match dtype {
            "f64" => {
                let expected = numel * 8;
                if bytes.len() != expected {
                    return Err(RuntimeError::ShapeMismatch {
                        expected,
                        got: bytes.len(),
                    });
                }
                let mut data = Vec::with_capacity(numel);
                for i in 0..numel {
                    let off = i * 8;
                    let mut buf = [0u8; 8];
                    buf.copy_from_slice(&bytes[off..off + 8]);
                    data.push(f64::from_le_bytes(buf));
                }
                Ok(Tensor {
                    buffer: Buffer::from_vec(data),
                    shape: shape.to_vec(),
                    strides: Self::compute_strides(shape),
                    offset: 0,
                })
            }
            "f32" => {
                let expected = numel * 4;
                if bytes.len() != expected {
                    return Err(RuntimeError::ShapeMismatch {
                        expected,
                        got: bytes.len(),
                    });
                }
                let mut data = Vec::with_capacity(numel);
                for i in 0..numel {
                    let off = i * 4;
                    let mut buf = [0u8; 4];
                    buf.copy_from_slice(&bytes[off..off + 4]);
                    data.push(f32::from_le_bytes(buf) as f64);
                }
                Ok(Tensor {
                    buffer: Buffer::from_vec(data),
                    shape: shape.to_vec(),
                    strides: Self::compute_strides(shape),
                    offset: 0,
                })
            }
            _ => Err(RuntimeError::InvalidOperation(
                format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
            )),
        }
    }

    /// Splits the last axis of a `[batch, seq, model_dim]` tensor into
    /// `num_heads` heads, returning a `[batch, num_heads, seq, head_dim]`
    /// view (the seq and head axes are swapped via strides, without copying).
    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 3 {
            return Err(RuntimeError::DimensionMismatch {
                expected: 3,
                got: self.ndim(),
            });
        }
        let batch = self.shape[0];
        let seq = self.shape[1];
        let model_dim = self.shape[2];
        if model_dim % num_heads != 0 {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "split_heads: model_dim {} not divisible by num_heads {}",
                    model_dim, num_heads
                ),
            ));
        }
        let head_dim = model_dim / num_heads;
        // `to_contiguous` is a cheap clone when the layout is already canonical.
        let tensor = self.to_contiguous();
        // First view the buffer as [batch, seq, num_heads, head_dim] ...
        let reshaped = Tensor {
            buffer: tensor.buffer.clone(),
            shape: vec![batch, seq, num_heads, head_dim],
            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
            offset: 0,
        };
        // ... then swap the seq and head axes by permuting strides.
        Ok(Tensor {
            buffer: reshaped.buffer,
            shape: vec![batch, num_heads, seq, head_dim],
            strides: vec![
                reshaped.strides[0],
                reshaped.strides[2],
                reshaped.strides[1],
                reshaped.strides[3],
            ],
            offset: 0,
        })
    }
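
    // Example (hypothetical sizes): a [batch=2, seq=10, model_dim=64] tensor
    // split into 8 heads becomes a [2, 8, 10, 8] view; head h of token s
    // aliases columns h*8 .. (h+1)*8 of the original row, with no copy.
    //
    //     let heads = x.split_heads(8)?;
    //     assert_eq!(heads.shape(), &[2, 8, 10, 8]);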

    /// Inverse of `split_heads`: turns `[batch, num_heads, seq, head_dim]`
    /// back into `[batch, seq, num_heads * head_dim]`, copying once to make
    /// the permuted view contiguous.
    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::DimensionMismatch {
                expected: 4,
                got: self.ndim(),
            });
        }
        let batch = self.shape[0];
        let num_heads = self.shape[1];
        let seq = self.shape[2];
        let head_dim = self.shape[3];
        // View as [batch, seq, num_heads, head_dim] by swapping strides back.
        let transposed = Tensor {
            buffer: self.buffer.clone(),
            shape: vec![batch, seq, num_heads, head_dim],
            strides: vec![
                self.strides[0],
                self.strides[2],
                self.strides[1],
                self.strides[3],
            ],
            offset: self.offset,
        };
        let contig = transposed.to_contiguous();
        let model_dim = num_heads * head_dim;
        Ok(Tensor {
            buffer: contig.buffer,
            shape: vec![batch, seq, model_dim],
            strides: Self::compute_strides(&[batch, seq, model_dim]),
            offset: 0,
        })
    }

    /// Alias for `reshape`, kept for API symmetry with view-based frameworks.
    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.reshape(new_shape)
    }

    /// Returns the flat indices that would sort the tensor ascending, encoded
    /// as a 1-D tensor of f64 values.
    pub fn argsort(&self) -> Tensor {
        let data = self.to_vec();
        let mut indices: Vec<usize> = (0..data.len()).collect();
        indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
        let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
        Tensor::from_vec_unchecked(result, &[data.len()])
    }

    /// Gathers elements along `dim` using `indices` (values cast from f64 to
    /// usize). Only 1-D and 2-D tensors are supported.
    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
        let data = self.to_vec();
        let idx_data = indices.to_vec();
        if self.ndim() == 1 {
            let mut result = Vec::with_capacity(idx_data.len());
            for &idx in &idx_data {
                let i = idx as usize;
                if i >= data.len() {
                    return Err(RuntimeError::InvalidOperation(
                        format!("gather: index {} out of bounds for size {}", i, data.len()),
                    ));
                }
                result.push(data[i]);
            }
            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
        } else if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];
            let idx_shape = indices.shape();
            let out_rows = idx_shape[0];
            let out_cols = idx_shape[1];
            let mut result = vec![0.0; out_rows * out_cols];
            for i in 0..out_rows {
                for j in 0..out_cols {
                    let idx = idx_data[i * out_cols + j] as usize;
                    let val = if dim == 0 {
                        if idx >= rows {
                            return Err(RuntimeError::InvalidOperation(
                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
                            ));
                        }
                        data[idx * cols + j]
                    } else {
                        if idx >= cols {
                            return Err(RuntimeError::InvalidOperation(
                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
                            ));
                        }
                        data[i * cols + idx]
                    };
                    result[i * out_cols + j] = val;
                }
            }
            Ok(Tensor::from_vec_unchecked(result, idx_shape))
        } else {
            Err(RuntimeError::InvalidOperation(
                "gather: only 1D and 2D tensors supported".into(),
            ))
        }
    }

    /// Out-of-place scatter: copies `self`, then writes `src` values at the
    /// positions named by `indices` along `dim`. 1-D and 2-D tensors only.
    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
        let mut result = self.to_vec();
        let idx_data = indices.to_vec();
        let src_data = src.to_vec();
        if self.ndim() == 1 {
            for (k, &idx) in idx_data.iter().enumerate() {
                let i = idx as usize;
                if i >= result.len() {
                    return Err(RuntimeError::InvalidOperation(
                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
                    ));
                }
                result[i] = src_data[k];
            }
            Ok(Tensor::from_vec_unchecked(result, self.shape()))
        } else if self.ndim() == 2 {
            let cols = self.shape[1];
            let idx_shape = indices.shape();
            let out_rows = idx_shape[0];
            let out_cols = idx_shape[1];
            for i in 0..out_rows {
                for j in 0..out_cols {
                    let idx = idx_data[i * out_cols + j] as usize;
                    let src_val = src_data[i * out_cols + j];
                    if dim == 0 {
                        if idx >= self.shape[0] {
                            return Err(RuntimeError::InvalidOperation(
                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
                            ));
                        }
                        result[idx * cols + j] = src_val;
                    } else {
                        if idx >= cols {
                            return Err(RuntimeError::InvalidOperation(
                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
                            ));
                        }
                        result[i * cols + idx] = src_val;
                    }
                }
            }
            Ok(Tensor::from_vec_unchecked(result, self.shape()))
        } else {
            Err(RuntimeError::InvalidOperation(
                "scatter: only 1D and 2D tensors supported".into(),
            ))
        }
    }

    /// Selects rows (`dim == 0`) or columns (`dim == 1`) by index, like
    /// `gather` but copying whole slices. 1-D and 2-D tensors only.
    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
        let data = self.to_vec();
        let idx_data = indices.to_vec();
        if self.ndim() == 1 {
            let mut result = Vec::with_capacity(idx_data.len());
            for &idx in &idx_data {
                let i = idx as usize;
                if i >= data.len() {
                    return Err(RuntimeError::InvalidOperation(
                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
                    ));
                }
                result.push(data[i]);
            }
            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
        } else if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];
            let n = idx_data.len();
            if dim == 0 {
                let mut result = Vec::with_capacity(n * cols);
                for &idx in &idx_data {
                    let i = idx as usize;
                    if i >= rows {
                        return Err(RuntimeError::InvalidOperation(
                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
                        ));
                    }
                    for j in 0..cols {
                        result.push(data[i * cols + j]);
                    }
                }
                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
            } else {
                let mut result = Vec::with_capacity(rows * n);
                for i in 0..rows {
                    for &idx in &idx_data {
                        let j = idx as usize;
                        if j >= cols {
                            return Err(RuntimeError::InvalidOperation(
                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
                            ));
                        }
                        result.push(data[i * cols + j]);
                    }
                }
                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
            }
        } else {
            Err(RuntimeError::InvalidOperation(
                "index_select: only 1D and 2D tensors supported".into(),
            ))
        }
    }

    /// Elementwise select: where `condition` is nonzero take `self`,
    /// otherwise take `other`. All three shapes must match exactly.
    pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.shape() != condition.shape() || self.shape() != other.shape() {
            return Err(RuntimeError::InvalidOperation(
                format!("where: shape mismatch self={:?} cond={:?} other={:?}",
                    self.shape(), condition.shape(), other.shape()),
            ));
        }
        let s = self.to_vec();
        let c = condition.to_vec();
        let o = other.to_vec();
        let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
            .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
            .collect();
        Tensor::from_vec(result, self.shape())
    }

    /// True if any element is nonzero.
    pub fn any(&self) -> bool {
        let data = self.to_vec();
        data.iter().any(|&x| x != 0.0)
    }

    /// True if every element is nonzero.
    pub fn all(&self) -> bool {
        let data = self.to_vec();
        data.iter().all(|&x| x != 0.0)
    }

    /// Flat indices of all nonzero elements as a 1-D tensor (shape [0] when
    /// there are none).
    pub fn nonzero(&self) -> Tensor {
        let data = self.to_vec();
        let indices: Vec<f64> = data.iter().enumerate()
            .filter(|(_, &v)| v != 0.0)
            .map(|(i, _)| i as f64)
            .collect();
        let len = indices.len();
        Tensor::from_vec(indices, &[len]).unwrap()
    }

    /// Replaces elements of `self` with `value` wherever `mask` is nonzero.
    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
        if self.shape() != mask.shape() {
            return Err(RuntimeError::InvalidOperation(
                format!("masked_fill: shape mismatch self={:?} mask={:?}",
                    self.shape(), mask.shape()),
            ));
        }
        let data = self.to_vec();
        let m = mask.to_vec();
        let result: Vec<f64> = data.iter().zip(m.iter())
            .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
            .collect();
        Tensor::from_vec(result, self.shape())
    }

    /// Shared scaffolding for axis reductions: iterates every slice along
    /// `axis`, collects its values, and applies `reduce_fn` to each slice.
    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
        -> Result<Tensor, RuntimeError>
    where
        F: Fn(&[f64]) -> f64,
    {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds {
                index: axis,
                length: ndim,
            });
        }

        let axis_len = self.shape[axis];
        let mut out_shape: Vec<usize> = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        let data = self.to_vec();
        // `data` is a row-major copy, so address it with canonical strides
        // rather than the tensor's own (possibly permuted) buffer strides.
        let data_strides = Self::compute_strides(&self.shape);
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decode the flat output index into a multi-index.
            {
                let mut remaining = out_idx;
                for d in 0..ndim {
                    indices[d] = remaining / out_strides[d];
                    remaining %= out_strides[d];
                }
            }

            let mut vals = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = 0usize;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * data_strides[d];
                }
                vals.push(data[flat]);
            }
            result.push(reduce_fn(&vals));
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis)
                .map(|(_, &v)| v)
                .collect();
            if s.is_empty() {
                // Reducing the only axis of a 1-D tensor yields shape [1].
                s.push(1);
            }
            s
        };

        Tensor::from_vec(result, &final_shape)
    }

    /// Mean along `axis`, accumulated with `BinnedAccumulatorF64`.
    pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        self.reduce_axis(axis, keepdim, |vals| {
            let mut acc = BinnedAccumulatorF64::new();
            for &v in vals { acc.add(v); }
            acc.finalize() / vals.len() as f64
        })
    }

    /// Maximum along `axis`; returns (values, argmax indices as f64). Ties
    /// resolve to the first occurrence along the axis.
    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        // Row-major copy: index with canonical strides (see `reduce_axis`).
        let data_strides = Self::compute_strides(&self.shape);
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mut best_val = f64::NEG_INFINITY;
            let mut best_idx = 0usize;
            for k in 0..axis_len {
                let mut flat = 0usize;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * data_strides[d];
                }
                let v = data[flat];
                if v > best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }

    /// Minimum along `axis`; returns (values, argmin indices as f64). Ties
    /// resolve to the first occurrence along the axis.
    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        // Row-major copy: index with canonical strides (see `reduce_axis`).
        let data_strides = Self::compute_strides(&self.shape);
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mut best_val = f64::INFINITY;
            let mut best_idx = 0usize;
            for k in 0..axis_len {
                let mut flat = 0usize;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * data_strides[d];
                }
                let v = data[flat];
                if v < best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }

    /// Population variance along `axis` (divides by the axis length, not
    /// length - 1), computed against the per-slice mean.
    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        let mean_t = self.mean_axis(axis, true)?;
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        // Row-major copy: index with canonical strides (see `reduce_axis`).
        let data_strides = Self::compute_strides(&self.shape);
        let mean_data = mean_t.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mu = mean_data[out_idx];
            let mut acc = BinnedAccumulatorF64::new();
            for k in 0..axis_len {
                let mut flat = 0usize;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * data_strides[d];
                }
                let diff = data[flat] - mu;
                acc.add(diff * diff);
            }
            result.push(acc.finalize() / axis_len as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Tensor::from_vec(result, &final_shape)
    }
2286
    pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        let var = self.var_axis(axis, keepdim)?;
        Ok(var.map(|x| x.sqrt()))
    }

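    /// Product of the elements along `axis`, built on the generic
    /// `reduce_axis` helper.
    ///
    /// ```ignore
    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
    /// let p = t.prod_axis(1, false)?; // [1*2, 3*4] -> [2.0, 12.0]
    /// ```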
    pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        self.reduce_axis(axis, keepdim, |vals| {
            let mut product = 1.0f64;
            for &v in vals { product *= v; }
            product
        })
    }

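    /// Sorts every 1-D lane along `axis`, ascending unless `descending` is
    /// set. Incomparable pairs (NaN) compare as equal and then fall back to
    /// original-index order, so the result is deterministic. A minimal
    /// sketch (illustrative values):
    ///
    /// ```ignore
    /// let t = Tensor::from_vec(vec![3.0, 1.0, 2.0], &[3])?;
    /// let sorted = t.sort_axis(0, false)?; // [1.0, 2.0, 3.0]
    /// ```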
    pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let data = self.to_vec();
        let axis_len = self.shape[axis];
        let out_shape = self.shape.clone();
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        // Iteration space over all dimensions except `axis`.
        let iter_shape: Vec<usize> = self.shape.iter().enumerate()
            .filter(|&(i, _)| i != axis)
            .map(|(_, &s)| s)
            .collect();
        let iter_strides = Self::compute_strides(&iter_shape);
        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);

        let mut result = vec![0.0f64; out_numel];
        let mut pos = vec![0usize; ndim];

        for slice_idx in 0..n_slices {
            // Decode slice_idx into a multi-index over the non-axis dimensions.
            let mut remaining = slice_idx;
            let mut dim_idx = 0;
            for d in 0..ndim {
                if d == axis {
                    pos[d] = 0;
                } else {
                    pos[d] = remaining / iter_strides[dim_idx];
                    remaining %= iter_strides[dim_idx];
                    dim_idx += 1;
                }
            }

            // Gather the 1-D lane along `axis`, remembering original indices.
            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { pos[d] };
                    flat += idx * self.strides[d];
                }
                vals.push((data[flat], k));
            }

            // Incomparable pairs (NaN) compare as equal, then fall back to
            // original-index order, keeping the sort deterministic.
            if descending {
                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.1.cmp(&b.1)));
            } else {
                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.1.cmp(&b.1)));
            }

            // Scatter the sorted values into the (contiguous) output.
            for (k, &(v, _)) in vals.iter().enumerate() {
                let mut flat = 0;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { pos[d] };
                    flat += idx * out_strides[d];
                }
                result[flat] = v;
            }
        }

        Tensor::from_vec(result, &out_shape)
    }

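    /// Like `sort_axis`, but writes the original index of each element in
    /// sorted order (as `f64`) instead of the value itself.
    ///
    /// ```ignore
    /// let t = Tensor::from_vec(vec![3.0, 1.0, 2.0], &[3])?;
    /// let order = t.argsort_axis(0, false)?; // [1.0, 2.0, 0.0]
    /// ```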
    pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let data = self.to_vec();
        let axis_len = self.shape[axis];
        let out_shape = self.shape.clone();
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        // Iteration space over all dimensions except `axis`.
        let iter_shape: Vec<usize> = self.shape.iter().enumerate()
            .filter(|&(i, _)| i != axis)
            .map(|(_, &s)| s)
            .collect();
        let iter_strides = Self::compute_strides(&iter_shape);
        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);

        let mut result = vec![0.0f64; out_numel];
        let mut pos = vec![0usize; ndim];

        for slice_idx in 0..n_slices {
            // Decode slice_idx into a multi-index over the non-axis dimensions.
            let mut remaining = slice_idx;
            let mut dim_idx = 0;
            for d in 0..ndim {
                if d == axis {
                    pos[d] = 0;
                } else {
                    pos[d] = remaining / iter_strides[dim_idx];
                    remaining %= iter_strides[dim_idx];
                    dim_idx += 1;
                }
            }

            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { pos[d] };
                    flat += idx * self.strides[d];
                }
                vals.push((data[flat], k));
            }

            // Same comparator as `sort_axis`, so both methods agree on order.
            if descending {
                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.1.cmp(&b.1)));
            } else {
                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.1.cmp(&b.1)));
            }

            // Write the original index (not the value) of the k-th element.
            for (k, &(_, orig_idx)) in vals.iter().enumerate() {
                let mut flat = 0;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { pos[d] };
                    flat += idx * out_strides[d];
                }
                result[flat] = orig_idx as f64;
            }
        }

        Tensor::from_vec(result, &out_shape)
    }

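    /// Einstein summation over an explicit `"inputs->output"` subscript
    /// string. Labels that appear in the inputs but not in the output are
    /// contracted (summed) with a `BinnedAccumulatorF64`; only the explicit
    /// arrow form is accepted. A minimal sketch of a matrix multiply
    /// (illustrative values):
    ///
    /// ```ignore
    /// let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])?;
    /// let b = Tensor::ones(&[2, 2]);
    /// let c = Tensor::einsum("ij,jk->ik", &[&a, &b])?;
    /// // c[i][k] = sum_j a[i][j] * b[j][k], here [[3, 3], [7, 7]]
    /// ```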
    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
        let parts: Vec<&str> = notation.split("->").collect();
        if parts.len() != 2 {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
            ));
        }
        let input_specs: Vec<&str> = parts[0].split(',').collect();
        let output_spec = parts[1];

        if input_specs.len() != inputs.len() {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
            ));
        }

        // Map each subscript label to its dimension size, rejecting conflicts.
        let mut label_size = std::collections::BTreeMap::new();
        for (i, &spec) in input_specs.iter().enumerate() {
            let chars: Vec<char> = spec.chars().collect();
            if chars.len() != inputs[i].ndim() {
                return Err(RuntimeError::InvalidOperation(
                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
                ));
            }
            for (d, &c) in chars.iter().enumerate() {
                let sz = inputs[i].shape()[d];
                if let Some(&prev) = label_size.get(&c) {
                    if prev != sz {
                        return Err(RuntimeError::InvalidOperation(
                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
                        ));
                    }
                } else {
                    label_size.insert(c, sz);
                }
            }
        }

        let output_chars: Vec<char> = output_spec.chars().collect();
        let output_shape: Vec<usize> = output_chars.iter()
            .map(|c| label_size.get(c).copied().ok_or_else(||
                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
            .collect::<Result<_, _>>()?;
        let out_numel = Self::shape_numel(&output_shape);

        // Labels absent from the output are contracted (summed over).
        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
        let contract_labels: Vec<char> = label_size.keys()
            .filter(|c| !output_set.contains(c))
            .copied()
            .collect();
        let contract_sizes: Vec<usize> = contract_labels.iter()
            .map(|c| label_size[c])
            .collect();
        let contract_strides = Self::compute_strides(&contract_sizes);
        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);

        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();

        let out_strides = Self::compute_strides(&output_shape);
        let mut result = vec![0.0f64; out_numel];

        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();

        for out_idx in 0..out_numel {
            // Decode the flat output index into one value per output label.
            let mut label_vals = std::collections::BTreeMap::new();
            let mut remaining = out_idx;
            for (d, &c) in output_chars.iter().enumerate() {
                label_vals.insert(c, remaining / out_strides[d]);
                remaining %= out_strides[d];
            }

            // Sum over every combination of the contracted labels.
            let mut acc = BinnedAccumulatorF64::new();
            for cidx in 0..contract_numel {
                let mut cr = cidx;
                for (ci, &cl) in contract_labels.iter().enumerate() {
                    label_vals.insert(cl, cr / contract_strides[ci]);
                    cr %= contract_strides[ci];
                }

                // Product of one element from each input at these label values.
                let mut product = 1.0f64;
                for (inp_idx, chars) in input_chars.iter().enumerate() {
                    let mut flat = input_offsets[inp_idx];
                    for (d, &c) in chars.iter().enumerate() {
                        flat += label_vals[&c] * input_strides[inp_idx][d];
                    }
                    product *= input_data[inp_idx][flat];
                }
                acc.add(product);
            }
            result[out_idx] = acc.finalize();
        }

        // A fully contracted result (empty output spec) becomes a 1-element tensor.
        if output_shape.is_empty() {
            Tensor::from_vec(result, &[1])
        } else {
            Tensor::from_vec(result, &output_shape)
        }
    }

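    /// Inserts a size-1 dimension at `dim` (any position in `0..=ndim`),
    /// implemented as a reshape.
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[2, 3]);
    /// let u = t.unsqueeze(0)?; // shape [1, 2, 3]
    /// ```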
    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
        let ndim = self.ndim();
        if dim > ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
        }
        let mut new_shape = self.shape.clone();
        new_shape.insert(dim, 1);
        self.reshape(&new_shape)
    }

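    /// Removes size-1 dimensions: a specific one when `dim` is `Some(d)`
    /// (erroring if that dimension's size is not 1), or all of them when
    /// `None`. The result always keeps at least one dimension.
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[1, 2, 1]);
    /// let s = t.squeeze(None)?; // shape [2]
    /// ```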
    pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
        match dim {
            Some(d) => {
                if d >= self.ndim() {
                    return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
                }
                if self.shape[d] != 1 {
                    return Err(RuntimeError::InvalidOperation(
                        format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
                    ));
                }
                let mut new_shape = self.shape.clone();
                new_shape.remove(d);
                if new_shape.is_empty() {
                    new_shape.push(1);
                }
                self.reshape(&new_shape)
            }
            None => {
                let new_shape: Vec<usize> = self.shape.iter()
                    .filter(|&&s| s != 1)
                    .copied()
                    .collect();
                let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
                self.reshape(&new_shape)
            }
        }
    }

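    /// Broadcasts to `target_shape`; a thin alias for `broadcast_to`. A
    /// sketch assuming the usual broadcasting rules (size-1 dimensions
    /// stretch to the target):
    ///
    /// ```ignore
    /// let row = Tensor::ones(&[1, 4]);
    /// let m = row.expand(&[3, 4])?; // shape [3, 4]
    /// ```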
    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.broadcast_to(target_shape)
    }

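    /// Collapses the inclusive dimension range `[start_dim, end_dim]` into a
    /// single dimension whose size is the product of the collapsed extents.
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[2, 3, 4]);
    /// let f = t.flatten(1, 2)?; // shape [2, 12]
    /// ```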
    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
        if start_dim > end_dim || end_dim >= self.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
            ));
        }
        let mut new_shape = Vec::new();
        for i in 0..start_dim {
            new_shape.push(self.shape[i]);
        }
        let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
        new_shape.push(flat_size);
        for i in (end_dim + 1)..self.ndim() {
            new_shape.push(self.shape[i]);
        }
        self.reshape(&new_shape)
    }

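    /// Splits into at most `n` pieces along `dim`, each of size
    /// `ceil(dim_size / n)` except possibly a smaller final piece; delegates
    /// to `split`.
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[5, 2]);
    /// let parts = t.chunk(2, 0)?; // shapes [3, 2] and [2, 2]
    /// ```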
    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
        if dim >= self.ndim() {
            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
        }
        if n == 0 {
            return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
        }
        let dim_size = self.shape[dim];
        // Ceiling division: every chunk but possibly the last has this size.
        let chunk_size = (dim_size + n - 1) / n;
        let mut sizes = Vec::new();
        let mut remaining = dim_size;
        while remaining > 0 {
            let s = remaining.min(chunk_size);
            sizes.push(s);
            remaining -= s;
        }
        self.split(&sizes, dim)
    }

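    /// Splits along `dim` into pieces of the given `sizes`, which must sum
    /// to that dimension's extent. Each piece is passed through
    /// `to_contiguous`, so the returned tensors do not alias `self`.
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[6, 2]);
    /// let parts = t.split(&[2, 4], 0)?; // shapes [2, 2] and [4, 2]
    /// ```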
    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
        if dim >= self.ndim() {
            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
        }
        let total: usize = sizes.iter().sum();
        if total != self.shape[dim] {
            return Err(RuntimeError::InvalidOperation(
                format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
            ));
        }

        let mut results = Vec::new();
        let mut offset = 0;

        for &sz in sizes {
            // Take the full extent of every dimension except `dim`.
            let ranges: Vec<(usize, usize)> = self.shape.iter()
                .enumerate()
                .map(|(i, &s)| {
                    if i == dim { (offset, offset + sz) } else { (0, s) }
                })
                .collect();
            let chunk = self.slice(&ranges)?;
            // Materialize the view so each piece owns its own buffer.
            results.push(chunk.to_contiguous());
            offset += sz;
        }

        Ok(results)
    }

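    /// Element-wise linear combination `alpha * self + beta * other` for
    /// same-shaped tensors (an axpby-style primitive).
    ///
    /// ```ignore
    /// let a = Tensor::ones(&[3]);
    /// let b = Tensor::ones(&[3]);
    /// let y = a.scale_add(2.0, &b, -1.0)?; // 2*a - b -> [1.0, 1.0, 1.0]
    /// ```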
    pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
        if self.shape != other.shape {
            return Err(RuntimeError::InvalidOperation(
                "scale_add: shape mismatch".to_string(),
            ));
        }
        let a = self.to_vec();
        let b = other.to_vec();
        let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
        Tensor::from_vec(result, &self.shape)
    }
}
