Skip to main content

cjc_runtime/
tensor.rs

1
2use cjc_repro::Rng;
3
4use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
5use cjc_repro::KahanAccumulatorF64;
6
7use crate::accumulator;
8use crate::buffer::Buffer;
9use crate::dispatch;
10use crate::error::RuntimeError;
11use crate::kernel as kernel_fns;
12use crate::tensor_simd::{self, BinOp, UnaryOp};
13use crate::tensor_tiled::TiledMatmul;
14
15// ---------------------------------------------------------------------------
16// 2. Tensor Runtime
17// ---------------------------------------------------------------------------
18
19/// An N-dimensional tensor backed by a `Buffer<f64>`.
20///
21/// Supports element-wise arithmetic, matrix multiplication (2-D), and
22/// numerically-stable reductions via BinnedAccumulator summation.
23#[derive(Debug, Clone)]
24pub struct Tensor {
25    pub buffer: Buffer<f64>,
26    pub(crate) shape: Vec<usize>,
27    pub(crate) strides: Vec<usize>,
28    pub(crate) offset: usize,
29}
30
31impl Tensor {
32    // -- Construction -------------------------------------------------------
33
34    /// Compute row-major strides for a given shape.
35    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
36        let mut strides = vec![1usize; shape.len()];
37        for i in (0..shape.len().saturating_sub(1)).rev() {
38            strides[i] = strides[i + 1] * shape[i + 1];
39        }
40        strides
41    }
42
43    /// Total number of elements implied by `shape`.
44    fn shape_numel(shape: &[usize]) -> usize {
45        shape.iter().product()
46    }
47
48    /// Create a tensor filled with zeros.
49    pub fn zeros(shape: &[usize]) -> Self {
50        let numel = Self::shape_numel(shape);
51        Tensor {
52            buffer: Buffer::alloc(numel, 0.0),
53            shape: shape.to_vec(),
54            strides: Self::compute_strides(shape),
55            offset: 0,
56        }
57    }
58
59    /// Create a tensor filled with ones.
60    pub fn ones(shape: &[usize]) -> Self {
61        let numel = Self::shape_numel(shape);
62        Tensor {
63            buffer: Buffer::alloc(numel, 1.0),
64            shape: shape.to_vec(),
65            strides: Self::compute_strides(shape),
66            offset: 0,
67        }
68    }
69
70    /// Create a tensor filled with samples from the standard normal
71    /// distribution, drawn deterministically from `rng`.
72    pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
73        let numel = Self::shape_numel(shape);
74        let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
75        Tensor {
76            buffer: Buffer::from_vec(data),
77            shape: shape.to_vec(),
78            strides: Self::compute_strides(shape),
79            offset: 0,
80        }
81    }
82
83    /// Create a tensor from raw data and a shape. Returns an error if the
84    /// number of elements does not match the shape.
85    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
86        let numel = Self::shape_numel(shape);
87        if data.len() != numel {
88            return Err(RuntimeError::ShapeMismatch {
89                expected: numel,
90                got: data.len(),
91            });
92        }
93        Ok(Tensor {
94            buffer: Buffer::from_vec(data),
95            shape: shape.to_vec(),
96            strides: Self::compute_strides(shape),
97            offset: 0,
98        })
99    }
100
101    // -- Accessors ----------------------------------------------------------
102
103    /// The shape of this tensor.
104    pub fn shape(&self) -> &[usize] {
105        &self.shape
106    }
107
108    /// Number of dimensions.
109    pub fn ndim(&self) -> usize {
110        self.shape.len()
111    }
112
113    /// Total number of elements.
114    pub fn len(&self) -> usize {
115        Self::shape_numel(&self.shape)
116    }
117
118    /// Whether the tensor has zero elements.
119    pub fn is_empty(&self) -> bool {
120        self.len() == 0
121    }
122
123    /// Flatten a multi-dimensional index to a linear offset in the buffer.
124    fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
125        if indices.len() != self.shape.len() {
126            return Err(RuntimeError::DimensionMismatch {
127                expected: self.shape.len(),
128                got: indices.len(),
129            });
130        }
131        let mut off = self.offset;
132        for (i, &idx) in indices.iter().enumerate() {
133            if idx >= self.shape[i] {
134                return Err(RuntimeError::IndexOutOfBounds {
135                    index: idx,
136                    length: self.shape[i],
137                });
138            }
139            off += idx * self.strides[i];
140        }
141        Ok(off)
142    }
143
144    /// Whether this tensor is contiguous in memory (row-major, no offset).
145    pub fn is_contiguous(&self) -> bool {
146        if self.offset != 0 {
147            return false;
148        }
149        let expected = Self::compute_strides(&self.shape);
150        self.strides == expected
151    }
152
153    /// Create a zero-copy slice (view) of this tensor.
154    /// `ranges` contains `(start, end)` for each dimension.
155    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
156        if ranges.len() != self.shape.len() {
157            return Err(RuntimeError::DimensionMismatch {
158                expected: self.shape.len(),
159                got: ranges.len(),
160            });
161        }
162        let mut new_offset = self.offset;
163        let mut new_shape = Vec::with_capacity(ranges.len());
164        for (i, &(start, end)) in ranges.iter().enumerate() {
165            if end > self.shape[i] || start > end {
166                return Err(RuntimeError::IndexOutOfBounds {
167                    index: end,
168                    length: self.shape[i],
169                });
170            }
171            new_offset += start * self.strides[i];
172            new_shape.push(end - start);
173        }
174        Ok(Tensor {
175            buffer: self.buffer.clone(), // shared — zero copy
176            shape: new_shape,
177            strides: self.strides.clone(),
178            offset: new_offset,
179        })
180    }
181
182    /// Materialize a contiguous copy if this tensor is non-contiguous.
183    pub fn to_contiguous(&self) -> Tensor {
184        if self.is_contiguous() {
185            return self.clone();
186        }
187        let data = self.to_vec();
188        Tensor {
189            buffer: Buffer::from_vec(data),
190            shape: self.shape.clone(),
191            strides: Self::compute_strides(&self.shape),
192            offset: 0,
193        }
194    }
195
196    /// Create a broadcast view of this tensor to `target_shape`.
197    /// Uses stride=0 for dimensions that need broadcasting (size 1 -> target size).
198    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
199        let src_ndim = self.shape.len();
200        let tgt_ndim = target_shape.len();
201        if tgt_ndim < src_ndim {
202            return Err(RuntimeError::InvalidOperation(
203                "cannot broadcast to a smaller rank".to_string(),
204            ));
205        }
206        let pad = tgt_ndim - src_ndim;
207        let mut new_strides = vec![0usize; tgt_ndim];
208        for i in 0..tgt_ndim {
209            if i < pad {
210                // Padded dimension: stride = 0 (broadcast)
211                new_strides[i] = 0;
212            } else {
213                let src_i = i - pad;
214                if self.shape[src_i] == target_shape[i] {
215                    new_strides[i] = self.strides[src_i];
216                } else if self.shape[src_i] == 1 {
217                    new_strides[i] = 0; // broadcast
218                } else {
219                    return Err(RuntimeError::ShapeMismatch {
220                        expected: target_shape[i],
221                        got: self.shape[src_i],
222                    });
223                }
224            }
225        }
226        Ok(Tensor {
227            buffer: self.buffer.clone(),
228            shape: target_shape.to_vec(),
229            strides: new_strides,
230            offset: self.offset,
231        })
232    }
233
234    /// Read the element at the given multi-dimensional index.
235    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
236        let offset = self.linear_index(indices)?;
237        self.buffer
238            .get(offset)
239            .ok_or(RuntimeError::IndexOutOfBounds {
240                index: offset,
241                length: self.buffer.len(),
242            })
243    }
244
245    /// Write the element at the given multi-dimensional index.
246    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
247        let offset = self.linear_index(indices)?;
248        self.buffer.set(offset, val)
249    }
250
251    /// Extract the raw data as a `Vec<f64>`, respecting strides and offset.
252    pub fn to_vec(&self) -> Vec<f64> {
253        if self.is_contiguous() {
254            let full = self.buffer.borrow_data();
255            let numel = self.len();
256            if full.len() == numel {
257                return full.to_vec();
258            }
259            // Buffer may be larger than the tensor's logical size
260            // (e.g. Scratchpad pre-allocates extra capacity)
261            return full[..numel].to_vec();
262        }
263        // Non-contiguous: iterate via strided access
264        let numel = self.len();
265        let mut result = Vec::with_capacity(numel);
266        let ndim = self.shape.len();
267        let mut indices = vec![0usize; ndim];
268        for _ in 0..numel {
269            let mut off = self.offset;
270            for d in 0..ndim {
271                off += indices[d] * self.strides[d];
272            }
273            result.push(self.buffer.get(off).unwrap_or(0.0));
274            // Increment multi-index (row-major order)
275            for d in (0..ndim).rev() {
276                indices[d] += 1;
277                if indices[d] < self.shape[d] {
278                    break;
279                }
280                indices[d] = 0;
281            }
282        }
283        result
284    }
285
286    // -- Reshape ------------------------------------------------------------
287
288    /// Reshape to `new_shape`. The new shape must have the same total number
289    /// of elements. The returned tensor **shares** the underlying buffer.
290    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
291        let new_numel = Self::shape_numel(new_shape);
292        if new_numel != self.len() {
293            return Err(RuntimeError::ShapeMismatch {
294                expected: self.len(),
295                got: new_numel,
296            });
297        }
298        // Reshape requires contiguous data; materialize if needed
299        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
300        Ok(Tensor {
301            buffer: tensor.buffer,
302            shape: new_shape.to_vec(),
303            strides: Self::compute_strides(new_shape),
304            offset: 0,
305        })
306    }
307
308    // -- Element-wise operations --------------------------------------------
309
310    /// Apply a binary operation element-wise with broadcasting support.
311    fn elementwise_binop(
312        &self,
313        other: &Tensor,
314        op: impl Fn(f64, f64) -> f64,
315    ) -> Result<Tensor, RuntimeError> {
316        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
317            // Fast path: same shape, both contiguous — borrow without cloning
318            let a = self.buffer.borrow_data();
319            let b = other.buffer.borrow_data();
320            let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
321            return Ok(Tensor {
322                buffer: Buffer::from_vec(data),
323                shape: self.shape.clone(),
324                strides: Self::compute_strides(&self.shape),
325                offset: 0,
326            });
327        }
328
329        // Broadcasting path: compute result shape
330        let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
331        let a_broadcast = self.broadcast_to(&result_shape)?;
332        let b_broadcast = other.broadcast_to(&result_shape)?;
333
334        let numel = Self::shape_numel(&result_shape);
335        let ndim = result_shape.len();
336        let mut data = Vec::with_capacity(numel);
337        let mut indices = vec![0usize; ndim];
338
339        for _ in 0..numel {
340            let mut off_a = a_broadcast.offset;
341            let mut off_b = b_broadcast.offset;
342            for d in 0..ndim {
343                off_a += indices[d] * a_broadcast.strides[d];
344                off_b += indices[d] * b_broadcast.strides[d];
345            }
346            let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
347                RuntimeError::InvalidOperation(format!(
348                    "broadcast binop: left operand index {} out of bounds (buffer len {})",
349                    off_a,
350                    a_broadcast.buffer.len()
351                ))
352            })?;
353            let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
354                RuntimeError::InvalidOperation(format!(
355                    "broadcast binop: right operand index {} out of bounds (buffer len {})",
356                    off_b,
357                    b_broadcast.buffer.len()
358                ))
359            })?;
360            data.push(op(va, vb));
361
362            for d in (0..ndim).rev() {
363                indices[d] += 1;
364                if indices[d] < result_shape[d] {
365                    break;
366                }
367                indices[d] = 0;
368            }
369        }
370
371        Ok(Tensor {
372            buffer: Buffer::from_vec(data),
373            shape: result_shape.clone(),
374            strides: Self::compute_strides(&result_shape),
375            offset: 0,
376        })
377    }
378
379    /// Compute the broadcast result shape for two shapes (NumPy rules).
380    fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
381        let max_ndim = a.len().max(b.len());
382        let mut result = Vec::with_capacity(max_ndim);
383        for i in 0..max_ndim {
384            let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
385            let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
386            if da == db {
387                result.push(da);
388            } else if da == 1 {
389                result.push(db);
390            } else if db == 1 {
391                result.push(da);
392            } else {
393                return Err(RuntimeError::ShapeMismatch {
394                    expected: da,
395                    got: db,
396                });
397            }
398        }
399        Ok(result)
400    }
401
402    /// SIMD-accelerated element-wise binary operation for known ops.
403    ///
404    /// For same-shape contiguous tensors, uses AVX2 (4-wide f64) when available.
405    /// Falls back to the generic closure path for broadcast cases.
406    fn elementwise_binop_simd(
407        &self,
408        other: &Tensor,
409        op: BinOp,
410        fallback: impl Fn(f64, f64) -> f64,
411    ) -> Result<Tensor, RuntimeError> {
412        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
413            // SIMD fast path: same shape, both contiguous
414            let a = self.buffer.borrow_data();
415            let b = other.buffer.borrow_data();
416            let data = tensor_simd::simd_binop(&a, &b, op);
417            return Ok(Tensor {
418                buffer: Buffer::from_vec(data),
419                shape: self.shape.clone(),
420                strides: Self::compute_strides(&self.shape),
421                offset: 0,
422            });
423        }
424        // Broadcast path: fall through to generic
425        self.elementwise_binop(other, fallback)
426    }
427
428    /// Element-wise addition (SIMD-accelerated for contiguous same-shape tensors).
429    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
430        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
431    }
432
433    /// Element-wise subtraction (SIMD-accelerated for contiguous same-shape tensors).
434    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
435        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
436    }
437
438    /// Element-wise (Hadamard) multiplication (SIMD-accelerated for contiguous same-shape tensors).
439    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
440        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
441    }
442
443    /// Element-wise division (SIMD-accelerated for contiguous same-shape tensors).
444    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
445        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
446    }
447
448    /// Fused multiply-add: `self * b + c` element-wise in a single pass.
449    ///
450    /// Eliminates the intermediate tensor that separate mul + add would create.
451    /// Uses software FMA (`a * b + c` with two roundings, not hardware FMA)
452    /// to preserve bit-identity with the non-fused path.
453    pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
454        if self.shape != b.shape || self.shape != c.shape {
455            return Err(RuntimeError::InvalidOperation(
456                "broadcast_fma: all three tensors must have the same shape".to_string(),
457            ));
458        }
459        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
460            let a_data = self.buffer.borrow_data();
461            let b_data = b.buffer.borrow_data();
462            let c_data = c.buffer.borrow_data();
463            let n = a_data.len();
464            let mut out = vec![0.0f64; n];
465            // Software FMA: a*b + c (two roundings — NOT hardware FMA which uses one rounding).
466            // This produces identical results to separate broadcast2("mul") + broadcast2("add").
467            for i in 0..n {
468                out[i] = a_data[i] * b_data[i] + c_data[i];
469            }
470            return Ok(Tensor {
471                buffer: Buffer::from_vec(out),
472                shape: self.shape.clone(),
473                strides: Self::compute_strides(&self.shape),
474                offset: 0,
475            });
476        }
477        // Non-contiguous fallback: mul then add
478        let temp = self.mul_elem(b)?;
479        temp.add(c)
480    }
481
482    // ── v0.1 Broadcasting: additional element-wise binary ops ──
483
484    /// Element-wise power: `a^b`.
485    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486        self.elementwise_binop(other, |a, b| a.powf(b))
487    }
488
489    /// Element-wise minimum.
490    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491        self.elementwise_binop(other, |a, b| a.min(b))
492    }
493
494    /// Element-wise maximum.
495    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496        self.elementwise_binop(other, |a, b| a.max(b))
497    }
498
499    /// Element-wise atan2(self, other).
500    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501        self.elementwise_binop(other, |a, b| a.atan2(b))
502    }
503
504    /// Element-wise hypot(self, other).
505    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
506        self.elementwise_binop(other, |a, b| a.hypot(b))
507    }
508
509    /// Apply a unary function to every element, returning a new contiguous tensor.
510    pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
511        let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
512        Tensor {
513            buffer: Buffer::from_vec(data),
514            shape: self.shape.clone(),
515            strides: Self::compute_strides(&self.shape),
516            offset: 0,
517        }
518    }
519
520    /// SIMD-accelerated unary map for known operations (sqrt, abs, neg, relu).
521    ///
522    /// Uses AVX2 (4-wide f64) when available, scalar fallback otherwise.
523    /// Bit-identical to `map(f)` for the supported operations.
524    pub fn map_simd(&self, op: UnaryOp) -> Tensor {
525        let src = self.to_vec();
526        let data = tensor_simd::simd_unary(&src, op);
527        Tensor {
528            buffer: Buffer::from_vec(data),
529            shape: self.shape.clone(),
530            strides: Self::compute_strides(&self.shape),
531            offset: 0,
532        }
533    }
534
535    // -- Reductions (using BinnedAccumulator) --------------------------------
536
537    /// Sum of all elements (binned accumulation — order-invariant, deterministic).
538    pub fn sum(&self) -> f64 {
539        let data = self.buffer.borrow_data();
540        binned_sum_f64(&data)
541    }
542
543    /// Sum of all elements using BinnedAccumulator (order-invariant, deterministic).
544    ///
545    /// Bit-identical results regardless of element ordering or reduction schedule.
546    pub fn binned_sum(&self) -> f64 {
547        let data = self.buffer.borrow_data();
548        accumulator::binned_sum_f64(&data)
549    }
550
551    /// Sum with dispatched strategy based on execution context.
552    ///
553    /// Uses Kahan in serial mode, Binned in parallel/@nogc/strict/linalg mode.
554    pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
555        let data = self.buffer.borrow_data();
556        dispatch::dispatch_sum_f64(&data, ctx)
557    }
558
559    /// Mean of all elements (binned sum / count).
560    pub fn mean(&self) -> f64 {
561        let n = self.len();
562        if n == 0 {
563            return 0.0;
564        }
565        self.sum() / n as f64
566    }
567
568    /// Mean with dispatched strategy based on execution context.
569    pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
570        let n = self.len();
571        if n == 0 {
572            return 0.0;
573        }
574        self.dispatched_sum(ctx) / n as f64
575    }
576
577    /// Sum along a specific axis, returning a tensor with that dimension reduced.
578    ///
579    /// Supports N-D tensors. The reduced axis becomes size 1 in the output.
580    /// Uses BinnedAccumulator for order-invariant, deterministic summation.
581    ///
582    /// Examples:
583    /// - 2D [M, N] with axis=0: result [1, N] (sum columns)
584    /// - 2D [M, N] with axis=1: result [M, 1] (sum rows)
585    /// - 3D [A, B, C] with axis=1: result [A, 1, C]
586    pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
587        let ndim = self.ndim();
588        if axis >= ndim {
589            return Err(RuntimeError::IndexOutOfBounds {
590                index: axis,
591                length: ndim,
592            });
593        }
594
595        // Build output shape: same as input but with axis dimension = 1
596        let mut out_shape = self.shape.clone();
597        out_shape[axis] = 1;
598        let out_numel = Self::shape_numel(&out_shape);
599        let out_strides = Self::compute_strides(&out_shape);
600
601        let data = self.to_vec();
602        let axis_len = self.shape[axis];
603        let mut result = vec![0.0f64; out_numel];
604
605        // For each output position, sum over the reduced axis with binned accumulation.
606        let mut indices = vec![0usize; ndim];
607        for out_idx in 0..out_numel {
608            // Compute the N-D index from flat output index
609            {
610                let mut remaining = out_idx;
611                for d in 0..ndim {
612                    indices[d] = remaining / out_strides[d];
613                    remaining %= out_strides[d];
614                }
615            }
616
617            let mut acc = BinnedAccumulatorF64::new();
618            for k in 0..axis_len {
619                // Compute input flat index with indices[axis] = k
620                let mut flat = self.offset;
621                for d in 0..ndim {
622                    let idx = if d == axis { k } else { indices[d] };
623                    flat += idx * self.strides[d];
624                }
625                acc.add(data[flat]);
626            }
627            result[out_idx] = acc.finalize();
628        }
629
630        Tensor::from_vec(result, &out_shape)
631    }
632
633    // -- Matrix multiplication (2-D only) -----------------------------------
634
635    /// Negate every element, returning a new tensor.
636    pub fn neg(&self) -> Tensor {
637        self.map(|x| -x)
638    }
639
640    /// Transpose a tensor. For 2-D: swaps rows and columns (zero-copy view).
641    /// For N-D: reverses all axes (zero-copy view).
642    pub fn transpose(&self) -> Tensor {
643        let ndim = self.ndim();
644        if ndim <= 1 {
645            return self.clone();
646        }
647        // Reverse shape and strides — zero-copy view
648        let mut new_shape = self.shape.clone();
649        let mut new_strides = self.strides.clone();
650        new_shape.reverse();
651        new_strides.reverse();
652        Tensor {
653            buffer: self.buffer.clone(), // shared — zero copy
654            shape: new_shape,
655            strides: new_strides,
656            offset: self.offset,
657        }
658    }
659
660    /// Transpose with explicit axis permutation (N-D). Zero-copy view.
661    ///
662    /// `axes` must be a permutation of `[0, 1, ..., ndim-1]`.
663    pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
664        let ndim = self.ndim();
665        if axes.len() != ndim {
666            return Err(RuntimeError::InvalidOperation(
667                format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
668            ));
669        }
670        // Validate permutation
671        let mut seen = vec![false; ndim];
672        for &ax in axes {
673            if ax >= ndim {
674                return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
675            }
676            if seen[ax] {
677                return Err(RuntimeError::InvalidOperation(
678                    format!("transpose_axes: duplicate axis {ax}"),
679                ));
680            }
681            seen[ax] = true;
682        }
683        let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
684        let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
685        Ok(Tensor {
686            buffer: self.buffer.clone(),
687            shape: new_shape,
688            strides: new_strides,
689            offset: self.offset,
690        })
691    }
692
693    /// Multiply every element by a scalar, returning a new tensor.
694    pub fn scalar_mul(&self, s: f64) -> Tensor {
695        self.map(|x| x * s)
696    }
697
698    // ── Panicking convenience constructors (used by AD engine) --------
699
700    /// Create a tensor from raw data and shape.
701    /// **Panics** if `data.len()` does not match the shape.
702    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
703        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
704    }
705
706    /// Element-wise addition. **Panics** on shape mismatch.
707    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
708        self.add(other).expect("Tensor::add shape mismatch")
709    }
710
711    /// Element-wise subtraction. **Panics** on shape mismatch.
712    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
713        self.sub(other).expect("Tensor::sub shape mismatch")
714    }
715
716    /// Element-wise multiplication. **Panics** on shape mismatch.
717    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
718        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
719    }
720
721    /// Element-wise division. **Panics** on shape mismatch.
722    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
723        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
724    }
725
726    /// Matrix multiplication. **Panics** on dimension mismatch.
727    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
728        self.matmul(other).expect("Tensor::matmul dimension mismatch")
729    }
730
    /// Matrix multiplication for 2-D tensors.
    ///
    /// `self` is (M, K), `other` is (K, N) => result is (M, N).
    ///
    /// Dispatches by size: parallel row-banded path (feature "parallel",
    /// any dim >= 256) → L2-tiled path (any dim >= 64) → sequential path
    /// with Kahan-compensated accumulation.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matmul requires 2-D tensors".to_string(),
            ));
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        // Materialize both operands contiguously (handles views/transposes).
        let a = self.to_vec();
        let b = other.to_vec();

        // Parallel path (Mode A): parallelize over output rows when the parallel
        // feature is enabled and the matrix is large enough (>= 256 in any dim).
        #[cfg(feature = "parallel")]
        {
            if m >= 256 || n >= 256 || k >= 256 {
                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
            }
        }

        // Tiled path: use L2-friendly tiled matmul for medium-to-large matrices.
        // Threshold: any dimension >= 64 (the default tile size).
        // NOTE: tiled path uses naive accumulation (not compensated) — different
        // numerical path for large matrices, but better cache locality.
        if m >= 64 || n >= 64 || k >= 64 {
            return Self::matmul_tiled(&a, &b, m, n, k);
        }

        // Sequential path: single-threaded with Kahan-compensated accumulation
        // (see `matmul_sequential`; it uses KahanAccumulatorF64, not binned).
        Self::matmul_sequential(&a, &b, m, n, k)
    }
774
775    /// Sequential matmul (always available, deterministic reference).
776    fn matmul_sequential(
777        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
778    ) -> Result<Tensor, RuntimeError> {
779        let mut result = vec![0.0f64; m * n];
780        for i in 0..m {
781            for j in 0..n {
782                let mut acc = KahanAccumulatorF64::new();
783                for p in 0..k {
784                    acc.add(a[i * k + p] * b[p * n + j]);
785                }
786                result[i * n + j] = acc.finalize();
787            }
788        }
789        Tensor::from_vec(result, &[m, n])
790    }
791
792    /// Tiled matmul: delegates to `TiledMatmul` for L2-cache-friendly tiling.
793    ///
794    /// Used for medium matrices (any dimension >= 64) where cache locality
795    /// matters but parallel overhead isn't justified. The tiled path uses
796    /// naive accumulation (not binned accumulation), trading a small amount of
797    /// floating-point precision for better cache behavior.
798    fn matmul_tiled(
799        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
800    ) -> Result<Tensor, RuntimeError> {
801        let engine = TiledMatmul::new();
802        let result = engine.matmul(a, m, k, b, n);
803        Tensor::from_vec(result, &[m, n])
804    }
805
    /// Parallel matmul Mode A: parallelize over output rows using tiled
    /// micro-kernels for cache locality.
    ///
    /// Deterministic because:
    /// - Each output row is computed by exactly one thread.
    /// - Within each row, accumulation uses tiled AXPY (deterministic order).
    /// - No cross-thread reduction or merge of partial sums.
    ///
    /// Uses KahanAccumulatorF64 (lightweight) instead of BinnedAccumulatorF64
    /// (32KB per accumulator) to avoid massive stack pressure in parallel mode.
    #[cfg(feature = "parallel")]
    fn matmul_parallel_mode_a(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        use rayon::prelude::*;
        use cjc_repro::KahanAccumulatorF64;

        // For large matrices, use tiled matmul (sequential but cache-friendly).
        // The tiling provides far better cache locality than parallel row-wise
        // with column-strided B access, and avoids the 32KB-per-element
        // BinnedAccumulator overhead that caused the 128→256 regression.
        if m >= 512 && n >= 512 {
            // Only use rayon for very large matrices where thread overhead
            // is amortized. Split into row-bands, each processed with tiled matmul.
            // Ceiling division so every thread gets at most one band.
            let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
            let band_size = band_size.max(64); // At least 64 rows per band
            let mut result = vec![0.0f64; m * n];

            // Each chunk of `band_size * n` output elements corresponds to one
            // band of `band_size` full rows, so band_idx → row range is exact.
            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    // The final band may be shorter than band_size.
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k .. i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    // Copy only the rows this band actually computed.
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });

            return Tensor::from_vec(result, &[m, n]);
        }

        // For medium matrices (256-511), use Kahan accumulator (16 bytes, not 32KB).
        // One rayon task per output row; no cross-thread accumulation.
        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });

        Tensor::from_vec(result, &[m, n])
    }
867
868    // -- Transformer Kernels ------------------------------------------------
869
    /// Batched matrix multiplication.
    ///
    /// `self` is `[..., M, K]`, `other` is `[..., K, N]` => result is `[..., M, N]`.
    /// The batch dimensions must be identical (no broadcast).
    /// For 2-D inputs, delegates to `matmul`.
    ///
    /// Per-batch strategy: matrices with any dimension >= 64 go through the
    /// tiled engine (naive accumulation); smaller ones use Kahan-summed dot
    /// products. With the `parallel` feature, batches are distributed across
    /// rayon workers when the per-matrix workload is large enough.
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "bmm requires at least 2-D tensors".to_string(),
            ));
        }
        // Plain 2-D × 2-D: let `matmul` pick its own dispatch path.
        if self.ndim() == 2 && other.ndim() == 2 {
            return self.matmul(other);
        }
        if self.ndim() != other.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm requires same number of dimensions, got {} and {}",
                    self.ndim(),
                    other.ndim()
                ),
            ));
        }
        let nd = self.ndim();
        // Everything before the trailing [M, K] / [K, N] pair is batch.
        let batch_dims_a = &self.shape[..nd - 2];
        let batch_dims_b = &other.shape[..nd - 2];
        if batch_dims_a != batch_dims_b {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm batch dimensions mismatch: {:?} vs {:?}",
                    batch_dims_a, batch_dims_b
                ),
            ));
        }
        let m = self.shape[nd - 2];
        let k = self.shape[nd - 1];
        let k2 = other.shape[nd - 2];
        let n = other.shape[nd - 1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        let batch_size: usize = batch_dims_a.iter().product();
        // Materialize contiguous copies so per-batch slices are plain ranges.
        let a = self.to_vec();
        let b = other.to_vec();
        let mat_a_stride = m * k;
        let mat_b_stride = k * n;
        let mat_c_stride = m * n;
        let mut result = vec![0.0f64; batch_size * mat_c_stride];

        // Helper closure: compute one batch into c_slice
        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];

            // Same 64-threshold as the 2-D `matmul` dispatcher.
            if m >= 64 || n >= 64 || k >= 64 {
                let engine = crate::tensor_tiled::TiledMatmul::new();
                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
                c_slice.copy_from_slice(&tiled);
            } else {
                // Small matrices: Kahan-summed dot products, deterministic order.
                for i in 0..m {
                    for j in 0..n {
                        let mut acc = KahanAccumulatorF64::new();
                        for p in 0..k {
                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
                        }
                        c_slice[i * n + j] = acc.finalize();
                    }
                }
            }
        };

        // Parallel path: parallelize over batches when workload is large enough
        // (each batch is independent, so this stays deterministic).
        #[cfg(feature = "parallel")]
        {
            if batch_size > 1 && m * k >= 4096 {
                use rayon::prelude::*;
                result
                    .par_chunks_mut(mat_c_stride)
                    .enumerate()
                    .for_each(|(batch, c_slice)| {
                        compute_batch(batch, c_slice);
                    });

                let mut out_shape = batch_dims_a.to_vec();
                out_shape.push(m);
                out_shape.push(n);
                return Tensor::from_vec(result, &out_shape);
            }
        }

        // Sequential fallback
        for batch in 0..batch_size {
            let c_off = batch * mat_c_stride;
            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
        }

        let mut out_shape = batch_dims_a.to_vec();
        out_shape.push(m);
        out_shape.push(n);
        Tensor::from_vec(result, &out_shape)
    }
975
976    /// Softmax along the last dimension (two-pass stable algorithm).
977    ///
978    /// Pass 1: find max per row (prevents overflow in exp)
979    /// Pass 2: compute exp(x - max), accumulate sum, normalize
980    ///
981    /// For a tensor of shape `[..., N]`, softmax is applied independently
982    /// to each length-N slice along the last axis.
983    pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
984        if self.ndim() == 0 {
985            return Err(RuntimeError::InvalidOperation(
986                "softmax requires at least 1-D tensor".to_string(),
987            ));
988        }
989        // Avoid allocation when tensor is already contiguous and starts at offset 0
990        let data_ref;
991        let data_vec;
992        let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
993            data_ref = self.buffer.borrow_data();
994            &data_ref
995        } else {
996            data_vec = self.to_vec();
997            &data_vec
998        };
999        let n = *self.shape.last().unwrap(); // last dimension size
1000        let outer: usize = data.len() / n;  // product of all dims except last
1001        let mut result = vec![0.0f64; data.len()];
1002
1003        for row in 0..outer {
1004            let start = row * n;
1005            let end = start + n;
1006            let slice = &data[start..end];
1007
1008            // Pass 1: find max for numerical stability
1009            let mut max_val = f64::NEG_INFINITY;
1010            for &v in slice {
1011                if v > max_val {
1012                    max_val = v;
1013                }
1014            }
1015
1016            // Pass 2: exp(x - max) and accumulate sum
1017            let mut exp_vals = vec![0.0f64; n];
1018            let mut sum = 0.0f64;
1019            let mut comp = 0.0f64; // Kahan compensation
1020            for i in 0..n {
1021                let e = (slice[i] - max_val).exp();
1022                exp_vals[i] = e;
1023                // Kahan summation for the denominator
1024                let y = e - comp;
1025                let t = sum + y;
1026                comp = (t - sum) - y;
1027                sum = t;
1028            }
1029
1030            // Normalize
1031            if sum == 0.0 {
1032                // Degenerate case: all -inf inputs → uniform
1033                let uniform = 1.0 / n as f64;
1034                for i in 0..n {
1035                    result[start + i] = uniform;
1036                }
1037            } else {
1038                for i in 0..n {
1039                    result[start + i] = exp_vals[i] / sum;
1040                }
1041            }
1042        }
1043
1044        Tensor::from_vec(result, &self.shape)
1045    }
1046
1047    /// Layer normalization over the last dimension.
1048    ///
1049    /// For each length-D slice along the last axis:
1050    ///   1. mean = Σx / D  (BinnedAccumulator)
1051    ///   2. var  = Σ(x - mean)² / D  (BinnedAccumulator)
1052    ///   3. normalized = (x - mean) / √(var + eps)
1053    ///   4. output = gamma * normalized + beta
1054    ///
1055    /// `gamma` and `beta` are 1-D tensors of shape `[D]`.
1056    /// `eps` is a small constant (typically 1e-5).
1057    pub fn layer_norm(
1058        &self,
1059        gamma: &Tensor,
1060        beta: &Tensor,
1061        eps: f64,
1062    ) -> Result<Tensor, RuntimeError> {
1063        if self.ndim() == 0 {
1064            return Err(RuntimeError::InvalidOperation(
1065                "layer_norm requires at least 1-D tensor".to_string(),
1066            ));
1067        }
1068        let d = *self.shape.last().unwrap();
1069        if gamma.len() != d || beta.len() != d {
1070            return Err(RuntimeError::InvalidOperation(
1071                format!(
1072                    "layer_norm: gamma/beta length {} must match last dim {}",
1073                    gamma.len(),
1074                    d
1075                ),
1076            ));
1077        }
1078
1079        let data = self.to_vec();
1080        let gamma_data = gamma.to_vec();
1081        let beta_data = beta.to_vec();
1082        let outer = data.len() / d;
1083        let mut result = vec![0.0f64; data.len()];
1084
1085        for row in 0..outer {
1086            let start = row * d;
1087            let slice = &data[start..start + d];
1088
1089            // Pass 1: compute mean via BinnedAccumulator
1090            let mean = binned_sum_f64(slice) / d as f64;
1091
1092            // Pass 2: compute variance via BinnedAccumulator
1093            let diffs: Vec<f64> = slice.iter().map(|&x| {
1094                let diff = x - mean;
1095                diff * diff
1096            }).collect();
1097            let variance = binned_sum_f64(&diffs) / d as f64;
1098
1099            // Normalize, scale, shift
1100            let inv_std = 1.0 / (variance + eps).sqrt();
1101            for i in 0..d {
1102                let normalized = (slice[i] - mean) * inv_std;
1103                result[start + i] = gamma_data[i] * normalized + beta_data[i];
1104            }
1105        }
1106
1107        Tensor::from_vec(result, &self.shape)
1108    }
1109
    /// Apply a function element-wise, producing a new tensor with the same shape.
    ///
    /// When this tensor is contiguous, offset-0, and the sole owner of its
    /// buffer, the backing data is cloned once and transformed within that
    /// single allocation; otherwise `to_vec` gathers a contiguous copy first
    /// and the mapped values are collected into a second vector.
    ///
    /// NOTE(review): despite the `refcount() == 1` check, the "fast path"
    /// still clones the underlying data (`borrow_data().clone()`), so it is
    /// one allocation rather than a true in-place mutation — confirm whether
    /// a mutable-borrow buffer API was intended here.
    fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
        if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
            // Sole owner + contiguous: clone the backing data once, map in place.
            let mut data = self.buffer.borrow_data().clone();
            for x in data.iter_mut() {
                *x = f(*x);
            }
            Tensor::from_vec(data, &self.shape).unwrap()
        } else {
            // Shared or strided view: materialize a contiguous copy, then map.
            let data = self.to_vec();
            let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
            Tensor::from_vec(result, &self.shape).unwrap()
        }
    }
1129
1130    pub fn relu(&self) -> Tensor {
1131        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1132    }
1133
1134    /// Sigmoid activation: 1 / (1 + exp(-x)) element-wise.
1135    pub fn sigmoid(&self) -> Tensor {
1136        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1137    }
1138
1139    /// Tanh activation element-wise.
1140    pub fn tanh_activation(&self) -> Tensor {
1141        self.map_elementwise(|x| x.tanh())
1142    }
1143
1144    /// Leaky ReLU activation: max(alpha*x, x) element-wise.
1145    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1146        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1147    }
1148
1149    /// SiLU (Swish) activation: x * sigmoid(x) element-wise.
1150    pub fn silu(&self) -> Tensor {
1151        let data = self.to_vec();
1152        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1153        Tensor::from_vec(result, &self.shape).unwrap()
1154    }
1155
1156    /// Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
1157    pub fn mish(&self) -> Tensor {
1158        let data = self.to_vec();
1159        let result: Vec<f64> = data.iter().map(|&x| {
1160            let sp = (1.0 + x.exp()).ln();
1161            x * sp.tanh()
1162        }).collect();
1163        Tensor::from_vec(result, &self.shape).unwrap()
1164    }
1165
1166    /// Argmax: index of the maximum element (first occurrence, deterministic).
1167    pub fn argmax(&self) -> usize {
1168        let data = self.to_vec();
1169        let mut best_idx = 0;
1170        let mut best_val = f64::NEG_INFINITY;
1171        for (i, &v) in data.iter().enumerate() {
1172            if v > best_val || (v == best_val && i < best_idx) {
1173                best_val = v;
1174                best_idx = i;
1175            }
1176        }
1177        best_idx
1178    }
1179
1180    /// Argmin: index of the minimum element (first occurrence, deterministic).
1181    pub fn argmin(&self) -> usize {
1182        let data = self.to_vec();
1183        let mut best_idx = 0;
1184        let mut best_val = f64::INFINITY;
1185        for (i, &v) in data.iter().enumerate() {
1186            if v < best_val || (v == best_val && i < best_idx) {
1187                best_val = v;
1188                best_idx = i;
1189            }
1190        }
1191        best_idx
1192    }
1193
1194    /// Clamp all elements to [min, max].
1195    pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1196        let data = self.to_vec();
1197        let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1198        Tensor::from_vec(result, &self.shape).unwrap()
1199    }
1200
1201    /// One-hot encoding: given a 1D tensor of integer indices and a depth,
1202    /// returns a 2D tensor of shape [len, depth].
1203    pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1204        let n = indices.len();
1205        let mut data = vec![0.0; n * depth];
1206        for (i, &idx) in indices.iter().enumerate() {
1207            if idx >= depth {
1208                return Err(RuntimeError::InvalidOperation(format!(
1209                    "one_hot: index {idx} >= depth {depth}"
1210                )));
1211            }
1212            data[i * depth + idx] = 1.0;
1213        }
1214        Tensor::from_vec(data, &[n, depth])
1215    }
1216
1217    // -----------------------------------------------------------------------
1218    // Phase B4: Tensor extensions (cat, stack, topk)
1219    // -----------------------------------------------------------------------
1220
    /// Concatenate tensors along existing axis.
    ///
    /// All inputs must share the same rank and identical sizes on every
    /// dimension except `axis`; the output's `axis` size is the sum of the
    /// inputs' sizes there.
    ///
    /// # Errors
    /// Fails on an empty input list, an out-of-range axis, or shape/rank
    /// mismatches among the inputs.
    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
        }
        let ndim = tensors[0].ndim();
        if axis >= ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.ndim() != ndim {
                return Err(RuntimeError::InvalidOperation(
                    format!("cat: tensor {i} has different ndim"),
                ));
            }
            for d in 0..ndim {
                if d != axis && t.shape[d] != tensors[0].shape[d] {
                    return Err(RuntimeError::InvalidOperation(
                        format!("cat: shape mismatch at dim {d}"),
                    ));
                }
            }
        }
        // Output shape: same as the first input, with the concat axis grown
        // by each additional tensor's extent.
        let mut out_shape = tensors[0].shape.clone();
        for t in tensors.iter().skip(1) {
            out_shape[axis] += t.shape[axis];
        }
        let total = out_shape.iter().product::<usize>();
        let mut result = vec![0.0; total];
        // Row-major strides for the output shape.
        let mut out_strides = vec![1usize; ndim];
        for d in (0..ndim - 1).rev() {
            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
        }
        // `offset` tracks where the current tensor's block starts along `axis`.
        let mut offset = 0;
        for t in tensors {
            let t_data = t.to_vec();
            let t_total: usize = t.shape.iter().product();
            // Row-major strides for this input's own shape.
            let mut t_strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
            }
            // Remap every flat input index into the output: decode the
            // multi-dim coordinate, shift the `axis` coordinate by `offset`,
            // then re-encode against the output strides.
            for idx in 0..t_total {
                let mut remaining = idx;
                let mut out_flat = 0;
                for d in 0..ndim {
                    let coord = remaining / t_strides[d];
                    remaining %= t_strides[d];
                    let out_coord = if d == axis { coord + offset } else { coord };
                    out_flat += out_coord * out_strides[d];
                }
                result[out_flat] = t_data[idx];
            }
            offset += t.shape[axis];
        }
        Tensor::from_vec(result, &out_shape)
    }
1279
    /// Stack tensors along a new axis.
    ///
    /// All inputs must have identical shapes; the result gains one extra
    /// dimension of size `tensors.len()` inserted at `axis`
    /// (`0 <= axis <= ndim` — the new axis may be appended last).
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
        }
        let base_shape = &tensors[0].shape;
        let ndim = base_shape.len();
        // Note `>` not `>=`: axis == ndim appends the new dimension at the end.
        if axis > ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("stack: axis {axis} out of bounds"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if &t.shape != base_shape {
                return Err(RuntimeError::InvalidOperation(
                    format!("stack: tensor {i} shape mismatch"),
                ));
            }
        }
        // Output shape: base shape with `tensors.len()` spliced in at `axis`.
        let mut out_shape = Vec::with_capacity(ndim + 1);
        for d in 0..axis { out_shape.push(base_shape[d]); }
        out_shape.push(tensors.len());
        for d in axis..ndim { out_shape.push(base_shape[d]); }
        let total: usize = out_shape.iter().product();
        let mut result = vec![0.0; total];
        // Each input splits into `outer_size` blocks (dims before `axis`) of
        // `inner_size` contiguous elements (dims from `axis` onward);
        // `.max(1)` handles the empty-product edge at axis == 0 / axis == ndim.
        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
        for (t_idx, t) in tensors.iter().enumerate() {
            let t_data = t.to_vec();
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let src = outer * inner_size + inner;
                    // Tensor t_idx's block is interleaved between the other
                    // tensors' blocks inside each outer stride.
                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
                    // NOTE(review): this bounds guard looks defensive — with
                    // matching shapes both indices should always be in range;
                    // confirm it is not masking an indexing assumption.
                    if src < t_data.len() && dst < result.len() {
                        result[dst] = t_data[src];
                    }
                }
            }
        }
        Tensor::from_vec(result, &out_shape)
    }
1321
1322    /// Top-k values and indices (largest k values from flat data).
1323    pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1324        let data = self.to_vec();
1325        let n = data.len();
1326        if k > n {
1327            return Err(RuntimeError::InvalidOperation(
1328                format!("topk: k={k} exceeds data length {n}"),
1329            ));
1330        }
1331        let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1332        indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1333        let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1334        let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1335        let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1336        Ok((Tensor::from_vec(values, &[k])?, indices))
1337    }
1338
1339    /// GELU activation (approximate): x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
1340    pub fn gelu(&self) -> Tensor {
1341        let data = self.to_vec();
1342        let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1343        let result: Vec<f64> = data.iter().map(|&x| {
1344            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1345            0.5 * x * (1.0 + inner.tanh())
1346        }).collect();
1347        Tensor::from_vec(result, &self.shape).unwrap()
1348    }
1349
1350    /// Linear layer: output = input @ weight^T + bias
1351    ///
1352    /// `self` is `[..., in_features]`, `weight` is `[out_features, in_features]`,
1353    /// `bias` is `[out_features]`.
1354    /// Result is `[..., out_features]`.
1355    pub fn linear(
1356        &self,
1357        weight: &Tensor,
1358        bias: &Tensor,
1359    ) -> Result<Tensor, RuntimeError> {
1360        if weight.ndim() != 2 {
1361            return Err(RuntimeError::InvalidOperation(
1362                "linear: weight must be 2-D [out_features, in_features]".to_string(),
1363            ));
1364        }
1365        let out_features = weight.shape[0];
1366        let in_features = weight.shape[1];
1367        let last_dim = *self.shape.last().ok_or_else(|| {
1368            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
1369        })?;
1370        if last_dim != in_features {
1371            return Err(RuntimeError::DimensionMismatch {
1372                expected: in_features,
1373                got: last_dim,
1374            });
1375        }
1376        if bias.len() != out_features {
1377            return Err(RuntimeError::InvalidOperation(
1378                format!(
1379                    "linear: bias length {} must match out_features {}",
1380                    bias.len(),
1381                    out_features
1382                ),
1383            ));
1384        }
1385
1386        let data = self.to_vec();
1387        let w = weight.to_vec();
1388        let b = bias.to_vec();
1389        let outer = data.len() / in_features;
1390        let mut result = vec![0.0f64; outer * out_features];
1391
1392        for row in 0..outer {
1393            let x_start = row * in_features;
1394            let x_slice = &data[x_start..x_start + in_features];
1395            let y_start = row * out_features;
1396            for j in 0..out_features {
1397                let w_start = j * in_features;
1398                let mut acc = BinnedAccumulatorF64::new();
1399                for p in 0..in_features {
1400                    acc.add(x_slice[p] * w[w_start + p]);
1401                }
1402                result[y_start + j] = acc.finalize() + b[j];
1403            }
1404        }
1405
1406        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
1407        out_shape.push(out_features);
1408        Tensor::from_vec(result, &out_shape)
1409    }
1410
1411    /// 1D convolution: signal `[signal_len]` * filters `[out_ch, kernel_size]` + bias
1412    ///
1413    /// Returns `[out_ch, signal_len - kernel_size + 1]` (valid mode, stride=1).
1414    pub fn conv1d(
1415        &self,
1416        filters: &Tensor,
1417        bias: &Tensor,
1418    ) -> Result<Tensor, RuntimeError> {
1419        if self.ndim() != 1 {
1420            return Err(RuntimeError::InvalidOperation(
1421                "conv1d: input must be 1-D [signal_len]".to_string(),
1422            ));
1423        }
1424        if filters.ndim() != 2 {
1425            return Err(RuntimeError::InvalidOperation(
1426                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1427            ));
1428        }
1429        let signal_len = self.shape[0];
1430        let out_channels = filters.shape[0];
1431        let kernel_size = filters.shape[1];
1432        if signal_len < kernel_size {
1433            return Err(RuntimeError::InvalidOperation(
1434                format!(
1435                    "conv1d: signal_len {} < kernel_size {}",
1436                    signal_len, kernel_size
1437                ),
1438            ));
1439        }
1440        if bias.len() != out_channels {
1441            return Err(RuntimeError::InvalidOperation(
1442                format!(
1443                    "conv1d: bias length {} must match out_channels {}",
1444                    bias.len(), out_channels
1445                ),
1446            ));
1447        }
1448        let out_len = signal_len - kernel_size + 1;
1449        let s = self.to_vec();
1450        let f = filters.to_vec();
1451        let b = bias.to_vec();
1452        let mut result = vec![0.0; out_channels * out_len];
1453        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1454        Tensor::from_vec(result, &[out_channels, out_len])
1455    }
1456
    /// 2D convolution — NCHW layout, valid mode, configurable stride.
    ///
    /// # Arguments
    /// - `self`:    `[N, C_in, H, W]` input tensor
    /// - `filters`: `[C_out, C_in, kH, kW]`
    /// - `bias`:    `[C_out]`
    /// - `stride`:  spatial stride (must be >= 1)
    ///
    /// # Returns
    /// `[N, C_out, H_out, W_out]` where `H_out = (H - kH) / stride + 1`.
    ///
    /// # Errors
    /// Fails when either tensor is not 4-D, `stride == 0`, channel counts
    /// disagree, the kernel is larger than the input, or the bias length
    /// does not match `C_out`.
    ///
    /// Uses `BinnedAccumulatorF64` for every dot product — bit-identical results
    /// across all runs and hardware configurations.
    pub fn conv2d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
        stride: usize,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
            ));
        }
        if filters.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
            ));
        }
        if stride == 0 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: stride must be >= 1".to_string(),
            ));
        }

        let n    = self.shape[0];
        let c_in = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        let c_out      = filters.shape[0];
        let c_in_check = filters.shape[1];
        let kh         = filters.shape[2];
        let kw         = filters.shape[3];

        if c_in != c_in_check {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input C_in={} does not match filter C_in={}",
                c_in, c_in_check
            )));
        }
        if h_in < kh || w_in < kw {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
                h_in, w_in, kh, kw
            )));
        }
        if bias.len() != c_out {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: bias length {} must match C_out={}",
                bias.len(), c_out
            )));
        }

        // Valid-mode output extents (no padding): floor division drops any
        // partial final window.
        let h_out = (h_in - kh) / stride + 1;
        let w_out = (w_in - kw) / stride + 1;

        // Materialize contiguous copies for the raw kernel.
        let inp = self.to_vec();
        let flt = filters.to_vec();
        let b   = bias.to_vec();
        let mut result = vec![0.0f64; n * c_out * h_out * w_out];

        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
                           n, c_in, h_in, w_in, c_out, kh, kw, stride);

        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
    }
1534
1535    /// 2D max-pooling — NCHW layout, non-overlapping windows.
1536    ///
1537    /// - `self`: `[N, C, H, W]`
1538    /// - `ph`, `pw`: pool height/width (stride = window size)
1539    ///
1540    /// Returns `[N, C, H/ph, W/pw]`.
1541    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1542        if self.ndim() != 4 {
1543            return Err(RuntimeError::InvalidOperation(
1544                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1545            ));
1546        }
1547        if ph == 0 || pw == 0 {
1548            return Err(RuntimeError::InvalidOperation(
1549                "maxpool2d: pool size must be >= 1".to_string(),
1550            ));
1551        }
1552
1553        let n    = self.shape[0];
1554        let c    = self.shape[1];
1555        let h_in = self.shape[2];
1556        let w_in = self.shape[3];
1557
1558        if h_in < ph || w_in < pw {
1559            return Err(RuntimeError::InvalidOperation(format!(
1560                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1561                h_in, w_in, ph, pw
1562            )));
1563        }
1564
1565        let h_out = h_in / ph;
1566        let w_out = w_in / pw;
1567
1568        let inp = self.to_vec();
1569        let mut result = vec![0.0f64; n * c * h_out * w_out];
1570
1571        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1572
1573        Tensor::from_vec(result, &[n, c, h_out, w_out])
1574    }
1575
1576    /// Scaled dot-product attention (single head).
1577    ///
1578    /// `queries` is `[..., T, d_k]`
1579    /// `keys`    is `[..., S, d_k]`
1580    /// `values`  is `[..., S, d_v]`
1581    ///
1582    /// Computes: softmax(Q × Kᵀ / √d_k) × V
1583    /// Returns `[..., T, d_v]`.
1584    pub fn scaled_dot_product_attention(
1585        queries: &Tensor,
1586        keys: &Tensor,
1587        values: &Tensor,
1588    ) -> Result<Tensor, RuntimeError> {
1589        if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1590            return Err(RuntimeError::InvalidOperation(
1591                "attention: Q, K, V must be at least 2-D".to_string(),
1592            ));
1593        }
1594        let nd = queries.ndim();
1595        let d_k = queries.shape[nd - 1];
1596        let scale = 1.0 / (d_k as f64).sqrt();
1597
1598        // Transpose keys: swap last two dims
1599        let keys_t = keys.transpose_last_two()?;
1600
1601        // Q × K^T → [... T, S]
1602        let scores = queries.bmm(&keys_t)?;
1603
1604        // Scale
1605        let scores_scaled = scores.scalar_mul(scale);
1606
1607        // Softmax along last dim
1608        let attn_weights = scores_scaled.softmax()?;
1609
1610        // Attn × V → [... T, d_v]
1611        attn_weights.bmm(values)
1612    }
1613
1614    /// Transpose the last two dimensions of a tensor.
1615    ///
1616    /// `[..., A, B]` → `[..., B, A]`
1617    pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1618        if self.ndim() < 2 {
1619            return Err(RuntimeError::InvalidOperation(
1620                "transpose_last_two requires at least 2-D tensor".to_string(),
1621            ));
1622        }
1623        let nd = self.ndim();
1624        let rows = self.shape[nd - 2];
1625        let cols = self.shape[nd - 1];
1626        let data = self.to_vec();
1627        let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1628        let mat_size = rows * cols;
1629        let mut result = vec![0.0f64; data.len()];
1630
1631        for b in 0..batch_size {
1632            let off = b * mat_size;
1633            for i in 0..rows {
1634                for j in 0..cols {
1635                    result[off + j * rows + i] = data[off + i * cols + j];
1636                }
1637            }
1638        }
1639
1640        let mut out_shape = self.shape.clone();
1641        out_shape[nd - 2] = cols;
1642        out_shape[nd - 1] = rows;
1643        Tensor::from_vec(result, &out_shape)
1644    }
1645
    // -- Weight Mapping (single-copy byte decode) ---------------------------
1647
1648    /// Create a tensor view from raw bytes — **zero allocation**.
1649    ///
1650    /// Interprets `bytes` as a contiguous block of `f64` (8 bytes each) or
1651    /// `f32` (4 bytes each, promoted to f64) values and maps them into a
1652    /// `Tensor` with the given shape.
1653    ///
1654    /// `dtype` must be `"f64"` or `"f32"`.
1655    ///
1656    /// For f64: bytes.len() must equal shape_numel * 8.
1657    /// For f32: bytes.len() must equal shape_numel * 4.
1658    ///
1659    /// The returned tensor **owns** its buffer (copied from the raw bytes)
1660    /// but performs exactly one allocation for the data vector.
1661    pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1662        let numel = Self::shape_numel(shape);
1663        match dtype {
1664            "f64" => {
1665                let expected = numel * 8;
1666                if bytes.len() != expected {
1667                    return Err(RuntimeError::ShapeMismatch {
1668                        expected,
1669                        got: bytes.len(),
1670                    });
1671                }
1672                let mut data = Vec::with_capacity(numel);
1673                for i in 0..numel {
1674                    let off = i * 8;
1675                    let mut buf = [0u8; 8];
1676                    buf.copy_from_slice(&bytes[off..off + 8]);
1677                    data.push(f64::from_le_bytes(buf));
1678                }
1679                Ok(Tensor {
1680                    buffer: Buffer::from_vec(data),
1681                    shape: shape.to_vec(),
1682                    strides: Self::compute_strides(shape),
1683                    offset: 0,
1684                })
1685            }
1686            "f32" => {
1687                let expected = numel * 4;
1688                if bytes.len() != expected {
1689                    return Err(RuntimeError::ShapeMismatch {
1690                        expected,
1691                        got: bytes.len(),
1692                    });
1693                }
1694                let mut data = Vec::with_capacity(numel);
1695                for i in 0..numel {
1696                    let off = i * 4;
1697                    let mut buf = [0u8; 4];
1698                    buf.copy_from_slice(&bytes[off..off + 4]);
1699                    data.push(f32::from_le_bytes(buf) as f64);
1700                }
1701                Ok(Tensor {
1702                    buffer: Buffer::from_vec(data),
1703                    shape: shape.to_vec(),
1704                    strides: Self::compute_strides(shape),
1705                    offset: 0,
1706                })
1707            }
1708            _ => Err(RuntimeError::InvalidOperation(
1709                format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1710            )),
1711        }
1712    }
1713
1714    // -- Multi-Head Attention Splitting -------------------------------------
1715
1716    /// Reshape a 3D tensor `[batch, seq, model_dim]` into 4D
1717    /// `[batch, num_heads, seq, head_dim]` by splitting the last dimension.
1718    ///
1719    /// This is a **zero-copy view** — it only changes shape/strides metadata.
1720    /// `model_dim` must be divisible by `num_heads`.
1721    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1722        if self.ndim() != 3 {
1723            return Err(RuntimeError::DimensionMismatch {
1724                expected: 3,
1725                got: self.ndim(),
1726            });
1727        }
1728        let batch = self.shape[0];
1729        let seq = self.shape[1];
1730        let model_dim = self.shape[2];
1731        if model_dim % num_heads != 0 {
1732            return Err(RuntimeError::InvalidOperation(
1733                format!(
1734                    "split_heads: model_dim {} not divisible by num_heads {}",
1735                    model_dim, num_heads
1736                ),
1737            ));
1738        }
1739        let head_dim = model_dim / num_heads;
1740        // Need contiguous data for the reshape
1741        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1742        // Reshape [B, S, H*D] -> [B, S, H, D] then transpose to [B, H, S, D]
1743        let reshaped = Tensor {
1744            buffer: tensor.buffer.clone(),
1745            shape: vec![batch, seq, num_heads, head_dim],
1746            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1747            offset: 0,
1748        };
1749        // Transpose dims 1 and 2: [B, S, H, D] -> [B, H, S, D]
1750        // New strides: swap strides[1] and strides[2]
1751        Ok(Tensor {
1752            buffer: reshaped.buffer,
1753            shape: vec![batch, num_heads, seq, head_dim],
1754            strides: vec![
1755                reshaped.strides[0], // batch stride unchanged
1756                reshaped.strides[2], // head stride (was dim 2)
1757                reshaped.strides[1], // seq stride (was dim 1)
1758                reshaped.strides[3], // head_dim stride unchanged
1759            ],
1760            offset: 0,
1761        })
1762    }
1763
1764    /// Merge heads back: reshape 4D `[batch, num_heads, seq, head_dim]` into
1765    /// 3D `[batch, seq, model_dim]`. Materializes if non-contiguous.
1766    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1767        if self.ndim() != 4 {
1768            return Err(RuntimeError::DimensionMismatch {
1769                expected: 4,
1770                got: self.ndim(),
1771            });
1772        }
1773        let batch = self.shape[0];
1774        let num_heads = self.shape[1];
1775        let seq = self.shape[2];
1776        let head_dim = self.shape[3];
1777        // Need [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
1778        // Transpose dims 1 and 2 first
1779        let transposed = Tensor {
1780            buffer: self.buffer.clone(),
1781            shape: vec![batch, seq, num_heads, head_dim],
1782            strides: vec![
1783                self.strides[0],
1784                self.strides[2], // seq stride
1785                self.strides[1], // head stride
1786                self.strides[3],
1787            ],
1788            offset: self.offset,
1789        };
1790        // Materialize contiguous then reshape
1791        let contig = transposed.to_contiguous();
1792        let model_dim = num_heads * head_dim;
1793        Ok(Tensor {
1794            buffer: contig.buffer,
1795            shape: vec![batch, seq, model_dim],
1796            strides: Self::compute_strides(&[batch, seq, model_dim]),
1797            offset: 0,
1798        })
1799    }
1800
1801    /// View-only reshape: reinterpret shape without copying.
1802    /// Only works on contiguous tensors. Falls back to copy if non-contiguous.
1803    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
1804        self.reshape(new_shape)
1805    }
1806
1807    // -----------------------------------------------------------------------
1808    // Phase C4: Sorting & Tensor Indexing
1809    // -----------------------------------------------------------------------
1810
1811    /// Returns indices that would sort the flattened tensor in ascending order.
1812    /// Uses f64::total_cmp for deterministic ordering of NaN.
1813    pub fn argsort(&self) -> Tensor {
1814        let data = self.to_vec();
1815        let mut indices: Vec<usize> = (0..data.len()).collect();
1816        indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1817        let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1818        Tensor::from_vec_unchecked(result, &[data.len()])
1819    }
1820
1821    /// Gather elements from the tensor along a dimension using index tensor.
1822    /// For 1D: result[i] = self[indices[i]]
1823    /// For 2D dim=0: result[i][j] = self[indices[i][j]][j]
1824    /// For 2D dim=1: result[i][j] = self[i][indices[i][j]]
1825    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1826        let data = self.to_vec();
1827        let idx_data = indices.to_vec();
1828        if self.ndim() == 1 {
1829            let mut result = Vec::with_capacity(idx_data.len());
1830            for &idx in &idx_data {
1831                let i = idx as usize;
1832                if i >= data.len() {
1833                    return Err(RuntimeError::InvalidOperation(
1834                        format!("gather: index {} out of bounds for size {}", i, data.len()),
1835                    ));
1836                }
1837                result.push(data[i]);
1838            }
1839            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
1840        } else if self.ndim() == 2 {
1841            let rows = self.shape[0];
1842            let cols = self.shape[1];
1843            let idx_shape = indices.shape();
1844            let out_rows = idx_shape[0];
1845            let out_cols = idx_shape[1];
1846            let mut result = vec![0.0; out_rows * out_cols];
1847            for i in 0..out_rows {
1848                for j in 0..out_cols {
1849                    let idx = idx_data[i * out_cols + j] as usize;
1850                    let val = if dim == 0 {
1851                        if idx >= rows {
1852                            return Err(RuntimeError::InvalidOperation(
1853                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
1854                            ));
1855                        }
1856                        data[idx * cols + j]
1857                    } else {
1858                        if idx >= cols {
1859                            return Err(RuntimeError::InvalidOperation(
1860                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
1861                            ));
1862                        }
1863                        data[i * cols + idx]
1864                    };
1865                    result[i * out_cols + j] = val;
1866                }
1867            }
1868            Ok(Tensor::from_vec_unchecked(result, idx_shape))
1869        } else {
1870            Err(RuntimeError::InvalidOperation(
1871                "gather: only 1D and 2D tensors supported".into(),
1872            ))
1873        }
1874    }
1875
1876    /// Scatter src values into a tensor of given shape at indices along a dimension.
1877    /// For 1D: result[indices[i]] = src[i]
1878    /// For 2D dim=0: result[indices[i][j]][j] = src[i][j]
1879    /// For 2D dim=1: result[i][indices[i][j]] = src[i][j]
1880    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
1881        let mut result = self.to_vec();
1882        let idx_data = indices.to_vec();
1883        let src_data = src.to_vec();
1884        if self.ndim() == 1 {
1885            for (k, &idx) in idx_data.iter().enumerate() {
1886                let i = idx as usize;
1887                if i >= result.len() {
1888                    return Err(RuntimeError::InvalidOperation(
1889                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
1890                    ));
1891                }
1892                result[i] = src_data[k];
1893            }
1894            Ok(Tensor::from_vec_unchecked(result, self.shape()))
1895        } else if self.ndim() == 2 {
1896            let cols = self.shape[1];
1897            let idx_shape = indices.shape();
1898            let out_cols = idx_shape[1];
1899            let out_rows = idx_shape[0];
1900            for i in 0..out_rows {
1901                for j in 0..out_cols {
1902                    let idx = idx_data[i * out_cols + j] as usize;
1903                    let src_val = src_data[i * out_cols + j];
1904                    if dim == 0 {
1905                        if idx >= self.shape[0] {
1906                            return Err(RuntimeError::InvalidOperation(
1907                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
1908                            ));
1909                        }
1910                        result[idx * cols + j] = src_val;
1911                    } else {
1912                        if idx >= cols {
1913                            return Err(RuntimeError::InvalidOperation(
1914                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
1915                            ));
1916                        }
1917                        result[i * cols + idx] = src_val;
1918                    }
1919                }
1920            }
1921            Ok(Tensor::from_vec_unchecked(result, self.shape()))
1922        } else {
1923            Err(RuntimeError::InvalidOperation(
1924                "scatter: only 1D and 2D tensors supported".into(),
1925            ))
1926        }
1927    }
1928
1929    /// Select slices along a dimension by index.
1930    /// For 2D dim=0: selects rows
1931    /// For 2D dim=1: selects columns
1932    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1933        let data = self.to_vec();
1934        let idx_data = indices.to_vec();
1935        if self.ndim() == 1 {
1936            let mut result = Vec::with_capacity(idx_data.len());
1937            for &idx in &idx_data {
1938                let i = idx as usize;
1939                if i >= data.len() {
1940                    return Err(RuntimeError::InvalidOperation(
1941                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
1942                    ));
1943                }
1944                result.push(data[i]);
1945            }
1946            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
1947        } else if self.ndim() == 2 {
1948            let rows = self.shape[0];
1949            let cols = self.shape[1];
1950            let n = idx_data.len();
1951            if dim == 0 {
1952                let mut result = Vec::with_capacity(n * cols);
1953                for &idx in &idx_data {
1954                    let i = idx as usize;
1955                    if i >= rows {
1956                        return Err(RuntimeError::InvalidOperation(
1957                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
1958                        ));
1959                    }
1960                    for j in 0..cols {
1961                        result.push(data[i * cols + j]);
1962                    }
1963                }
1964                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
1965            } else {
1966                let mut result = Vec::with_capacity(rows * n);
1967                for i in 0..rows {
1968                    for &idx in &idx_data {
1969                        let j = idx as usize;
1970                        if j >= cols {
1971                            return Err(RuntimeError::InvalidOperation(
1972                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
1973                            ));
1974                        }
1975                        result.push(data[i * cols + j]);
1976                    }
1977                }
1978                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
1979            }
1980        } else {
1981            Err(RuntimeError::InvalidOperation(
1982                "index_select: only 1D and 2D tensors supported".into(),
1983            ))
1984        }
1985    }
1986
1987    // -----------------------------------------------------------------------
1988    // Phase 2: Boolean / Masking Ops
1989    // -----------------------------------------------------------------------
1990
1991    /// Element-wise conditional select: `where(condition, other)`.
1992    /// For each element, returns `self[i]` if `condition[i] != 0.0`, else `other[i]`.
1993    pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
1994        if self.shape() != condition.shape() || self.shape() != other.shape() {
1995            return Err(RuntimeError::InvalidOperation(
1996                format!("where: shape mismatch self={:?} cond={:?} other={:?}",
1997                    self.shape(), condition.shape(), other.shape()),
1998            ));
1999        }
2000        let s = self.to_vec();
2001        let c = condition.to_vec();
2002        let o = other.to_vec();
2003        let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2004            .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2005            .collect();
2006        Tensor::from_vec(result, self.shape())
2007    }
2008
2009    /// Returns `true` if any element is non-zero.
2010    pub fn any(&self) -> bool {
2011        let data = self.to_vec();
2012        data.iter().any(|&x| x != 0.0)
2013    }
2014
2015    /// Returns `true` if all elements are non-zero.
2016    pub fn all(&self) -> bool {
2017        let data = self.to_vec();
2018        data.iter().all(|&x| x != 0.0)
2019    }
2020
2021    /// Returns a 1-D tensor of flat indices where elements are non-zero.
2022    pub fn nonzero(&self) -> Tensor {
2023        let data = self.to_vec();
2024        let indices: Vec<f64> = data.iter().enumerate()
2025            .filter(|(_, &v)| v != 0.0)
2026            .map(|(i, _)| i as f64)
2027            .collect();
2028        let len = indices.len();
2029        if len == 0 {
2030            Tensor::from_vec(vec![], &[0]).unwrap()
2031        } else {
2032            Tensor::from_vec(indices, &[len]).unwrap()
2033        }
2034    }
2035
2036    /// Fill elements where `mask` is non-zero with `value`.
2037    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2038        if self.shape() != mask.shape() {
2039            return Err(RuntimeError::InvalidOperation(
2040                format!("masked_fill: shape mismatch self={:?} mask={:?}",
2041                    self.shape(), mask.shape()),
2042            ));
2043        }
2044        let data = self.to_vec();
2045        let m = mask.to_vec();
2046        let result: Vec<f64> = data.iter().zip(m.iter())
2047            .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2048            .collect();
2049        Tensor::from_vec(result, self.shape())
2050    }
2051
2052    // -----------------------------------------------------------------------
2053    // Phase 2: Axis Reductions with keepdim
2054    // -----------------------------------------------------------------------
2055
    /// Helper: generic axis reduction using BinnedAccumulator.
    /// `reduce_fn` takes a slice of values and returns the reduced value.
    ///
    /// Walks every position of the output tensor (the input with `axis`
    /// collapsed to length 1), gathers the `axis_len` input values mapping
    /// onto that position, and applies `reduce_fn` to the gathered slice.
    ///
    /// With `keepdim` the reduced axis is kept at size 1; otherwise it is
    /// removed (a fully reduced tensor becomes shape `[1]`).
    ///
    /// Returns `IndexOutOfBounds` when `axis >= ndim`.
    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
        -> Result<Tensor, RuntimeError>
    where
        F: Fn(&[f64]) -> f64,
    {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds {
                index: axis,
                length: ndim,
            });
        }

        let axis_len = self.shape[axis];
        // Build output shape: same as input but with the reduced axis at 1.
        let mut out_shape: Vec<usize> = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        // NOTE(review): `data` is indexed below with self.offset/self.strides,
        // which assumes `to_vec` exposes the full backing buffer (the same
        // pattern max_axis/min_axis/var_axis use) — confirm for
        // non-contiguous views.
        let data = self.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Compute N-D index from flat output index (row-major unflatten).
            {
                let mut remaining = out_idx;
                for d in 0..ndim {
                    indices[d] = remaining / out_strides[d];
                    remaining %= out_strides[d];
                }
            }

            // Gather values along the reduction axis; dimension `axis` sweeps
            // k = 0..axis_len while all other coordinates stay fixed.
            let mut vals = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                vals.push(data[flat]);
            }
            result.push(reduce_fn(&vals));
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Remove the axis dimension
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis)
                .map(|(_, &v)| v)
                .collect();
            if s.is_empty() {
                s.push(1); // scalar result
            }
            s
        };

        Tensor::from_vec(result, &final_shape)
    }
2121
2122    /// Mean along an axis with optional keepdim.
2123    pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2124        self.reduce_axis(axis, keepdim, |vals| {
2125            let mut acc = BinnedAccumulatorF64::new();
2126            for &v in vals { acc.add(v); }
2127            acc.finalize() / vals.len() as f64
2128        })
2129    }
2130
    /// Max along an axis with optional keepdim. Returns (values, indices).
    ///
    /// `values[p]` is the largest element along `axis` at output position `p`,
    /// `indices[p]` its position (as f64) along that axis; the first
    /// occurrence wins on exact ties (the `>` comparison is strict).
    ///
    /// NaN caveat: `v > best_val` is false for NaN, so NaN entries are
    /// skipped; an axis containing only NaN (or of length 0) reports
    /// `-inf` with index 0.
    ///
    /// Returns `IndexOutOfBounds` when `axis >= ndim`.
    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        // Output shape: input with the reduced axis collapsed to 1.
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        // NOTE(review): indexed below with self.offset/strides — assumes
        // to_vec exposes the full backing buffer; confirm for views.
        let data = self.to_vec();
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Unflatten out_idx into an N-D coordinate (row-major).
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mut best_val = f64::NEG_INFINITY;
            let mut best_idx = 0usize;
            // Scan the axis_len candidates mapping onto this output slot.
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let v = data[flat];
                if v > best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; a fully reduced tensor becomes [1].
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }
2184
    /// Min along an axis with optional keepdim. Returns (values, indices).
    ///
    /// Mirror of `max_axis`: `values[p]` is the smallest element along `axis`
    /// at output position `p`, `indices[p]` its position (as f64) along that
    /// axis; the first occurrence wins on exact ties (the `<` is strict).
    ///
    /// NaN caveat: `v < best_val` is false for NaN, so NaN entries are
    /// skipped; an axis containing only NaN (or of length 0) reports
    /// `+inf` with index 0.
    ///
    /// Returns `IndexOutOfBounds` when `axis >= ndim`.
    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        // Output shape: input with the reduced axis collapsed to 1.
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        // NOTE(review): indexed below with self.offset/strides — assumes
        // to_vec exposes the full backing buffer; confirm for views.
        let data = self.to_vec();
        let mut values = Vec::with_capacity(out_numel);
        let mut idx_vals = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Unflatten out_idx into an N-D coordinate (row-major).
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mut best_val = f64::INFINITY;
            let mut best_idx = 0usize;
            // Scan the axis_len candidates mapping onto this output slot.
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let v = data[flat];
                if v < best_val {
                    best_val = v;
                    best_idx = k;
                }
            }
            values.push(best_val);
            idx_vals.push(best_idx as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; a fully reduced tensor becomes [1].
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Ok((
            Tensor::from_vec(values, &final_shape)?,
            Tensor::from_vec(idx_vals, &final_shape)?,
        ))
    }
2238
    /// Variance along an axis with optional keepdim.
    ///
    /// Population variance: mean of squared deviations with division by
    /// `axis_len` (not `axis_len - 1`); the squared deviations are summed
    /// with a BinnedAccumulator for deterministic results.
    ///
    /// Returns `IndexOutOfBounds` when `axis >= ndim`.
    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        // keepdim=true so mean_data[out_idx] lines up with the flat output
        // index used in the loop below.
        let mean_t = self.mean_axis(axis, true)?;
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        // Output shape: input with the reduced axis collapsed to 1.
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        // NOTE(review): indexed below with self.offset/strides — assumes
        // to_vec exposes the full backing buffer; confirm for views.
        let data = self.to_vec();
        let mean_data = mean_t.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Unflatten out_idx into an N-D coordinate (row-major).
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mu = mean_data[out_idx];
            // Accumulate (x - mu)^2 along the axis.
            let mut acc = BinnedAccumulatorF64::new();
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let diff = data[flat] - mu;
                acc.add(diff * diff);
            }
            result.push(acc.finalize() / axis_len as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; a fully reduced tensor becomes [1].
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Tensor::from_vec(result, &final_shape)
    }
2286
2287    /// Standard deviation along an axis with optional keepdim.
2288    pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2289        let var = self.var_axis(axis, keepdim)?;
2290        Ok(var.map(|x| x.sqrt()))
2291    }
2292
2293    /// Product along an axis with optional keepdim.
2294    pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2295        self.reduce_axis(axis, keepdim, |vals| {
2296            // Product via exp(sum(ln(abs))) for numerical stability is overkill here;
2297            // simple product is deterministic and exact for integer-like values.
2298            let mut product = 1.0f64;
2299            for &v in vals { product *= v; }
2300            product
2301        })
2302    }
2303
2304    // -----------------------------------------------------------------------
2305    // Phase 2: Sort Operations
2306    // -----------------------------------------------------------------------
2307
2308    /// Sort along an axis (stable sort). Returns the sorted tensor.
2309    /// For N-D tensors, sorts slices along the specified axis.
2310    pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2311        let ndim = self.ndim();
2312        if axis >= ndim {
2313            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2314        }
2315        let data = self.to_vec();
2316        let axis_len = self.shape[axis];
2317        let out_shape = self.shape.clone();
2318        let out_numel = Self::shape_numel(&out_shape);
2319
2320        // Build strides for iterating over all non-axis positions
2321        let mut iter_shape: Vec<usize> = Vec::new();
2322        for (i, &s) in self.shape.iter().enumerate() {
2323            if i != axis { iter_shape.push(s); }
2324        }
2325        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2326
2327        let mut result = vec![0.0f64; out_numel];
2328
2329        // We iterate over all positions with axis index = 0
2330        let mut pos = vec![0usize; ndim];
2331        for slice_idx in 0..n_slices {
2332            // Compute the N-D position (with axis dim = 0)
2333            let mut remaining = slice_idx;
2334            let mut dim_idx = 0;
2335            for d in 0..ndim {
2336                if d == axis {
2337                    pos[d] = 0;
2338                } else {
2339                    let stride = {
2340                        let mut s = 1usize;
2341                        let mut di = 0;
2342                        for d2 in 0..ndim {
2343                            if d2 == axis { continue; }
2344                            if di > dim_idx { s *= self.shape[d2]; }
2345                            di += 1;
2346                        }
2347                        s
2348                    };
2349                    pos[d] = remaining / stride;
2350                    remaining %= stride;
2351                    dim_idx += 1;
2352                }
2353            }
2354
2355            // Gather values along axis
2356            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2357            for k in 0..axis_len {
2358                let mut flat = self.offset;
2359                for d in 0..ndim {
2360                    let idx = if d == axis { k } else { pos[d] };
2361                    flat += idx * self.strides[d];
2362                }
2363                vals.push((data[flat], k));
2364            }
2365
2366            // Stable sort with deterministic tie-breaking by original index
2367            if descending {
2368                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2369                    .then(a.1.cmp(&b.1)));
2370            } else {
2371                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2372                    .then(a.1.cmp(&b.1)));
2373            }
2374
2375            // Scatter back
2376            for (k, &(v, _)) in vals.iter().enumerate() {
2377                let mut flat = 0;
2378                let out_strides_local = Self::compute_strides(&out_shape);
2379                for d in 0..ndim {
2380                    let idx = if d == axis { k } else { pos[d] };
2381                    flat += idx * out_strides_local[d];
2382                }
2383                result[flat] = v;
2384            }
2385        }
2386
2387        Tensor::from_vec(result, &out_shape)
2388    }
2389
2390    /// N-D argsort along an axis. Returns indices tensor.
2391    pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2392        let ndim = self.ndim();
2393        if axis >= ndim {
2394            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2395        }
2396        let data = self.to_vec();
2397        let axis_len = self.shape[axis];
2398        let out_shape = self.shape.clone();
2399        let out_numel = Self::shape_numel(&out_shape);
2400
2401        let mut iter_shape: Vec<usize> = Vec::new();
2402        for (i, &s) in self.shape.iter().enumerate() {
2403            if i != axis { iter_shape.push(s); }
2404        }
2405        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2406
2407        let mut result = vec![0.0f64; out_numel];
2408        let mut pos = vec![0usize; ndim];
2409
2410        for slice_idx in 0..n_slices {
2411            let mut remaining = slice_idx;
2412            let mut dim_idx = 0;
2413            for d in 0..ndim {
2414                if d == axis {
2415                    pos[d] = 0;
2416                } else {
2417                    let stride = {
2418                        let mut s = 1usize;
2419                        let mut di = 0;
2420                        for d2 in 0..ndim {
2421                            if d2 == axis { continue; }
2422                            if di > dim_idx { s *= self.shape[d2]; }
2423                            di += 1;
2424                        }
2425                        s
2426                    };
2427                    pos[d] = remaining / stride;
2428                    remaining %= stride;
2429                    dim_idx += 1;
2430                }
2431            }
2432
2433            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2434            for k in 0..axis_len {
2435                let mut flat = self.offset;
2436                for d in 0..ndim {
2437                    let idx = if d == axis { k } else { pos[d] };
2438                    flat += idx * self.strides[d];
2439                }
2440                vals.push((data[flat], k));
2441            }
2442
2443            if descending {
2444                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2445                    .then(a.1.cmp(&b.1)));
2446            } else {
2447                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2448                    .then(a.1.cmp(&b.1)));
2449            }
2450
2451            for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2452                let out_strides_local = Self::compute_strides(&out_shape);
2453                let mut flat = 0;
2454                for d in 0..ndim {
2455                    let idx = if d == axis { k } else { pos[d] };
2456                    flat += idx * out_strides_local[d];
2457                }
2458                result[flat] = orig_idx as f64;
2459            }
2460        }
2461
2462        Tensor::from_vec(result, &out_shape)
2463    }
2464
2465    // -----------------------------------------------------------------------
2466    // Phase 2: Einsum
2467    // -----------------------------------------------------------------------
2468
    /// Einstein summation notation.
    /// Supports patterns like "ij,jk->ik" (matmul), "ii->i" (diagonal),
    /// "ij->ji" (transpose), "ijk,ikl->ijl" (batched matmul).
    /// Uses BinnedAccumulator for all reductions.
    ///
    /// Only the explicit `"subscripts->output"` form is accepted. Labels
    /// that appear in an input spec but not in the output spec are summed
    /// over (contracted). A full contraction like `"ij->"` yields a
    /// 1-element tensor of shape `[1]`.
    ///
    /// # Errors
    /// Returns `RuntimeError::InvalidOperation` when the notation lacks a
    /// single `"->"`, the spec count differs from `inputs.len()`, a spec's
    /// length differs from its tensor's rank, one label maps to two
    /// different sizes, or the output uses a label absent from the inputs.
    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
        // Parse notation: the explicit "subscripts->output" form is required
        // (exactly one "->"; implicit-output notation is rejected below).
        let parts: Vec<&str> = notation.split("->").collect();
        if parts.len() != 2 {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
            ));
        }
        let input_specs: Vec<&str> = parts[0].split(',').collect();
        let output_spec = parts[1];

        if input_specs.len() != inputs.len() {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
            ));
        }

        // Build label → size mapping.
        // BTreeMap keeps label iteration order deterministic across runs.
        let mut label_size = std::collections::BTreeMap::new();
        for (i, &spec) in input_specs.iter().enumerate() {
            let chars: Vec<char> = spec.chars().collect();
            if chars.len() != inputs[i].ndim() {
                return Err(RuntimeError::InvalidOperation(
                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
                ));
            }
            for (d, &c) in chars.iter().enumerate() {
                let sz = inputs[i].shape()[d];
                // A label must denote the same size everywhere it appears
                // (no broadcasting of mismatched dims).
                if let Some(&prev) = label_size.get(&c) {
                    if prev != sz {
                        return Err(RuntimeError::InvalidOperation(
                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
                        ));
                    }
                } else {
                    label_size.insert(c, sz);
                }
            }
        }

        // Determine output shape from the output labels' sizes.
        let output_chars: Vec<char> = output_spec.chars().collect();
        let output_shape: Vec<usize> = output_chars.iter()
            .map(|c| label_size.get(c).copied().ok_or_else(||
                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
            .collect::<Result<_, _>>()?;
        let out_numel = Self::shape_numel(&output_shape);

        // Determine contraction labels (in input but not output).
        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
        let contract_labels: Vec<char> = label_size.keys()
            .filter(|c| !output_set.contains(c))
            .copied()
            .collect();
        let contract_sizes: Vec<usize> = contract_labels.iter()
            .map(|c| label_size[c])
            .collect();
        // max(1): with no contraction labels the inner loop still runs once,
        // producing a pure element-wise/permutation result.
        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);

        // Precompute input spec chars (one Vec<char> per input tensor).
        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();

        // For each output position, iterate over contraction indices.
        let out_strides = Self::compute_strides(&output_shape);
        let mut result = vec![0.0f64; out_numel];

        // Pre-read input data; keep each tensor's own strides/offset so views
        // and transposes are addressed correctly without materializing them.
        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();

        for out_idx in 0..out_numel {
            // Decompose the flat output index into a per-label coordinate map.
            let mut label_vals = std::collections::BTreeMap::new();
            let mut remaining = out_idx;
            for (d, &c) in output_chars.iter().enumerate() {
                let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
                label_vals.insert(c, remaining / stride);
                remaining %= stride;
            }

            // Sum over the contracted subspace with the binned accumulator.
            let mut acc = BinnedAccumulatorF64::new();
            // Iterate over all contraction index combinations.
            for cidx in 0..contract_numel {
                // Decompose cidx into per-contraction-label coordinates
                // (row-major over contract_sizes).
                let mut cr = cidx;
                for (ci, &cl) in contract_labels.iter().enumerate() {
                    let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
                    label_vals.insert(cl, cr / stride);
                    cr %= stride;
                }

                // Product of one element from each input, addressed by mapping
                // every spec label through label_vals and that input's strides.
                let mut product = 1.0f64;
                for (inp_idx, chars) in input_chars.iter().enumerate() {
                    let mut flat = input_offsets[inp_idx];
                    for (d, &c) in chars.iter().enumerate() {
                        flat += label_vals[&c] * input_strides[inp_idx][d];
                    }
                    product *= input_data[inp_idx][flat];
                }
                acc.add(product);
            }
            result[out_idx] = acc.finalize();
        }

        // Empty output spec ("ij->") means a full reduction to a scalar;
        // represent it as a 1-element tensor.
        if output_shape.is_empty() {
            Tensor::from_vec(result, &[1])
        } else {
            Tensor::from_vec(result, &output_shape)
        }
    }
2585
2586    // -----------------------------------------------------------------------
2587    // Phase 2: Reshape / View Enhancements
2588    // -----------------------------------------------------------------------
2589
2590    /// Add a dimension of size 1 at position `dim`.
2591    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2592        let ndim = self.ndim();
2593        if dim > ndim {
2594            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2595        }
2596        let mut new_shape = self.shape.clone();
2597        new_shape.insert(dim, 1);
2598        self.reshape(&new_shape)
2599    }
2600
2601    /// Remove a dimension of size 1 at position `dim`.
2602    /// If `dim` is `None`, removes all dimensions of size 1.
2603    pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2604        match dim {
2605            Some(d) => {
2606                if d >= self.ndim() {
2607                    return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2608                }
2609                if self.shape[d] != 1 {
2610                    return Err(RuntimeError::InvalidOperation(
2611                        format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2612                    ));
2613                }
2614                let mut new_shape = self.shape.clone();
2615                new_shape.remove(d);
2616                if new_shape.is_empty() {
2617                    new_shape.push(1); // scalar
2618                }
2619                self.reshape(&new_shape)
2620            }
2621            None => {
2622                let new_shape: Vec<usize> = self.shape.iter()
2623                    .filter(|&&s| s != 1)
2624                    .copied()
2625                    .collect();
2626                let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2627                self.reshape(&new_shape)
2628            }
2629        }
2630    }
2631
2632    /// Broadcast without copying. Returns a view with stride=0 for broadcasted dims.
2633    /// Same as `broadcast_to` but named for consistency with the gap-fix plan.
2634    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2635        self.broadcast_to(target_shape)
2636    }
2637
2638    /// Flatten a range of dimensions [start_dim, end_dim] into a single dimension.
2639    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2640        if start_dim > end_dim || end_dim >= self.ndim() {
2641            return Err(RuntimeError::InvalidOperation(
2642                format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2643            ));
2644        }
2645        let mut new_shape = Vec::new();
2646        for i in 0..start_dim {
2647            new_shape.push(self.shape[i]);
2648        }
2649        let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2650        new_shape.push(flat_size);
2651        for i in (end_dim + 1)..self.ndim() {
2652            new_shape.push(self.shape[i]);
2653        }
2654        self.reshape(&new_shape)
2655    }
2656
2657    /// Split tensor into `n` roughly equal chunks along dimension `dim`.
2658    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2659        if dim >= self.ndim() {
2660            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2661        }
2662        if n == 0 {
2663            return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2664        }
2665        let dim_size = self.shape[dim];
2666        let chunk_size = (dim_size + n - 1) / n;
2667        let mut sizes = Vec::new();
2668        let mut remaining = dim_size;
2669        while remaining > 0 {
2670            let s = remaining.min(chunk_size);
2671            sizes.push(s);
2672            remaining -= s;
2673        }
2674        self.split(&sizes, dim)
2675    }
2676
2677    /// Split tensor along dimension `dim` according to the given sizes.
2678    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2679        if dim >= self.ndim() {
2680            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2681        }
2682        let total: usize = sizes.iter().sum();
2683        if total != self.shape[dim] {
2684            return Err(RuntimeError::InvalidOperation(
2685                format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2686            ));
2687        }
2688
2689        let mut results = Vec::new();
2690        let mut offset = 0;
2691
2692        for &sz in sizes {
2693            let ranges: Vec<(usize, usize)> = self.shape.iter()
2694                .enumerate()
2695                .map(|(i, &s)| {
2696                    if i == dim { (offset, offset + sz) } else { (0, s) }
2697                })
2698                .collect();
2699            let chunk = self.slice(&ranges)?;
2700            // Materialize as contiguous
2701            results.push(chunk.to_contiguous());
2702            offset += sz;
2703        }
2704
2705        Ok(results)
2706    }
2707
2708    /// Fused `alpha * self + beta * other` element-wise. Single pass, one allocation.
2709    ///
2710    /// Critical for LSTM/GRU gates where `f * c_prev + i * g` would otherwise
2711    /// create 3 intermediate tensors.
2712    pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2713        if self.shape != other.shape {
2714            return Err(RuntimeError::InvalidOperation(
2715                "scale_add: shape mismatch".to_string(),
2716            ));
2717        }
2718        let a = self.to_vec();
2719        let b = other.to_vec();
2720        let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2721        Tensor::from_vec(result, &self.shape)
2722    }
2723}
2724