1use cjc_repro::Rng;
43
44use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
45use cjc_repro::KahanAccumulatorF64;
46
47use crate::accumulator;
48use crate::buffer::Buffer;
49use crate::dispatch;
50use crate::error::RuntimeError;
51use crate::kernel as kernel_fns;
52use crate::tensor_simd::{self, BinOp, UnaryOp};
53use crate::tensor_tiled::TiledMatmul;
54
/// An N-dimensional `f64` array expressed as a strided view over a [`Buffer`].
///
/// The element at index `(i0, i1, ...)` is read from buffer position
/// `offset + Σ i_d * strides[d]`, which is how `linear_index` resolves it.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Backing element storage. Views such as `slice` and `transpose` reuse it
    /// through `Buffer::clone` (presumably a shared handle — confirm in `Buffer`).
    pub buffer: Buffer<f64>,
    /// Extent of each dimension.
    pub(crate) shape: Vec<usize>,
    /// Step, in elements, between consecutive indices along each dimension.
    pub(crate) strides: Vec<usize>,
    /// Buffer position of the element at index `[0, 0, ...]`.
    pub(crate) offset: usize,
}
86
87impl Tensor {
/// Returns the row-major strides (in elements) for `shape`.
///
/// The last dimension always has stride `1`; every earlier stride is the
/// product of all later extents. An empty shape yields an empty vector.
pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1usize; rank];
    // Walk from the innermost axis outwards, deriving each stride from its right neighbour.
    for axis in (1..rank).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}
98
/// Number of elements implied by `shape`: the product of its extents
/// (`1` for an empty shape).
fn shape_numel(shape: &[usize]) -> usize {
    shape.iter().fold(1, |count, &extent| count * extent)
}
103
104 pub fn zeros(shape: &[usize]) -> Self {
106 let numel = Self::shape_numel(shape);
107 Tensor {
108 buffer: Buffer::alloc(numel, 0.0),
109 shape: shape.to_vec(),
110 strides: Self::compute_strides(shape),
111 offset: 0,
112 }
113 }
114
115 pub fn ones(shape: &[usize]) -> Self {
117 let numel = Self::shape_numel(shape);
118 Tensor {
119 buffer: Buffer::alloc(numel, 1.0),
120 shape: shape.to_vec(),
121 strides: Self::compute_strides(shape),
122 offset: 0,
123 }
124 }
125
126 pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
129 let numel = Self::shape_numel(shape);
130 let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
131 Tensor {
132 buffer: Buffer::from_vec(data),
133 shape: shape.to_vec(),
134 strides: Self::compute_strides(shape),
135 offset: 0,
136 }
137 }
138
139 pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
142 let numel = Self::shape_numel(shape);
143 if data.len() != numel {
144 return Err(RuntimeError::ShapeMismatch {
145 expected: numel,
146 got: data.len(),
147 });
148 }
149 Ok(Tensor {
150 buffer: Buffer::from_vec(data),
151 shape: shape.to_vec(),
152 strides: Self::compute_strides(shape),
153 offset: 0,
154 })
155 }
156
157 pub fn shape(&self) -> &[usize] {
161 &self.shape
162 }
163
164 pub fn ndim(&self) -> usize {
166 self.shape.len()
167 }
168
169 pub fn len(&self) -> usize {
171 Self::shape_numel(&self.shape)
172 }
173
174 pub fn is_empty(&self) -> bool {
176 self.len() == 0
177 }
178
179 fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
181 if indices.len() != self.shape.len() {
182 return Err(RuntimeError::DimensionMismatch {
183 expected: self.shape.len(),
184 got: indices.len(),
185 });
186 }
187 let mut off = self.offset;
188 for (i, &idx) in indices.iter().enumerate() {
189 if idx >= self.shape[i] {
190 return Err(RuntimeError::IndexOutOfBounds {
191 index: idx,
192 length: self.shape[i],
193 });
194 }
195 off += idx * self.strides[i];
196 }
197 Ok(off)
198 }
199
200 pub fn is_contiguous(&self) -> bool {
202 if self.offset != 0 {
203 return false;
204 }
205 let expected = Self::compute_strides(&self.shape);
206 self.strides == expected
207 }
208
209 pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
212 if ranges.len() != self.shape.len() {
213 return Err(RuntimeError::DimensionMismatch {
214 expected: self.shape.len(),
215 got: ranges.len(),
216 });
217 }
218 let mut new_offset = self.offset;
219 let mut new_shape = Vec::with_capacity(ranges.len());
220 for (i, &(start, end)) in ranges.iter().enumerate() {
221 if end > self.shape[i] || start > end {
222 return Err(RuntimeError::IndexOutOfBounds {
223 index: end,
224 length: self.shape[i],
225 });
226 }
227 new_offset += start * self.strides[i];
228 new_shape.push(end - start);
229 }
230 Ok(Tensor {
231 buffer: self.buffer.clone(), shape: new_shape,
233 strides: self.strides.clone(),
234 offset: new_offset,
235 })
236 }
237
238 pub fn to_contiguous(&self) -> Tensor {
240 if self.is_contiguous() {
241 return self.clone();
242 }
243 let data = self.to_vec();
244 Tensor {
245 buffer: Buffer::from_vec(data),
246 shape: self.shape.clone(),
247 strides: Self::compute_strides(&self.shape),
248 offset: 0,
249 }
250 }
251
252 pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
255 let src_ndim = self.shape.len();
256 let tgt_ndim = target_shape.len();
257 if tgt_ndim < src_ndim {
258 return Err(RuntimeError::InvalidOperation(
259 "cannot broadcast to a smaller rank".to_string(),
260 ));
261 }
262 let pad = tgt_ndim - src_ndim;
263 let mut new_strides = vec![0usize; tgt_ndim];
264 for i in 0..tgt_ndim {
265 if i < pad {
266 new_strides[i] = 0;
268 } else {
269 let src_i = i - pad;
270 if self.shape[src_i] == target_shape[i] {
271 new_strides[i] = self.strides[src_i];
272 } else if self.shape[src_i] == 1 {
273 new_strides[i] = 0; } else {
275 return Err(RuntimeError::ShapeMismatch {
276 expected: target_shape[i],
277 got: self.shape[src_i],
278 });
279 }
280 }
281 }
282 Ok(Tensor {
283 buffer: self.buffer.clone(),
284 shape: target_shape.to_vec(),
285 strides: new_strides,
286 offset: self.offset,
287 })
288 }
289
290 pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
292 let offset = self.linear_index(indices)?;
293 self.buffer
294 .get(offset)
295 .ok_or(RuntimeError::IndexOutOfBounds {
296 index: offset,
297 length: self.buffer.len(),
298 })
299 }
300
301 pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
303 let offset = self.linear_index(indices)?;
304 self.buffer.set(offset, val)
305 }
306
307 pub fn to_vec(&self) -> Vec<f64> {
309 if self.is_contiguous() {
310 let full = self.buffer.borrow_data();
311 let numel = self.len();
312 if full.len() == numel {
313 return full.to_vec();
314 }
315 return full[..numel].to_vec();
318 }
319 let numel = self.len();
321 let mut result = Vec::with_capacity(numel);
322 let ndim = self.shape.len();
323 let mut indices = vec![0usize; ndim];
324 for _ in 0..numel {
325 let mut off = self.offset;
326 for d in 0..ndim {
327 off += indices[d] * self.strides[d];
328 }
329 result.push(self.buffer.get(off).unwrap_or(0.0));
330 for d in (0..ndim).rev() {
332 indices[d] += 1;
333 if indices[d] < self.shape[d] {
334 break;
335 }
336 indices[d] = 0;
337 }
338 }
339 result
340 }
341
342 pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
347 let new_numel = Self::shape_numel(new_shape);
348 if new_numel != self.len() {
349 return Err(RuntimeError::ShapeMismatch {
350 expected: self.len(),
351 got: new_numel,
352 });
353 }
354 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
356 Ok(Tensor {
357 buffer: tensor.buffer,
358 shape: new_shape.to_vec(),
359 strides: Self::compute_strides(new_shape),
360 offset: 0,
361 })
362 }
363
364 fn elementwise_binop(
368 &self,
369 other: &Tensor,
370 op: impl Fn(f64, f64) -> f64,
371 ) -> Result<Tensor, RuntimeError> {
372 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
373 let a = self.buffer.borrow_data();
375 let b = other.buffer.borrow_data();
376 let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
377 return Ok(Tensor {
378 buffer: Buffer::from_vec(data),
379 shape: self.shape.clone(),
380 strides: Self::compute_strides(&self.shape),
381 offset: 0,
382 });
383 }
384
385 let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
387 let a_broadcast = self.broadcast_to(&result_shape)?;
388 let b_broadcast = other.broadcast_to(&result_shape)?;
389
390 let numel = Self::shape_numel(&result_shape);
391 let ndim = result_shape.len();
392 let mut data = Vec::with_capacity(numel);
393 let mut indices = vec![0usize; ndim];
394
395 for _ in 0..numel {
396 let mut off_a = a_broadcast.offset;
397 let mut off_b = b_broadcast.offset;
398 for d in 0..ndim {
399 off_a += indices[d] * a_broadcast.strides[d];
400 off_b += indices[d] * b_broadcast.strides[d];
401 }
402 let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
403 RuntimeError::InvalidOperation(format!(
404 "broadcast binop: left operand index {} out of bounds (buffer len {})",
405 off_a,
406 a_broadcast.buffer.len()
407 ))
408 })?;
409 let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
410 RuntimeError::InvalidOperation(format!(
411 "broadcast binop: right operand index {} out of bounds (buffer len {})",
412 off_b,
413 b_broadcast.buffer.len()
414 ))
415 })?;
416 data.push(op(va, vb));
417
418 for d in (0..ndim).rev() {
419 indices[d] += 1;
420 if indices[d] < result_shape[d] {
421 break;
422 }
423 indices[d] = 0;
424 }
425 }
426
427 Ok(Tensor {
428 buffer: Buffer::from_vec(data),
429 shape: result_shape.clone(),
430 strides: Self::compute_strides(&result_shape),
431 offset: 0,
432 })
433 }
434
435 fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
437 let max_ndim = a.len().max(b.len());
438 let mut result = Vec::with_capacity(max_ndim);
439 for i in 0..max_ndim {
440 let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
441 let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
442 if da == db {
443 result.push(da);
444 } else if da == 1 {
445 result.push(db);
446 } else if db == 1 {
447 result.push(da);
448 } else {
449 return Err(RuntimeError::ShapeMismatch {
450 expected: da,
451 got: db,
452 });
453 }
454 }
455 Ok(result)
456 }
457
458 fn elementwise_binop_simd(
463 &self,
464 other: &Tensor,
465 op: BinOp,
466 fallback: impl Fn(f64, f64) -> f64,
467 ) -> Result<Tensor, RuntimeError> {
468 if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
469 let a = self.buffer.borrow_data();
471 let b = other.buffer.borrow_data();
472 let data = tensor_simd::simd_binop(&a, &b, op);
473 return Ok(Tensor {
474 buffer: Buffer::from_vec(data),
475 shape: self.shape.clone(),
476 strides: Self::compute_strides(&self.shape),
477 offset: 0,
478 });
479 }
480 self.elementwise_binop(other, fallback)
482 }
483
484 pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486 self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
487 }
488
489 pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491 self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
492 }
493
494 pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496 self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
497 }
498
499 pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501 self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
502 }
503
504 pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
510 if self.shape != b.shape || self.shape != c.shape {
511 return Err(RuntimeError::InvalidOperation(
512 "broadcast_fma: all three tensors must have the same shape".to_string(),
513 ));
514 }
515 if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
516 let a_data = self.buffer.borrow_data();
517 let b_data = b.buffer.borrow_data();
518 let c_data = c.buffer.borrow_data();
519 let n = a_data.len();
520 let mut out = vec![0.0f64; n];
521 for i in 0..n {
524 out[i] = a_data[i] * b_data[i] + c_data[i];
525 }
526 return Ok(Tensor {
527 buffer: Buffer::from_vec(out),
528 shape: self.shape.clone(),
529 strides: Self::compute_strides(&self.shape),
530 offset: 0,
531 });
532 }
533 let temp = self.mul_elem(b)?;
535 temp.add(c)
536 }
537
538 pub fn fused_axpy(&self, alpha: f64, y: &Tensor) -> Result<Tensor, RuntimeError> {
544 if self.shape != y.shape {
545 return Err(RuntimeError::InvalidOperation(
546 "fused_axpy: self and y must have the same shape".to_string(),
547 ));
548 }
549 if self.is_contiguous() && y.is_contiguous() {
550 let x_data = self.buffer.borrow_data();
551 let y_data = y.buffer.borrow_data();
552 let n = x_data.len();
553 let mut out = vec![0.0f64; n];
554 for i in 0..n {
555 out[i] = alpha * x_data[i] + y_data[i];
556 }
557 return Ok(Tensor {
558 buffer: Buffer::from_vec(out),
559 shape: self.shape.clone(),
560 strides: Self::compute_strides(&self.shape),
561 offset: 0,
562 });
563 }
564 self.scalar_mul(alpha).add(y)
566 }
567
568 pub fn fused_mul_sub(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
571 if self.shape != b.shape || self.shape != c.shape {
572 return Err(RuntimeError::InvalidOperation(
573 "fused_mul_sub: all three tensors must have the same shape".to_string(),
574 ));
575 }
576 if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
577 let a_data = self.buffer.borrow_data();
578 let b_data = b.buffer.borrow_data();
579 let c_data = c.buffer.borrow_data();
580 let n = a_data.len();
581 let mut out = vec![0.0f64; n];
582 for i in 0..n {
583 out[i] = a_data[i] * b_data[i] - c_data[i];
584 }
585 return Ok(Tensor {
586 buffer: Buffer::from_vec(out),
587 shape: self.shape.clone(),
588 strides: Self::compute_strides(&self.shape),
589 offset: 0,
590 });
591 }
592 let temp = self.mul_elem(b)?;
593 temp.sub(c)
594 }
595
596 pub fn fused_sub_sq(&self, b: &Tensor) -> Result<Tensor, RuntimeError> {
600 if self.shape != b.shape {
601 return Err(RuntimeError::InvalidOperation(
602 "fused_sub_sq: self and b must have the same shape".to_string(),
603 ));
604 }
605 if self.is_contiguous() && b.is_contiguous() {
606 let a_data = self.buffer.borrow_data();
607 let b_data = b.buffer.borrow_data();
608 let n = a_data.len();
609 let mut out = vec![0.0f64; n];
610 for i in 0..n {
611 let d = a_data[i] - b_data[i];
612 out[i] = d * d;
613 }
614 return Ok(Tensor {
615 buffer: Buffer::from_vec(out),
616 shape: self.shape.clone(),
617 strides: Self::compute_strides(&self.shape),
618 offset: 0,
619 });
620 }
621 let d = self.sub(b)?;
622 d.mul_elem(&d)
623 }
624
625 pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
629 self.elementwise_binop(other, |a, b| a.powf(b))
630 }
631
632 pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
634 self.elementwise_binop(other, |a, b| a.min(b))
635 }
636
637 pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
639 self.elementwise_binop(other, |a, b| a.max(b))
640 }
641
642 pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
644 self.elementwise_binop(other, |a, b| a.atan2(b))
645 }
646
647 pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
649 self.elementwise_binop(other, |a, b| a.hypot(b))
650 }
651
652 pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
654 let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
655 Tensor {
656 buffer: Buffer::from_vec(data),
657 shape: self.shape.clone(),
658 strides: Self::compute_strides(&self.shape),
659 offset: 0,
660 }
661 }
662
663 pub fn map_simd(&self, op: UnaryOp) -> Tensor {
668 let src = self.to_vec();
669 let data = tensor_simd::simd_unary(&src, op);
670 Tensor {
671 buffer: Buffer::from_vec(data),
672 shape: self.shape.clone(),
673 strides: Self::compute_strides(&self.shape),
674 offset: 0,
675 }
676 }
677
678 pub fn sum(&self) -> f64 {
682 let data = self.buffer.borrow_data();
683 binned_sum_f64(&data)
684 }
685
686 pub fn binned_sum(&self) -> f64 {
690 let data = self.buffer.borrow_data();
691 accumulator::binned_sum_f64(&data)
692 }
693
694 pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
698 let data = self.buffer.borrow_data();
699 dispatch::dispatch_sum_f64(&data, ctx)
700 }
701
702 pub fn mean(&self) -> f64 {
704 let n = self.len();
705 if n == 0 {
706 return 0.0;
707 }
708 self.sum() / n as f64
709 }
710
711 pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
713 let n = self.len();
714 if n == 0 {
715 return 0.0;
716 }
717 self.dispatched_sum(ctx) / n as f64
718 }
719
720 pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
730 let ndim = self.ndim();
731 if axis >= ndim {
732 return Err(RuntimeError::IndexOutOfBounds {
733 index: axis,
734 length: ndim,
735 });
736 }
737
738 let mut out_shape = self.shape.clone();
740 out_shape[axis] = 1;
741 let out_numel = Self::shape_numel(&out_shape);
742 let out_strides = Self::compute_strides(&out_shape);
743
744 let data = self.to_vec();
745 let axis_len = self.shape[axis];
746 let mut result = vec![0.0f64; out_numel];
747
748 let mut indices = vec![0usize; ndim];
750 for out_idx in 0..out_numel {
751 {
753 let mut remaining = out_idx;
754 for d in 0..ndim {
755 indices[d] = remaining / out_strides[d];
756 remaining %= out_strides[d];
757 }
758 }
759
760 let mut acc = BinnedAccumulatorF64::new();
761 for k in 0..axis_len {
762 let mut flat = self.offset;
764 for d in 0..ndim {
765 let idx = if d == axis { k } else { indices[d] };
766 flat += idx * self.strides[d];
767 }
768 acc.add(data[flat]);
769 }
770 result[out_idx] = acc.finalize();
771 }
772
773 Tensor::from_vec(result, &out_shape)
774 }
775
776 pub fn neg(&self) -> Tensor {
780 self.map(|x| -x)
781 }
782
783 pub fn transpose(&self) -> Tensor {
786 let ndim = self.ndim();
787 if ndim <= 1 {
788 return self.clone();
789 }
790 let mut new_shape = self.shape.clone();
792 let mut new_strides = self.strides.clone();
793 new_shape.reverse();
794 new_strides.reverse();
795 Tensor {
796 buffer: self.buffer.clone(), shape: new_shape,
798 strides: new_strides,
799 offset: self.offset,
800 }
801 }
802
803 pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
807 let ndim = self.ndim();
808 if axes.len() != ndim {
809 return Err(RuntimeError::InvalidOperation(
810 format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
811 ));
812 }
813 let mut seen = vec![false; ndim];
815 for &ax in axes {
816 if ax >= ndim {
817 return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
818 }
819 if seen[ax] {
820 return Err(RuntimeError::InvalidOperation(
821 format!("transpose_axes: duplicate axis {ax}"),
822 ));
823 }
824 seen[ax] = true;
825 }
826 let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
827 let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
828 Ok(Tensor {
829 buffer: self.buffer.clone(),
830 shape: new_shape,
831 strides: new_strides,
832 offset: self.offset,
833 })
834 }
835
836 pub fn scalar_mul(&self, s: f64) -> Tensor {
838 self.map(|x| x * s)
839 }
840
841 pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
846 Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
847 }
848
849 pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
851 self.add(other).expect("Tensor::add shape mismatch")
852 }
853
854 pub fn add_assign_unchecked(&mut self, other: &Tensor) {
859 assert_eq!(self.shape, other.shape, "Tensor::add_assign shape mismatch");
860 self.buffer.make_unique();
861 let other_data = other.buffer.borrow_data();
862 let mut self_data = self.buffer.borrow_data_mut();
863 for (a, b) in self_data.iter_mut().zip(other_data.iter()) {
864 *a += b;
865 }
866 }
867
868 pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
870 self.sub(other).expect("Tensor::sub shape mismatch")
871 }
872
873 pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
875 self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
876 }
877
878 pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
880 self.div_elem(other).expect("Tensor::div_elem shape mismatch")
881 }
882
883 pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
885 self.matmul(other).expect("Tensor::matmul dimension mismatch")
886 }
887
888 pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
892 if self.ndim() != 2 || other.ndim() != 2 {
893 return Err(RuntimeError::InvalidOperation(
894 "matmul requires 2-D tensors".to_string(),
895 ));
896 }
897 let m = self.shape[0];
898 let k = self.shape[1];
899 let k2 = other.shape[0];
900 let n = other.shape[1];
901 if k != k2 {
902 return Err(RuntimeError::DimensionMismatch {
903 expected: k,
904 got: k2,
905 });
906 }
907
908 let a = self.to_vec();
909 let b = other.to_vec();
910
911 #[cfg(feature = "parallel")]
914 {
915 if m >= 256 || n >= 256 || k >= 256 {
916 return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
917 }
918 }
919
920 if m >= 64 || n >= 64 || k >= 64 {
925 return Self::matmul_tiled(&a, &b, m, n, k);
926 }
927
928 Self::matmul_sequential(&a, &b, m, n, k)
930 }
931
932 fn matmul_sequential(
934 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
935 ) -> Result<Tensor, RuntimeError> {
936 let mut result = vec![0.0f64; m * n];
937 for i in 0..m {
938 for j in 0..n {
939 let mut acc = KahanAccumulatorF64::new();
940 for p in 0..k {
941 acc.add(a[i * k + p] * b[p * n + j]);
942 }
943 result[i * n + j] = acc.finalize();
944 }
945 }
946 Tensor::from_vec(result, &[m, n])
947 }
948
949 fn matmul_tiled(
956 a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
957 ) -> Result<Tensor, RuntimeError> {
958 let engine = TiledMatmul::new();
959 let result = engine.matmul(a, m, k, b, n);
960 Tensor::from_vec(result, &[m, n])
961 }
962
/// Parallel matrix product of row-major `a` (`m x k`) and `b` (`k x n`),
/// compiled only with the `parallel` feature.
///
/// Two strategies are used, both writing disjoint chunks of the output:
/// * `m >= 512 && n >= 512`: rows are grouped into bands and each band is
///   multiplied by its own `TiledMatmul` engine;
/// * otherwise: each output row is computed independently, with every
///   element accumulated in increasing `k` order by a `KahanAccumulatorF64`.
///
/// Both run inside `crate::runtime_policy::run_parallel` (presumably this
/// selects the thread pool — confirm in `runtime_policy`).
#[cfg(feature = "parallel")]
fn matmul_parallel_mode_a(
    a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
) -> Result<Tensor, RuntimeError> {
    use rayon::prelude::*;
    use cjc_repro::KahanAccumulatorF64;

    if m >= 512 && n >= 512 {
        let result = crate::runtime_policy::run_parallel(|| {
            // Rows per band: an even split across rayon threads, but never
            // fewer than 64 rows.
            let band_size =
                ((m + rayon::current_num_threads() - 1) / rayon::current_num_threads()).max(64);
            let mut result = vec![0.0f64; m * n];
            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    // The final band may hold fewer than `band_size` rows.
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k..i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });
            result
        });

        return Tensor::from_vec(result, &[m, n]);
    }

    let result = crate::runtime_policy::run_parallel(|| {
        let mut result = vec![0.0f64; m * n];
        result
            // One chunk of `n` elements per output row.
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });
        result
    });

    Tensor::from_vec(result, &[m, n])
}
1035
1036 pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
1044 if self.ndim() < 2 || other.ndim() < 2 {
1045 return Err(RuntimeError::InvalidOperation(
1046 "bmm requires at least 2-D tensors".to_string(),
1047 ));
1048 }
1049 if self.ndim() == 2 && other.ndim() == 2 {
1050 return self.matmul(other);
1051 }
1052 if self.ndim() != other.ndim() {
1053 return Err(RuntimeError::InvalidOperation(
1054 format!(
1055 "bmm requires same number of dimensions, got {} and {}",
1056 self.ndim(),
1057 other.ndim()
1058 ),
1059 ));
1060 }
1061 let nd = self.ndim();
1062 let batch_dims_a = &self.shape[..nd - 2];
1063 let batch_dims_b = &other.shape[..nd - 2];
1064 if batch_dims_a != batch_dims_b {
1065 return Err(RuntimeError::InvalidOperation(
1066 format!(
1067 "bmm batch dimensions mismatch: {:?} vs {:?}",
1068 batch_dims_a, batch_dims_b
1069 ),
1070 ));
1071 }
1072 let m = self.shape[nd - 2];
1073 let k = self.shape[nd - 1];
1074 let k2 = other.shape[nd - 2];
1075 let n = other.shape[nd - 1];
1076 if k != k2 {
1077 return Err(RuntimeError::DimensionMismatch {
1078 expected: k,
1079 got: k2,
1080 });
1081 }
1082
1083 let batch_size: usize = batch_dims_a.iter().product();
1084 let a = self.to_vec();
1085 let b = other.to_vec();
1086 let mat_a_stride = m * k;
1087 let mat_b_stride = k * n;
1088 let mat_c_stride = m * n;
1089 let mut result = vec![0.0f64; batch_size * mat_c_stride];
1090
1091 let compute_batch = |batch: usize, c_slice: &mut [f64]| {
1093 let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
1094 let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];
1095
1096 if m >= 64 || n >= 64 || k >= 64 {
1097 let engine = crate::tensor_tiled::TiledMatmul::new();
1098 let tiled = engine.matmul(a_slice, m, k, b_slice, n);
1099 c_slice.copy_from_slice(&tiled);
1100 } else {
1101 for i in 0..m {
1102 for j in 0..n {
1103 let mut acc = KahanAccumulatorF64::new();
1104 for p in 0..k {
1105 acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
1106 }
1107 c_slice[i * n + j] = acc.finalize();
1108 }
1109 }
1110 }
1111 };
1112
1113 #[cfg(feature = "parallel")]
1115 {
1116 if batch_size > 1 && m * k >= 4096 {
1117 use rayon::prelude::*;
1118 crate::runtime_policy::run_parallel(|| {
1121 result
1122 .par_chunks_mut(mat_c_stride)
1123 .enumerate()
1124 .for_each(|(batch, c_slice)| {
1125 compute_batch(batch, c_slice);
1126 });
1127 });
1128
1129 let mut out_shape = batch_dims_a.to_vec();
1130 out_shape.push(m);
1131 out_shape.push(n);
1132 return Tensor::from_vec(result, &out_shape);
1133 }
1134 }
1135
1136 for batch in 0..batch_size {
1138 let c_off = batch * mat_c_stride;
1139 compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
1140 }
1141
1142 let mut out_shape = batch_dims_a.to_vec();
1143 out_shape.push(m);
1144 out_shape.push(n);
1145 Tensor::from_vec(result, &out_shape)
1146 }
1147
1148 pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
1156 if self.ndim() == 0 {
1157 return Err(RuntimeError::InvalidOperation(
1158 "softmax requires at least 1-D tensor".to_string(),
1159 ));
1160 }
1161 let data_ref;
1163 let data_vec;
1164 let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
1165 data_ref = self.buffer.borrow_data();
1166 &data_ref
1167 } else {
1168 data_vec = self.to_vec();
1169 &data_vec
1170 };
1171 let n = *self.shape.last().unwrap(); let outer: usize = data.len() / n; let mut result = vec![0.0f64; data.len()];
1174
1175 for row in 0..outer {
1176 let start = row * n;
1177 let end = start + n;
1178 let slice = &data[start..end];
1179
1180 let mut max_val = f64::NEG_INFINITY;
1182 for &v in slice {
1183 if v > max_val {
1184 max_val = v;
1185 }
1186 }
1187
1188 let mut exp_vals = vec![0.0f64; n];
1190 let mut sum = 0.0f64;
1191 let mut comp = 0.0f64; for i in 0..n {
1193 let e = (slice[i] - max_val).exp();
1194 exp_vals[i] = e;
1195 let y = e - comp;
1197 let t = sum + y;
1198 comp = (t - sum) - y;
1199 sum = t;
1200 }
1201
1202 if sum == 0.0 {
1204 let uniform = 1.0 / n as f64;
1206 for i in 0..n {
1207 result[start + i] = uniform;
1208 }
1209 } else {
1210 for i in 0..n {
1211 result[start + i] = exp_vals[i] / sum;
1212 }
1213 }
1214 }
1215
1216 Tensor::from_vec(result, &self.shape)
1217 }
1218
1219 pub fn layer_norm(
1230 &self,
1231 gamma: &Tensor,
1232 beta: &Tensor,
1233 eps: f64,
1234 ) -> Result<Tensor, RuntimeError> {
1235 if self.ndim() == 0 {
1236 return Err(RuntimeError::InvalidOperation(
1237 "layer_norm requires at least 1-D tensor".to_string(),
1238 ));
1239 }
1240 let d = *self.shape.last().unwrap();
1241 if gamma.len() != d || beta.len() != d {
1242 return Err(RuntimeError::InvalidOperation(
1243 format!(
1244 "layer_norm: gamma/beta length {} must match last dim {}",
1245 gamma.len(),
1246 d
1247 ),
1248 ));
1249 }
1250
1251 let data = self.to_vec();
1252 let gamma_data = gamma.to_vec();
1253 let beta_data = beta.to_vec();
1254 let outer = data.len() / d;
1255 let mut result = vec![0.0f64; data.len()];
1256
1257 for row in 0..outer {
1258 let start = row * d;
1259 let slice = &data[start..start + d];
1260
1261 let mean = binned_sum_f64(slice) / d as f64;
1263
1264 let diffs: Vec<f64> = slice.iter().map(|&x| {
1266 let diff = x - mean;
1267 diff * diff
1268 }).collect();
1269 let variance = binned_sum_f64(&diffs) / d as f64;
1270
1271 let inv_std = 1.0 / (variance + eps).sqrt();
1273 for i in 0..d {
1274 let normalized = (slice[i] - mean) * inv_std;
1275 result[start + i] = gamma_data[i] * normalized + beta_data[i];
1276 }
1277 }
1278
1279 Tensor::from_vec(result, &self.shape)
1280 }
1281
1282 fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
1288 if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
1289 let mut data = self.buffer.borrow_data().clone();
1291 for x in data.iter_mut() {
1292 *x = f(*x);
1293 }
1294 Tensor::from_vec(data, &self.shape).unwrap()
1295 } else {
1296 let data = self.to_vec();
1298 let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
1299 Tensor::from_vec(result, &self.shape).unwrap()
1300 }
1301 }
1302
1303 pub fn relu(&self) -> Tensor {
1305 self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1306 }
1307
1308 pub fn sigmoid(&self) -> Tensor {
1310 self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1311 }
1312
1313 pub fn tanh_activation(&self) -> Tensor {
1315 self.map_elementwise(|x| x.tanh())
1316 }
1317
1318 pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1320 self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1321 }
1322
1323 pub fn silu(&self) -> Tensor {
1325 let data = self.to_vec();
1326 let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1327 Tensor::from_vec(result, &self.shape).unwrap()
1328 }
1329
1330 pub fn mish(&self) -> Tensor {
1332 let data = self.to_vec();
1333 let result: Vec<f64> = data.iter().map(|&x| {
1334 let sp = (1.0 + x.exp()).ln();
1335 x * sp.tanh()
1336 }).collect();
1337 Tensor::from_vec(result, &self.shape).unwrap()
1338 }
1339
1340 pub fn argmax(&self) -> usize {
1342 let data = self.to_vec();
1343 let mut best_idx = 0;
1344 let mut best_val = f64::NEG_INFINITY;
1345 for (i, &v) in data.iter().enumerate() {
1346 if v > best_val || (v == best_val && i < best_idx) {
1347 best_val = v;
1348 best_idx = i;
1349 }
1350 }
1351 best_idx
1352 }
1353
1354 pub fn argmin(&self) -> usize {
1356 let data = self.to_vec();
1357 let mut best_idx = 0;
1358 let mut best_val = f64::INFINITY;
1359 for (i, &v) in data.iter().enumerate() {
1360 if v < best_val || (v == best_val && i < best_idx) {
1361 best_val = v;
1362 best_idx = i;
1363 }
1364 }
1365 best_idx
1366 }
1367
1368 pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1370 let data = self.to_vec();
1371 let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1372 Tensor::from_vec(result, &self.shape).unwrap()
1373 }
1374
1375 pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1378 let n = indices.len();
1379 let mut data = vec![0.0; n * depth];
1380 for (i, &idx) in indices.iter().enumerate() {
1381 if idx >= depth {
1382 return Err(RuntimeError::InvalidOperation(format!(
1383 "one_hot: index {idx} >= depth {depth}"
1384 )));
1385 }
1386 data[i * depth + idx] = 1.0;
1387 }
1388 Tensor::from_vec(data, &[n, depth])
1389 }
1390
1391 pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1397 if tensors.is_empty() {
1398 return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
1399 }
1400 let ndim = tensors[0].ndim();
1401 if axis >= ndim {
1402 return Err(RuntimeError::InvalidOperation(
1403 format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
1404 ));
1405 }
1406 for (i, t) in tensors.iter().enumerate().skip(1) {
1407 if t.ndim() != ndim {
1408 return Err(RuntimeError::InvalidOperation(
1409 format!("cat: tensor {i} has different ndim"),
1410 ));
1411 }
1412 for d in 0..ndim {
1413 if d != axis && t.shape[d] != tensors[0].shape[d] {
1414 return Err(RuntimeError::InvalidOperation(
1415 format!("cat: shape mismatch at dim {d}"),
1416 ));
1417 }
1418 }
1419 }
1420 let mut out_shape = tensors[0].shape.clone();
1421 for t in tensors.iter().skip(1) {
1422 out_shape[axis] += t.shape[axis];
1423 }
1424 let total = out_shape.iter().product::<usize>();
1425 let mut result = vec![0.0; total];
1426 let mut out_strides = vec![1usize; ndim];
1427 for d in (0..ndim - 1).rev() {
1428 out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
1429 }
1430 let mut offset = 0;
1431 for t in tensors {
1432 let t_data = t.to_vec();
1433 let t_total: usize = t.shape.iter().product();
1434 let mut t_strides = vec![1usize; ndim];
1435 for d in (0..ndim - 1).rev() {
1436 t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
1437 }
1438 for idx in 0..t_total {
1439 let mut remaining = idx;
1440 let mut out_flat = 0;
1441 for d in 0..ndim {
1442 let coord = remaining / t_strides[d];
1443 remaining %= t_strides[d];
1444 let out_coord = if d == axis { coord + offset } else { coord };
1445 out_flat += out_coord * out_strides[d];
1446 }
1447 result[out_flat] = t_data[idx];
1448 }
1449 offset += t.shape[axis];
1450 }
1451 Tensor::from_vec(result, &out_shape)
1452 }
1453
1454 pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1456 if tensors.is_empty() {
1457 return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
1458 }
1459 let base_shape = &tensors[0].shape;
1460 let ndim = base_shape.len();
1461 if axis > ndim {
1462 return Err(RuntimeError::InvalidOperation(
1463 format!("stack: axis {axis} out of bounds"),
1464 ));
1465 }
1466 for (i, t) in tensors.iter().enumerate().skip(1) {
1467 if &t.shape != base_shape {
1468 return Err(RuntimeError::InvalidOperation(
1469 format!("stack: tensor {i} shape mismatch"),
1470 ));
1471 }
1472 }
1473 let mut out_shape = Vec::with_capacity(ndim + 1);
1474 for d in 0..axis { out_shape.push(base_shape[d]); }
1475 out_shape.push(tensors.len());
1476 for d in axis..ndim { out_shape.push(base_shape[d]); }
1477 let total: usize = out_shape.iter().product();
1478 let mut result = vec![0.0; total];
1479 let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
1480 let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
1481 for (t_idx, t) in tensors.iter().enumerate() {
1482 let t_data = t.to_vec();
1483 for outer in 0..outer_size {
1484 for inner in 0..inner_size {
1485 let src = outer * inner_size + inner;
1486 let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
1487 if src < t_data.len() && dst < result.len() {
1488 result[dst] = t_data[src];
1489 }
1490 }
1491 }
1492 }
1493 Tensor::from_vec(result, &out_shape)
1494 }
1495
1496 pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1498 let data = self.to_vec();
1499 let n = data.len();
1500 if k > n {
1501 return Err(RuntimeError::InvalidOperation(
1502 format!("topk: k={k} exceeds data length {n}"),
1503 ));
1504 }
1505 let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1506 indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1507 let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1508 let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1509 let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1510 Ok((Tensor::from_vec(values, &[k])?, indices))
1511 }
1512
1513 pub fn gelu(&self) -> Tensor {
1515 let data = self.to_vec();
1516 let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1517 let result: Vec<f64> = data.iter().map(|&x| {
1518 let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1519 0.5 * x * (1.0 + inner.tanh())
1520 }).collect();
1521 Tensor::from_vec(result, &self.shape).unwrap()
1522 }
1523
1524 pub fn linear(
1530 &self,
1531 weight: &Tensor,
1532 bias: &Tensor,
1533 ) -> Result<Tensor, RuntimeError> {
1534 if weight.ndim() != 2 {
1535 return Err(RuntimeError::InvalidOperation(
1536 "linear: weight must be 2-D [out_features, in_features]".to_string(),
1537 ));
1538 }
1539 let out_features = weight.shape[0];
1540 let in_features = weight.shape[1];
1541 let last_dim = *self.shape.last().ok_or_else(|| {
1542 RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
1543 })?;
1544 if last_dim != in_features {
1545 return Err(RuntimeError::DimensionMismatch {
1546 expected: in_features,
1547 got: last_dim,
1548 });
1549 }
1550 if bias.len() != out_features {
1551 return Err(RuntimeError::InvalidOperation(
1552 format!(
1553 "linear: bias length {} must match out_features {}",
1554 bias.len(),
1555 out_features
1556 ),
1557 ));
1558 }
1559
1560 let data = self.to_vec();
1561 let w = weight.to_vec();
1562 let b = bias.to_vec();
1563 let outer = data.len() / in_features;
1564 let mut result = vec![0.0f64; outer * out_features];
1565
1566 for row in 0..outer {
1567 let x_start = row * in_features;
1568 let x_slice = &data[x_start..x_start + in_features];
1569 let y_start = row * out_features;
1570 for j in 0..out_features {
1571 let w_start = j * in_features;
1572 let mut acc = BinnedAccumulatorF64::new();
1573 for p in 0..in_features {
1574 acc.add(x_slice[p] * w[w_start + p]);
1575 }
1576 result[y_start + j] = acc.finalize() + b[j];
1577 }
1578 }
1579
1580 let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
1581 out_shape.push(out_features);
1582 Tensor::from_vec(result, &out_shape)
1583 }
1584
1585 pub fn conv1d(
1589 &self,
1590 filters: &Tensor,
1591 bias: &Tensor,
1592 ) -> Result<Tensor, RuntimeError> {
1593 if self.ndim() != 1 {
1594 return Err(RuntimeError::InvalidOperation(
1595 "conv1d: input must be 1-D [signal_len]".to_string(),
1596 ));
1597 }
1598 if filters.ndim() != 2 {
1599 return Err(RuntimeError::InvalidOperation(
1600 "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1601 ));
1602 }
1603 let signal_len = self.shape[0];
1604 let out_channels = filters.shape[0];
1605 let kernel_size = filters.shape[1];
1606 if signal_len < kernel_size {
1607 return Err(RuntimeError::InvalidOperation(
1608 format!(
1609 "conv1d: signal_len {} < kernel_size {}",
1610 signal_len, kernel_size
1611 ),
1612 ));
1613 }
1614 if bias.len() != out_channels {
1615 return Err(RuntimeError::InvalidOperation(
1616 format!(
1617 "conv1d: bias length {} must match out_channels {}",
1618 bias.len(), out_channels
1619 ),
1620 ));
1621 }
1622 let out_len = signal_len - kernel_size + 1;
1623 let s = self.to_vec();
1624 let f = filters.to_vec();
1625 let b = bias.to_vec();
1626 let mut result = vec![0.0; out_channels * out_len];
1627 kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1628 Tensor::from_vec(result, &[out_channels, out_len])
1629 }
1630
1631 pub fn conv2d(
1645 &self,
1646 filters: &Tensor,
1647 bias: &Tensor,
1648 stride: usize,
1649 ) -> Result<Tensor, RuntimeError> {
1650 if self.ndim() != 4 {
1651 return Err(RuntimeError::InvalidOperation(
1652 "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
1653 ));
1654 }
1655 if filters.ndim() != 4 {
1656 return Err(RuntimeError::InvalidOperation(
1657 "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
1658 ));
1659 }
1660 if stride == 0 {
1661 return Err(RuntimeError::InvalidOperation(
1662 "conv2d: stride must be >= 1".to_string(),
1663 ));
1664 }
1665
1666 let n = self.shape[0];
1667 let c_in = self.shape[1];
1668 let h_in = self.shape[2];
1669 let w_in = self.shape[3];
1670
1671 let c_out = filters.shape[0];
1672 let c_in_check = filters.shape[1];
1673 let kh = filters.shape[2];
1674 let kw = filters.shape[3];
1675
1676 if c_in != c_in_check {
1677 return Err(RuntimeError::InvalidOperation(format!(
1678 "conv2d: input C_in={} does not match filter C_in={}",
1679 c_in, c_in_check
1680 )));
1681 }
1682 if h_in < kh || w_in < kw {
1683 return Err(RuntimeError::InvalidOperation(format!(
1684 "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
1685 h_in, w_in, kh, kw
1686 )));
1687 }
1688 if bias.len() != c_out {
1689 return Err(RuntimeError::InvalidOperation(format!(
1690 "conv2d: bias length {} must match C_out={}",
1691 bias.len(), c_out
1692 )));
1693 }
1694
1695 let h_out = (h_in - kh) / stride + 1;
1696 let w_out = (w_in - kw) / stride + 1;
1697
1698 let inp = self.to_vec();
1699 let flt = filters.to_vec();
1700 let b = bias.to_vec();
1701 let mut result = vec![0.0f64; n * c_out * h_out * w_out];
1702
1703 kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
1704 n, c_in, h_in, w_in, c_out, kh, kw, stride);
1705
1706 Tensor::from_vec(result, &[n, c_out, h_out, w_out])
1707 }
1708
1709 pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1716 if self.ndim() != 4 {
1717 return Err(RuntimeError::InvalidOperation(
1718 "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1719 ));
1720 }
1721 if ph == 0 || pw == 0 {
1722 return Err(RuntimeError::InvalidOperation(
1723 "maxpool2d: pool size must be >= 1".to_string(),
1724 ));
1725 }
1726
1727 let n = self.shape[0];
1728 let c = self.shape[1];
1729 let h_in = self.shape[2];
1730 let w_in = self.shape[3];
1731
1732 if h_in < ph || w_in < pw {
1733 return Err(RuntimeError::InvalidOperation(format!(
1734 "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1735 h_in, w_in, ph, pw
1736 )));
1737 }
1738
1739 let h_out = h_in / ph;
1740 let w_out = w_in / pw;
1741
1742 let inp = self.to_vec();
1743 let mut result = vec![0.0f64; n * c * h_out * w_out];
1744
1745 kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1746
1747 Tensor::from_vec(result, &[n, c, h_out, w_out])
1748 }
1749
1750 pub fn avgpool2d(&self, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, RuntimeError> {
1765 let shape = self.shape();
1766 if shape.len() != 3 {
1767 return Err(RuntimeError::InvalidOperation(format!("avgpool2d requires 3-D [C,H,W], got {:?}", shape)));
1768 }
1769 let (c, h, w) = (shape[0], shape[1], shape[2]);
1770 if kernel_h > h || kernel_w > w {
1771 return Err(RuntimeError::InvalidOperation("avgpool2d: kernel larger than input".into()));
1772 }
1773 let out_h = (h - kernel_h) / stride_h + 1;
1774 let out_w = (w - kernel_w) / stride_w + 1;
1775 let data = self.to_vec();
1776 let mut out = Vec::with_capacity(c * out_h * out_w);
1777 let pool_size = (kernel_h * kernel_w) as f64;
1778
1779 for ch in 0..c {
1780 for oh in 0..out_h {
1781 for ow in 0..out_w {
1782 let mut sum = 0.0f64;
1783 for kh in 0..kernel_h {
1784 for kw in 0..kernel_w {
1785 let ih = oh * stride_h + kh;
1786 let iw = ow * stride_w + kw;
1787 sum += data[ch * h * w + ih * w + iw];
1788 }
1789 }
1790 out.push(sum / pool_size);
1791 }
1792 }
1793 }
1794 Tensor::from_vec(out, &[c, out_h, out_w])
1795 }
1796
1797 pub fn scaled_dot_product_attention(
1806 queries: &Tensor,
1807 keys: &Tensor,
1808 values: &Tensor,
1809 ) -> Result<Tensor, RuntimeError> {
1810 if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1811 return Err(RuntimeError::InvalidOperation(
1812 "attention: Q, K, V must be at least 2-D".to_string(),
1813 ));
1814 }
1815 let nd = queries.ndim();
1816 let d_k = queries.shape[nd - 1];
1817 let scale = 1.0 / (d_k as f64).sqrt();
1818
1819 let keys_t = keys.transpose_last_two()?;
1821
1822 let scores = queries.bmm(&keys_t)?;
1824
1825 let scores_scaled = scores.scalar_mul(scale);
1827
1828 let attn_weights = scores_scaled.softmax()?;
1830
1831 attn_weights.bmm(values)
1833 }
1834
1835 pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1839 if self.ndim() < 2 {
1840 return Err(RuntimeError::InvalidOperation(
1841 "transpose_last_two requires at least 2-D tensor".to_string(),
1842 ));
1843 }
1844 let nd = self.ndim();
1845 let rows = self.shape[nd - 2];
1846 let cols = self.shape[nd - 1];
1847 let data = self.to_vec();
1848 let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1849 let mat_size = rows * cols;
1850 let mut result = vec![0.0f64; data.len()];
1851
1852 for b in 0..batch_size {
1853 let off = b * mat_size;
1854 for i in 0..rows {
1855 for j in 0..cols {
1856 result[off + j * rows + i] = data[off + i * cols + j];
1857 }
1858 }
1859 }
1860
1861 let mut out_shape = self.shape.clone();
1862 out_shape[nd - 2] = cols;
1863 out_shape[nd - 1] = rows;
1864 Tensor::from_vec(result, &out_shape)
1865 }
1866
1867 pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1883 let numel = Self::shape_numel(shape);
1884 match dtype {
1885 "f64" => {
1886 let expected = numel * 8;
1887 if bytes.len() != expected {
1888 return Err(RuntimeError::ShapeMismatch {
1889 expected,
1890 got: bytes.len(),
1891 });
1892 }
1893 let mut data = Vec::with_capacity(numel);
1894 for i in 0..numel {
1895 let off = i * 8;
1896 let mut buf = [0u8; 8];
1897 buf.copy_from_slice(&bytes[off..off + 8]);
1898 data.push(f64::from_le_bytes(buf));
1899 }
1900 Ok(Tensor {
1901 buffer: Buffer::from_vec(data),
1902 shape: shape.to_vec(),
1903 strides: Self::compute_strides(shape),
1904 offset: 0,
1905 })
1906 }
1907 "f32" => {
1908 let expected = numel * 4;
1909 if bytes.len() != expected {
1910 return Err(RuntimeError::ShapeMismatch {
1911 expected,
1912 got: bytes.len(),
1913 });
1914 }
1915 let mut data = Vec::with_capacity(numel);
1916 for i in 0..numel {
1917 let off = i * 4;
1918 let mut buf = [0u8; 4];
1919 buf.copy_from_slice(&bytes[off..off + 4]);
1920 data.push(f32::from_le_bytes(buf) as f64);
1921 }
1922 Ok(Tensor {
1923 buffer: Buffer::from_vec(data),
1924 shape: shape.to_vec(),
1925 strides: Self::compute_strides(shape),
1926 offset: 0,
1927 })
1928 }
1929 _ => Err(RuntimeError::InvalidOperation(
1930 format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1931 )),
1932 }
1933 }
1934
1935 pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1943 if self.ndim() != 3 {
1944 return Err(RuntimeError::DimensionMismatch {
1945 expected: 3,
1946 got: self.ndim(),
1947 });
1948 }
1949 let batch = self.shape[0];
1950 let seq = self.shape[1];
1951 let model_dim = self.shape[2];
1952 if model_dim % num_heads != 0 {
1953 return Err(RuntimeError::InvalidOperation(
1954 format!(
1955 "split_heads: model_dim {} not divisible by num_heads {}",
1956 model_dim, num_heads
1957 ),
1958 ));
1959 }
1960 let head_dim = model_dim / num_heads;
1961 let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1963 let reshaped = Tensor {
1965 buffer: tensor.buffer.clone(),
1966 shape: vec![batch, seq, num_heads, head_dim],
1967 strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1968 offset: 0,
1969 };
1970 Ok(Tensor {
1973 buffer: reshaped.buffer,
1974 shape: vec![batch, num_heads, seq, head_dim],
1975 strides: vec![
1976 reshaped.strides[0], reshaped.strides[2], reshaped.strides[1], reshaped.strides[3], ],
1981 offset: 0,
1982 })
1983 }
1984
1985 pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1988 if self.ndim() != 4 {
1989 return Err(RuntimeError::DimensionMismatch {
1990 expected: 4,
1991 got: self.ndim(),
1992 });
1993 }
1994 let batch = self.shape[0];
1995 let num_heads = self.shape[1];
1996 let seq = self.shape[2];
1997 let head_dim = self.shape[3];
1998 let transposed = Tensor {
2001 buffer: self.buffer.clone(),
2002 shape: vec![batch, seq, num_heads, head_dim],
2003 strides: vec![
2004 self.strides[0],
2005 self.strides[2], self.strides[1], self.strides[3],
2008 ],
2009 offset: self.offset,
2010 };
2011 let contig = transposed.to_contiguous();
2013 let model_dim = num_heads * head_dim;
2014 Ok(Tensor {
2015 buffer: contig.buffer,
2016 shape: vec![batch, seq, model_dim],
2017 strides: Self::compute_strides(&[batch, seq, model_dim]),
2018 offset: 0,
2019 })
2020 }
2021
    /// Reshapes the tensor to `new_shape`.
    ///
    /// This is a thin alias that forwards directly to `reshape`, including
    /// whatever error `reshape` reports.
    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.reshape(new_shape)
    }
2027
2028 pub fn argsort(&self) -> Tensor {
2035 let data = self.to_vec();
2036 let mut indices: Vec<usize> = (0..data.len()).collect();
2037 indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
2038 let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
2039 Tensor::from_vec_unchecked(result, &[data.len()])
2040 }
2041
2042 pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2047 let data = self.to_vec();
2048 let idx_data = indices.to_vec();
2049 if self.ndim() == 1 {
2050 let mut result = Vec::with_capacity(idx_data.len());
2051 for &idx in &idx_data {
2052 let i = idx as usize;
2053 if i >= data.len() {
2054 return Err(RuntimeError::InvalidOperation(
2055 format!("gather: index {} out of bounds for size {}", i, data.len()),
2056 ));
2057 }
2058 result.push(data[i]);
2059 }
2060 Ok(Tensor::from_vec_unchecked(result, indices.shape()))
2061 } else if self.ndim() == 2 {
2062 let rows = self.shape[0];
2063 let cols = self.shape[1];
2064 let idx_shape = indices.shape();
2065 let out_rows = idx_shape[0];
2066 let out_cols = idx_shape[1];
2067 let mut result = vec![0.0; out_rows * out_cols];
2068 for i in 0..out_rows {
2069 for j in 0..out_cols {
2070 let idx = idx_data[i * out_cols + j] as usize;
2071 let val = if dim == 0 {
2072 if idx >= rows {
2073 return Err(RuntimeError::InvalidOperation(
2074 format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
2075 ));
2076 }
2077 data[idx * cols + j]
2078 } else {
2079 if idx >= cols {
2080 return Err(RuntimeError::InvalidOperation(
2081 format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
2082 ));
2083 }
2084 data[i * cols + idx]
2085 };
2086 result[i * out_cols + j] = val;
2087 }
2088 }
2089 Ok(Tensor::from_vec_unchecked(result, idx_shape))
2090 } else {
2091 Err(RuntimeError::InvalidOperation(
2092 "gather: only 1D and 2D tensors supported".into(),
2093 ))
2094 }
2095 }
2096
2097 pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
2102 let mut result = self.to_vec();
2103 let idx_data = indices.to_vec();
2104 let src_data = src.to_vec();
2105 if self.ndim() == 1 {
2106 for (k, &idx) in idx_data.iter().enumerate() {
2107 let i = idx as usize;
2108 if i >= result.len() {
2109 return Err(RuntimeError::InvalidOperation(
2110 format!("scatter: index {} out of bounds for size {}", i, result.len()),
2111 ));
2112 }
2113 result[i] = src_data[k];
2114 }
2115 Ok(Tensor::from_vec_unchecked(result, self.shape()))
2116 } else if self.ndim() == 2 {
2117 let cols = self.shape[1];
2118 let idx_shape = indices.shape();
2119 let out_cols = idx_shape[1];
2120 let out_rows = idx_shape[0];
2121 for i in 0..out_rows {
2122 for j in 0..out_cols {
2123 let idx = idx_data[i * out_cols + j] as usize;
2124 let src_val = src_data[i * out_cols + j];
2125 if dim == 0 {
2126 if idx >= self.shape[0] {
2127 return Err(RuntimeError::InvalidOperation(
2128 format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
2129 ));
2130 }
2131 result[idx * cols + j] = src_val;
2132 } else {
2133 if idx >= cols {
2134 return Err(RuntimeError::InvalidOperation(
2135 format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
2136 ));
2137 }
2138 result[i * cols + idx] = src_val;
2139 }
2140 }
2141 }
2142 Ok(Tensor::from_vec_unchecked(result, self.shape()))
2143 } else {
2144 Err(RuntimeError::InvalidOperation(
2145 "scatter: only 1D and 2D tensors supported".into(),
2146 ))
2147 }
2148 }
2149
2150 pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2154 let data = self.to_vec();
2155 let idx_data = indices.to_vec();
2156 if self.ndim() == 1 {
2157 let mut result = Vec::with_capacity(idx_data.len());
2158 for &idx in &idx_data {
2159 let i = idx as usize;
2160 if i >= data.len() {
2161 return Err(RuntimeError::InvalidOperation(
2162 format!("index_select: index {} out of bounds for size {}", i, data.len()),
2163 ));
2164 }
2165 result.push(data[i]);
2166 }
2167 Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
2168 } else if self.ndim() == 2 {
2169 let rows = self.shape[0];
2170 let cols = self.shape[1];
2171 let n = idx_data.len();
2172 if dim == 0 {
2173 let mut result = Vec::with_capacity(n * cols);
2174 for &idx in &idx_data {
2175 let i = idx as usize;
2176 if i >= rows {
2177 return Err(RuntimeError::InvalidOperation(
2178 format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
2179 ));
2180 }
2181 for j in 0..cols {
2182 result.push(data[i * cols + j]);
2183 }
2184 }
2185 Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
2186 } else {
2187 let mut result = Vec::with_capacity(rows * n);
2188 for i in 0..rows {
2189 for &idx in &idx_data {
2190 let j = idx as usize;
2191 if j >= cols {
2192 return Err(RuntimeError::InvalidOperation(
2193 format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
2194 ));
2195 }
2196 result.push(data[i * cols + j]);
2197 }
2198 }
2199 Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
2200 }
2201 } else {
2202 Err(RuntimeError::InvalidOperation(
2203 "index_select: only 1D and 2D tensors supported".into(),
2204 ))
2205 }
2206 }
2207
2208 pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
2215 if self.shape() != condition.shape() || self.shape() != other.shape() {
2216 return Err(RuntimeError::InvalidOperation(
2217 format!("where: shape mismatch self={:?} cond={:?} other={:?}",
2218 self.shape(), condition.shape(), other.shape()),
2219 ));
2220 }
2221 let s = self.to_vec();
2222 let c = condition.to_vec();
2223 let o = other.to_vec();
2224 let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2225 .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2226 .collect();
2227 Tensor::from_vec(result, self.shape())
2228 }
2229
2230 pub fn any(&self) -> bool {
2232 let data = self.to_vec();
2233 data.iter().any(|&x| x != 0.0)
2234 }
2235
2236 pub fn all(&self) -> bool {
2238 let data = self.to_vec();
2239 data.iter().all(|&x| x != 0.0)
2240 }
2241
2242 pub fn nonzero(&self) -> Tensor {
2246 let data = self.to_vec();
2247 let indices: Vec<f64> = data.iter().enumerate()
2248 .filter(|(_, &v)| v != 0.0)
2249 .map(|(i, _)| i as f64)
2250 .collect();
2251 let len = indices.len();
2252 if len == 0 {
2253 Tensor::from_vec(vec![], &[0]).unwrap()
2254 } else {
2255 Tensor::from_vec(indices, &[len]).unwrap()
2256 }
2257 }
2258
2259 pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2261 if self.shape() != mask.shape() {
2262 return Err(RuntimeError::InvalidOperation(
2263 format!("masked_fill: shape mismatch self={:?} mask={:?}",
2264 self.shape(), mask.shape()),
2265 ));
2266 }
2267 let data = self.to_vec();
2268 let m = mask.to_vec();
2269 let result: Vec<f64> = data.iter().zip(m.iter())
2270 .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2271 .collect();
2272 Tensor::from_vec(result, self.shape())
2273 }
2274
2275 fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
2285 -> Result<Tensor, RuntimeError>
2286 where
2287 F: Fn(&[f64]) -> f64,
2288 {
2289 let ndim = self.ndim();
2290 if axis >= ndim {
2291 return Err(RuntimeError::IndexOutOfBounds {
2292 index: axis,
2293 length: ndim,
2294 });
2295 }
2296
2297 let axis_len = self.shape[axis];
2298 let mut out_shape: Vec<usize> = self.shape.clone();
2300 out_shape[axis] = 1;
2301 let out_numel = Self::shape_numel(&out_shape);
2302 let out_strides = Self::compute_strides(&out_shape);
2303
2304 let data = self.to_vec();
2305 let mut result = Vec::with_capacity(out_numel);
2306 let mut indices = vec![0usize; ndim];
2307
2308 for out_idx in 0..out_numel {
2309 {
2311 let mut remaining = out_idx;
2312 for d in 0..ndim {
2313 indices[d] = remaining / out_strides[d];
2314 remaining %= out_strides[d];
2315 }
2316 }
2317
2318 let mut vals = Vec::with_capacity(axis_len);
2320 for k in 0..axis_len {
2321 let mut flat = self.offset;
2322 for d in 0..ndim {
2323 let idx = if d == axis { k } else { indices[d] };
2324 flat += idx * self.strides[d];
2325 }
2326 vals.push(data[flat]);
2327 }
2328 result.push(reduce_fn(&vals));
2329 }
2330
2331 let final_shape = if keepdim {
2332 out_shape
2333 } else {
2334 let mut s: Vec<usize> = self.shape.iter().enumerate()
2336 .filter(|&(i, _)| i != axis)
2337 .map(|(_, &v)| v)
2338 .collect();
2339 if s.is_empty() {
2340 s.push(1); }
2342 s
2343 };
2344
2345 Tensor::from_vec(result, &final_shape)
2346 }
2347
2348 pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2353 self.reduce_axis(axis, keepdim, |vals| {
2354 let mut acc = BinnedAccumulatorF64::new();
2355 for &v in vals { acc.add(v); }
2356 acc.finalize() / vals.len() as f64
2357 })
2358 }
2359
2360 pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2364 let ndim = self.ndim();
2365 if axis >= ndim {
2366 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2367 }
2368 let axis_len = self.shape[axis];
2369 let mut out_shape = self.shape.clone();
2370 out_shape[axis] = 1;
2371 let out_numel = Self::shape_numel(&out_shape);
2372 let out_strides = Self::compute_strides(&out_shape);
2373 let data = self.to_vec();
2374 let mut values = Vec::with_capacity(out_numel);
2375 let mut idx_vals = Vec::with_capacity(out_numel);
2376 let mut indices = vec![0usize; ndim];
2377
2378 for out_idx in 0..out_numel {
2379 let mut remaining = out_idx;
2380 for d in 0..ndim {
2381 indices[d] = remaining / out_strides[d];
2382 remaining %= out_strides[d];
2383 }
2384 let mut best_val = f64::NEG_INFINITY;
2385 let mut best_idx = 0usize;
2386 for k in 0..axis_len {
2387 let mut flat = self.offset;
2388 for d in 0..ndim {
2389 let idx = if d == axis { k } else { indices[d] };
2390 flat += idx * self.strides[d];
2391 }
2392 let v = data[flat];
2393 if v > best_val {
2394 best_val = v;
2395 best_idx = k;
2396 }
2397 }
2398 values.push(best_val);
2399 idx_vals.push(best_idx as f64);
2400 }
2401
2402 let final_shape = if keepdim {
2403 out_shape
2404 } else {
2405 let mut s: Vec<usize> = self.shape.iter().enumerate()
2406 .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2407 if s.is_empty() { s.push(1); }
2408 s
2409 };
2410 Ok((
2411 Tensor::from_vec(values, &final_shape)?,
2412 Tensor::from_vec(idx_vals, &final_shape)?,
2413 ))
2414 }
2415
2416 pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2420 let ndim = self.ndim();
2421 if axis >= ndim {
2422 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2423 }
2424 let axis_len = self.shape[axis];
2425 let mut out_shape = self.shape.clone();
2426 out_shape[axis] = 1;
2427 let out_numel = Self::shape_numel(&out_shape);
2428 let out_strides = Self::compute_strides(&out_shape);
2429 let data = self.to_vec();
2430 let mut values = Vec::with_capacity(out_numel);
2431 let mut idx_vals = Vec::with_capacity(out_numel);
2432 let mut indices = vec![0usize; ndim];
2433
2434 for out_idx in 0..out_numel {
2435 let mut remaining = out_idx;
2436 for d in 0..ndim {
2437 indices[d] = remaining / out_strides[d];
2438 remaining %= out_strides[d];
2439 }
2440 let mut best_val = f64::INFINITY;
2441 let mut best_idx = 0usize;
2442 for k in 0..axis_len {
2443 let mut flat = self.offset;
2444 for d in 0..ndim {
2445 let idx = if d == axis { k } else { indices[d] };
2446 flat += idx * self.strides[d];
2447 }
2448 let v = data[flat];
2449 if v < best_val {
2450 best_val = v;
2451 best_idx = k;
2452 }
2453 }
2454 values.push(best_val);
2455 idx_vals.push(best_idx as f64);
2456 }
2457
2458 let final_shape = if keepdim {
2459 out_shape
2460 } else {
2461 let mut s: Vec<usize> = self.shape.iter().enumerate()
2462 .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2463 if s.is_empty() { s.push(1); }
2464 s
2465 };
2466 Ok((
2467 Tensor::from_vec(values, &final_shape)?,
2468 Tensor::from_vec(idx_vals, &final_shape)?,
2469 ))
2470 }
2471
2472 pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2478 let mean_t = self.mean_axis(axis, true)?;
2479 let ndim = self.ndim();
2480 if axis >= ndim {
2481 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2482 }
2483 let axis_len = self.shape[axis];
2484 let mut out_shape = self.shape.clone();
2485 out_shape[axis] = 1;
2486 let out_numel = Self::shape_numel(&out_shape);
2487 let out_strides = Self::compute_strides(&out_shape);
2488 let data = self.to_vec();
2489 let mean_data = mean_t.to_vec();
2490 let mut result = Vec::with_capacity(out_numel);
2491 let mut indices = vec![0usize; ndim];
2492
2493 for out_idx in 0..out_numel {
2494 let mut remaining = out_idx;
2495 for d in 0..ndim {
2496 indices[d] = remaining / out_strides[d];
2497 remaining %= out_strides[d];
2498 }
2499 let mu = mean_data[out_idx];
2500 let mut acc = BinnedAccumulatorF64::new();
2501 for k in 0..axis_len {
2502 let mut flat = self.offset;
2503 for d in 0..ndim {
2504 let idx = if d == axis { k } else { indices[d] };
2505 flat += idx * self.strides[d];
2506 }
2507 let diff = data[flat] - mu;
2508 acc.add(diff * diff);
2509 }
2510 result.push(acc.finalize() / axis_len as f64);
2511 }
2512
2513 let final_shape = if keepdim {
2514 out_shape
2515 } else {
2516 let mut s: Vec<usize> = self.shape.iter().enumerate()
2517 .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2518 if s.is_empty() { s.push(1); }
2519 s
2520 };
2521 Tensor::from_vec(result, &final_shape)
2522 }
2523
2524 pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2528 let var = self.var_axis(axis, keepdim)?;
2529 Ok(var.map(|x| x.sqrt()))
2530 }
2531
2532 pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2536 self.reduce_axis(axis, keepdim, |vals| {
2537 let mut product = 1.0f64;
2540 for &v in vals { product *= v; }
2541 product
2542 })
2543 }
2544
2545 pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2555 let ndim = self.ndim();
2556 if axis >= ndim {
2557 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2558 }
2559 let data = self.to_vec();
2560 let axis_len = self.shape[axis];
2561 let out_shape = self.shape.clone();
2562 let out_numel = Self::shape_numel(&out_shape);
2563
2564 let mut iter_shape: Vec<usize> = Vec::new();
2566 for (i, &s) in self.shape.iter().enumerate() {
2567 if i != axis { iter_shape.push(s); }
2568 }
2569 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2570
2571 let mut result = vec![0.0f64; out_numel];
2572
2573 let mut pos = vec![0usize; ndim];
2575 for slice_idx in 0..n_slices {
2576 let mut remaining = slice_idx;
2578 let mut dim_idx = 0;
2579 for d in 0..ndim {
2580 if d == axis {
2581 pos[d] = 0;
2582 } else {
2583 let stride = {
2584 let mut s = 1usize;
2585 let mut di = 0;
2586 for d2 in 0..ndim {
2587 if d2 == axis { continue; }
2588 if di > dim_idx { s *= self.shape[d2]; }
2589 di += 1;
2590 }
2591 s
2592 };
2593 pos[d] = remaining / stride;
2594 remaining %= stride;
2595 dim_idx += 1;
2596 }
2597 }
2598
2599 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2601 for k in 0..axis_len {
2602 let mut flat = self.offset;
2603 for d in 0..ndim {
2604 let idx = if d == axis { k } else { pos[d] };
2605 flat += idx * self.strides[d];
2606 }
2607 vals.push((data[flat], k));
2608 }
2609
2610 if descending {
2612 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2613 .then(a.1.cmp(&b.1)));
2614 } else {
2615 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2616 .then(a.1.cmp(&b.1)));
2617 }
2618
2619 for (k, &(v, _)) in vals.iter().enumerate() {
2621 let mut flat = 0;
2622 let out_strides_local = Self::compute_strides(&out_shape);
2623 for d in 0..ndim {
2624 let idx = if d == axis { k } else { pos[d] };
2625 flat += idx * out_strides_local[d];
2626 }
2627 result[flat] = v;
2628 }
2629 }
2630
2631 Tensor::from_vec(result, &out_shape)
2632 }
2633
2634 pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2639 let ndim = self.ndim();
2640 if axis >= ndim {
2641 return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2642 }
2643 let data = self.to_vec();
2644 let axis_len = self.shape[axis];
2645 let out_shape = self.shape.clone();
2646 let out_numel = Self::shape_numel(&out_shape);
2647
2648 let mut iter_shape: Vec<usize> = Vec::new();
2649 for (i, &s) in self.shape.iter().enumerate() {
2650 if i != axis { iter_shape.push(s); }
2651 }
2652 let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2653
2654 let mut result = vec![0.0f64; out_numel];
2655 let mut pos = vec![0usize; ndim];
2656
2657 for slice_idx in 0..n_slices {
2658 let mut remaining = slice_idx;
2659 let mut dim_idx = 0;
2660 for d in 0..ndim {
2661 if d == axis {
2662 pos[d] = 0;
2663 } else {
2664 let stride = {
2665 let mut s = 1usize;
2666 let mut di = 0;
2667 for d2 in 0..ndim {
2668 if d2 == axis { continue; }
2669 if di > dim_idx { s *= self.shape[d2]; }
2670 di += 1;
2671 }
2672 s
2673 };
2674 pos[d] = remaining / stride;
2675 remaining %= stride;
2676 dim_idx += 1;
2677 }
2678 }
2679
2680 let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2681 for k in 0..axis_len {
2682 let mut flat = self.offset;
2683 for d in 0..ndim {
2684 let idx = if d == axis { k } else { pos[d] };
2685 flat += idx * self.strides[d];
2686 }
2687 vals.push((data[flat], k));
2688 }
2689
2690 if descending {
2691 vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2692 .then(a.1.cmp(&b.1)));
2693 } else {
2694 vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2695 .then(a.1.cmp(&b.1)));
2696 }
2697
2698 for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2699 let out_strides_local = Self::compute_strides(&out_shape);
2700 let mut flat = 0;
2701 for d in 0..ndim {
2702 let idx = if d == axis { k } else { pos[d] };
2703 flat += idx * out_strides_local[d];
2704 }
2705 result[flat] = orig_idx as f64;
2706 }
2707 }
2708
2709 Tensor::from_vec(result, &out_shape)
2710 }
2711
2712 pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
2721 let parts: Vec<&str> = notation.split("->").collect();
2723 if parts.len() != 2 {
2724 return Err(RuntimeError::InvalidOperation(
2725 format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
2726 ));
2727 }
2728 let input_specs: Vec<&str> = parts[0].split(',').collect();
2729 let output_spec = parts[1];
2730
2731 if input_specs.len() != inputs.len() {
2732 return Err(RuntimeError::InvalidOperation(
2733 format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
2734 ));
2735 }
2736
2737 let mut label_size = std::collections::BTreeMap::new();
2739 for (i, &spec) in input_specs.iter().enumerate() {
2740 let chars: Vec<char> = spec.chars().collect();
2741 if chars.len() != inputs[i].ndim() {
2742 return Err(RuntimeError::InvalidOperation(
2743 format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
2744 ));
2745 }
2746 for (d, &c) in chars.iter().enumerate() {
2747 let sz = inputs[i].shape()[d];
2748 if let Some(&prev) = label_size.get(&c) {
2749 if prev != sz {
2750 return Err(RuntimeError::InvalidOperation(
2751 format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
2752 ));
2753 }
2754 } else {
2755 label_size.insert(c, sz);
2756 }
2757 }
2758 }
2759
2760 let output_chars: Vec<char> = output_spec.chars().collect();
2762 let output_shape: Vec<usize> = output_chars.iter()
2763 .map(|c| label_size.get(c).copied().ok_or_else(||
2764 RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
2765 .collect::<Result<_, _>>()?;
2766 let out_numel = Self::shape_numel(&output_shape);
2767
2768 let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
2770 let contract_labels: Vec<char> = label_size.keys()
2771 .filter(|c| !output_set.contains(c))
2772 .copied()
2773 .collect();
2774 let contract_sizes: Vec<usize> = contract_labels.iter()
2775 .map(|c| label_size[c])
2776 .collect();
2777 let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);
2778
2779 let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();
2781
2782 let out_strides = Self::compute_strides(&output_shape);
2784 let mut result = vec![0.0f64; out_numel];
2785
2786 let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
2788 let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
2789 let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();
2790
2791 for out_idx in 0..out_numel {
2792 let mut label_vals = std::collections::BTreeMap::new();
2794 let mut remaining = out_idx;
2795 for (d, &c) in output_chars.iter().enumerate() {
2796 let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
2797 label_vals.insert(c, remaining / stride);
2798 remaining %= stride;
2799 }
2800
2801 let mut acc = BinnedAccumulatorF64::new();
2802 for cidx in 0..contract_numel {
2804 let mut cr = cidx;
2806 for (ci, &cl) in contract_labels.iter().enumerate() {
2807 let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
2808 label_vals.insert(cl, cr / stride);
2809 cr %= stride;
2810 }
2811
2812 let mut product = 1.0f64;
2814 for (inp_idx, chars) in input_chars.iter().enumerate() {
2815 let mut flat = input_offsets[inp_idx];
2816 for (d, &c) in chars.iter().enumerate() {
2817 flat += label_vals[&c] * input_strides[inp_idx][d];
2818 }
2819 product *= input_data[inp_idx][flat];
2820 }
2821 acc.add(product);
2822 }
2823 result[out_idx] = acc.finalize();
2824 }
2825
2826 if output_shape.is_empty() {
2827 Tensor::from_vec(result, &[1])
2828 } else {
2829 Tensor::from_vec(result, &output_shape)
2830 }
2831 }
2832
2833 pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2842 let ndim = self.ndim();
2843 if dim > ndim {
2844 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2845 }
2846 let mut new_shape = self.shape.clone();
2847 new_shape.insert(dim, 1);
2848 self.reshape(&new_shape)
2849 }
2850
2851 pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2854 match dim {
2855 Some(d) => {
2856 if d >= self.ndim() {
2857 return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2858 }
2859 if self.shape[d] != 1 {
2860 return Err(RuntimeError::InvalidOperation(
2861 format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2862 ));
2863 }
2864 let mut new_shape = self.shape.clone();
2865 new_shape.remove(d);
2866 if new_shape.is_empty() {
2867 new_shape.push(1); }
2869 self.reshape(&new_shape)
2870 }
2871 None => {
2872 let new_shape: Vec<usize> = self.shape.iter()
2873 .filter(|&&s| s != 1)
2874 .copied()
2875 .collect();
2876 let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2877 self.reshape(&new_shape)
2878 }
2879 }
2880 }
2881
2882 pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2886 self.broadcast_to(target_shape)
2887 }
2888
2889 pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2891 if start_dim > end_dim || end_dim >= self.ndim() {
2892 return Err(RuntimeError::InvalidOperation(
2893 format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2894 ));
2895 }
2896 let mut new_shape = Vec::new();
2897 for i in 0..start_dim {
2898 new_shape.push(self.shape[i]);
2899 }
2900 let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2901 new_shape.push(flat_size);
2902 for i in (end_dim + 1)..self.ndim() {
2903 new_shape.push(self.shape[i]);
2904 }
2905 self.reshape(&new_shape)
2906 }
2907
2908 pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2910 if dim >= self.ndim() {
2911 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2912 }
2913 if n == 0 {
2914 return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2915 }
2916 let dim_size = self.shape[dim];
2917 let chunk_size = (dim_size + n - 1) / n;
2918 let mut sizes = Vec::new();
2919 let mut remaining = dim_size;
2920 while remaining > 0 {
2921 let s = remaining.min(chunk_size);
2922 sizes.push(s);
2923 remaining -= s;
2924 }
2925 self.split(&sizes, dim)
2926 }
2927
2928 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2930 if dim >= self.ndim() {
2931 return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2932 }
2933 let total: usize = sizes.iter().sum();
2934 if total != self.shape[dim] {
2935 return Err(RuntimeError::InvalidOperation(
2936 format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2937 ));
2938 }
2939
2940 let mut results = Vec::new();
2941 let mut offset = 0;
2942
2943 for &sz in sizes {
2944 let ranges: Vec<(usize, usize)> = self.shape.iter()
2945 .enumerate()
2946 .map(|(i, &s)| {
2947 if i == dim { (offset, offset + sz) } else { (0, s) }
2948 })
2949 .collect();
2950 let chunk = self.slice(&ranges)?;
2951 results.push(chunk.to_contiguous());
2953 offset += sz;
2954 }
2955
2956 Ok(results)
2957 }
2958
2959 pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2964 if self.shape != other.shape {
2965 return Err(RuntimeError::InvalidOperation(
2966 "scale_add: shape mismatch".to_string(),
2967 ));
2968 }
2969 let a = self.to_vec();
2970 let b = other.to_vec();
2971 let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2972 Tensor::from_vec(result, &self.shape)
2973 }
2974}
2975