//! cjc_runtime/tensor.rs — N-dimensional `f64` tensor runtime.
2use cjc_repro::Rng;
3
4use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
5use cjc_repro::KahanAccumulatorF64;
6
7use crate::accumulator;
8use crate::buffer::Buffer;
9use crate::dispatch;
10use crate::error::RuntimeError;
11use crate::kernel as kernel_fns;
12use crate::tensor_simd::{self, BinOp, UnaryOp};
13use crate::tensor_tiled::TiledMatmul;
14
15// ---------------------------------------------------------------------------
16// 2. Tensor Runtime
17// ---------------------------------------------------------------------------
18
/// An N-dimensional tensor backed by a `Buffer<f64>`.
///
/// Supports element-wise arithmetic, matrix multiplication (2-D), and
/// numerically-stable reductions via BinnedAccumulator summation.
///
/// Layout is described by `shape`/`strides`/`offset`, so slices, transposes,
/// and broadcasts can share one buffer as zero-copy views.
#[derive(Debug, Clone)]
pub struct Tensor {
    // Shared storage; views clone this handle (zero copy). May be larger
    // than the logical element count (see `to_vec`).
    pub buffer: Buffer<f64>,
    // Extent of each dimension.
    pub(crate) shape: Vec<usize>,
    // Element step per dimension; 0 means "broadcast along this axis".
    pub(crate) strides: Vec<usize>,
    // Flat index of the first logical element within `buffer`.
    pub(crate) offset: usize,
}
30
31impl Tensor {
32    // -- Construction -------------------------------------------------------
33
34    /// Compute row-major strides for a given shape.
35    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
36        let mut strides = vec![1usize; shape.len()];
37        for i in (0..shape.len().saturating_sub(1)).rev() {
38            strides[i] = strides[i + 1] * shape[i + 1];
39        }
40        strides
41    }
42
43    /// Total number of elements implied by `shape`.
44    fn shape_numel(shape: &[usize]) -> usize {
45        shape.iter().product()
46    }
47
48    /// Create a tensor filled with zeros.
49    pub fn zeros(shape: &[usize]) -> Self {
50        let numel = Self::shape_numel(shape);
51        Tensor {
52            buffer: Buffer::alloc(numel, 0.0),
53            shape: shape.to_vec(),
54            strides: Self::compute_strides(shape),
55            offset: 0,
56        }
57    }
58
59    /// Create a tensor filled with ones.
60    pub fn ones(shape: &[usize]) -> Self {
61        let numel = Self::shape_numel(shape);
62        Tensor {
63            buffer: Buffer::alloc(numel, 1.0),
64            shape: shape.to_vec(),
65            strides: Self::compute_strides(shape),
66            offset: 0,
67        }
68    }
69
70    /// Create a tensor filled with samples from the standard normal
71    /// distribution, drawn deterministically from `rng`.
72    pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
73        let numel = Self::shape_numel(shape);
74        let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
75        Tensor {
76            buffer: Buffer::from_vec(data),
77            shape: shape.to_vec(),
78            strides: Self::compute_strides(shape),
79            offset: 0,
80        }
81    }
82
83    /// Create a tensor from raw data and a shape. Returns an error if the
84    /// number of elements does not match the shape.
85    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
86        let numel = Self::shape_numel(shape);
87        if data.len() != numel {
88            return Err(RuntimeError::ShapeMismatch {
89                expected: numel,
90                got: data.len(),
91            });
92        }
93        Ok(Tensor {
94            buffer: Buffer::from_vec(data),
95            shape: shape.to_vec(),
96            strides: Self::compute_strides(shape),
97            offset: 0,
98        })
99    }
100
101    // -- Accessors ----------------------------------------------------------
102
103    /// The shape of this tensor.
104    pub fn shape(&self) -> &[usize] {
105        &self.shape
106    }
107
108    /// Number of dimensions.
109    pub fn ndim(&self) -> usize {
110        self.shape.len()
111    }
112
113    /// Total number of elements.
114    pub fn len(&self) -> usize {
115        Self::shape_numel(&self.shape)
116    }
117
118    /// Whether the tensor has zero elements.
119    pub fn is_empty(&self) -> bool {
120        self.len() == 0
121    }
122
123    /// Flatten a multi-dimensional index to a linear offset in the buffer.
124    fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
125        if indices.len() != self.shape.len() {
126            return Err(RuntimeError::DimensionMismatch {
127                expected: self.shape.len(),
128                got: indices.len(),
129            });
130        }
131        let mut off = self.offset;
132        for (i, &idx) in indices.iter().enumerate() {
133            if idx >= self.shape[i] {
134                return Err(RuntimeError::IndexOutOfBounds {
135                    index: idx,
136                    length: self.shape[i],
137                });
138            }
139            off += idx * self.strides[i];
140        }
141        Ok(off)
142    }
143
144    /// Whether this tensor is contiguous in memory (row-major, no offset).
145    pub fn is_contiguous(&self) -> bool {
146        if self.offset != 0 {
147            return false;
148        }
149        let expected = Self::compute_strides(&self.shape);
150        self.strides == expected
151    }
152
153    /// Create a zero-copy slice (view) of this tensor.
154    /// `ranges` contains `(start, end)` for each dimension.
155    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
156        if ranges.len() != self.shape.len() {
157            return Err(RuntimeError::DimensionMismatch {
158                expected: self.shape.len(),
159                got: ranges.len(),
160            });
161        }
162        let mut new_offset = self.offset;
163        let mut new_shape = Vec::with_capacity(ranges.len());
164        for (i, &(start, end)) in ranges.iter().enumerate() {
165            if end > self.shape[i] || start > end {
166                return Err(RuntimeError::IndexOutOfBounds {
167                    index: end,
168                    length: self.shape[i],
169                });
170            }
171            new_offset += start * self.strides[i];
172            new_shape.push(end - start);
173        }
174        Ok(Tensor {
175            buffer: self.buffer.clone(), // shared — zero copy
176            shape: new_shape,
177            strides: self.strides.clone(),
178            offset: new_offset,
179        })
180    }
181
182    /// Materialize a contiguous copy if this tensor is non-contiguous.
183    pub fn to_contiguous(&self) -> Tensor {
184        if self.is_contiguous() {
185            return self.clone();
186        }
187        let data = self.to_vec();
188        Tensor {
189            buffer: Buffer::from_vec(data),
190            shape: self.shape.clone(),
191            strides: Self::compute_strides(&self.shape),
192            offset: 0,
193        }
194    }
195
196    /// Create a broadcast view of this tensor to `target_shape`.
197    /// Uses stride=0 for dimensions that need broadcasting (size 1 -> target size).
198    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
199        let src_ndim = self.shape.len();
200        let tgt_ndim = target_shape.len();
201        if tgt_ndim < src_ndim {
202            return Err(RuntimeError::InvalidOperation(
203                "cannot broadcast to a smaller rank".to_string(),
204            ));
205        }
206        let pad = tgt_ndim - src_ndim;
207        let mut new_strides = vec![0usize; tgt_ndim];
208        for i in 0..tgt_ndim {
209            if i < pad {
210                // Padded dimension: stride = 0 (broadcast)
211                new_strides[i] = 0;
212            } else {
213                let src_i = i - pad;
214                if self.shape[src_i] == target_shape[i] {
215                    new_strides[i] = self.strides[src_i];
216                } else if self.shape[src_i] == 1 {
217                    new_strides[i] = 0; // broadcast
218                } else {
219                    return Err(RuntimeError::ShapeMismatch {
220                        expected: target_shape[i],
221                        got: self.shape[src_i],
222                    });
223                }
224            }
225        }
226        Ok(Tensor {
227            buffer: self.buffer.clone(),
228            shape: target_shape.to_vec(),
229            strides: new_strides,
230            offset: self.offset,
231        })
232    }
233
234    /// Read the element at the given multi-dimensional index.
235    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
236        let offset = self.linear_index(indices)?;
237        self.buffer
238            .get(offset)
239            .ok_or(RuntimeError::IndexOutOfBounds {
240                index: offset,
241                length: self.buffer.len(),
242            })
243    }
244
245    /// Write the element at the given multi-dimensional index.
246    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
247        let offset = self.linear_index(indices)?;
248        self.buffer.set(offset, val)
249    }
250
    /// Extract the raw data as a `Vec<f64>`, respecting strides and offset.
    ///
    /// The result is always a packed row-major copy of the logical elements.
    pub fn to_vec(&self) -> Vec<f64> {
        if self.is_contiguous() {
            let full = self.buffer.borrow_data();
            let numel = self.len();
            if full.len() == numel {
                return full.to_vec();
            }
            // Buffer may be larger than the tensor's logical size
            // (e.g. Scratchpad pre-allocates extra capacity)
            return full[..numel].to_vec();
        }
        // Non-contiguous: iterate via strided access
        let numel = self.len();
        let mut result = Vec::with_capacity(numel);
        let ndim = self.shape.len();
        let mut indices = vec![0usize; ndim];
        for _ in 0..numel {
            // Resolve the current N-D index to a flat buffer offset.
            let mut off = self.offset;
            for d in 0..ndim {
                off += indices[d] * self.strides[d];
            }
            // NOTE(review): an out-of-range read silently becomes 0.0 here —
            // confirm that is intended rather than a masked invariant bug.
            result.push(self.buffer.get(off).unwrap_or(0.0));
            // Increment multi-index (row-major order)
            for d in (0..ndim).rev() {
                indices[d] += 1;
                if indices[d] < self.shape[d] {
                    break;
                }
                indices[d] = 0;
            }
        }
        result
    }
285
286    // -- Reshape ------------------------------------------------------------
287
288    /// Reshape to `new_shape`. The new shape must have the same total number
289    /// of elements. The returned tensor **shares** the underlying buffer.
290    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
291        let new_numel = Self::shape_numel(new_shape);
292        if new_numel != self.len() {
293            return Err(RuntimeError::ShapeMismatch {
294                expected: self.len(),
295                got: new_numel,
296            });
297        }
298        // Reshape requires contiguous data; materialize if needed
299        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
300        Ok(Tensor {
301            buffer: tensor.buffer,
302            shape: new_shape.to_vec(),
303            strides: Self::compute_strides(new_shape),
304            offset: 0,
305        })
306    }
307
308    // -- Element-wise operations --------------------------------------------
309
310    /// Apply a binary operation element-wise with broadcasting support.
311    fn elementwise_binop(
312        &self,
313        other: &Tensor,
314        op: impl Fn(f64, f64) -> f64,
315    ) -> Result<Tensor, RuntimeError> {
316        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
317            // Fast path: same shape, both contiguous — borrow without cloning
318            let a = self.buffer.borrow_data();
319            let b = other.buffer.borrow_data();
320            let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
321            return Ok(Tensor {
322                buffer: Buffer::from_vec(data),
323                shape: self.shape.clone(),
324                strides: Self::compute_strides(&self.shape),
325                offset: 0,
326            });
327        }
328
329        // Broadcasting path: compute result shape
330        let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
331        let a_broadcast = self.broadcast_to(&result_shape)?;
332        let b_broadcast = other.broadcast_to(&result_shape)?;
333
334        let numel = Self::shape_numel(&result_shape);
335        let ndim = result_shape.len();
336        let mut data = Vec::with_capacity(numel);
337        let mut indices = vec![0usize; ndim];
338
339        for _ in 0..numel {
340            let mut off_a = a_broadcast.offset;
341            let mut off_b = b_broadcast.offset;
342            for d in 0..ndim {
343                off_a += indices[d] * a_broadcast.strides[d];
344                off_b += indices[d] * b_broadcast.strides[d];
345            }
346            let va = a_broadcast.buffer.get(off_a).unwrap_or(0.0);
347            let vb = b_broadcast.buffer.get(off_b).unwrap_or(0.0);
348            data.push(op(va, vb));
349
350            for d in (0..ndim).rev() {
351                indices[d] += 1;
352                if indices[d] < result_shape[d] {
353                    break;
354                }
355                indices[d] = 0;
356            }
357        }
358
359        Ok(Tensor {
360            buffer: Buffer::from_vec(data),
361            shape: result_shape.clone(),
362            strides: Self::compute_strides(&result_shape),
363            offset: 0,
364        })
365    }
366
367    /// Compute the broadcast result shape for two shapes (NumPy rules).
368    fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
369        let max_ndim = a.len().max(b.len());
370        let mut result = Vec::with_capacity(max_ndim);
371        for i in 0..max_ndim {
372            let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
373            let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
374            if da == db {
375                result.push(da);
376            } else if da == 1 {
377                result.push(db);
378            } else if db == 1 {
379                result.push(da);
380            } else {
381                return Err(RuntimeError::ShapeMismatch {
382                    expected: da,
383                    got: db,
384                });
385            }
386        }
387        Ok(result)
388    }
389
    /// SIMD-accelerated element-wise binary operation for known ops.
    ///
    /// For same-shape contiguous tensors, uses AVX2 (4-wide f64) when available.
    /// Falls back to the generic closure path for broadcast cases.
    fn elementwise_binop_simd(
        &self,
        other: &Tensor,
        op: BinOp,
        fallback: impl Fn(f64, f64) -> f64,
    ) -> Result<Tensor, RuntimeError> {
        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
            // SIMD fast path: same shape, both contiguous.
            // NOTE(review): this passes the full backing buffers; if a buffer
            // is over-allocated beyond `len()` (see `to_vec`), the result may
            // carry extra trailing elements — confirm `simd_binop` semantics.
            let a = self.buffer.borrow_data();
            let b = other.buffer.borrow_data();
            let data = tensor_simd::simd_binop(&a, &b, op);
            return Ok(Tensor {
                buffer: Buffer::from_vec(data),
                shape: self.shape.clone(),
                strides: Self::compute_strides(&self.shape),
                offset: 0,
            });
        }
        // Broadcast (or non-contiguous) path: fall through to the generic
        // closure implementation.
        self.elementwise_binop(other, fallback)
    }
415
416    /// Element-wise addition (SIMD-accelerated for contiguous same-shape tensors).
417    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
418        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
419    }
420
421    /// Element-wise subtraction (SIMD-accelerated for contiguous same-shape tensors).
422    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
423        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
424    }
425
426    /// Element-wise (Hadamard) multiplication (SIMD-accelerated for contiguous same-shape tensors).
427    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
428        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
429    }
430
431    /// Element-wise division (SIMD-accelerated for contiguous same-shape tensors).
432    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
433        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
434    }
435
436    /// Fused multiply-add: `self * b + c` element-wise in a single pass.
437    ///
438    /// Eliminates the intermediate tensor that separate mul + add would create.
439    /// Uses software FMA (`a * b + c` with two roundings, not hardware FMA)
440    /// to preserve bit-identity with the non-fused path.
441    pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
442        if self.shape != b.shape || self.shape != c.shape {
443            return Err(RuntimeError::InvalidOperation(
444                "broadcast_fma: all three tensors must have the same shape".to_string(),
445            ));
446        }
447        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
448            let a_data = self.buffer.borrow_data();
449            let b_data = b.buffer.borrow_data();
450            let c_data = c.buffer.borrow_data();
451            let n = a_data.len();
452            let mut out = vec![0.0f64; n];
453            // Software FMA: a*b + c (two roundings — NOT hardware FMA which uses one rounding).
454            // This produces identical results to separate broadcast2("mul") + broadcast2("add").
455            for i in 0..n {
456                out[i] = a_data[i] * b_data[i] + c_data[i];
457            }
458            return Ok(Tensor {
459                buffer: Buffer::from_vec(out),
460                shape: self.shape.clone(),
461                strides: Self::compute_strides(&self.shape),
462                offset: 0,
463            });
464        }
465        // Non-contiguous fallback: mul then add
466        let temp = self.mul_elem(b)?;
467        temp.add(c)
468    }
469
470    // ── v0.1 Broadcasting: additional element-wise binary ops ──
471
472    /// Element-wise power: `a^b`.
473    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
474        self.elementwise_binop(other, |a, b| a.powf(b))
475    }
476
477    /// Element-wise minimum.
478    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
479        self.elementwise_binop(other, |a, b| a.min(b))
480    }
481
482    /// Element-wise maximum.
483    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
484        self.elementwise_binop(other, |a, b| a.max(b))
485    }
486
487    /// Element-wise atan2(self, other).
488    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
489        self.elementwise_binop(other, |a, b| a.atan2(b))
490    }
491
492    /// Element-wise hypot(self, other).
493    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
494        self.elementwise_binop(other, |a, b| a.hypot(b))
495    }
496
497    /// Apply a unary function to every element, returning a new contiguous tensor.
498    pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
499        let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
500        Tensor {
501            buffer: Buffer::from_vec(data),
502            shape: self.shape.clone(),
503            strides: Self::compute_strides(&self.shape),
504            offset: 0,
505        }
506    }
507
508    /// SIMD-accelerated unary map for known operations (sqrt, abs, neg, relu).
509    ///
510    /// Uses AVX2 (4-wide f64) when available, scalar fallback otherwise.
511    /// Bit-identical to `map(f)` for the supported operations.
512    pub fn map_simd(&self, op: UnaryOp) -> Tensor {
513        let src = self.to_vec();
514        let data = tensor_simd::simd_unary(&src, op);
515        Tensor {
516            buffer: Buffer::from_vec(data),
517            shape: self.shape.clone(),
518            strides: Self::compute_strides(&self.shape),
519            offset: 0,
520        }
521    }
522
523    // -- Reductions (using BinnedAccumulator) --------------------------------
524
525    /// Sum of all elements (binned accumulation — order-invariant, deterministic).
526    pub fn sum(&self) -> f64 {
527        let data = self.buffer.borrow_data();
528        binned_sum_f64(&data)
529    }
530
531    /// Sum of all elements using BinnedAccumulator (order-invariant, deterministic).
532    ///
533    /// Bit-identical results regardless of element ordering or reduction schedule.
534    pub fn binned_sum(&self) -> f64 {
535        let data = self.buffer.borrow_data();
536        accumulator::binned_sum_f64(&data)
537    }
538
539    /// Sum with dispatched strategy based on execution context.
540    ///
541    /// Uses Kahan in serial mode, Binned in parallel/@nogc/strict/linalg mode.
542    pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
543        let data = self.buffer.borrow_data();
544        dispatch::dispatch_sum_f64(&data, ctx)
545    }
546
547    /// Mean of all elements (binned sum / count).
548    pub fn mean(&self) -> f64 {
549        let n = self.len();
550        if n == 0 {
551            return 0.0;
552        }
553        self.sum() / n as f64
554    }
555
556    /// Mean with dispatched strategy based on execution context.
557    pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
558        let n = self.len();
559        if n == 0 {
560            return 0.0;
561        }
562        self.dispatched_sum(ctx) / n as f64
563    }
564
565    /// Sum along a specific axis, returning a tensor with that dimension reduced.
566    ///
567    /// Supports N-D tensors. The reduced axis becomes size 1 in the output.
568    /// Uses BinnedAccumulator for order-invariant, deterministic summation.
569    ///
570    /// Examples:
571    /// - 2D [M, N] with axis=0: result [1, N] (sum columns)
572    /// - 2D [M, N] with axis=1: result [M, 1] (sum rows)
573    /// - 3D [A, B, C] with axis=1: result [A, 1, C]
574    pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
575        let ndim = self.ndim();
576        if axis >= ndim {
577            return Err(RuntimeError::IndexOutOfBounds {
578                index: axis,
579                length: ndim,
580            });
581        }
582
583        // Build output shape: same as input but with axis dimension = 1
584        let mut out_shape = self.shape.clone();
585        out_shape[axis] = 1;
586        let out_numel = Self::shape_numel(&out_shape);
587        let out_strides = Self::compute_strides(&out_shape);
588
589        let data = self.to_vec();
590        let axis_len = self.shape[axis];
591        let mut result = vec![0.0f64; out_numel];
592
593        // For each output position, sum over the reduced axis with binned accumulation.
594        let mut indices = vec![0usize; ndim];
595        for out_idx in 0..out_numel {
596            // Compute the N-D index from flat output index
597            {
598                let mut remaining = out_idx;
599                for d in 0..ndim {
600                    indices[d] = remaining / out_strides[d];
601                    remaining %= out_strides[d];
602                }
603            }
604
605            let mut acc = BinnedAccumulatorF64::new();
606            for k in 0..axis_len {
607                // Compute input flat index with indices[axis] = k
608                let mut flat = self.offset;
609                for d in 0..ndim {
610                    let idx = if d == axis { k } else { indices[d] };
611                    flat += idx * self.strides[d];
612                }
613                acc.add(data[flat]);
614            }
615            result[out_idx] = acc.finalize();
616        }
617
618        Tensor::from_vec(result, &out_shape)
619    }
620
621    // -- Matrix multiplication (2-D only) -----------------------------------
622
623    /// Negate every element, returning a new tensor.
624    pub fn neg(&self) -> Tensor {
625        self.map(|x| -x)
626    }
627
628    /// Transpose a tensor. For 2-D: swaps rows and columns (zero-copy view).
629    /// For N-D: reverses all axes (zero-copy view).
630    pub fn transpose(&self) -> Tensor {
631        let ndim = self.ndim();
632        if ndim <= 1 {
633            return self.clone();
634        }
635        // Reverse shape and strides — zero-copy view
636        let mut new_shape = self.shape.clone();
637        let mut new_strides = self.strides.clone();
638        new_shape.reverse();
639        new_strides.reverse();
640        Tensor {
641            buffer: self.buffer.clone(), // shared — zero copy
642            shape: new_shape,
643            strides: new_strides,
644            offset: self.offset,
645        }
646    }
647
648    /// Transpose with explicit axis permutation (N-D). Zero-copy view.
649    ///
650    /// `axes` must be a permutation of `[0, 1, ..., ndim-1]`.
651    pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
652        let ndim = self.ndim();
653        if axes.len() != ndim {
654            return Err(RuntimeError::InvalidOperation(
655                format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
656            ));
657        }
658        // Validate permutation
659        let mut seen = vec![false; ndim];
660        for &ax in axes {
661            if ax >= ndim {
662                return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
663            }
664            if seen[ax] {
665                return Err(RuntimeError::InvalidOperation(
666                    format!("transpose_axes: duplicate axis {ax}"),
667                ));
668            }
669            seen[ax] = true;
670        }
671        let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
672        let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
673        Ok(Tensor {
674            buffer: self.buffer.clone(),
675            shape: new_shape,
676            strides: new_strides,
677            offset: self.offset,
678        })
679    }
680
681    /// Multiply every element by a scalar, returning a new tensor.
682    pub fn scalar_mul(&self, s: f64) -> Tensor {
683        self.map(|x| x * s)
684    }
685
686    // ── Panicking convenience constructors (used by AD engine) --------
687
    /// Create a tensor from raw data and shape.
    /// **Panics** if `data.len()` does not match the shape
    /// (delegates to [`Tensor::from_vec`] and unwraps).
    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
    }
693
    /// Element-wise addition. **Panics** on shape mismatch
    /// (delegates to [`Tensor::add`] and unwraps).
    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
        self.add(other).expect("Tensor::add shape mismatch")
    }
698
    /// Element-wise subtraction. **Panics** on shape mismatch
    /// (delegates to [`Tensor::sub`] and unwraps).
    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
        self.sub(other).expect("Tensor::sub shape mismatch")
    }
703
    /// Element-wise multiplication. **Panics** on shape mismatch
    /// (delegates to [`Tensor::mul_elem`] and unwraps).
    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
    }
708
    /// Element-wise division. **Panics** on shape mismatch
    /// (delegates to [`Tensor::div_elem`] and unwraps).
    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
    }
713
    /// Matrix multiplication. **Panics** on dimension mismatch
    /// (delegates to [`Tensor::matmul`] and unwraps).
    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
        self.matmul(other).expect("Tensor::matmul dimension mismatch")
    }
718
    /// Matrix multiplication for 2-D tensors.
    ///
    /// `self` is (M, K), `other` is (K, N) => result is (M, N).
    ///
    /// Dispatches on size: parallel path (feature-gated, any dim >= 256),
    /// L2-tiled path (any dim >= 64), else sequential Kahan path.
    ///
    /// # Errors
    /// `InvalidOperation` unless both operands are 2-D;
    /// `DimensionMismatch` when the inner (K) dimensions differ.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matmul requires 2-D tensors".to_string(),
            ));
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        // Pack both operands row-major (handles slices/views/broadcasts).
        let a = self.to_vec();
        let b = other.to_vec();

        // Parallel path (Mode A): parallelize over output rows when the parallel
        // feature is enabled and the matrix is large enough (>= 256 in any dim).
        #[cfg(feature = "parallel")]
        {
            if m >= 256 || n >= 256 || k >= 256 {
                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
            }
        }

        // Tiled path: use L2-friendly tiled matmul for medium-to-large matrices.
        // Threshold: any dimension >= 64 (the default tile size).
        // NOTE: tiled path uses naive accumulation (not binned) — different
        // numerical path for large matrices, but better cache locality.
        if m >= 64 || n >= 64 || k >= 64 {
            return Self::matmul_tiled(&a, &b, m, n, k);
        }

        // Sequential path: single-threaded with Kahan-compensated accumulation
        // (see `matmul_sequential` — it uses KahanAccumulatorF64, not binned).
        Self::matmul_sequential(&a, &b, m, n, k)
    }
762
763    /// Sequential matmul (always available, deterministic reference).
764    fn matmul_sequential(
765        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
766    ) -> Result<Tensor, RuntimeError> {
767        let mut result = vec![0.0f64; m * n];
768        for i in 0..m {
769            for j in 0..n {
770                let mut acc = KahanAccumulatorF64::new();
771                for p in 0..k {
772                    acc.add(a[i * k + p] * b[p * n + j]);
773                }
774                result[i * n + j] = acc.finalize();
775            }
776        }
777        Tensor::from_vec(result, &[m, n])
778    }
779
780    /// Tiled matmul: delegates to `TiledMatmul` for L2-cache-friendly tiling.
781    ///
782    /// Used for medium matrices (any dimension >= 64) where cache locality
783    /// matters but parallel overhead isn't justified. The tiled path uses
784    /// naive accumulation (not binned accumulation), trading a small amount of
785    /// floating-point precision for better cache behavior.
786    fn matmul_tiled(
787        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
788    ) -> Result<Tensor, RuntimeError> {
789        let engine = TiledMatmul::new();
790        let result = engine.matmul(a, m, k, b, n);
791        Tensor::from_vec(result, &[m, n])
792    }
793
    /// Parallel matmul Mode A: parallelize over output rows using tiled
    /// micro-kernels for cache locality.
    ///
    /// Deterministic because:
    /// - Each output row is computed by exactly one thread.
    /// - Within each row, accumulation uses tiled AXPY (deterministic order).
    /// - No cross-thread reduction or merge of partial sums.
    ///
    /// Uses KahanAccumulatorF64 (lightweight) instead of BinnedAccumulatorF64
    /// (32KB per accumulator) to avoid massive stack pressure in parallel mode.
    ///
    /// NOTE(review): `band_size` below depends on `rayon::current_num_threads()`,
    /// so the row-band partition varies with the pool size. Results stay
    /// bit-identical across thread counts only if `TiledMatmul`'s per-row
    /// accumulation order is independent of the band height — confirm.
    #[cfg(feature = "parallel")]
    fn matmul_parallel_mode_a(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        use rayon::prelude::*;
        use cjc_repro::KahanAccumulatorF64;

        // For large matrices, use tiled matmul (sequential but cache-friendly).
        // The tiling provides far better cache locality than parallel row-wise
        // with column-strided B access, and avoids the 32KB-per-element
        // BinnedAccumulator overhead that caused the 128→256 regression.
        if m >= 512 && n >= 512 {
            // Only use rayon for very large matrices where thread overhead
            // is amortized. Split into row-bands, each processed with tiled matmul.
            let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
            let band_size = band_size.max(64); // At least 64 rows per band
            let mut result = vec![0.0f64; m * n];

            // Every chunk except possibly the last is exactly `band_size * n`
            // elements, so `band_idx * band_size` is the chunk's first row and
            // `band.len() == band_m * n` holds for the tail band as well.
            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k .. i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });

            return Tensor::from_vec(result, &[m, n]);
        }

        // For medium matrices (256-511), use Kahan accumulator (16 bytes, not 32KB).
        // One rayon task per output row; each row's dot products run in a fixed
        // sequential order, so the result does not depend on scheduling.
        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });

        Tensor::from_vec(result, &[m, n])
    }
855
856    // -- Transformer Kernels ------------------------------------------------
857
    /// Batched matrix multiplication.
    ///
    /// `self` is `[..., M, K]`, `other` is `[..., K, N]` => result is `[..., M, N]`.
    /// The batch dimensions must be identical (no broadcast).
    /// For 2-D inputs, delegates to `matmul`.
    ///
    /// # Errors
    /// `InvalidOperation` for rank < 2, rank mismatch, or batch-shape
    /// mismatch; `DimensionMismatch` when the inner dimensions disagree.
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "bmm requires at least 2-D tensors".to_string(),
            ));
        }
        if self.ndim() == 2 && other.ndim() == 2 {
            return self.matmul(other);
        }
        if self.ndim() != other.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm requires same number of dimensions, got {} and {}",
                    self.ndim(),
                    other.ndim()
                ),
            ));
        }
        let nd = self.ndim();
        let batch_dims_a = &self.shape[..nd - 2];
        let batch_dims_b = &other.shape[..nd - 2];
        if batch_dims_a != batch_dims_b {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm batch dimensions mismatch: {:?} vs {:?}",
                    batch_dims_a, batch_dims_b
                ),
            ));
        }
        let m = self.shape[nd - 2];
        let k = self.shape[nd - 1];
        let k2 = other.shape[nd - 2];
        let n = other.shape[nd - 1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        // Materialize both operands contiguously so each batch is a flat
        // `[M*K]` / `[K*N]` slab addressed by a fixed stride.
        let batch_size: usize = batch_dims_a.iter().product();
        let a = self.to_vec();
        let b = other.to_vec();
        let mat_a_stride = m * k;
        let mat_b_stride = k * n;
        let mat_c_stride = m * n;
        let mut result = vec![0.0f64; batch_size * mat_c_stride];

        // Helper closure: compute one batch into c_slice
        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];

            // Medium/large per-batch matrices go through the cache-tiled
            // kernel (same >= 64 threshold as `matmul_tiled`); small ones use
            // Kahan-compensated dot products.
            if m >= 64 || n >= 64 || k >= 64 {
                let engine = crate::tensor_tiled::TiledMatmul::new();
                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
                c_slice.copy_from_slice(&tiled);
            } else {
                for i in 0..m {
                    for j in 0..n {
                        let mut acc = KahanAccumulatorF64::new();
                        for p in 0..k {
                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
                        }
                        c_slice[i * n + j] = acc.finalize();
                    }
                }
            }
        };

        // Parallel path: parallelize over batches when workload is large enough.
        // Deterministic: each batch writes a disjoint `mat_c_stride` chunk and
        // the per-batch computation is identical to the sequential path.
        #[cfg(feature = "parallel")]
        {
            if batch_size > 1 && m * k >= 4096 {
                use rayon::prelude::*;
                result
                    .par_chunks_mut(mat_c_stride)
                    .enumerate()
                    .for_each(|(batch, c_slice)| {
                        compute_batch(batch, c_slice);
                    });

                let mut out_shape = batch_dims_a.to_vec();
                out_shape.push(m);
                out_shape.push(n);
                return Tensor::from_vec(result, &out_shape);
            }
        }

        // Sequential fallback
        for batch in 0..batch_size {
            let c_off = batch * mat_c_stride;
            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
        }

        let mut out_shape = batch_dims_a.to_vec();
        out_shape.push(m);
        out_shape.push(n);
        Tensor::from_vec(result, &out_shape)
    }
963
    /// Softmax along the last dimension (two-pass stable algorithm).
    ///
    /// Pass 1: find max per row (prevents overflow in exp)
    /// Pass 2: compute exp(x - max), accumulate sum, normalize
    ///
    /// For a tensor of shape `[..., N]`, softmax is applied independently
    /// to each length-N slice along the last axis.
    ///
    /// # Errors
    /// `InvalidOperation` if the tensor is 0-D.
    pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
        if self.ndim() == 0 {
            return Err(RuntimeError::InvalidOperation(
                "softmax requires at least 1-D tensor".to_string(),
            ));
        }
        // Avoid allocation when tensor is already contiguous and starts at
        // offset 0. The two deferred bindings exist so that whichever backing
        // storage is used (buffer borrow or gathered copy) outlives `data`.
        let data_ref;
        let data_vec;
        let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
            data_ref = self.buffer.borrow_data();
            &data_ref
        } else {
            data_vec = self.to_vec();
            &data_vec
        };
        let n = *self.shape.last().unwrap(); // last dimension size
        // NOTE(review): a zero-length last dim makes this divide by zero —
        // confirm upstream guards against shapes with a trailing 0.
        let outer: usize = data.len() / n;  // product of all dims except last
        let mut result = vec![0.0f64; data.len()];

        for row in 0..outer {
            let start = row * n;
            let end = start + n;
            let slice = &data[start..end];

            // Pass 1: find max for numerical stability
            let mut max_val = f64::NEG_INFINITY;
            for &v in slice {
                if v > max_val {
                    max_val = v;
                }
            }

            // Pass 2: exp(x - max) and accumulate sum
            let mut exp_vals = vec![0.0f64; n];
            let mut sum = 0.0f64;
            let mut comp = 0.0f64; // Kahan compensation
            for i in 0..n {
                let e = (slice[i] - max_val).exp();
                exp_vals[i] = e;
                // Kahan summation for the denominator
                let y = e - comp;
                let t = sum + y;
                comp = (t - sum) - y;
                sum = t;
            }

            // Normalize
            if sum == 0.0 {
                // Degenerate case: all -inf inputs → uniform
                let uniform = 1.0 / n as f64;
                for i in 0..n {
                    result[start + i] = uniform;
                }
            } else {
                for i in 0..n {
                    result[start + i] = exp_vals[i] / sum;
                }
            }
        }

        Tensor::from_vec(result, &self.shape)
    }
1034
1035    /// Layer normalization over the last dimension.
1036    ///
1037    /// For each length-D slice along the last axis:
1038    ///   1. mean = Σx / D  (BinnedAccumulator)
1039    ///   2. var  = Σ(x - mean)² / D  (BinnedAccumulator)
1040    ///   3. normalized = (x - mean) / √(var + eps)
1041    ///   4. output = gamma * normalized + beta
1042    ///
1043    /// `gamma` and `beta` are 1-D tensors of shape `[D]`.
1044    /// `eps` is a small constant (typically 1e-5).
1045    pub fn layer_norm(
1046        &self,
1047        gamma: &Tensor,
1048        beta: &Tensor,
1049        eps: f64,
1050    ) -> Result<Tensor, RuntimeError> {
1051        if self.ndim() == 0 {
1052            return Err(RuntimeError::InvalidOperation(
1053                "layer_norm requires at least 1-D tensor".to_string(),
1054            ));
1055        }
1056        let d = *self.shape.last().unwrap();
1057        if gamma.len() != d || beta.len() != d {
1058            return Err(RuntimeError::InvalidOperation(
1059                format!(
1060                    "layer_norm: gamma/beta length {} must match last dim {}",
1061                    gamma.len(),
1062                    d
1063                ),
1064            ));
1065        }
1066
1067        let data = self.to_vec();
1068        let gamma_data = gamma.to_vec();
1069        let beta_data = beta.to_vec();
1070        let outer = data.len() / d;
1071        let mut result = vec![0.0f64; data.len()];
1072
1073        for row in 0..outer {
1074            let start = row * d;
1075            let slice = &data[start..start + d];
1076
1077            // Pass 1: compute mean via BinnedAccumulator
1078            let mean = binned_sum_f64(slice) / d as f64;
1079
1080            // Pass 2: compute variance via BinnedAccumulator
1081            let diffs: Vec<f64> = slice.iter().map(|&x| {
1082                let diff = x - mean;
1083                diff * diff
1084            }).collect();
1085            let variance = binned_sum_f64(&diffs) / d as f64;
1086
1087            // Normalize, scale, shift
1088            let inv_std = 1.0 / (variance + eps).sqrt();
1089            for i in 0..d {
1090                let normalized = (slice[i] - mean) * inv_std;
1091                result[start + i] = gamma_data[i] * normalized + beta_data[i];
1092            }
1093        }
1094
1095        Tensor::from_vec(result, &self.shape)
1096    }
1097
    /// Apply a function element-wise over the tensor's data.
    ///
    /// NOTE(review): the previous comments here claimed in-place mutation with
    /// zero allocations, but the "fast path" still clones the backing `Vec`
    /// via `borrow_data().clone()` — it only skips the strided `to_vec()`
    /// gather. Confirm whether `Buffer` exposes a mutable borrow before
    /// claiming true COW in-place mutation.
    fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
        if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
            // Contiguous, sole-owner case: copy the raw buffer directly
            // (cheaper than the general strided gather below).
            let mut data = self.buffer.borrow_data().clone();
            for x in data.iter_mut() {
                *x = f(*x);
            }
            Tensor::from_vec(data, &self.shape).unwrap()
        } else {
            // General case: gather through strides/offset, then map.
            let data = self.to_vec();
            let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
            Tensor::from_vec(result, &self.shape).unwrap()
        }
    }
1117
1118    pub fn relu(&self) -> Tensor {
1119        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1120    }
1121
1122    /// Sigmoid activation: 1 / (1 + exp(-x)) element-wise.
1123    pub fn sigmoid(&self) -> Tensor {
1124        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1125    }
1126
1127    /// Tanh activation element-wise.
1128    pub fn tanh_activation(&self) -> Tensor {
1129        self.map_elementwise(|x| x.tanh())
1130    }
1131
1132    /// Leaky ReLU activation: max(alpha*x, x) element-wise.
1133    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1134        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1135    }
1136
1137    /// SiLU (Swish) activation: x * sigmoid(x) element-wise.
1138    pub fn silu(&self) -> Tensor {
1139        let data = self.to_vec();
1140        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1141        Tensor::from_vec(result, &self.shape).unwrap()
1142    }
1143
1144    /// Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
1145    pub fn mish(&self) -> Tensor {
1146        let data = self.to_vec();
1147        let result: Vec<f64> = data.iter().map(|&x| {
1148            let sp = (1.0 + x.exp()).ln();
1149            x * sp.tanh()
1150        }).collect();
1151        Tensor::from_vec(result, &self.shape).unwrap()
1152    }
1153
1154    /// Argmax: index of the maximum element (first occurrence, deterministic).
1155    pub fn argmax(&self) -> usize {
1156        let data = self.to_vec();
1157        let mut best_idx = 0;
1158        let mut best_val = f64::NEG_INFINITY;
1159        for (i, &v) in data.iter().enumerate() {
1160            if v > best_val || (v == best_val && i < best_idx) {
1161                best_val = v;
1162                best_idx = i;
1163            }
1164        }
1165        best_idx
1166    }
1167
1168    /// Argmin: index of the minimum element (first occurrence, deterministic).
1169    pub fn argmin(&self) -> usize {
1170        let data = self.to_vec();
1171        let mut best_idx = 0;
1172        let mut best_val = f64::INFINITY;
1173        for (i, &v) in data.iter().enumerate() {
1174            if v < best_val || (v == best_val && i < best_idx) {
1175                best_val = v;
1176                best_idx = i;
1177            }
1178        }
1179        best_idx
1180    }
1181
1182    /// Clamp all elements to [min, max].
1183    pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1184        let data = self.to_vec();
1185        let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1186        Tensor::from_vec(result, &self.shape).unwrap()
1187    }
1188
1189    /// One-hot encoding: given a 1D tensor of integer indices and a depth,
1190    /// returns a 2D tensor of shape [len, depth].
1191    pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1192        let n = indices.len();
1193        let mut data = vec![0.0; n * depth];
1194        for (i, &idx) in indices.iter().enumerate() {
1195            if idx >= depth {
1196                return Err(RuntimeError::InvalidOperation(format!(
1197                    "one_hot: index {idx} >= depth {depth}"
1198                )));
1199            }
1200            data[i * depth + idx] = 1.0;
1201        }
1202        Tensor::from_vec(data, &[n, depth])
1203    }
1204
1205    // -----------------------------------------------------------------------
1206    // Phase B4: Tensor extensions (cat, stack, topk)
1207    // -----------------------------------------------------------------------
1208
    /// Concatenate tensors along existing axis.
    ///
    /// All inputs must share the same rank and the same extent on every
    /// dimension except `axis`; the output's `axis` extent is the sum.
    ///
    /// # Errors
    /// `InvalidOperation` for an empty input list, an out-of-range axis, or
    /// any rank/shape mismatch.
    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
        }
        let ndim = tensors[0].ndim();
        if axis >= ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.ndim() != ndim {
                return Err(RuntimeError::InvalidOperation(
                    format!("cat: tensor {i} has different ndim"),
                ));
            }
            for d in 0..ndim {
                if d != axis && t.shape[d] != tensors[0].shape[d] {
                    return Err(RuntimeError::InvalidOperation(
                        format!("cat: shape mismatch at dim {d}"),
                    ));
                }
            }
        }
        // Output shape: same as input 0, with the concat axis grown by the
        // remaining inputs.
        let mut out_shape = tensors[0].shape.clone();
        for t in tensors.iter().skip(1) {
            out_shape[axis] += t.shape[axis];
        }
        let total = out_shape.iter().product::<usize>();
        let mut result = vec![0.0; total];
        // Row-major strides of the output, computed inline.
        let mut out_strides = vec![1usize; ndim];
        for d in (0..ndim - 1).rev() {
            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
        }
        // Running start coordinate along `axis` for the tensor being copied.
        let mut offset = 0;
        for t in tensors {
            let t_data = t.to_vec();
            let t_total: usize = t.shape.iter().product();
            let mut t_strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
            }
            for idx in 0..t_total {
                // Unravel `idx` into coordinates via the source strides, shift
                // the `axis` coordinate by `offset`, and re-ravel with the
                // output strides.
                let mut remaining = idx;
                let mut out_flat = 0;
                for d in 0..ndim {
                    let coord = remaining / t_strides[d];
                    remaining %= t_strides[d];
                    let out_coord = if d == axis { coord + offset } else { coord };
                    out_flat += out_coord * out_strides[d];
                }
                result[out_flat] = t_data[idx];
            }
            offset += t.shape[axis];
        }
        Tensor::from_vec(result, &out_shape)
    }
1267
    /// Stack tensors along a new axis.
    ///
    /// All inputs must have identical shapes; the result gains one extra
    /// dimension of extent `tensors.len()` at position `axis`
    /// (`axis == ndim` appends it last).
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
        }
        let base_shape = &tensors[0].shape;
        let ndim = base_shape.len();
        if axis > ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("stack: axis {axis} out of bounds"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if &t.shape != base_shape {
                return Err(RuntimeError::InvalidOperation(
                    format!("stack: tensor {i} shape mismatch"),
                ));
            }
        }
        // Insert the new axis of length `tensors.len()` at position `axis`.
        let mut out_shape = Vec::with_capacity(ndim + 1);
        for d in 0..axis { out_shape.push(base_shape[d]); }
        out_shape.push(tensors.len());
        for d in axis..ndim { out_shape.push(base_shape[d]); }
        let total: usize = out_shape.iter().product();
        let mut result = vec![0.0; total];
        // Each source element splits as (outer, inner) around the new axis:
        // dims before `axis` form `outer`, dims from `axis` on form `inner`.
        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
        for (t_idx, t) in tensors.iter().enumerate() {
            let t_data = t.to_vec();
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let src = outer * inner_size + inner;
                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
                    // The `.max(1)` floors above mean zero-sized shapes still
                    // enter this loop once with empty buffers; this bounds
                    // guard is what keeps those iterations from going OOB.
                    if src < t_data.len() && dst < result.len() {
                        result[dst] = t_data[src];
                    }
                }
            }
        }
        Tensor::from_vec(result, &out_shape)
    }
1309
1310    /// Top-k values and indices (largest k values from flat data).
1311    pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1312        let data = self.to_vec();
1313        let n = data.len();
1314        if k > n {
1315            return Err(RuntimeError::InvalidOperation(
1316                format!("topk: k={k} exceeds data length {n}"),
1317            ));
1318        }
1319        let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1320        indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1321        let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1322        let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1323        let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1324        Ok((Tensor::from_vec(values, &[k])?, indices))
1325    }
1326
1327    /// GELU activation (approximate): x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
1328    pub fn gelu(&self) -> Tensor {
1329        let data = self.to_vec();
1330        let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1331        let result: Vec<f64> = data.iter().map(|&x| {
1332            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1333            0.5 * x * (1.0 + inner.tanh())
1334        }).collect();
1335        Tensor::from_vec(result, &self.shape).unwrap()
1336    }
1337
    /// Linear layer: output = input @ weight^T + bias
    ///
    /// `self` is `[..., in_features]`, `weight` is `[out_features, in_features]`,
    /// `bias` is `[out_features]`.
    /// Result is `[..., out_features]`.
    ///
    /// # Errors
    /// `InvalidOperation` for a non-2-D weight, 0-D input, or bias length
    /// mismatch; `DimensionMismatch` when the input's last dim differs from
    /// `in_features`.
    pub fn linear(
        &self,
        weight: &Tensor,
        bias: &Tensor,
    ) -> Result<Tensor, RuntimeError> {
        if weight.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "linear: weight must be 2-D [out_features, in_features]".to_string(),
            ));
        }
        let out_features = weight.shape[0];
        let in_features = weight.shape[1];
        let last_dim = *self.shape.last().ok_or_else(|| {
            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
        })?;
        if last_dim != in_features {
            return Err(RuntimeError::DimensionMismatch {
                expected: in_features,
                got: last_dim,
            });
        }
        if bias.len() != out_features {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "linear: bias length {} must match out_features {}",
                    bias.len(),
                    out_features
                ),
            ));
        }

        let data = self.to_vec();
        let w = weight.to_vec();
        let b = bias.to_vec();
        // All leading dims are flattened into `outer` rows of `in_features`.
        let outer = data.len() / in_features;
        let mut result = vec![0.0f64; outer * out_features];

        for row in 0..outer {
            let x_start = row * in_features;
            let x_slice = &data[x_start..x_start + in_features];
            let y_start = row * out_features;
            for j in 0..out_features {
                // Row j of `weight` is contiguous, so x·wⱼ is a flat dot product.
                let w_start = j * in_features;
                // NOTE(review): one BinnedAccumulatorF64 per output element;
                // per the matmul comments elsewhere these are ~32KB each, so
                // this is a deliberate precision-over-speed choice — confirm
                // it's intended for large layers.
                let mut acc = BinnedAccumulatorF64::new();
                for p in 0..in_features {
                    acc.add(x_slice[p] * w[w_start + p]);
                }
                result[y_start + j] = acc.finalize() + b[j];
            }
        }

        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
        out_shape.push(out_features);
        Tensor::from_vec(result, &out_shape)
    }
1398
1399    /// 1D convolution: signal `[signal_len]` * filters `[out_ch, kernel_size]` + bias
1400    ///
1401    /// Returns `[out_ch, signal_len - kernel_size + 1]` (valid mode, stride=1).
1402    pub fn conv1d(
1403        &self,
1404        filters: &Tensor,
1405        bias: &Tensor,
1406    ) -> Result<Tensor, RuntimeError> {
1407        if self.ndim() != 1 {
1408            return Err(RuntimeError::InvalidOperation(
1409                "conv1d: input must be 1-D [signal_len]".to_string(),
1410            ));
1411        }
1412        if filters.ndim() != 2 {
1413            return Err(RuntimeError::InvalidOperation(
1414                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1415            ));
1416        }
1417        let signal_len = self.shape[0];
1418        let out_channels = filters.shape[0];
1419        let kernel_size = filters.shape[1];
1420        if signal_len < kernel_size {
1421            return Err(RuntimeError::InvalidOperation(
1422                format!(
1423                    "conv1d: signal_len {} < kernel_size {}",
1424                    signal_len, kernel_size
1425                ),
1426            ));
1427        }
1428        if bias.len() != out_channels {
1429            return Err(RuntimeError::InvalidOperation(
1430                format!(
1431                    "conv1d: bias length {} must match out_channels {}",
1432                    bias.len(), out_channels
1433                ),
1434            ));
1435        }
1436        let out_len = signal_len - kernel_size + 1;
1437        let s = self.to_vec();
1438        let f = filters.to_vec();
1439        let b = bias.to_vec();
1440        let mut result = vec![0.0; out_channels * out_len];
1441        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1442        Tensor::from_vec(result, &[out_channels, out_len])
1443    }
1444
    /// 2D convolution — NCHW layout, valid mode, configurable stride.
    ///
    /// # Arguments
    /// - `self`:    `[N, C_in, H, W]` input tensor
    /// - `filters`: `[C_out, C_in, kH, kW]`
    /// - `bias`:    `[C_out]`
    /// - `stride`:  spatial stride (must be >= 1; there is no default)
    ///
    /// # Returns
    /// `[N, C_out, H_out, W_out]` where `H_out = (H - kH) / stride + 1`.
    ///
    /// # Errors
    /// `InvalidOperation` for wrong ranks, `stride == 0`, a channel mismatch,
    /// a kernel larger than the input, or a bias length != `C_out`.
    ///
    /// Uses `BinnedAccumulatorF64` for every dot product — bit-identical results
    /// across all runs and hardware configurations.
    pub fn conv2d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
        stride: usize,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
            ));
        }
        if filters.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
            ));
        }
        if stride == 0 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: stride must be >= 1".to_string(),
            ));
        }

        let n    = self.shape[0];
        let c_in = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        let c_out      = filters.shape[0];
        let c_in_check = filters.shape[1];
        let kh         = filters.shape[2];
        let kw         = filters.shape[3];

        if c_in != c_in_check {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input C_in={} does not match filter C_in={}",
                c_in, c_in_check
            )));
        }
        if h_in < kh || w_in < kw {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
                h_in, w_in, kh, kw
            )));
        }
        if bias.len() != c_out {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: bias length {} must match C_out={}",
                bias.len(), c_out
            )));
        }

        // Valid-mode output extent: floor((in - kernel) / stride) + 1.
        let h_out = (h_in - kh) / stride + 1;
        let w_out = (w_in - kw) / stride + 1;

        // Materialize contiguous copies; the raw kernel works on flat slices.
        let inp = self.to_vec();
        let flt = filters.to_vec();
        let b   = bias.to_vec();
        let mut result = vec![0.0f64; n * c_out * h_out * w_out];

        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
                           n, c_in, h_in, w_in, c_out, kh, kw, stride);

        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
    }
1522
1523    /// 2D max-pooling — NCHW layout, non-overlapping windows.
1524    ///
1525    /// - `self`: `[N, C, H, W]`
1526    /// - `ph`, `pw`: pool height/width (stride = window size)
1527    ///
1528    /// Returns `[N, C, H/ph, W/pw]`.
1529    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1530        if self.ndim() != 4 {
1531            return Err(RuntimeError::InvalidOperation(
1532                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1533            ));
1534        }
1535        if ph == 0 || pw == 0 {
1536            return Err(RuntimeError::InvalidOperation(
1537                "maxpool2d: pool size must be >= 1".to_string(),
1538            ));
1539        }
1540
1541        let n    = self.shape[0];
1542        let c    = self.shape[1];
1543        let h_in = self.shape[2];
1544        let w_in = self.shape[3];
1545
1546        if h_in < ph || w_in < pw {
1547            return Err(RuntimeError::InvalidOperation(format!(
1548                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1549                h_in, w_in, ph, pw
1550            )));
1551        }
1552
1553        let h_out = h_in / ph;
1554        let w_out = w_in / pw;
1555
1556        let inp = self.to_vec();
1557        let mut result = vec![0.0f64; n * c * h_out * w_out];
1558
1559        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1560
1561        Tensor::from_vec(result, &[n, c, h_out, w_out])
1562    }
1563
1564    /// Scaled dot-product attention (single head).
1565    ///
1566    /// `queries` is `[..., T, d_k]`
1567    /// `keys`    is `[..., S, d_k]`
1568    /// `values`  is `[..., S, d_v]`
1569    ///
1570    /// Computes: softmax(Q × Kᵀ / √d_k) × V
1571    /// Returns `[..., T, d_v]`.
1572    pub fn scaled_dot_product_attention(
1573        queries: &Tensor,
1574        keys: &Tensor,
1575        values: &Tensor,
1576    ) -> Result<Tensor, RuntimeError> {
1577        if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1578            return Err(RuntimeError::InvalidOperation(
1579                "attention: Q, K, V must be at least 2-D".to_string(),
1580            ));
1581        }
1582        let nd = queries.ndim();
1583        let d_k = queries.shape[nd - 1];
1584        let scale = 1.0 / (d_k as f64).sqrt();
1585
1586        // Transpose keys: swap last two dims
1587        let keys_t = keys.transpose_last_two()?;
1588
1589        // Q × K^T → [... T, S]
1590        let scores = queries.bmm(&keys_t)?;
1591
1592        // Scale
1593        let scores_scaled = scores.scalar_mul(scale);
1594
1595        // Softmax along last dim
1596        let attn_weights = scores_scaled.softmax()?;
1597
1598        // Attn × V → [... T, d_v]
1599        attn_weights.bmm(values)
1600    }
1601
1602    /// Transpose the last two dimensions of a tensor.
1603    ///
1604    /// `[..., A, B]` → `[..., B, A]`
1605    pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1606        if self.ndim() < 2 {
1607            return Err(RuntimeError::InvalidOperation(
1608                "transpose_last_two requires at least 2-D tensor".to_string(),
1609            ));
1610        }
1611        let nd = self.ndim();
1612        let rows = self.shape[nd - 2];
1613        let cols = self.shape[nd - 1];
1614        let data = self.to_vec();
1615        let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1616        let mat_size = rows * cols;
1617        let mut result = vec![0.0f64; data.len()];
1618
1619        for b in 0..batch_size {
1620            let off = b * mat_size;
1621            for i in 0..rows {
1622                for j in 0..cols {
1623                    result[off + j * rows + i] = data[off + i * cols + j];
1624                }
1625            }
1626        }
1627
1628        let mut out_shape = self.shape.clone();
1629        out_shape[nd - 2] = cols;
1630        out_shape[nd - 1] = rows;
1631        Tensor::from_vec(result, &out_shape)
1632    }
1633
    // -- Weight Mapping from Raw Bytes (single copy) ------------------------
1635
1636    /// Create a tensor view from raw bytes — **zero allocation**.
1637    ///
1638    /// Interprets `bytes` as a contiguous block of `f64` (8 bytes each) or
1639    /// `f32` (4 bytes each, promoted to f64) values and maps them into a
1640    /// `Tensor` with the given shape.
1641    ///
1642    /// `dtype` must be `"f64"` or `"f32"`.
1643    ///
1644    /// For f64: bytes.len() must equal shape_numel * 8.
1645    /// For f32: bytes.len() must equal shape_numel * 4.
1646    ///
1647    /// The returned tensor **owns** its buffer (copied from the raw bytes)
1648    /// but performs exactly one allocation for the data vector.
1649    pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1650        let numel = Self::shape_numel(shape);
1651        match dtype {
1652            "f64" => {
1653                let expected = numel * 8;
1654                if bytes.len() != expected {
1655                    return Err(RuntimeError::ShapeMismatch {
1656                        expected,
1657                        got: bytes.len(),
1658                    });
1659                }
1660                let mut data = Vec::with_capacity(numel);
1661                for i in 0..numel {
1662                    let off = i * 8;
1663                    let mut buf = [0u8; 8];
1664                    buf.copy_from_slice(&bytes[off..off + 8]);
1665                    data.push(f64::from_le_bytes(buf));
1666                }
1667                Ok(Tensor {
1668                    buffer: Buffer::from_vec(data),
1669                    shape: shape.to_vec(),
1670                    strides: Self::compute_strides(shape),
1671                    offset: 0,
1672                })
1673            }
1674            "f32" => {
1675                let expected = numel * 4;
1676                if bytes.len() != expected {
1677                    return Err(RuntimeError::ShapeMismatch {
1678                        expected,
1679                        got: bytes.len(),
1680                    });
1681                }
1682                let mut data = Vec::with_capacity(numel);
1683                for i in 0..numel {
1684                    let off = i * 4;
1685                    let mut buf = [0u8; 4];
1686                    buf.copy_from_slice(&bytes[off..off + 4]);
1687                    data.push(f32::from_le_bytes(buf) as f64);
1688                }
1689                Ok(Tensor {
1690                    buffer: Buffer::from_vec(data),
1691                    shape: shape.to_vec(),
1692                    strides: Self::compute_strides(shape),
1693                    offset: 0,
1694                })
1695            }
1696            _ => Err(RuntimeError::InvalidOperation(
1697                format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1698            )),
1699        }
1700    }
1701
1702    // -- Multi-Head Attention Splitting -------------------------------------
1703
1704    /// Reshape a 3D tensor `[batch, seq, model_dim]` into 4D
1705    /// `[batch, num_heads, seq, head_dim]` by splitting the last dimension.
1706    ///
1707    /// This is a **zero-copy view** — it only changes shape/strides metadata.
1708    /// `model_dim` must be divisible by `num_heads`.
1709    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1710        if self.ndim() != 3 {
1711            return Err(RuntimeError::DimensionMismatch {
1712                expected: 3,
1713                got: self.ndim(),
1714            });
1715        }
1716        let batch = self.shape[0];
1717        let seq = self.shape[1];
1718        let model_dim = self.shape[2];
1719        if model_dim % num_heads != 0 {
1720            return Err(RuntimeError::InvalidOperation(
1721                format!(
1722                    "split_heads: model_dim {} not divisible by num_heads {}",
1723                    model_dim, num_heads
1724                ),
1725            ));
1726        }
1727        let head_dim = model_dim / num_heads;
1728        // Need contiguous data for the reshape
1729        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1730        // Reshape [B, S, H*D] -> [B, S, H, D] then transpose to [B, H, S, D]
1731        let reshaped = Tensor {
1732            buffer: tensor.buffer.clone(),
1733            shape: vec![batch, seq, num_heads, head_dim],
1734            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1735            offset: 0,
1736        };
1737        // Transpose dims 1 and 2: [B, S, H, D] -> [B, H, S, D]
1738        // New strides: swap strides[1] and strides[2]
1739        Ok(Tensor {
1740            buffer: reshaped.buffer,
1741            shape: vec![batch, num_heads, seq, head_dim],
1742            strides: vec![
1743                reshaped.strides[0], // batch stride unchanged
1744                reshaped.strides[2], // head stride (was dim 2)
1745                reshaped.strides[1], // seq stride (was dim 1)
1746                reshaped.strides[3], // head_dim stride unchanged
1747            ],
1748            offset: 0,
1749        })
1750    }
1751
1752    /// Merge heads back: reshape 4D `[batch, num_heads, seq, head_dim]` into
1753    /// 3D `[batch, seq, model_dim]`. Materializes if non-contiguous.
1754    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1755        if self.ndim() != 4 {
1756            return Err(RuntimeError::DimensionMismatch {
1757                expected: 4,
1758                got: self.ndim(),
1759            });
1760        }
1761        let batch = self.shape[0];
1762        let num_heads = self.shape[1];
1763        let seq = self.shape[2];
1764        let head_dim = self.shape[3];
1765        // Need [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
1766        // Transpose dims 1 and 2 first
1767        let transposed = Tensor {
1768            buffer: self.buffer.clone(),
1769            shape: vec![batch, seq, num_heads, head_dim],
1770            strides: vec![
1771                self.strides[0],
1772                self.strides[2], // seq stride
1773                self.strides[1], // head stride
1774                self.strides[3],
1775            ],
1776            offset: self.offset,
1777        };
1778        // Materialize contiguous then reshape
1779        let contig = transposed.to_contiguous();
1780        let model_dim = num_heads * head_dim;
1781        Ok(Tensor {
1782            buffer: contig.buffer,
1783            shape: vec![batch, seq, model_dim],
1784            strides: Self::compute_strides(&[batch, seq, model_dim]),
1785            offset: 0,
1786        })
1787    }
1788
    /// View-only reshape: reinterpret shape without copying.
    /// Only works on contiguous tensors. Falls back to copy if non-contiguous.
    ///
    /// Thin alias for `reshape`; the copy/no-copy decision lives there.
    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.reshape(new_shape)
    }
1794
1795    // -----------------------------------------------------------------------
1796    // Phase C4: Sorting & Tensor Indexing
1797    // -----------------------------------------------------------------------
1798
1799    /// Returns indices that would sort the flattened tensor in ascending order.
1800    /// Uses f64::total_cmp for deterministic ordering of NaN.
1801    pub fn argsort(&self) -> Tensor {
1802        let data = self.to_vec();
1803        let mut indices: Vec<usize> = (0..data.len()).collect();
1804        indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1805        let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1806        Tensor::from_vec_unchecked(result, &[data.len()])
1807    }
1808
1809    /// Gather elements from the tensor along a dimension using index tensor.
1810    /// For 1D: result[i] = self[indices[i]]
1811    /// For 2D dim=0: result[i][j] = self[indices[i][j]][j]
1812    /// For 2D dim=1: result[i][j] = self[i][indices[i][j]]
1813    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1814        let data = self.to_vec();
1815        let idx_data = indices.to_vec();
1816        if self.ndim() == 1 {
1817            let mut result = Vec::with_capacity(idx_data.len());
1818            for &idx in &idx_data {
1819                let i = idx as usize;
1820                if i >= data.len() {
1821                    return Err(RuntimeError::InvalidOperation(
1822                        format!("gather: index {} out of bounds for size {}", i, data.len()),
1823                    ));
1824                }
1825                result.push(data[i]);
1826            }
1827            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
1828        } else if self.ndim() == 2 {
1829            let rows = self.shape[0];
1830            let cols = self.shape[1];
1831            let idx_shape = indices.shape();
1832            let out_rows = idx_shape[0];
1833            let out_cols = idx_shape[1];
1834            let mut result = vec![0.0; out_rows * out_cols];
1835            for i in 0..out_rows {
1836                for j in 0..out_cols {
1837                    let idx = idx_data[i * out_cols + j] as usize;
1838                    let val = if dim == 0 {
1839                        if idx >= rows {
1840                            return Err(RuntimeError::InvalidOperation(
1841                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
1842                            ));
1843                        }
1844                        data[idx * cols + j]
1845                    } else {
1846                        if idx >= cols {
1847                            return Err(RuntimeError::InvalidOperation(
1848                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
1849                            ));
1850                        }
1851                        data[i * cols + idx]
1852                    };
1853                    result[i * out_cols + j] = val;
1854                }
1855            }
1856            Ok(Tensor::from_vec_unchecked(result, idx_shape))
1857        } else {
1858            Err(RuntimeError::InvalidOperation(
1859                "gather: only 1D and 2D tensors supported".into(),
1860            ))
1861        }
1862    }
1863
1864    /// Scatter src values into a tensor of given shape at indices along a dimension.
1865    /// For 1D: result[indices[i]] = src[i]
1866    /// For 2D dim=0: result[indices[i][j]][j] = src[i][j]
1867    /// For 2D dim=1: result[i][indices[i][j]] = src[i][j]
1868    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
1869        let mut result = self.to_vec();
1870        let idx_data = indices.to_vec();
1871        let src_data = src.to_vec();
1872        if self.ndim() == 1 {
1873            for (k, &idx) in idx_data.iter().enumerate() {
1874                let i = idx as usize;
1875                if i >= result.len() {
1876                    return Err(RuntimeError::InvalidOperation(
1877                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
1878                    ));
1879                }
1880                result[i] = src_data[k];
1881            }
1882            Ok(Tensor::from_vec_unchecked(result, self.shape()))
1883        } else if self.ndim() == 2 {
1884            let cols = self.shape[1];
1885            let idx_shape = indices.shape();
1886            let out_cols = idx_shape[1];
1887            let out_rows = idx_shape[0];
1888            for i in 0..out_rows {
1889                for j in 0..out_cols {
1890                    let idx = idx_data[i * out_cols + j] as usize;
1891                    let src_val = src_data[i * out_cols + j];
1892                    if dim == 0 {
1893                        if idx >= self.shape[0] {
1894                            return Err(RuntimeError::InvalidOperation(
1895                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
1896                            ));
1897                        }
1898                        result[idx * cols + j] = src_val;
1899                    } else {
1900                        if idx >= cols {
1901                            return Err(RuntimeError::InvalidOperation(
1902                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
1903                            ));
1904                        }
1905                        result[i * cols + idx] = src_val;
1906                    }
1907                }
1908            }
1909            Ok(Tensor::from_vec_unchecked(result, self.shape()))
1910        } else {
1911            Err(RuntimeError::InvalidOperation(
1912                "scatter: only 1D and 2D tensors supported".into(),
1913            ))
1914        }
1915    }
1916
1917    /// Select slices along a dimension by index.
1918    /// For 2D dim=0: selects rows
1919    /// For 2D dim=1: selects columns
1920    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1921        let data = self.to_vec();
1922        let idx_data = indices.to_vec();
1923        if self.ndim() == 1 {
1924            let mut result = Vec::with_capacity(idx_data.len());
1925            for &idx in &idx_data {
1926                let i = idx as usize;
1927                if i >= data.len() {
1928                    return Err(RuntimeError::InvalidOperation(
1929                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
1930                    ));
1931                }
1932                result.push(data[i]);
1933            }
1934            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
1935        } else if self.ndim() == 2 {
1936            let rows = self.shape[0];
1937            let cols = self.shape[1];
1938            let n = idx_data.len();
1939            if dim == 0 {
1940                let mut result = Vec::with_capacity(n * cols);
1941                for &idx in &idx_data {
1942                    let i = idx as usize;
1943                    if i >= rows {
1944                        return Err(RuntimeError::InvalidOperation(
1945                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
1946                        ));
1947                    }
1948                    for j in 0..cols {
1949                        result.push(data[i * cols + j]);
1950                    }
1951                }
1952                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
1953            } else {
1954                let mut result = Vec::with_capacity(rows * n);
1955                for i in 0..rows {
1956                    for &idx in &idx_data {
1957                        let j = idx as usize;
1958                        if j >= cols {
1959                            return Err(RuntimeError::InvalidOperation(
1960                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
1961                            ));
1962                        }
1963                        result.push(data[i * cols + j]);
1964                    }
1965                }
1966                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
1967            }
1968        } else {
1969            Err(RuntimeError::InvalidOperation(
1970                "index_select: only 1D and 2D tensors supported".into(),
1971            ))
1972        }
1973    }
1974
1975    // -----------------------------------------------------------------------
1976    // Phase 2: Boolean / Masking Ops
1977    // -----------------------------------------------------------------------
1978
1979    /// Element-wise conditional select: `where(condition, other)`.
1980    /// For each element, returns `self[i]` if `condition[i] != 0.0`, else `other[i]`.
1981    pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
1982        if self.shape() != condition.shape() || self.shape() != other.shape() {
1983            return Err(RuntimeError::InvalidOperation(
1984                format!("where: shape mismatch self={:?} cond={:?} other={:?}",
1985                    self.shape(), condition.shape(), other.shape()),
1986            ));
1987        }
1988        let s = self.to_vec();
1989        let c = condition.to_vec();
1990        let o = other.to_vec();
1991        let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
1992            .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
1993            .collect();
1994        Tensor::from_vec(result, self.shape())
1995    }
1996
1997    /// Returns `true` if any element is non-zero.
1998    pub fn any(&self) -> bool {
1999        let data = self.to_vec();
2000        data.iter().any(|&x| x != 0.0)
2001    }
2002
2003    /// Returns `true` if all elements are non-zero.
2004    pub fn all(&self) -> bool {
2005        let data = self.to_vec();
2006        data.iter().all(|&x| x != 0.0)
2007    }
2008
2009    /// Returns a 1-D tensor of flat indices where elements are non-zero.
2010    pub fn nonzero(&self) -> Tensor {
2011        let data = self.to_vec();
2012        let indices: Vec<f64> = data.iter().enumerate()
2013            .filter(|(_, &v)| v != 0.0)
2014            .map(|(i, _)| i as f64)
2015            .collect();
2016        let len = indices.len();
2017        if len == 0 {
2018            Tensor::from_vec(vec![], &[0]).unwrap()
2019        } else {
2020            Tensor::from_vec(indices, &[len]).unwrap()
2021        }
2022    }
2023
2024    /// Fill elements where `mask` is non-zero with `value`.
2025    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2026        if self.shape() != mask.shape() {
2027            return Err(RuntimeError::InvalidOperation(
2028                format!("masked_fill: shape mismatch self={:?} mask={:?}",
2029                    self.shape(), mask.shape()),
2030            ));
2031        }
2032        let data = self.to_vec();
2033        let m = mask.to_vec();
2034        let result: Vec<f64> = data.iter().zip(m.iter())
2035            .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2036            .collect();
2037        Tensor::from_vec(result, self.shape())
2038    }
2039
2040    // -----------------------------------------------------------------------
2041    // Phase 2: Axis Reductions with keepdim
2042    // -----------------------------------------------------------------------
2043
    /// Helper: generic reduction along a single axis.
    ///
    /// For each output position, the `axis_len` values lying along `axis`
    /// are gathered into a temporary slice and handed to `reduce_fn`; its
    /// return value becomes the output element. Callers supply the actual
    /// reduction (e.g. BinnedAccumulator-based sums in `mean_axis`).
    ///
    /// With `keepdim` the reduced axis stays with size 1; otherwise it is
    /// removed, and a fully reduced tensor keeps shape `[1]`.
    ///
    /// NOTE(review): `data` comes from `to_vec()` but is indexed below with
    /// `self.offset`/`self.strides`, while `transpose_last_two` treats
    /// `to_vec()` output as already contiguous — confirm `to_vec` semantics
    /// for non-contiguous views.
    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
        -> Result<Tensor, RuntimeError>
    where
        F: Fn(&[f64]) -> f64,
    {
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds {
                index: axis,
                length: ndim,
            });
        }

        let axis_len = self.shape[axis];
        // Build output shape: input shape with the reduced axis set to 1.
        let mut out_shape: Vec<usize> = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);

        let data = self.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Compute N-D index from flat output index.
            // indices[axis] always decomposes to 0 since out_shape[axis] == 1.
            {
                let mut remaining = out_idx;
                for d in 0..ndim {
                    indices[d] = remaining / out_strides[d];
                    remaining %= out_strides[d];
                }
            }

            // Gather values along the reduction axis, substituting k for
            // the coordinate on the reduced dimension.
            let mut vals = Vec::with_capacity(axis_len);
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                vals.push(data[flat]);
            }
            result.push(reduce_fn(&vals));
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Remove the axis dimension
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis)
                .map(|(_, &v)| v)
                .collect();
            if s.is_empty() {
                s.push(1); // scalar result
            }
            s
        };

        Tensor::from_vec(result, &final_shape)
    }
2109
2110    /// Mean along an axis with optional keepdim.
2111    pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2112        self.reduce_axis(axis, keepdim, |vals| {
2113            let mut acc = BinnedAccumulatorF64::new();
2114            for &v in vals { acc.add(v); }
2115            acc.finalize() / vals.len() as f64
2116        })
2117    }
2118
2119    /// Max along an axis with optional keepdim. Returns (values, indices).
2120    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2121        let ndim = self.ndim();
2122        if axis >= ndim {
2123            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2124        }
2125        let axis_len = self.shape[axis];
2126        let mut out_shape = self.shape.clone();
2127        out_shape[axis] = 1;
2128        let out_numel = Self::shape_numel(&out_shape);
2129        let out_strides = Self::compute_strides(&out_shape);
2130        let data = self.to_vec();
2131        let mut values = Vec::with_capacity(out_numel);
2132        let mut idx_vals = Vec::with_capacity(out_numel);
2133        let mut indices = vec![0usize; ndim];
2134
2135        for out_idx in 0..out_numel {
2136            let mut remaining = out_idx;
2137            for d in 0..ndim {
2138                indices[d] = remaining / out_strides[d];
2139                remaining %= out_strides[d];
2140            }
2141            let mut best_val = f64::NEG_INFINITY;
2142            let mut best_idx = 0usize;
2143            for k in 0..axis_len {
2144                let mut flat = self.offset;
2145                for d in 0..ndim {
2146                    let idx = if d == axis { k } else { indices[d] };
2147                    flat += idx * self.strides[d];
2148                }
2149                let v = data[flat];
2150                if v > best_val {
2151                    best_val = v;
2152                    best_idx = k;
2153                }
2154            }
2155            values.push(best_val);
2156            idx_vals.push(best_idx as f64);
2157        }
2158
2159        let final_shape = if keepdim {
2160            out_shape
2161        } else {
2162            let mut s: Vec<usize> = self.shape.iter().enumerate()
2163                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2164            if s.is_empty() { s.push(1); }
2165            s
2166        };
2167        Ok((
2168            Tensor::from_vec(values, &final_shape)?,
2169            Tensor::from_vec(idx_vals, &final_shape)?,
2170        ))
2171    }
2172
2173    /// Min along an axis with optional keepdim. Returns (values, indices).
2174    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2175        let ndim = self.ndim();
2176        if axis >= ndim {
2177            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2178        }
2179        let axis_len = self.shape[axis];
2180        let mut out_shape = self.shape.clone();
2181        out_shape[axis] = 1;
2182        let out_numel = Self::shape_numel(&out_shape);
2183        let out_strides = Self::compute_strides(&out_shape);
2184        let data = self.to_vec();
2185        let mut values = Vec::with_capacity(out_numel);
2186        let mut idx_vals = Vec::with_capacity(out_numel);
2187        let mut indices = vec![0usize; ndim];
2188
2189        for out_idx in 0..out_numel {
2190            let mut remaining = out_idx;
2191            for d in 0..ndim {
2192                indices[d] = remaining / out_strides[d];
2193                remaining %= out_strides[d];
2194            }
2195            let mut best_val = f64::INFINITY;
2196            let mut best_idx = 0usize;
2197            for k in 0..axis_len {
2198                let mut flat = self.offset;
2199                for d in 0..ndim {
2200                    let idx = if d == axis { k } else { indices[d] };
2201                    flat += idx * self.strides[d];
2202                }
2203                let v = data[flat];
2204                if v < best_val {
2205                    best_val = v;
2206                    best_idx = k;
2207                }
2208            }
2209            values.push(best_val);
2210            idx_vals.push(best_idx as f64);
2211        }
2212
2213        let final_shape = if keepdim {
2214            out_shape
2215        } else {
2216            let mut s: Vec<usize> = self.shape.iter().enumerate()
2217                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2218            if s.is_empty() { s.push(1); }
2219            s
2220        };
2221        Ok((
2222            Tensor::from_vec(values, &final_shape)?,
2223            Tensor::from_vec(idx_vals, &final_shape)?,
2224        ))
2225    }
2226
    /// Population variance along an axis with optional keepdim
    /// (divides by `axis_len`, not `axis_len - 1`).
    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
        // Mean with keepdim=true: mean_t then has exactly out_shape below,
        // so mean_data[out_idx] lines up with this loop's iteration order.
        let mean_t = self.mean_axis(axis, true)?;
        let ndim = self.ndim();
        if axis >= ndim {
            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
        }
        let axis_len = self.shape[axis];
        let mut out_shape = self.shape.clone();
        out_shape[axis] = 1;
        let out_numel = Self::shape_numel(&out_shape);
        let out_strides = Self::compute_strides(&out_shape);
        let data = self.to_vec();
        let mean_data = mean_t.to_vec();
        let mut result = Vec::with_capacity(out_numel);
        let mut indices = vec![0usize; ndim];

        for out_idx in 0..out_numel {
            // Decompose the flat output index into N-D coordinates.
            let mut remaining = out_idx;
            for d in 0..ndim {
                indices[d] = remaining / out_strides[d];
                remaining %= out_strides[d];
            }
            let mu = mean_data[out_idx];
            // Sum squared deviations with binned accumulation for stability.
            let mut acc = BinnedAccumulatorF64::new();
            for k in 0..axis_len {
                let mut flat = self.offset;
                for d in 0..ndim {
                    let idx = if d == axis { k } else { indices[d] };
                    flat += idx * self.strides[d];
                }
                let diff = data[flat] - mu;
                acc.add(diff * diff);
            }
            result.push(acc.finalize() / axis_len as f64);
        }

        let final_shape = if keepdim {
            out_shape
        } else {
            // Drop the reduced axis; a fully reduced tensor keeps shape [1].
            let mut s: Vec<usize> = self.shape.iter().enumerate()
                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
            if s.is_empty() { s.push(1); }
            s
        };
        Tensor::from_vec(result, &final_shape)
    }
2274
2275    /// Standard deviation along an axis with optional keepdim.
2276    pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2277        let var = self.var_axis(axis, keepdim)?;
2278        Ok(var.map(|x| x.sqrt()))
2279    }
2280
2281    /// Product along an axis with optional keepdim.
2282    pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2283        self.reduce_axis(axis, keepdim, |vals| {
2284            // Product via exp(sum(ln(abs))) for numerical stability is overkill here;
2285            // simple product is deterministic and exact for integer-like values.
2286            let mut product = 1.0f64;
2287            for &v in vals { product *= v; }
2288            product
2289        })
2290    }
2291
2292    // -----------------------------------------------------------------------
2293    // Phase 2: Sort Operations
2294    // -----------------------------------------------------------------------
2295
2296    /// Sort along an axis (stable sort). Returns the sorted tensor.
2297    /// For N-D tensors, sorts slices along the specified axis.
2298    pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2299        let ndim = self.ndim();
2300        if axis >= ndim {
2301            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2302        }
2303        let data = self.to_vec();
2304        let axis_len = self.shape[axis];
2305        let out_shape = self.shape.clone();
2306        let out_numel = Self::shape_numel(&out_shape);
2307
2308        // Build strides for iterating over all non-axis positions
2309        let mut iter_shape: Vec<usize> = Vec::new();
2310        for (i, &s) in self.shape.iter().enumerate() {
2311            if i != axis { iter_shape.push(s); }
2312        }
2313        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2314
2315        let mut result = vec![0.0f64; out_numel];
2316
2317        // We iterate over all positions with axis index = 0
2318        let mut pos = vec![0usize; ndim];
2319        for slice_idx in 0..n_slices {
2320            // Compute the N-D position (with axis dim = 0)
2321            let mut remaining = slice_idx;
2322            let mut dim_idx = 0;
2323            for d in 0..ndim {
2324                if d == axis {
2325                    pos[d] = 0;
2326                } else {
2327                    let stride = {
2328                        let mut s = 1usize;
2329                        let mut di = 0;
2330                        for d2 in 0..ndim {
2331                            if d2 == axis { continue; }
2332                            if di > dim_idx { s *= self.shape[d2]; }
2333                            di += 1;
2334                        }
2335                        s
2336                    };
2337                    pos[d] = remaining / stride;
2338                    remaining %= stride;
2339                    dim_idx += 1;
2340                }
2341            }
2342
2343            // Gather values along axis
2344            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2345            for k in 0..axis_len {
2346                let mut flat = self.offset;
2347                for d in 0..ndim {
2348                    let idx = if d == axis { k } else { pos[d] };
2349                    flat += idx * self.strides[d];
2350                }
2351                vals.push((data[flat], k));
2352            }
2353
2354            // Stable sort with deterministic tie-breaking by original index
2355            if descending {
2356                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2357                    .then(a.1.cmp(&b.1)));
2358            } else {
2359                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2360                    .then(a.1.cmp(&b.1)));
2361            }
2362
2363            // Scatter back
2364            for (k, &(v, _)) in vals.iter().enumerate() {
2365                let mut flat = 0;
2366                let out_strides_local = Self::compute_strides(&out_shape);
2367                for d in 0..ndim {
2368                    let idx = if d == axis { k } else { pos[d] };
2369                    flat += idx * out_strides_local[d];
2370                }
2371                result[flat] = v;
2372            }
2373        }
2374
2375        Tensor::from_vec(result, &out_shape)
2376    }
2377
2378    /// N-D argsort along an axis. Returns indices tensor.
2379    pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2380        let ndim = self.ndim();
2381        if axis >= ndim {
2382            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2383        }
2384        let data = self.to_vec();
2385        let axis_len = self.shape[axis];
2386        let out_shape = self.shape.clone();
2387        let out_numel = Self::shape_numel(&out_shape);
2388
2389        let mut iter_shape: Vec<usize> = Vec::new();
2390        for (i, &s) in self.shape.iter().enumerate() {
2391            if i != axis { iter_shape.push(s); }
2392        }
2393        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2394
2395        let mut result = vec![0.0f64; out_numel];
2396        let mut pos = vec![0usize; ndim];
2397
2398        for slice_idx in 0..n_slices {
2399            let mut remaining = slice_idx;
2400            let mut dim_idx = 0;
2401            for d in 0..ndim {
2402                if d == axis {
2403                    pos[d] = 0;
2404                } else {
2405                    let stride = {
2406                        let mut s = 1usize;
2407                        let mut di = 0;
2408                        for d2 in 0..ndim {
2409                            if d2 == axis { continue; }
2410                            if di > dim_idx { s *= self.shape[d2]; }
2411                            di += 1;
2412                        }
2413                        s
2414                    };
2415                    pos[d] = remaining / stride;
2416                    remaining %= stride;
2417                    dim_idx += 1;
2418                }
2419            }
2420
2421            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2422            for k in 0..axis_len {
2423                let mut flat = self.offset;
2424                for d in 0..ndim {
2425                    let idx = if d == axis { k } else { pos[d] };
2426                    flat += idx * self.strides[d];
2427                }
2428                vals.push((data[flat], k));
2429            }
2430
2431            if descending {
2432                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2433                    .then(a.1.cmp(&b.1)));
2434            } else {
2435                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2436                    .then(a.1.cmp(&b.1)));
2437            }
2438
2439            for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2440                let out_strides_local = Self::compute_strides(&out_shape);
2441                let mut flat = 0;
2442                for d in 0..ndim {
2443                    let idx = if d == axis { k } else { pos[d] };
2444                    flat += idx * out_strides_local[d];
2445                }
2446                result[flat] = orig_idx as f64;
2447            }
2448        }
2449
2450        Tensor::from_vec(result, &out_shape)
2451    }
2452
2453    // -----------------------------------------------------------------------
2454    // Phase 2: Einsum
2455    // -----------------------------------------------------------------------
2456
    /// Einstein summation notation.
    /// Supports patterns like "ij,jk->ik" (matmul), "ii->i" (diagonal),
    /// "ij->ji" (transpose), "ijk,ikl->ijl" (batched matmul).
    /// Uses BinnedAccumulator for all reductions.
    ///
    /// The output subscript is mandatory: notation without "->" is rejected.
    /// Repeated labels within one input spec (e.g. "ii") address the diagonal,
    /// because every occurrence of a label is driven by the same index value.
    ///
    /// # Errors
    /// Returns `InvalidOperation` for malformed notation, a spec/tensor count
    /// mismatch, a spec whose length differs from the tensor's rank,
    /// conflicting sizes for one label, or an output label never seen in any
    /// input.
    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
        // Parse notation: an explicit "subscripts->output" is required;
        // implicit-output notation (plain "subscripts") is NOT supported and
        // fails the length check below.
        let parts: Vec<&str> = notation.split("->").collect();
        if parts.len() != 2 {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
            ));
        }
        let input_specs: Vec<&str> = parts[0].split(',').collect();
        let output_spec = parts[1];

        if input_specs.len() != inputs.len() {
            return Err(RuntimeError::InvalidOperation(
                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
            ));
        }

        // Build label → size mapping; every occurrence of a label must agree
        // on the dimension size (BTreeMap keeps label iteration deterministic).
        let mut label_size = std::collections::BTreeMap::new();
        for (i, &spec) in input_specs.iter().enumerate() {
            let chars: Vec<char> = spec.chars().collect();
            if chars.len() != inputs[i].ndim() {
                return Err(RuntimeError::InvalidOperation(
                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
                ));
            }
            for (d, &c) in chars.iter().enumerate() {
                let sz = inputs[i].shape()[d];
                if let Some(&prev) = label_size.get(&c) {
                    if prev != sz {
                        return Err(RuntimeError::InvalidOperation(
                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
                        ));
                    }
                } else {
                    label_size.insert(c, sz);
                }
            }
        }

        // Determine output shape: one dimension per output label, sized from
        // the label table. Unknown labels are rejected.
        let output_chars: Vec<char> = output_spec.chars().collect();
        let output_shape: Vec<usize> = output_chars.iter()
            .map(|c| label_size.get(c).copied().ok_or_else(||
                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
            .collect::<Result<_, _>>()?;
        let out_numel = Self::shape_numel(&output_shape);

        // Determine contraction labels (in input but not output); these are
        // the axes summed over.
        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
        let contract_labels: Vec<char> = label_size.keys()
            .filter(|c| !output_set.contains(c))
            .copied()
            .collect();
        let contract_sizes: Vec<usize> = contract_labels.iter()
            .map(|c| label_size[c])
            .collect();
        // .max(1): with no contraction labels the empty product is still 1,
        // so the inner accumulation loop runs exactly once per output element.
        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);

        // Precompute input spec chars
        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();

        // For each output position, iterate over contraction indices
        let out_strides = Self::compute_strides(&output_shape);
        let mut result = vec![0.0f64; out_numel];

        // Pre-read input data; strides/offsets let us address each input
        // directly without materializing contiguous copies per element.
        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();

        for out_idx in 0..out_numel {
            // Decode out_idx (row-major) into a value for each output label.
            let mut label_vals = std::collections::BTreeMap::new();
            let mut remaining = out_idx;
            for (d, &c) in output_chars.iter().enumerate() {
                // Defensive bound: output_chars and out_strides have equal
                // length by construction, so the fallback should never fire.
                let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
                label_vals.insert(c, remaining / stride);
                remaining %= stride;
            }

            // Numerically-stable summation over all contraction combinations.
            let mut acc = BinnedAccumulatorF64::new();
            // Iterate over all contraction index combinations
            for cidx in 0..contract_numel {
                // Decode cidx (row-major over contract_sizes) into a value
                // for each contraction label.
                let mut cr = cidx;
                for (ci, &cl) in contract_labels.iter().enumerate() {
                    let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
                    label_vals.insert(cl, cr / stride);
                    cr %= stride;
                }

                // Compute product of input elements addressed by the current
                // label assignment (repeated labels reuse the same value).
                let mut product = 1.0f64;
                for (inp_idx, chars) in input_chars.iter().enumerate() {
                    let mut flat = input_offsets[inp_idx];
                    for (d, &c) in chars.iter().enumerate() {
                        flat += label_vals[&c] * input_strides[inp_idx][d];
                    }
                    product *= input_data[inp_idx][flat];
                }
                acc.add(product);
            }
            result[out_idx] = acc.finalize();
        }

        // A fully-contracted result ("ij->") has an empty shape; represent
        // the scalar as a 1-element tensor.
        if output_shape.is_empty() {
            Tensor::from_vec(result, &[1])
        } else {
            Tensor::from_vec(result, &output_shape)
        }
    }
2573
2574    // -----------------------------------------------------------------------
2575    // Phase 2: Reshape / View Enhancements
2576    // -----------------------------------------------------------------------
2577
2578    /// Add a dimension of size 1 at position `dim`.
2579    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2580        let ndim = self.ndim();
2581        if dim > ndim {
2582            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2583        }
2584        let mut new_shape = self.shape.clone();
2585        new_shape.insert(dim, 1);
2586        self.reshape(&new_shape)
2587    }
2588
2589    /// Remove a dimension of size 1 at position `dim`.
2590    /// If `dim` is `None`, removes all dimensions of size 1.
2591    pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2592        match dim {
2593            Some(d) => {
2594                if d >= self.ndim() {
2595                    return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2596                }
2597                if self.shape[d] != 1 {
2598                    return Err(RuntimeError::InvalidOperation(
2599                        format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2600                    ));
2601                }
2602                let mut new_shape = self.shape.clone();
2603                new_shape.remove(d);
2604                if new_shape.is_empty() {
2605                    new_shape.push(1); // scalar
2606                }
2607                self.reshape(&new_shape)
2608            }
2609            None => {
2610                let new_shape: Vec<usize> = self.shape.iter()
2611                    .filter(|&&s| s != 1)
2612                    .copied()
2613                    .collect();
2614                let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2615                self.reshape(&new_shape)
2616            }
2617        }
2618    }
2619
    /// Broadcast without copying. Returns a view with stride=0 for broadcasted dims.
    /// Same as `broadcast_to` but named for consistency with the gap-fix plan.
    ///
    /// Errors are whatever `broadcast_to` reports for an incompatible target shape.
    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.broadcast_to(target_shape)
    }
2625
2626    /// Flatten a range of dimensions [start_dim, end_dim] into a single dimension.
2627    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2628        if start_dim > end_dim || end_dim >= self.ndim() {
2629            return Err(RuntimeError::InvalidOperation(
2630                format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2631            ));
2632        }
2633        let mut new_shape = Vec::new();
2634        for i in 0..start_dim {
2635            new_shape.push(self.shape[i]);
2636        }
2637        let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2638        new_shape.push(flat_size);
2639        for i in (end_dim + 1)..self.ndim() {
2640            new_shape.push(self.shape[i]);
2641        }
2642        self.reshape(&new_shape)
2643    }
2644
2645    /// Split tensor into `n` roughly equal chunks along dimension `dim`.
2646    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2647        if dim >= self.ndim() {
2648            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2649        }
2650        if n == 0 {
2651            return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2652        }
2653        let dim_size = self.shape[dim];
2654        let chunk_size = (dim_size + n - 1) / n;
2655        let mut sizes = Vec::new();
2656        let mut remaining = dim_size;
2657        while remaining > 0 {
2658            let s = remaining.min(chunk_size);
2659            sizes.push(s);
2660            remaining -= s;
2661        }
2662        self.split(&sizes, dim)
2663    }
2664
2665    /// Split tensor along dimension `dim` according to the given sizes.
2666    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2667        if dim >= self.ndim() {
2668            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2669        }
2670        let total: usize = sizes.iter().sum();
2671        if total != self.shape[dim] {
2672            return Err(RuntimeError::InvalidOperation(
2673                format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2674            ));
2675        }
2676
2677        let mut results = Vec::new();
2678        let mut offset = 0;
2679
2680        for &sz in sizes {
2681            let ranges: Vec<(usize, usize)> = self.shape.iter()
2682                .enumerate()
2683                .map(|(i, &s)| {
2684                    if i == dim { (offset, offset + sz) } else { (0, s) }
2685                })
2686                .collect();
2687            let chunk = self.slice(&ranges)?;
2688            // Materialize as contiguous
2689            results.push(chunk.to_contiguous());
2690            offset += sz;
2691        }
2692
2693        Ok(results)
2694    }
2695
2696    /// Fused `alpha * self + beta * other` element-wise. Single pass, one allocation.
2697    ///
2698    /// Critical for LSTM/GRU gates where `f * c_prev + i * g` would otherwise
2699    /// create 3 intermediate tensors.
2700    pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2701        if self.shape != other.shape {
2702            return Err(RuntimeError::InvalidOperation(
2703                "scale_add: shape mismatch".to_string(),
2704            ));
2705        }
2706        let a = self.to_vec();
2707        let b = other.to_vec();
2708        let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2709        Tensor::from_vec(result, &self.shape)
2710    }
2711}
2712