Skip to main content

cjc_runtime/
tensor.rs

1//! N-dimensional tensor with element-wise, reduction, linalg, and NN operations.
2//!
3//! [`Tensor`] is the primary numerical type in CJC. It is backed by a
4//! [`Buffer<f64>`] with COW (copy-on-write) semantics, so cloning a tensor
5//! is O(1) and mutation triggers a deep copy only when the buffer is shared.
6//!
7//! # Determinism Guarantees
8//!
9//! - **Reductions** (`sum`, `mean`, `sum_axis`) use [`BinnedAccumulatorF64`]
10//!   for order-invariant, bit-identical results.
11//! - **Matmul** uses Kahan-compensated accumulation (sequential path) or
12//!   tiled + parallel strategies (large matrices) with deterministic per-row
13//!   thread assignment.
14//! - **SIMD** kernels (via [`tensor_simd`]) avoid hardware FMA to preserve
15//!   cross-platform bit-identity.
16//! - **No** `HashMap`/`HashSet` anywhere -- all ordering is deterministic.
17//!
18//! # Layout
19//!
20//! Tensors use row-major (C-order) layout with explicit strides. Non-contiguous
21//! views (from `slice`, `transpose`, `broadcast_to`) share the underlying
22//! buffer with adjusted strides and offset. Call [`Tensor::to_contiguous`] to
23//! materialize a contiguous copy when needed.
24//!
25//! # Operation Categories
26//!
27//! | Category | Methods |
28//! |----------|---------|
29//! | Construction | `zeros`, `ones`, `randn`, `from_vec`, `from_bytes` |
30//! | Shape | `shape`, `ndim`, `len`, `reshape`, `transpose`, `slice`, `unsqueeze`, `squeeze`, `flatten` |
31//! | Element-wise | `add`, `sub`, `mul_elem`, `div_elem`, `elem_pow`, `map`, `map_simd` |
32//! | Reductions | `sum`, `mean`, `sum_axis`, `mean_axis`, `var_axis`, `std_axis` |
33//! | Linalg | `matmul`, `bmm`, `linear`, `einsum` |
34//! | NN | `softmax`, `layer_norm`, `relu`, `sigmoid`, `gelu`, `conv1d`, `conv2d`, `maxpool2d` |
35//! | Indexing | `get`, `set`, `gather`, `scatter`, `index_select`, `argsort` |
36//! | Attention | `scaled_dot_product_attention`, `split_heads`, `merge_heads` |
37//!
38//! [`Buffer<f64>`]: crate::buffer::Buffer
39//! [`BinnedAccumulatorF64`]: crate::accumulator::BinnedAccumulatorF64
40//! [`tensor_simd`]: crate::tensor_simd
41
42use cjc_repro::Rng;
43
44use crate::accumulator::{binned_sum_f64, BinnedAccumulatorF64};
45use cjc_repro::KahanAccumulatorF64;
46
47use crate::accumulator;
48use crate::buffer::Buffer;
49use crate::dispatch;
50use crate::error::RuntimeError;
51use crate::kernel as kernel_fns;
52use crate::tensor_simd::{self, BinOp, UnaryOp};
53use crate::tensor_tiled::TiledMatmul;
54
55// ---------------------------------------------------------------------------
56// 2. Tensor Runtime
57// ---------------------------------------------------------------------------
58
59/// An N-dimensional tensor backed by a [`Buffer<f64>`](crate::buffer::Buffer).
60///
61/// Supports element-wise arithmetic (SIMD-accelerated), matrix multiplication
62/// (tiled + parallel), numerically-stable reductions via
63/// [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64),
64/// and neural network operations (softmax, layer norm, attention).
65///
66/// # Memory
67///
68/// The underlying [`Buffer<f64>`](crate::buffer::Buffer) uses copy-on-write
69/// semantics. Cloning a `Tensor` is O(1). Operations like `reshape` and
70/// `transpose` return zero-copy views when possible. Mutation via `set`
71/// triggers a deep copy only when the buffer is shared.
72#[derive(Debug, Clone)]
73pub struct Tensor {
74    /// The underlying COW data buffer.
75    pub buffer: Buffer<f64>,
76    /// Dimensions of the tensor (e.g., `[3, 4]` for a 3x4 matrix).
77    pub(crate) shape: Vec<usize>,
78    /// Row-major strides for each dimension. Non-contiguous views may have
79    /// strides that differ from the standard row-major pattern (e.g., stride=0
80    /// for broadcast dimensions).
81    pub(crate) strides: Vec<usize>,
82    /// Byte offset into the buffer where this tensor's data begins.
83    /// Non-zero for slice views.
84    pub(crate) offset: usize,
85}
86
87impl Tensor {
88    // -- Construction -------------------------------------------------------
89
90    /// Compute row-major strides for a given shape.
91    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
92        let mut strides = vec![1usize; shape.len()];
93        for i in (0..shape.len().saturating_sub(1)).rev() {
94            strides[i] = strides[i + 1] * shape[i + 1];
95        }
96        strides
97    }
98
99    /// Total number of elements implied by `shape`.
100    fn shape_numel(shape: &[usize]) -> usize {
101        shape.iter().product()
102    }
103
104    /// Create a tensor filled with zeros.
105    pub fn zeros(shape: &[usize]) -> Self {
106        let numel = Self::shape_numel(shape);
107        Tensor {
108            buffer: Buffer::alloc(numel, 0.0),
109            shape: shape.to_vec(),
110            strides: Self::compute_strides(shape),
111            offset: 0,
112        }
113    }
114
115    /// Create a tensor filled with ones.
116    pub fn ones(shape: &[usize]) -> Self {
117        let numel = Self::shape_numel(shape);
118        Tensor {
119            buffer: Buffer::alloc(numel, 1.0),
120            shape: shape.to_vec(),
121            strides: Self::compute_strides(shape),
122            offset: 0,
123        }
124    }
125
126    /// Create a tensor filled with samples from the standard normal
127    /// distribution, drawn deterministically from `rng`.
128    pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
129        let numel = Self::shape_numel(shape);
130        let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
131        Tensor {
132            buffer: Buffer::from_vec(data),
133            shape: shape.to_vec(),
134            strides: Self::compute_strides(shape),
135            offset: 0,
136        }
137    }
138
139    /// Create a tensor from raw data and a shape. Returns an error if the
140    /// number of elements does not match the shape.
141    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
142        let numel = Self::shape_numel(shape);
143        if data.len() != numel {
144            return Err(RuntimeError::ShapeMismatch {
145                expected: numel,
146                got: data.len(),
147            });
148        }
149        Ok(Tensor {
150            buffer: Buffer::from_vec(data),
151            shape: shape.to_vec(),
152            strides: Self::compute_strides(shape),
153            offset: 0,
154        })
155    }
156
157    // -- Accessors ----------------------------------------------------------
158
159    /// The shape of this tensor.
160    pub fn shape(&self) -> &[usize] {
161        &self.shape
162    }
163
164    /// Number of dimensions.
165    pub fn ndim(&self) -> usize {
166        self.shape.len()
167    }
168
169    /// Total number of elements.
170    pub fn len(&self) -> usize {
171        Self::shape_numel(&self.shape)
172    }
173
174    /// Whether the tensor has zero elements.
175    pub fn is_empty(&self) -> bool {
176        self.len() == 0
177    }
178
179    /// Flatten a multi-dimensional index to a linear offset in the buffer.
180    fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
181        if indices.len() != self.shape.len() {
182            return Err(RuntimeError::DimensionMismatch {
183                expected: self.shape.len(),
184                got: indices.len(),
185            });
186        }
187        let mut off = self.offset;
188        for (i, &idx) in indices.iter().enumerate() {
189            if idx >= self.shape[i] {
190                return Err(RuntimeError::IndexOutOfBounds {
191                    index: idx,
192                    length: self.shape[i],
193                });
194            }
195            off += idx * self.strides[i];
196        }
197        Ok(off)
198    }
199
200    /// Whether this tensor is contiguous in memory (row-major, no offset).
201    pub fn is_contiguous(&self) -> bool {
202        if self.offset != 0 {
203            return false;
204        }
205        let expected = Self::compute_strides(&self.shape);
206        self.strides == expected
207    }
208
209    /// Create a zero-copy slice (view) of this tensor.
210    /// `ranges` contains `(start, end)` for each dimension.
211    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
212        if ranges.len() != self.shape.len() {
213            return Err(RuntimeError::DimensionMismatch {
214                expected: self.shape.len(),
215                got: ranges.len(),
216            });
217        }
218        let mut new_offset = self.offset;
219        let mut new_shape = Vec::with_capacity(ranges.len());
220        for (i, &(start, end)) in ranges.iter().enumerate() {
221            if end > self.shape[i] || start > end {
222                return Err(RuntimeError::IndexOutOfBounds {
223                    index: end,
224                    length: self.shape[i],
225                });
226            }
227            new_offset += start * self.strides[i];
228            new_shape.push(end - start);
229        }
230        Ok(Tensor {
231            buffer: self.buffer.clone(), // shared — zero copy
232            shape: new_shape,
233            strides: self.strides.clone(),
234            offset: new_offset,
235        })
236    }
237
238    /// Materialize a contiguous copy if this tensor is non-contiguous.
239    pub fn to_contiguous(&self) -> Tensor {
240        if self.is_contiguous() {
241            return self.clone();
242        }
243        let data = self.to_vec();
244        Tensor {
245            buffer: Buffer::from_vec(data),
246            shape: self.shape.clone(),
247            strides: Self::compute_strides(&self.shape),
248            offset: 0,
249        }
250    }
251
252    /// Create a broadcast view of this tensor to `target_shape`.
253    /// Uses stride=0 for dimensions that need broadcasting (size 1 -> target size).
254    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
255        let src_ndim = self.shape.len();
256        let tgt_ndim = target_shape.len();
257        if tgt_ndim < src_ndim {
258            return Err(RuntimeError::InvalidOperation(
259                "cannot broadcast to a smaller rank".to_string(),
260            ));
261        }
262        let pad = tgt_ndim - src_ndim;
263        let mut new_strides = vec![0usize; tgt_ndim];
264        for i in 0..tgt_ndim {
265            if i < pad {
266                // Padded dimension: stride = 0 (broadcast)
267                new_strides[i] = 0;
268            } else {
269                let src_i = i - pad;
270                if self.shape[src_i] == target_shape[i] {
271                    new_strides[i] = self.strides[src_i];
272                } else if self.shape[src_i] == 1 {
273                    new_strides[i] = 0; // broadcast
274                } else {
275                    return Err(RuntimeError::ShapeMismatch {
276                        expected: target_shape[i],
277                        got: self.shape[src_i],
278                    });
279                }
280            }
281        }
282        Ok(Tensor {
283            buffer: self.buffer.clone(),
284            shape: target_shape.to_vec(),
285            strides: new_strides,
286            offset: self.offset,
287        })
288    }
289
290    /// Read the element at the given multi-dimensional index.
291    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
292        let offset = self.linear_index(indices)?;
293        self.buffer
294            .get(offset)
295            .ok_or(RuntimeError::IndexOutOfBounds {
296                index: offset,
297                length: self.buffer.len(),
298            })
299    }
300
301    /// Write the element at the given multi-dimensional index.
302    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
303        let offset = self.linear_index(indices)?;
304        self.buffer.set(offset, val)
305    }
306
307    /// Extract the raw data as a `Vec<f64>`, respecting strides and offset.
308    pub fn to_vec(&self) -> Vec<f64> {
309        if self.is_contiguous() {
310            let full = self.buffer.borrow_data();
311            let numel = self.len();
312            if full.len() == numel {
313                return full.to_vec();
314            }
315            // Buffer may be larger than the tensor's logical size
316            // (e.g. Scratchpad pre-allocates extra capacity)
317            return full[..numel].to_vec();
318        }
319        // Non-contiguous: iterate via strided access
320        let numel = self.len();
321        let mut result = Vec::with_capacity(numel);
322        let ndim = self.shape.len();
323        let mut indices = vec![0usize; ndim];
324        for _ in 0..numel {
325            let mut off = self.offset;
326            for d in 0..ndim {
327                off += indices[d] * self.strides[d];
328            }
329            result.push(self.buffer.get(off).unwrap_or(0.0));
330            // Increment multi-index (row-major order)
331            for d in (0..ndim).rev() {
332                indices[d] += 1;
333                if indices[d] < self.shape[d] {
334                    break;
335                }
336                indices[d] = 0;
337            }
338        }
339        result
340    }
341
342    // -- Reshape ------------------------------------------------------------
343
344    /// Reshape to `new_shape`. The new shape must have the same total number
345    /// of elements. The returned tensor **shares** the underlying buffer.
346    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
347        let new_numel = Self::shape_numel(new_shape);
348        if new_numel != self.len() {
349            return Err(RuntimeError::ShapeMismatch {
350                expected: self.len(),
351                got: new_numel,
352            });
353        }
354        // Reshape requires contiguous data; materialize if needed
355        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
356        Ok(Tensor {
357            buffer: tensor.buffer,
358            shape: new_shape.to_vec(),
359            strides: Self::compute_strides(new_shape),
360            offset: 0,
361        })
362    }
363
364    // -- Element-wise operations --------------------------------------------
365
366    /// Apply a binary operation element-wise with broadcasting support.
367    fn elementwise_binop(
368        &self,
369        other: &Tensor,
370        op: impl Fn(f64, f64) -> f64,
371    ) -> Result<Tensor, RuntimeError> {
372        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
373            // Fast path: same shape, both contiguous — borrow without cloning
374            let a = self.buffer.borrow_data();
375            let b = other.buffer.borrow_data();
376            let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
377            return Ok(Tensor {
378                buffer: Buffer::from_vec(data),
379                shape: self.shape.clone(),
380                strides: Self::compute_strides(&self.shape),
381                offset: 0,
382            });
383        }
384
385        // Broadcasting path: compute result shape
386        let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
387        let a_broadcast = self.broadcast_to(&result_shape)?;
388        let b_broadcast = other.broadcast_to(&result_shape)?;
389
390        let numel = Self::shape_numel(&result_shape);
391        let ndim = result_shape.len();
392        let mut data = Vec::with_capacity(numel);
393        let mut indices = vec![0usize; ndim];
394
395        for _ in 0..numel {
396            let mut off_a = a_broadcast.offset;
397            let mut off_b = b_broadcast.offset;
398            for d in 0..ndim {
399                off_a += indices[d] * a_broadcast.strides[d];
400                off_b += indices[d] * b_broadcast.strides[d];
401            }
402            let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
403                RuntimeError::InvalidOperation(format!(
404                    "broadcast binop: left operand index {} out of bounds (buffer len {})",
405                    off_a,
406                    a_broadcast.buffer.len()
407                ))
408            })?;
409            let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
410                RuntimeError::InvalidOperation(format!(
411                    "broadcast binop: right operand index {} out of bounds (buffer len {})",
412                    off_b,
413                    b_broadcast.buffer.len()
414                ))
415            })?;
416            data.push(op(va, vb));
417
418            for d in (0..ndim).rev() {
419                indices[d] += 1;
420                if indices[d] < result_shape[d] {
421                    break;
422                }
423                indices[d] = 0;
424            }
425        }
426
427        Ok(Tensor {
428            buffer: Buffer::from_vec(data),
429            shape: result_shape.clone(),
430            strides: Self::compute_strides(&result_shape),
431            offset: 0,
432        })
433    }
434
435    /// Compute the broadcast result shape for two shapes (NumPy rules).
436    fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
437        let max_ndim = a.len().max(b.len());
438        let mut result = Vec::with_capacity(max_ndim);
439        for i in 0..max_ndim {
440            let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
441            let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
442            if da == db {
443                result.push(da);
444            } else if da == 1 {
445                result.push(db);
446            } else if db == 1 {
447                result.push(da);
448            } else {
449                return Err(RuntimeError::ShapeMismatch {
450                    expected: da,
451                    got: db,
452                });
453            }
454        }
455        Ok(result)
456    }
457
458    /// SIMD-accelerated element-wise binary operation for known ops.
459    ///
460    /// For same-shape contiguous tensors, uses AVX2 (4-wide f64) when available.
461    /// Falls back to the generic closure path for broadcast cases.
462    fn elementwise_binop_simd(
463        &self,
464        other: &Tensor,
465        op: BinOp,
466        fallback: impl Fn(f64, f64) -> f64,
467    ) -> Result<Tensor, RuntimeError> {
468        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
469            // SIMD fast path: same shape, both contiguous
470            let a = self.buffer.borrow_data();
471            let b = other.buffer.borrow_data();
472            let data = tensor_simd::simd_binop(&a, &b, op);
473            return Ok(Tensor {
474                buffer: Buffer::from_vec(data),
475                shape: self.shape.clone(),
476                strides: Self::compute_strides(&self.shape),
477                offset: 0,
478            });
479        }
480        // Broadcast path: fall through to generic
481        self.elementwise_binop(other, fallback)
482    }
483
484    /// Element-wise addition (SIMD-accelerated for contiguous same-shape tensors).
485    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
487    }
488
489    /// Element-wise subtraction (SIMD-accelerated for contiguous same-shape tensors).
490    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
492    }
493
494    /// Element-wise (Hadamard) multiplication (SIMD-accelerated for contiguous same-shape tensors).
495    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
497    }
498
499    /// Element-wise division (SIMD-accelerated for contiguous same-shape tensors).
500    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
502    }
503
504    /// Fused multiply-add: `self * b + c` element-wise in a single pass.
505    ///
506    /// Eliminates the intermediate tensor that separate mul + add would create.
507    /// Uses software FMA (`a * b + c` with two roundings, not hardware FMA)
508    /// to preserve bit-identity with the non-fused path.
509    pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
510        if self.shape != b.shape || self.shape != c.shape {
511            return Err(RuntimeError::InvalidOperation(
512                "broadcast_fma: all three tensors must have the same shape".to_string(),
513            ));
514        }
515        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
516            let a_data = self.buffer.borrow_data();
517            let b_data = b.buffer.borrow_data();
518            let c_data = c.buffer.borrow_data();
519            let n = a_data.len();
520            let mut out = vec![0.0f64; n];
521            // Software FMA: a*b + c (two roundings — NOT hardware FMA which uses one rounding).
522            // This produces identical results to separate broadcast2("mul") + broadcast2("add").
523            for i in 0..n {
524                out[i] = a_data[i] * b_data[i] + c_data[i];
525            }
526            return Ok(Tensor {
527                buffer: Buffer::from_vec(out),
528                shape: self.shape.clone(),
529                strides: Self::compute_strides(&self.shape),
530                offset: 0,
531            });
532        }
533        // Non-contiguous fallback: mul then add
534        let temp = self.mul_elem(b)?;
535        temp.add(c)
536    }
537
538    // ── v0.1 Broadcasting: additional element-wise binary ops ──
539
540    /// Element-wise power: `a^b`.
541    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
542        self.elementwise_binop(other, |a, b| a.powf(b))
543    }
544
545    /// Element-wise minimum.
546    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
547        self.elementwise_binop(other, |a, b| a.min(b))
548    }
549
550    /// Element-wise maximum.
551    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
552        self.elementwise_binop(other, |a, b| a.max(b))
553    }
554
555    /// Element-wise atan2(self, other).
556    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
557        self.elementwise_binop(other, |a, b| a.atan2(b))
558    }
559
560    /// Element-wise hypot(self, other).
561    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
562        self.elementwise_binop(other, |a, b| a.hypot(b))
563    }
564
565    /// Apply a unary function to every element, returning a new contiguous tensor.
566    pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
567        let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
568        Tensor {
569            buffer: Buffer::from_vec(data),
570            shape: self.shape.clone(),
571            strides: Self::compute_strides(&self.shape),
572            offset: 0,
573        }
574    }
575
576    /// SIMD-accelerated unary map for known operations (sqrt, abs, neg, relu).
577    ///
578    /// Uses AVX2 (4-wide f64) when available, scalar fallback otherwise.
579    /// Bit-identical to `map(f)` for the supported operations.
580    pub fn map_simd(&self, op: UnaryOp) -> Tensor {
581        let src = self.to_vec();
582        let data = tensor_simd::simd_unary(&src, op);
583        Tensor {
584            buffer: Buffer::from_vec(data),
585            shape: self.shape.clone(),
586            strides: Self::compute_strides(&self.shape),
587            offset: 0,
588        }
589    }
590
591    // -- Reductions (using BinnedAccumulator) --------------------------------
592
593    /// Sum of all elements (binned accumulation — order-invariant, deterministic).
594    pub fn sum(&self) -> f64 {
595        let data = self.buffer.borrow_data();
596        binned_sum_f64(&data)
597    }
598
599    /// Sum of all elements using BinnedAccumulator (order-invariant, deterministic).
600    ///
601    /// Bit-identical results regardless of element ordering or reduction schedule.
602    pub fn binned_sum(&self) -> f64 {
603        let data = self.buffer.borrow_data();
604        accumulator::binned_sum_f64(&data)
605    }
606
607    /// Sum with dispatched strategy based on execution context.
608    ///
609    /// Uses Kahan in serial mode, Binned in parallel/@nogc/strict/linalg mode.
610    pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
611        let data = self.buffer.borrow_data();
612        dispatch::dispatch_sum_f64(&data, ctx)
613    }
614
615    /// Mean of all elements (binned sum / count).
616    pub fn mean(&self) -> f64 {
617        let n = self.len();
618        if n == 0 {
619            return 0.0;
620        }
621        self.sum() / n as f64
622    }
623
624    /// Mean with dispatched strategy based on execution context.
625    pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
626        let n = self.len();
627        if n == 0 {
628            return 0.0;
629        }
630        self.dispatched_sum(ctx) / n as f64
631    }
632
633    /// Sum along a specific axis, returning a tensor with that dimension reduced.
634    ///
635    /// Supports N-D tensors. The reduced axis becomes size 1 in the output.
636    /// Uses BinnedAccumulator for order-invariant, deterministic summation.
637    ///
638    /// Examples:
639    /// - 2D [M, N] with axis=0: result [1, N] (sum columns)
640    /// - 2D [M, N] with axis=1: result [M, 1] (sum rows)
641    /// - 3D [A, B, C] with axis=1: result [A, 1, C]
642    pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
643        let ndim = self.ndim();
644        if axis >= ndim {
645            return Err(RuntimeError::IndexOutOfBounds {
646                index: axis,
647                length: ndim,
648            });
649        }
650
651        // Build output shape: same as input but with axis dimension = 1
652        let mut out_shape = self.shape.clone();
653        out_shape[axis] = 1;
654        let out_numel = Self::shape_numel(&out_shape);
655        let out_strides = Self::compute_strides(&out_shape);
656
657        let data = self.to_vec();
658        let axis_len = self.shape[axis];
659        let mut result = vec![0.0f64; out_numel];
660
661        // For each output position, sum over the reduced axis with binned accumulation.
662        let mut indices = vec![0usize; ndim];
663        for out_idx in 0..out_numel {
664            // Compute the N-D index from flat output index
665            {
666                let mut remaining = out_idx;
667                for d in 0..ndim {
668                    indices[d] = remaining / out_strides[d];
669                    remaining %= out_strides[d];
670                }
671            }
672
673            let mut acc = BinnedAccumulatorF64::new();
674            for k in 0..axis_len {
675                // Compute input flat index with indices[axis] = k
676                let mut flat = self.offset;
677                for d in 0..ndim {
678                    let idx = if d == axis { k } else { indices[d] };
679                    flat += idx * self.strides[d];
680                }
681                acc.add(data[flat]);
682            }
683            result[out_idx] = acc.finalize();
684        }
685
686        Tensor::from_vec(result, &out_shape)
687    }
688
    // -- Negation, transposition, and scalar scaling --------------------------
690
691    /// Negate every element, returning a new tensor.
692    pub fn neg(&self) -> Tensor {
693        self.map(|x| -x)
694    }
695
696    /// Transpose a tensor. For 2-D: swaps rows and columns (zero-copy view).
697    /// For N-D: reverses all axes (zero-copy view).
698    pub fn transpose(&self) -> Tensor {
699        let ndim = self.ndim();
700        if ndim <= 1 {
701            return self.clone();
702        }
703        // Reverse shape and strides — zero-copy view
704        let mut new_shape = self.shape.clone();
705        let mut new_strides = self.strides.clone();
706        new_shape.reverse();
707        new_strides.reverse();
708        Tensor {
709            buffer: self.buffer.clone(), // shared — zero copy
710            shape: new_shape,
711            strides: new_strides,
712            offset: self.offset,
713        }
714    }
715
716    /// Transpose with explicit axis permutation (N-D). Zero-copy view.
717    ///
718    /// `axes` must be a permutation of `[0, 1, ..., ndim-1]`.
719    pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
720        let ndim = self.ndim();
721        if axes.len() != ndim {
722            return Err(RuntimeError::InvalidOperation(
723                format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
724            ));
725        }
726        // Validate permutation
727        let mut seen = vec![false; ndim];
728        for &ax in axes {
729            if ax >= ndim {
730                return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
731            }
732            if seen[ax] {
733                return Err(RuntimeError::InvalidOperation(
734                    format!("transpose_axes: duplicate axis {ax}"),
735                ));
736            }
737            seen[ax] = true;
738        }
739        let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
740        let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
741        Ok(Tensor {
742            buffer: self.buffer.clone(),
743            shape: new_shape,
744            strides: new_strides,
745            offset: self.offset,
746        })
747    }
748
749    /// Multiply every element by a scalar, returning a new tensor.
750    pub fn scalar_mul(&self, s: f64) -> Tensor {
751        self.map(|x| x * s)
752    }
753
754    // ── Panicking convenience constructors (used by AD engine) --------
755
756    /// Create a tensor from raw data and shape.
757    /// **Panics** if `data.len()` does not match the shape.
758    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
759        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
760    }
761
762    /// Element-wise addition. **Panics** on shape mismatch.
763    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
764        self.add(other).expect("Tensor::add shape mismatch")
765    }
766
767    /// In-place element-wise addition. **Panics** on shape mismatch.
768    ///
769    /// Mutates `self` instead of allocating a new tensor. Uses COW: if the
770    /// underlying buffer is shared, `make_unique()` deep-copies first.
771    pub fn add_assign_unchecked(&mut self, other: &Tensor) {
772        assert_eq!(self.shape, other.shape, "Tensor::add_assign shape mismatch");
773        self.buffer.make_unique();
774        let other_data = other.buffer.borrow_data();
775        let mut self_data = self.buffer.borrow_data_mut();
776        for (a, b) in self_data.iter_mut().zip(other_data.iter()) {
777            *a += b;
778        }
779    }
780
781    /// Element-wise subtraction. **Panics** on shape mismatch.
782    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
783        self.sub(other).expect("Tensor::sub shape mismatch")
784    }
785
786    /// Element-wise multiplication. **Panics** on shape mismatch.
787    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
788        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
789    }
790
791    /// Element-wise division. **Panics** on shape mismatch.
792    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
793        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
794    }
795
796    /// Matrix multiplication. **Panics** on dimension mismatch.
797    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
798        self.matmul(other).expect("Tensor::matmul dimension mismatch")
799    }
800
801    /// Matrix multiplication for 2-D tensors.
802    ///
803    /// `self` is (M, K), `other` is (K, N) => result is (M, N).
804    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
805        if self.ndim() != 2 || other.ndim() != 2 {
806            return Err(RuntimeError::InvalidOperation(
807                "matmul requires 2-D tensors".to_string(),
808            ));
809        }
810        let m = self.shape[0];
811        let k = self.shape[1];
812        let k2 = other.shape[0];
813        let n = other.shape[1];
814        if k != k2 {
815            return Err(RuntimeError::DimensionMismatch {
816                expected: k,
817                got: k2,
818            });
819        }
820
821        let a = self.to_vec();
822        let b = other.to_vec();
823
824        // Parallel path (Mode A): parallelize over output rows when the parallel
825        // feature is enabled and the matrix is large enough (>= 256 in any dim).
826        #[cfg(feature = "parallel")]
827        {
828            if m >= 256 || n >= 256 || k >= 256 {
829                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
830            }
831        }
832
833        // Tiled path: use L2-friendly tiled matmul for medium-to-large matrices.
834        // Threshold: any dimension >= 64 (the default tile size).
835        // NOTE: tiled path uses naive accumulation (not binned) — different
836        // numerical path for large matrices, but better cache locality.
837        if m >= 64 || n >= 64 || k >= 64 {
838            return Self::matmul_tiled(&a, &b, m, n, k);
839        }
840
841        // Sequential path: single-threaded with binned accumulation.
842        Self::matmul_sequential(&a, &b, m, n, k)
843    }
844
845    /// Sequential matmul (always available, deterministic reference).
846    fn matmul_sequential(
847        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
848    ) -> Result<Tensor, RuntimeError> {
849        let mut result = vec![0.0f64; m * n];
850        for i in 0..m {
851            for j in 0..n {
852                let mut acc = KahanAccumulatorF64::new();
853                for p in 0..k {
854                    acc.add(a[i * k + p] * b[p * n + j]);
855                }
856                result[i * n + j] = acc.finalize();
857            }
858        }
859        Tensor::from_vec(result, &[m, n])
860    }
861
862    /// Tiled matmul: delegates to `TiledMatmul` for L2-cache-friendly tiling.
863    ///
864    /// Used for medium matrices (any dimension >= 64) where cache locality
865    /// matters but parallel overhead isn't justified. The tiled path uses
866    /// naive accumulation (not binned accumulation), trading a small amount of
867    /// floating-point precision for better cache behavior.
868    fn matmul_tiled(
869        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
870    ) -> Result<Tensor, RuntimeError> {
871        let engine = TiledMatmul::new();
872        let result = engine.matmul(a, m, k, b, n);
873        Tensor::from_vec(result, &[m, n])
874    }
875
876    /// Parallel matmul Mode A: parallelize over output rows using tiled
877    /// micro-kernels for cache locality.
878    ///
879    /// Deterministic because:
880    /// - Each output row is computed by exactly one thread.
881    /// - Within each row, accumulation uses tiled AXPY (deterministic order).
882    /// - No cross-thread reduction or merge of partial sums.
883    ///
884    /// Uses KahanAccumulatorF64 (lightweight) instead of BinnedAccumulatorF64
885    /// (32KB per accumulator) to avoid massive stack pressure in parallel mode.
886    #[cfg(feature = "parallel")]
887    fn matmul_parallel_mode_a(
888        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
889    ) -> Result<Tensor, RuntimeError> {
890        use rayon::prelude::*;
891        use cjc_repro::KahanAccumulatorF64;
892
893        // For large matrices, use tiled matmul (sequential but cache-friendly).
894        // The tiling provides far better cache locality than parallel row-wise
895        // with column-strided B access, and avoids the 32KB-per-element
896        // BinnedAccumulator overhead that caused the 128→256 regression.
897        if m >= 512 && n >= 512 {
898            // Only use rayon for very large matrices where thread overhead
899            // is amortized. Split into row-bands, each processed with tiled matmul.
900            let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
901            let band_size = band_size.max(64); // At least 64 rows per band
902            let mut result = vec![0.0f64; m * n];
903
904            result
905                .par_chunks_mut(band_size * n)
906                .enumerate()
907                .for_each(|(band_idx, band)| {
908                    let i_start = band_idx * band_size;
909                    let i_end = (i_start + band_size).min(m);
910                    let band_m = i_end - i_start;
911                    let a_band = &a[i_start * k .. i_end * k];
912                    let engine = crate::tensor_tiled::TiledMatmul::new();
913                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
914                    band[..band_m * n].copy_from_slice(&tiled_result);
915                });
916
917            return Tensor::from_vec(result, &[m, n]);
918        }
919
920        // For medium matrices (256-511), use Kahan accumulator (16 bytes, not 32KB).
921        let mut result = vec![0.0f64; m * n];
922        result
923            .par_chunks_mut(n)
924            .enumerate()
925            .for_each(|(i, row)| {
926                for j in 0..n {
927                    let mut acc = KahanAccumulatorF64::new();
928                    for p in 0..k {
929                        acc.add(a[i * k + p] * b[p * n + j]);
930                    }
931                    row[j] = acc.finalize();
932                }
933            });
934
935        Tensor::from_vec(result, &[m, n])
936    }
937
938    // -- Transformer Kernels ------------------------------------------------
939
940    /// Batched matrix multiplication.
941    ///
942    /// `self` is `[..., M, K]`, `other` is `[..., K, N]` => result is `[..., M, N]`.
943    /// The batch dimensions must be identical (no broadcast).
944    /// For 2-D inputs, delegates to `matmul`.
945    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
946        if self.ndim() < 2 || other.ndim() < 2 {
947            return Err(RuntimeError::InvalidOperation(
948                "bmm requires at least 2-D tensors".to_string(),
949            ));
950        }
951        if self.ndim() == 2 && other.ndim() == 2 {
952            return self.matmul(other);
953        }
954        if self.ndim() != other.ndim() {
955            return Err(RuntimeError::InvalidOperation(
956                format!(
957                    "bmm requires same number of dimensions, got {} and {}",
958                    self.ndim(),
959                    other.ndim()
960                ),
961            ));
962        }
963        let nd = self.ndim();
964        let batch_dims_a = &self.shape[..nd - 2];
965        let batch_dims_b = &other.shape[..nd - 2];
966        if batch_dims_a != batch_dims_b {
967            return Err(RuntimeError::InvalidOperation(
968                format!(
969                    "bmm batch dimensions mismatch: {:?} vs {:?}",
970                    batch_dims_a, batch_dims_b
971                ),
972            ));
973        }
974        let m = self.shape[nd - 2];
975        let k = self.shape[nd - 1];
976        let k2 = other.shape[nd - 2];
977        let n = other.shape[nd - 1];
978        if k != k2 {
979            return Err(RuntimeError::DimensionMismatch {
980                expected: k,
981                got: k2,
982            });
983        }
984
985        let batch_size: usize = batch_dims_a.iter().product();
986        let a = self.to_vec();
987        let b = other.to_vec();
988        let mat_a_stride = m * k;
989        let mat_b_stride = k * n;
990        let mat_c_stride = m * n;
991        let mut result = vec![0.0f64; batch_size * mat_c_stride];
992
993        // Helper closure: compute one batch into c_slice
994        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
995            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
996            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];
997
998            if m >= 64 || n >= 64 || k >= 64 {
999                let engine = crate::tensor_tiled::TiledMatmul::new();
1000                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
1001                c_slice.copy_from_slice(&tiled);
1002            } else {
1003                for i in 0..m {
1004                    for j in 0..n {
1005                        let mut acc = KahanAccumulatorF64::new();
1006                        for p in 0..k {
1007                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
1008                        }
1009                        c_slice[i * n + j] = acc.finalize();
1010                    }
1011                }
1012            }
1013        };
1014
1015        // Parallel path: parallelize over batches when workload is large enough
1016        #[cfg(feature = "parallel")]
1017        {
1018            if batch_size > 1 && m * k >= 4096 {
1019                use rayon::prelude::*;
1020                result
1021                    .par_chunks_mut(mat_c_stride)
1022                    .enumerate()
1023                    .for_each(|(batch, c_slice)| {
1024                        compute_batch(batch, c_slice);
1025                    });
1026
1027                let mut out_shape = batch_dims_a.to_vec();
1028                out_shape.push(m);
1029                out_shape.push(n);
1030                return Tensor::from_vec(result, &out_shape);
1031            }
1032        }
1033
1034        // Sequential fallback
1035        for batch in 0..batch_size {
1036            let c_off = batch * mat_c_stride;
1037            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
1038        }
1039
1040        let mut out_shape = batch_dims_a.to_vec();
1041        out_shape.push(m);
1042        out_shape.push(n);
1043        Tensor::from_vec(result, &out_shape)
1044    }
1045
1046    /// Softmax along the last dimension (two-pass stable algorithm).
1047    ///
1048    /// Pass 1: find max per row (prevents overflow in exp)
1049    /// Pass 2: compute exp(x - max), accumulate sum, normalize
1050    ///
1051    /// For a tensor of shape `[..., N]`, softmax is applied independently
1052    /// to each length-N slice along the last axis.
1053    pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
1054        if self.ndim() == 0 {
1055            return Err(RuntimeError::InvalidOperation(
1056                "softmax requires at least 1-D tensor".to_string(),
1057            ));
1058        }
1059        // Avoid allocation when tensor is already contiguous and starts at offset 0
1060        let data_ref;
1061        let data_vec;
1062        let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
1063            data_ref = self.buffer.borrow_data();
1064            &data_ref
1065        } else {
1066            data_vec = self.to_vec();
1067            &data_vec
1068        };
1069        let n = *self.shape.last().unwrap(); // last dimension size
1070        let outer: usize = data.len() / n;  // product of all dims except last
1071        let mut result = vec![0.0f64; data.len()];
1072
1073        for row in 0..outer {
1074            let start = row * n;
1075            let end = start + n;
1076            let slice = &data[start..end];
1077
1078            // Pass 1: find max for numerical stability
1079            let mut max_val = f64::NEG_INFINITY;
1080            for &v in slice {
1081                if v > max_val {
1082                    max_val = v;
1083                }
1084            }
1085
1086            // Pass 2: exp(x - max) and accumulate sum
1087            let mut exp_vals = vec![0.0f64; n];
1088            let mut sum = 0.0f64;
1089            let mut comp = 0.0f64; // Kahan compensation
1090            for i in 0..n {
1091                let e = (slice[i] - max_val).exp();
1092                exp_vals[i] = e;
1093                // Kahan summation for the denominator
1094                let y = e - comp;
1095                let t = sum + y;
1096                comp = (t - sum) - y;
1097                sum = t;
1098            }
1099
1100            // Normalize
1101            if sum == 0.0 {
1102                // Degenerate case: all -inf inputs → uniform
1103                let uniform = 1.0 / n as f64;
1104                for i in 0..n {
1105                    result[start + i] = uniform;
1106                }
1107            } else {
1108                for i in 0..n {
1109                    result[start + i] = exp_vals[i] / sum;
1110                }
1111            }
1112        }
1113
1114        Tensor::from_vec(result, &self.shape)
1115    }
1116
1117    /// Layer normalization over the last dimension.
1118    ///
1119    /// For each length-D slice along the last axis:
1120    ///   1. mean = Σx / D  (BinnedAccumulator)
1121    ///   2. var  = Σ(x - mean)² / D  (BinnedAccumulator)
1122    ///   3. normalized = (x - mean) / √(var + eps)
1123    ///   4. output = gamma * normalized + beta
1124    ///
1125    /// `gamma` and `beta` are 1-D tensors of shape `[D]`.
1126    /// `eps` is a small constant (typically 1e-5).
1127    pub fn layer_norm(
1128        &self,
1129        gamma: &Tensor,
1130        beta: &Tensor,
1131        eps: f64,
1132    ) -> Result<Tensor, RuntimeError> {
1133        if self.ndim() == 0 {
1134            return Err(RuntimeError::InvalidOperation(
1135                "layer_norm requires at least 1-D tensor".to_string(),
1136            ));
1137        }
1138        let d = *self.shape.last().unwrap();
1139        if gamma.len() != d || beta.len() != d {
1140            return Err(RuntimeError::InvalidOperation(
1141                format!(
1142                    "layer_norm: gamma/beta length {} must match last dim {}",
1143                    gamma.len(),
1144                    d
1145                ),
1146            ));
1147        }
1148
1149        let data = self.to_vec();
1150        let gamma_data = gamma.to_vec();
1151        let beta_data = beta.to_vec();
1152        let outer = data.len() / d;
1153        let mut result = vec![0.0f64; data.len()];
1154
1155        for row in 0..outer {
1156            let start = row * d;
1157            let slice = &data[start..start + d];
1158
1159            // Pass 1: compute mean via BinnedAccumulator
1160            let mean = binned_sum_f64(slice) / d as f64;
1161
1162            // Pass 2: compute variance via BinnedAccumulator
1163            let diffs: Vec<f64> = slice.iter().map(|&x| {
1164                let diff = x - mean;
1165                diff * diff
1166            }).collect();
1167            let variance = binned_sum_f64(&diffs) / d as f64;
1168
1169            // Normalize, scale, shift
1170            let inv_std = 1.0 / (variance + eps).sqrt();
1171            for i in 0..d {
1172                let normalized = (slice[i] - mean) * inv_std;
1173                result[start + i] = gamma_data[i] * normalized + beta_data[i];
1174            }
1175        }
1176
1177        Tensor::from_vec(result, &self.shape)
1178    }
1179
1180    /// Apply a function element-wise, reusing the buffer when possible (COW).
1181    ///
1182    /// If the tensor is contiguous, starts at offset 0, and has refcount == 1,
1183    /// the data is mutated in place (zero allocations). Otherwise, allocates
1184    /// a new buffer.
1185    fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
1186        if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
1187            // Fast path: mutate in place (COW — we're the sole owner)
1188            let mut data = self.buffer.borrow_data().clone();
1189            for x in data.iter_mut() {
1190                *x = f(*x);
1191            }
1192            Tensor::from_vec(data, &self.shape).unwrap()
1193        } else {
1194            // Fallback: allocate new buffer
1195            let data = self.to_vec();
1196            let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
1197            Tensor::from_vec(result, &self.shape).unwrap()
1198        }
1199    }
1200
1201    /// ReLU activation: `max(0, x)` element-wise.
1202    pub fn relu(&self) -> Tensor {
1203        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1204    }
1205
1206    /// Sigmoid activation: 1 / (1 + exp(-x)) element-wise.
1207    pub fn sigmoid(&self) -> Tensor {
1208        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1209    }
1210
1211    /// Tanh activation element-wise.
1212    pub fn tanh_activation(&self) -> Tensor {
1213        self.map_elementwise(|x| x.tanh())
1214    }
1215
1216    /// Leaky ReLU activation: max(alpha*x, x) element-wise.
1217    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1218        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1219    }
1220
1221    /// SiLU (Swish) activation: x * sigmoid(x) element-wise.
1222    pub fn silu(&self) -> Tensor {
1223        let data = self.to_vec();
1224        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1225        Tensor::from_vec(result, &self.shape).unwrap()
1226    }
1227
1228    /// Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
1229    pub fn mish(&self) -> Tensor {
1230        let data = self.to_vec();
1231        let result: Vec<f64> = data.iter().map(|&x| {
1232            let sp = (1.0 + x.exp()).ln();
1233            x * sp.tanh()
1234        }).collect();
1235        Tensor::from_vec(result, &self.shape).unwrap()
1236    }
1237
1238    /// Argmax: index of the maximum element (first occurrence, deterministic).
1239    pub fn argmax(&self) -> usize {
1240        let data = self.to_vec();
1241        let mut best_idx = 0;
1242        let mut best_val = f64::NEG_INFINITY;
1243        for (i, &v) in data.iter().enumerate() {
1244            if v > best_val || (v == best_val && i < best_idx) {
1245                best_val = v;
1246                best_idx = i;
1247            }
1248        }
1249        best_idx
1250    }
1251
1252    /// Argmin: index of the minimum element (first occurrence, deterministic).
1253    pub fn argmin(&self) -> usize {
1254        let data = self.to_vec();
1255        let mut best_idx = 0;
1256        let mut best_val = f64::INFINITY;
1257        for (i, &v) in data.iter().enumerate() {
1258            if v < best_val || (v == best_val && i < best_idx) {
1259                best_val = v;
1260                best_idx = i;
1261            }
1262        }
1263        best_idx
1264    }
1265
1266    /// Clamp all elements to [min, max].
1267    pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1268        let data = self.to_vec();
1269        let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1270        Tensor::from_vec(result, &self.shape).unwrap()
1271    }
1272
1273    /// One-hot encoding: given a 1D tensor of integer indices and a depth,
1274    /// returns a 2D tensor of shape [len, depth].
1275    pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1276        let n = indices.len();
1277        let mut data = vec![0.0; n * depth];
1278        for (i, &idx) in indices.iter().enumerate() {
1279            if idx >= depth {
1280                return Err(RuntimeError::InvalidOperation(format!(
1281                    "one_hot: index {idx} >= depth {depth}"
1282                )));
1283            }
1284            data[i * depth + idx] = 1.0;
1285        }
1286        Tensor::from_vec(data, &[n, depth])
1287    }
1288
1289    // -----------------------------------------------------------------------
1290    // Phase B4: Tensor extensions (cat, stack, topk)
1291    // -----------------------------------------------------------------------
1292
1293    /// Concatenate tensors along existing axis.
1294    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1295        if tensors.is_empty() {
1296            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
1297        }
1298        let ndim = tensors[0].ndim();
1299        if axis >= ndim {
1300            return Err(RuntimeError::InvalidOperation(
1301                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
1302            ));
1303        }
1304        for (i, t) in tensors.iter().enumerate().skip(1) {
1305            if t.ndim() != ndim {
1306                return Err(RuntimeError::InvalidOperation(
1307                    format!("cat: tensor {i} has different ndim"),
1308                ));
1309            }
1310            for d in 0..ndim {
1311                if d != axis && t.shape[d] != tensors[0].shape[d] {
1312                    return Err(RuntimeError::InvalidOperation(
1313                        format!("cat: shape mismatch at dim {d}"),
1314                    ));
1315                }
1316            }
1317        }
1318        let mut out_shape = tensors[0].shape.clone();
1319        for t in tensors.iter().skip(1) {
1320            out_shape[axis] += t.shape[axis];
1321        }
1322        let total = out_shape.iter().product::<usize>();
1323        let mut result = vec![0.0; total];
1324        let mut out_strides = vec![1usize; ndim];
1325        for d in (0..ndim - 1).rev() {
1326            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
1327        }
1328        let mut offset = 0;
1329        for t in tensors {
1330            let t_data = t.to_vec();
1331            let t_total: usize = t.shape.iter().product();
1332            let mut t_strides = vec![1usize; ndim];
1333            for d in (0..ndim - 1).rev() {
1334                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
1335            }
1336            for idx in 0..t_total {
1337                let mut remaining = idx;
1338                let mut out_flat = 0;
1339                for d in 0..ndim {
1340                    let coord = remaining / t_strides[d];
1341                    remaining %= t_strides[d];
1342                    let out_coord = if d == axis { coord + offset } else { coord };
1343                    out_flat += out_coord * out_strides[d];
1344                }
1345                result[out_flat] = t_data[idx];
1346            }
1347            offset += t.shape[axis];
1348        }
1349        Tensor::from_vec(result, &out_shape)
1350    }
1351
1352    /// Stack tensors along a new axis.
1353    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
1354        if tensors.is_empty() {
1355            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
1356        }
1357        let base_shape = &tensors[0].shape;
1358        let ndim = base_shape.len();
1359        if axis > ndim {
1360            return Err(RuntimeError::InvalidOperation(
1361                format!("stack: axis {axis} out of bounds"),
1362            ));
1363        }
1364        for (i, t) in tensors.iter().enumerate().skip(1) {
1365            if &t.shape != base_shape {
1366                return Err(RuntimeError::InvalidOperation(
1367                    format!("stack: tensor {i} shape mismatch"),
1368                ));
1369            }
1370        }
1371        let mut out_shape = Vec::with_capacity(ndim + 1);
1372        for d in 0..axis { out_shape.push(base_shape[d]); }
1373        out_shape.push(tensors.len());
1374        for d in axis..ndim { out_shape.push(base_shape[d]); }
1375        let total: usize = out_shape.iter().product();
1376        let mut result = vec![0.0; total];
1377        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
1378        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
1379        for (t_idx, t) in tensors.iter().enumerate() {
1380            let t_data = t.to_vec();
1381            for outer in 0..outer_size {
1382                for inner in 0..inner_size {
1383                    let src = outer * inner_size + inner;
1384                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
1385                    if src < t_data.len() && dst < result.len() {
1386                        result[dst] = t_data[src];
1387                    }
1388                }
1389            }
1390        }
1391        Tensor::from_vec(result, &out_shape)
1392    }
1393
1394    /// Top-k values and indices (largest k values from flat data).
1395    pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1396        let data = self.to_vec();
1397        let n = data.len();
1398        if k > n {
1399            return Err(RuntimeError::InvalidOperation(
1400                format!("topk: k={k} exceeds data length {n}"),
1401            ));
1402        }
1403        let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1404        indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1405        let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1406        let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1407        let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1408        Ok((Tensor::from_vec(values, &[k])?, indices))
1409    }
1410
1411    /// GELU activation (approximate): x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
1412    pub fn gelu(&self) -> Tensor {
1413        let data = self.to_vec();
1414        let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1415        let result: Vec<f64> = data.iter().map(|&x| {
1416            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1417            0.5 * x * (1.0 + inner.tanh())
1418        }).collect();
1419        Tensor::from_vec(result, &self.shape).unwrap()
1420    }
1421
1422    /// Linear layer: output = input @ weight^T + bias
1423    ///
1424    /// `self` is `[..., in_features]`, `weight` is `[out_features, in_features]`,
1425    /// `bias` is `[out_features]`.
1426    /// Result is `[..., out_features]`.
1427    pub fn linear(
1428        &self,
1429        weight: &Tensor,
1430        bias: &Tensor,
1431    ) -> Result<Tensor, RuntimeError> {
1432        if weight.ndim() != 2 {
1433            return Err(RuntimeError::InvalidOperation(
1434                "linear: weight must be 2-D [out_features, in_features]".to_string(),
1435            ));
1436        }
1437        let out_features = weight.shape[0];
1438        let in_features = weight.shape[1];
1439        let last_dim = *self.shape.last().ok_or_else(|| {
1440            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
1441        })?;
1442        if last_dim != in_features {
1443            return Err(RuntimeError::DimensionMismatch {
1444                expected: in_features,
1445                got: last_dim,
1446            });
1447        }
1448        if bias.len() != out_features {
1449            return Err(RuntimeError::InvalidOperation(
1450                format!(
1451                    "linear: bias length {} must match out_features {}",
1452                    bias.len(),
1453                    out_features
1454                ),
1455            ));
1456        }
1457
1458        let data = self.to_vec();
1459        let w = weight.to_vec();
1460        let b = bias.to_vec();
1461        let outer = data.len() / in_features;
1462        let mut result = vec![0.0f64; outer * out_features];
1463
1464        for row in 0..outer {
1465            let x_start = row * in_features;
1466            let x_slice = &data[x_start..x_start + in_features];
1467            let y_start = row * out_features;
1468            for j in 0..out_features {
1469                let w_start = j * in_features;
1470                let mut acc = BinnedAccumulatorF64::new();
1471                for p in 0..in_features {
1472                    acc.add(x_slice[p] * w[w_start + p]);
1473                }
1474                result[y_start + j] = acc.finalize() + b[j];
1475            }
1476        }
1477
1478        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
1479        out_shape.push(out_features);
1480        Tensor::from_vec(result, &out_shape)
1481    }
1482
1483    /// 1D convolution: signal `[signal_len]` * filters `[out_ch, kernel_size]` + bias
1484    ///
1485    /// Returns `[out_ch, signal_len - kernel_size + 1]` (valid mode, stride=1).
1486    pub fn conv1d(
1487        &self,
1488        filters: &Tensor,
1489        bias: &Tensor,
1490    ) -> Result<Tensor, RuntimeError> {
1491        if self.ndim() != 1 {
1492            return Err(RuntimeError::InvalidOperation(
1493                "conv1d: input must be 1-D [signal_len]".to_string(),
1494            ));
1495        }
1496        if filters.ndim() != 2 {
1497            return Err(RuntimeError::InvalidOperation(
1498                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1499            ));
1500        }
1501        let signal_len = self.shape[0];
1502        let out_channels = filters.shape[0];
1503        let kernel_size = filters.shape[1];
1504        if signal_len < kernel_size {
1505            return Err(RuntimeError::InvalidOperation(
1506                format!(
1507                    "conv1d: signal_len {} < kernel_size {}",
1508                    signal_len, kernel_size
1509                ),
1510            ));
1511        }
1512        if bias.len() != out_channels {
1513            return Err(RuntimeError::InvalidOperation(
1514                format!(
1515                    "conv1d: bias length {} must match out_channels {}",
1516                    bias.len(), out_channels
1517                ),
1518            ));
1519        }
1520        let out_len = signal_len - kernel_size + 1;
1521        let s = self.to_vec();
1522        let f = filters.to_vec();
1523        let b = bias.to_vec();
1524        let mut result = vec![0.0; out_channels * out_len];
1525        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1526        Tensor::from_vec(result, &[out_channels, out_len])
1527    }
1528
1529    /// 2D convolution — NCHW layout, valid mode, configurable stride.
1530    ///
1531    /// # Arguments
1532    /// - `self`:    `[N, C_in, H, W]` input tensor
1533    /// - `filters`: `[C_out, C_in, kH, kW]`
1534    /// - `bias`:    `[C_out]`
1535    /// - `stride`:  spatial stride (default 1)
1536    ///
1537    /// # Returns
1538    /// `[N, C_out, H_out, W_out]` where `H_out = (H - kH) / stride + 1`.
1539    ///
1540    /// Uses `BinnedAccumulatorF64` for every dot product — bit-identical results
1541    /// across all runs and hardware configurations.
1542    pub fn conv2d(
1543        &self,
1544        filters: &Tensor,
1545        bias: &Tensor,
1546        stride: usize,
1547    ) -> Result<Tensor, RuntimeError> {
1548        if self.ndim() != 4 {
1549            return Err(RuntimeError::InvalidOperation(
1550                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
1551            ));
1552        }
1553        if filters.ndim() != 4 {
1554            return Err(RuntimeError::InvalidOperation(
1555                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
1556            ));
1557        }
1558        if stride == 0 {
1559            return Err(RuntimeError::InvalidOperation(
1560                "conv2d: stride must be >= 1".to_string(),
1561            ));
1562        }
1563
1564        let n    = self.shape[0];
1565        let c_in = self.shape[1];
1566        let h_in = self.shape[2];
1567        let w_in = self.shape[3];
1568
1569        let c_out      = filters.shape[0];
1570        let c_in_check = filters.shape[1];
1571        let kh         = filters.shape[2];
1572        let kw         = filters.shape[3];
1573
1574        if c_in != c_in_check {
1575            return Err(RuntimeError::InvalidOperation(format!(
1576                "conv2d: input C_in={} does not match filter C_in={}",
1577                c_in, c_in_check
1578            )));
1579        }
1580        if h_in < kh || w_in < kw {
1581            return Err(RuntimeError::InvalidOperation(format!(
1582                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
1583                h_in, w_in, kh, kw
1584            )));
1585        }
1586        if bias.len() != c_out {
1587            return Err(RuntimeError::InvalidOperation(format!(
1588                "conv2d: bias length {} must match C_out={}",
1589                bias.len(), c_out
1590            )));
1591        }
1592
1593        let h_out = (h_in - kh) / stride + 1;
1594        let w_out = (w_in - kw) / stride + 1;
1595
1596        let inp = self.to_vec();
1597        let flt = filters.to_vec();
1598        let b   = bias.to_vec();
1599        let mut result = vec![0.0f64; n * c_out * h_out * w_out];
1600
1601        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
1602                           n, c_in, h_in, w_in, c_out, kh, kw, stride);
1603
1604        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
1605    }
1606
1607    /// 2D max-pooling — NCHW layout, non-overlapping windows.
1608    ///
1609    /// - `self`: `[N, C, H, W]`
1610    /// - `ph`, `pw`: pool height/width (stride = window size)
1611    ///
1612    /// Returns `[N, C, H/ph, W/pw]`.
1613    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1614        if self.ndim() != 4 {
1615            return Err(RuntimeError::InvalidOperation(
1616                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1617            ));
1618        }
1619        if ph == 0 || pw == 0 {
1620            return Err(RuntimeError::InvalidOperation(
1621                "maxpool2d: pool size must be >= 1".to_string(),
1622            ));
1623        }
1624
1625        let n    = self.shape[0];
1626        let c    = self.shape[1];
1627        let h_in = self.shape[2];
1628        let w_in = self.shape[3];
1629
1630        if h_in < ph || w_in < pw {
1631            return Err(RuntimeError::InvalidOperation(format!(
1632                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1633                h_in, w_in, ph, pw
1634            )));
1635        }
1636
1637        let h_out = h_in / ph;
1638        let w_out = w_in / pw;
1639
1640        let inp = self.to_vec();
1641        let mut result = vec![0.0f64; n * c * h_out * w_out];
1642
1643        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1644
1645        Tensor::from_vec(result, &[n, c, h_out, w_out])
1646    }
1647
1648    /// Applies 2-D average pooling over a `[C, H, W]` tensor.
1649    ///
1650    /// # Arguments
1651    ///
1652    /// * `kernel_h` / `kernel_w` - Pooling window size
1653    /// * `stride_h` / `stride_w` - Stride for the pooling window
1654    ///
1655    /// # Returns
1656    ///
1657    /// Tensor of shape `[C, out_h, out_w]` where `out_h = (H - kernel_h) / stride_h + 1`.
1658    ///
1659    /// # Errors
1660    ///
1661    /// Returns an error if the tensor is not 3-D or if kernel/stride produce invalid output dimensions.
1662    pub fn avgpool2d(&self, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, RuntimeError> {
1663        let shape = self.shape();
1664        if shape.len() != 3 {
1665            return Err(RuntimeError::InvalidOperation(format!("avgpool2d requires 3-D [C,H,W], got {:?}", shape)));
1666        }
1667        let (c, h, w) = (shape[0], shape[1], shape[2]);
1668        if kernel_h > h || kernel_w > w {
1669            return Err(RuntimeError::InvalidOperation("avgpool2d: kernel larger than input".into()));
1670        }
1671        let out_h = (h - kernel_h) / stride_h + 1;
1672        let out_w = (w - kernel_w) / stride_w + 1;
1673        let data = self.to_vec();
1674        let mut out = Vec::with_capacity(c * out_h * out_w);
1675        let pool_size = (kernel_h * kernel_w) as f64;
1676
1677        for ch in 0..c {
1678            for oh in 0..out_h {
1679                for ow in 0..out_w {
1680                    let mut sum = 0.0f64;
1681                    for kh in 0..kernel_h {
1682                        for kw in 0..kernel_w {
1683                            let ih = oh * stride_h + kh;
1684                            let iw = ow * stride_w + kw;
1685                            sum += data[ch * h * w + ih * w + iw];
1686                        }
1687                    }
1688                    out.push(sum / pool_size);
1689                }
1690            }
1691        }
1692        Tensor::from_vec(out, &[c, out_h, out_w])
1693    }
1694
1695    /// Scaled dot-product attention (single head).
1696    ///
1697    /// `queries` is `[..., T, d_k]`
1698    /// `keys`    is `[..., S, d_k]`
1699    /// `values`  is `[..., S, d_v]`
1700    ///
1701    /// Computes: softmax(Q × Kᵀ / √d_k) × V
1702    /// Returns `[..., T, d_v]`.
1703    pub fn scaled_dot_product_attention(
1704        queries: &Tensor,
1705        keys: &Tensor,
1706        values: &Tensor,
1707    ) -> Result<Tensor, RuntimeError> {
1708        if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1709            return Err(RuntimeError::InvalidOperation(
1710                "attention: Q, K, V must be at least 2-D".to_string(),
1711            ));
1712        }
1713        let nd = queries.ndim();
1714        let d_k = queries.shape[nd - 1];
1715        let scale = 1.0 / (d_k as f64).sqrt();
1716
1717        // Transpose keys: swap last two dims
1718        let keys_t = keys.transpose_last_two()?;
1719
1720        // Q × K^T → [... T, S]
1721        let scores = queries.bmm(&keys_t)?;
1722
1723        // Scale
1724        let scores_scaled = scores.scalar_mul(scale);
1725
1726        // Softmax along last dim
1727        let attn_weights = scores_scaled.softmax()?;
1728
1729        // Attn × V → [... T, d_v]
1730        attn_weights.bmm(values)
1731    }
1732
1733    /// Transpose the last two dimensions of a tensor.
1734    ///
1735    /// `[..., A, B]` → `[..., B, A]`
1736    pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1737        if self.ndim() < 2 {
1738            return Err(RuntimeError::InvalidOperation(
1739                "transpose_last_two requires at least 2-D tensor".to_string(),
1740            ));
1741        }
1742        let nd = self.ndim();
1743        let rows = self.shape[nd - 2];
1744        let cols = self.shape[nd - 1];
1745        let data = self.to_vec();
1746        let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1747        let mat_size = rows * cols;
1748        let mut result = vec![0.0f64; data.len()];
1749
1750        for b in 0..batch_size {
1751            let off = b * mat_size;
1752            for i in 0..rows {
1753                for j in 0..cols {
1754                    result[off + j * rows + i] = data[off + i * cols + j];
1755                }
1756            }
1757        }
1758
1759        let mut out_shape = self.shape.clone();
1760        out_shape[nd - 2] = cols;
1761        out_shape[nd - 1] = rows;
1762        Tensor::from_vec(result, &out_shape)
1763    }
1764
    // -- Weight Loading from Raw Bytes ---------------------------------------
1766
1767    /// Create a tensor view from raw bytes — **zero allocation**.
1768    ///
1769    /// Interprets `bytes` as a contiguous block of `f64` (8 bytes each) or
1770    /// `f32` (4 bytes each, promoted to f64) values and maps them into a
1771    /// `Tensor` with the given shape.
1772    ///
1773    /// `dtype` must be `"f64"` or `"f32"`.
1774    ///
1775    /// For f64: bytes.len() must equal shape_numel * 8.
1776    /// For f32: bytes.len() must equal shape_numel * 4.
1777    ///
1778    /// The returned tensor **owns** its buffer (copied from the raw bytes)
1779    /// but performs exactly one allocation for the data vector.
1780    pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1781        let numel = Self::shape_numel(shape);
1782        match dtype {
1783            "f64" => {
1784                let expected = numel * 8;
1785                if bytes.len() != expected {
1786                    return Err(RuntimeError::ShapeMismatch {
1787                        expected,
1788                        got: bytes.len(),
1789                    });
1790                }
1791                let mut data = Vec::with_capacity(numel);
1792                for i in 0..numel {
1793                    let off = i * 8;
1794                    let mut buf = [0u8; 8];
1795                    buf.copy_from_slice(&bytes[off..off + 8]);
1796                    data.push(f64::from_le_bytes(buf));
1797                }
1798                Ok(Tensor {
1799                    buffer: Buffer::from_vec(data),
1800                    shape: shape.to_vec(),
1801                    strides: Self::compute_strides(shape),
1802                    offset: 0,
1803                })
1804            }
1805            "f32" => {
1806                let expected = numel * 4;
1807                if bytes.len() != expected {
1808                    return Err(RuntimeError::ShapeMismatch {
1809                        expected,
1810                        got: bytes.len(),
1811                    });
1812                }
1813                let mut data = Vec::with_capacity(numel);
1814                for i in 0..numel {
1815                    let off = i * 4;
1816                    let mut buf = [0u8; 4];
1817                    buf.copy_from_slice(&bytes[off..off + 4]);
1818                    data.push(f32::from_le_bytes(buf) as f64);
1819                }
1820                Ok(Tensor {
1821                    buffer: Buffer::from_vec(data),
1822                    shape: shape.to_vec(),
1823                    strides: Self::compute_strides(shape),
1824                    offset: 0,
1825                })
1826            }
1827            _ => Err(RuntimeError::InvalidOperation(
1828                format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1829            )),
1830        }
1831    }
1832
1833    // -- Multi-Head Attention Splitting -------------------------------------
1834
1835    /// Reshape a 3D tensor `[batch, seq, model_dim]` into 4D
1836    /// `[batch, num_heads, seq, head_dim]` by splitting the last dimension.
1837    ///
1838    /// This is a **zero-copy view** — it only changes shape/strides metadata.
1839    /// `model_dim` must be divisible by `num_heads`.
1840    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1841        if self.ndim() != 3 {
1842            return Err(RuntimeError::DimensionMismatch {
1843                expected: 3,
1844                got: self.ndim(),
1845            });
1846        }
1847        let batch = self.shape[0];
1848        let seq = self.shape[1];
1849        let model_dim = self.shape[2];
1850        if model_dim % num_heads != 0 {
1851            return Err(RuntimeError::InvalidOperation(
1852                format!(
1853                    "split_heads: model_dim {} not divisible by num_heads {}",
1854                    model_dim, num_heads
1855                ),
1856            ));
1857        }
1858        let head_dim = model_dim / num_heads;
1859        // Need contiguous data for the reshape
1860        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1861        // Reshape [B, S, H*D] -> [B, S, H, D] then transpose to [B, H, S, D]
1862        let reshaped = Tensor {
1863            buffer: tensor.buffer.clone(),
1864            shape: vec![batch, seq, num_heads, head_dim],
1865            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1866            offset: 0,
1867        };
1868        // Transpose dims 1 and 2: [B, S, H, D] -> [B, H, S, D]
1869        // New strides: swap strides[1] and strides[2]
1870        Ok(Tensor {
1871            buffer: reshaped.buffer,
1872            shape: vec![batch, num_heads, seq, head_dim],
1873            strides: vec![
1874                reshaped.strides[0], // batch stride unchanged
1875                reshaped.strides[2], // head stride (was dim 2)
1876                reshaped.strides[1], // seq stride (was dim 1)
1877                reshaped.strides[3], // head_dim stride unchanged
1878            ],
1879            offset: 0,
1880        })
1881    }
1882
1883    /// Merge heads back: reshape 4D `[batch, num_heads, seq, head_dim]` into
1884    /// 3D `[batch, seq, model_dim]`. Materializes if non-contiguous.
1885    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1886        if self.ndim() != 4 {
1887            return Err(RuntimeError::DimensionMismatch {
1888                expected: 4,
1889                got: self.ndim(),
1890            });
1891        }
1892        let batch = self.shape[0];
1893        let num_heads = self.shape[1];
1894        let seq = self.shape[2];
1895        let head_dim = self.shape[3];
1896        // Need [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
1897        // Transpose dims 1 and 2 first
1898        let transposed = Tensor {
1899            buffer: self.buffer.clone(),
1900            shape: vec![batch, seq, num_heads, head_dim],
1901            strides: vec![
1902                self.strides[0],
1903                self.strides[2], // seq stride
1904                self.strides[1], // head stride
1905                self.strides[3],
1906            ],
1907            offset: self.offset,
1908        };
1909        // Materialize contiguous then reshape
1910        let contig = transposed.to_contiguous();
1911        let model_dim = num_heads * head_dim;
1912        Ok(Tensor {
1913            buffer: contig.buffer,
1914            shape: vec![batch, seq, model_dim],
1915            strides: Self::compute_strides(&[batch, seq, model_dim]),
1916            offset: 0,
1917        })
1918    }
1919
1920    /// View-only reshape: reinterpret shape without copying.
1921    /// Only works on contiguous tensors. Falls back to copy if non-contiguous.
1922    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
1923        self.reshape(new_shape)
1924    }
1925
1926    // -----------------------------------------------------------------------
1927    // Phase C4: Sorting & Tensor Indexing
1928    // -----------------------------------------------------------------------
1929
1930    /// Returns indices that would sort the flattened tensor in ascending order.
1931    /// Uses f64::total_cmp for deterministic ordering of NaN.
1932    pub fn argsort(&self) -> Tensor {
1933        let data = self.to_vec();
1934        let mut indices: Vec<usize> = (0..data.len()).collect();
1935        indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1936        let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1937        Tensor::from_vec_unchecked(result, &[data.len()])
1938    }
1939
1940    /// Gather elements from the tensor along a dimension using index tensor.
1941    /// For 1D: result[i] = self[indices[i]]
1942    /// For 2D dim=0: result[i][j] = self[indices[i][j]][j]
1943    /// For 2D dim=1: result[i][j] = self[i][indices[i][j]]
1944    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1945        let data = self.to_vec();
1946        let idx_data = indices.to_vec();
1947        if self.ndim() == 1 {
1948            let mut result = Vec::with_capacity(idx_data.len());
1949            for &idx in &idx_data {
1950                let i = idx as usize;
1951                if i >= data.len() {
1952                    return Err(RuntimeError::InvalidOperation(
1953                        format!("gather: index {} out of bounds for size {}", i, data.len()),
1954                    ));
1955                }
1956                result.push(data[i]);
1957            }
1958            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
1959        } else if self.ndim() == 2 {
1960            let rows = self.shape[0];
1961            let cols = self.shape[1];
1962            let idx_shape = indices.shape();
1963            let out_rows = idx_shape[0];
1964            let out_cols = idx_shape[1];
1965            let mut result = vec![0.0; out_rows * out_cols];
1966            for i in 0..out_rows {
1967                for j in 0..out_cols {
1968                    let idx = idx_data[i * out_cols + j] as usize;
1969                    let val = if dim == 0 {
1970                        if idx >= rows {
1971                            return Err(RuntimeError::InvalidOperation(
1972                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
1973                            ));
1974                        }
1975                        data[idx * cols + j]
1976                    } else {
1977                        if idx >= cols {
1978                            return Err(RuntimeError::InvalidOperation(
1979                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
1980                            ));
1981                        }
1982                        data[i * cols + idx]
1983                    };
1984                    result[i * out_cols + j] = val;
1985                }
1986            }
1987            Ok(Tensor::from_vec_unchecked(result, idx_shape))
1988        } else {
1989            Err(RuntimeError::InvalidOperation(
1990                "gather: only 1D and 2D tensors supported".into(),
1991            ))
1992        }
1993    }
1994
1995    /// Scatter src values into a tensor of given shape at indices along a dimension.
1996    /// For 1D: result[indices[i]] = src[i]
1997    /// For 2D dim=0: result[indices[i][j]][j] = src[i][j]
1998    /// For 2D dim=1: result[i][indices[i][j]] = src[i][j]
1999    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
2000        let mut result = self.to_vec();
2001        let idx_data = indices.to_vec();
2002        let src_data = src.to_vec();
2003        if self.ndim() == 1 {
2004            for (k, &idx) in idx_data.iter().enumerate() {
2005                let i = idx as usize;
2006                if i >= result.len() {
2007                    return Err(RuntimeError::InvalidOperation(
2008                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
2009                    ));
2010                }
2011                result[i] = src_data[k];
2012            }
2013            Ok(Tensor::from_vec_unchecked(result, self.shape()))
2014        } else if self.ndim() == 2 {
2015            let cols = self.shape[1];
2016            let idx_shape = indices.shape();
2017            let out_cols = idx_shape[1];
2018            let out_rows = idx_shape[0];
2019            for i in 0..out_rows {
2020                for j in 0..out_cols {
2021                    let idx = idx_data[i * out_cols + j] as usize;
2022                    let src_val = src_data[i * out_cols + j];
2023                    if dim == 0 {
2024                        if idx >= self.shape[0] {
2025                            return Err(RuntimeError::InvalidOperation(
2026                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
2027                            ));
2028                        }
2029                        result[idx * cols + j] = src_val;
2030                    } else {
2031                        if idx >= cols {
2032                            return Err(RuntimeError::InvalidOperation(
2033                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
2034                            ));
2035                        }
2036                        result[i * cols + idx] = src_val;
2037                    }
2038                }
2039            }
2040            Ok(Tensor::from_vec_unchecked(result, self.shape()))
2041        } else {
2042            Err(RuntimeError::InvalidOperation(
2043                "scatter: only 1D and 2D tensors supported".into(),
2044            ))
2045        }
2046    }
2047
2048    /// Select slices along a dimension by index.
2049    /// For 2D dim=0: selects rows
2050    /// For 2D dim=1: selects columns
2051    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2052        let data = self.to_vec();
2053        let idx_data = indices.to_vec();
2054        if self.ndim() == 1 {
2055            let mut result = Vec::with_capacity(idx_data.len());
2056            for &idx in &idx_data {
2057                let i = idx as usize;
2058                if i >= data.len() {
2059                    return Err(RuntimeError::InvalidOperation(
2060                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
2061                    ));
2062                }
2063                result.push(data[i]);
2064            }
2065            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
2066        } else if self.ndim() == 2 {
2067            let rows = self.shape[0];
2068            let cols = self.shape[1];
2069            let n = idx_data.len();
2070            if dim == 0 {
2071                let mut result = Vec::with_capacity(n * cols);
2072                for &idx in &idx_data {
2073                    let i = idx as usize;
2074                    if i >= rows {
2075                        return Err(RuntimeError::InvalidOperation(
2076                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
2077                        ));
2078                    }
2079                    for j in 0..cols {
2080                        result.push(data[i * cols + j]);
2081                    }
2082                }
2083                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
2084            } else {
2085                let mut result = Vec::with_capacity(rows * n);
2086                for i in 0..rows {
2087                    for &idx in &idx_data {
2088                        let j = idx as usize;
2089                        if j >= cols {
2090                            return Err(RuntimeError::InvalidOperation(
2091                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
2092                            ));
2093                        }
2094                        result.push(data[i * cols + j]);
2095                    }
2096                }
2097                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
2098            }
2099        } else {
2100            Err(RuntimeError::InvalidOperation(
2101                "index_select: only 1D and 2D tensors supported".into(),
2102            ))
2103        }
2104    }
2105
2106    // -----------------------------------------------------------------------
2107    // Phase 2: Boolean / Masking Ops
2108    // -----------------------------------------------------------------------
2109
2110    /// Element-wise conditional select: `where(condition, other)`.
2111    /// For each element, returns `self[i]` if `condition[i] != 0.0`, else `other[i]`.
2112    pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
2113        if self.shape() != condition.shape() || self.shape() != other.shape() {
2114            return Err(RuntimeError::InvalidOperation(
2115                format!("where: shape mismatch self={:?} cond={:?} other={:?}",
2116                    self.shape(), condition.shape(), other.shape()),
2117            ));
2118        }
2119        let s = self.to_vec();
2120        let c = condition.to_vec();
2121        let o = other.to_vec();
2122        let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2123            .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2124            .collect();
2125        Tensor::from_vec(result, self.shape())
2126    }
2127
2128    /// Return `true` if any element is non-zero.
2129    pub fn any(&self) -> bool {
2130        let data = self.to_vec();
2131        data.iter().any(|&x| x != 0.0)
2132    }
2133
2134    /// Return `true` if all elements are non-zero.
2135    pub fn all(&self) -> bool {
2136        let data = self.to_vec();
2137        data.iter().all(|&x| x != 0.0)
2138    }
2139
2140    /// Return a 1-D tensor of flat indices where elements are non-zero.
2141    ///
2142    /// If no elements are non-zero, returns an empty tensor of shape `[0]`.
2143    pub fn nonzero(&self) -> Tensor {
2144        let data = self.to_vec();
2145        let indices: Vec<f64> = data.iter().enumerate()
2146            .filter(|(_, &v)| v != 0.0)
2147            .map(|(i, _)| i as f64)
2148            .collect();
2149        let len = indices.len();
2150        if len == 0 {
2151            Tensor::from_vec(vec![], &[0]).unwrap()
2152        } else {
2153            Tensor::from_vec(indices, &[len]).unwrap()
2154        }
2155    }
2156
2157    /// Fill elements where `mask` is non-zero with `value`.
2158    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2159        if self.shape() != mask.shape() {
2160            return Err(RuntimeError::InvalidOperation(
2161                format!("masked_fill: shape mismatch self={:?} mask={:?}",
2162                    self.shape(), mask.shape()),
2163            ));
2164        }
2165        let data = self.to_vec();
2166        let m = mask.to_vec();
2167        let result: Vec<f64> = data.iter().zip(m.iter())
2168            .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2169            .collect();
2170        Tensor::from_vec(result, self.shape())
2171    }
2172
2173    // -----------------------------------------------------------------------
2174    // Phase 2: Axis Reductions with keepdim
2175    // -----------------------------------------------------------------------
2176
2177    /// Generic axis reduction using a caller-provided reduction function.
2178    ///
2179    /// Gathers values along the specified axis for each output position
2180    /// and applies `reduce_fn`. If `keepdim` is `true`, the reduced axis
2181    /// retains size 1; otherwise it is removed from the output shape.
2182    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
2183        -> Result<Tensor, RuntimeError>
2184    where
2185        F: Fn(&[f64]) -> f64,
2186    {
2187        let ndim = self.ndim();
2188        if axis >= ndim {
2189            return Err(RuntimeError::IndexOutOfBounds {
2190                index: axis,
2191                length: ndim,
2192            });
2193        }
2194
2195        let axis_len = self.shape[axis];
2196        // Build output shape
2197        let mut out_shape: Vec<usize> = self.shape.clone();
2198        out_shape[axis] = 1;
2199        let out_numel = Self::shape_numel(&out_shape);
2200        let out_strides = Self::compute_strides(&out_shape);
2201
2202        let data = self.to_vec();
2203        let mut result = Vec::with_capacity(out_numel);
2204        let mut indices = vec![0usize; ndim];
2205
2206        for out_idx in 0..out_numel {
2207            // Compute N-D index from flat output index
2208            {
2209                let mut remaining = out_idx;
2210                for d in 0..ndim {
2211                    indices[d] = remaining / out_strides[d];
2212                    remaining %= out_strides[d];
2213                }
2214            }
2215
2216            // Gather values along the reduction axis
2217            let mut vals = Vec::with_capacity(axis_len);
2218            for k in 0..axis_len {
2219                let mut flat = self.offset;
2220                for d in 0..ndim {
2221                    let idx = if d == axis { k } else { indices[d] };
2222                    flat += idx * self.strides[d];
2223                }
2224                vals.push(data[flat]);
2225            }
2226            result.push(reduce_fn(&vals));
2227        }
2228
2229        let final_shape = if keepdim {
2230            out_shape
2231        } else {
2232            // Remove the axis dimension
2233            let mut s: Vec<usize> = self.shape.iter().enumerate()
2234                .filter(|&(i, _)| i != axis)
2235                .map(|(_, &v)| v)
2236                .collect();
2237            if s.is_empty() {
2238                s.push(1); // scalar result
2239            }
2240            s
2241        };
2242
2243        Tensor::from_vec(result, &final_shape)
2244    }
2245
2246    /// Mean along an axis with optional keepdim.
2247    ///
2248    /// Uses [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64)
2249    /// for deterministic summation before dividing by the axis length.
2250    pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2251        self.reduce_axis(axis, keepdim, |vals| {
2252            let mut acc = BinnedAccumulatorF64::new();
2253            for &v in vals { acc.add(v); }
2254            acc.finalize() / vals.len() as f64
2255        })
2256    }
2257
2258    /// Max along an axis with optional keepdim. Return `(values, indices)`.
2259    ///
2260    /// Ties are broken by choosing the first occurrence (smallest index).
2261    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2262        let ndim = self.ndim();
2263        if axis >= ndim {
2264            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2265        }
2266        let axis_len = self.shape[axis];
2267        let mut out_shape = self.shape.clone();
2268        out_shape[axis] = 1;
2269        let out_numel = Self::shape_numel(&out_shape);
2270        let out_strides = Self::compute_strides(&out_shape);
2271        let data = self.to_vec();
2272        let mut values = Vec::with_capacity(out_numel);
2273        let mut idx_vals = Vec::with_capacity(out_numel);
2274        let mut indices = vec![0usize; ndim];
2275
2276        for out_idx in 0..out_numel {
2277            let mut remaining = out_idx;
2278            for d in 0..ndim {
2279                indices[d] = remaining / out_strides[d];
2280                remaining %= out_strides[d];
2281            }
2282            let mut best_val = f64::NEG_INFINITY;
2283            let mut best_idx = 0usize;
2284            for k in 0..axis_len {
2285                let mut flat = self.offset;
2286                for d in 0..ndim {
2287                    let idx = if d == axis { k } else { indices[d] };
2288                    flat += idx * self.strides[d];
2289                }
2290                let v = data[flat];
2291                if v > best_val {
2292                    best_val = v;
2293                    best_idx = k;
2294                }
2295            }
2296            values.push(best_val);
2297            idx_vals.push(best_idx as f64);
2298        }
2299
2300        let final_shape = if keepdim {
2301            out_shape
2302        } else {
2303            let mut s: Vec<usize> = self.shape.iter().enumerate()
2304                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2305            if s.is_empty() { s.push(1); }
2306            s
2307        };
2308        Ok((
2309            Tensor::from_vec(values, &final_shape)?,
2310            Tensor::from_vec(idx_vals, &final_shape)?,
2311        ))
2312    }
2313
2314    /// Min along an axis with optional keepdim. Return `(values, indices)`.
2315    ///
2316    /// Ties are broken by choosing the first occurrence (smallest index).
2317    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2318        let ndim = self.ndim();
2319        if axis >= ndim {
2320            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2321        }
2322        let axis_len = self.shape[axis];
2323        let mut out_shape = self.shape.clone();
2324        out_shape[axis] = 1;
2325        let out_numel = Self::shape_numel(&out_shape);
2326        let out_strides = Self::compute_strides(&out_shape);
2327        let data = self.to_vec();
2328        let mut values = Vec::with_capacity(out_numel);
2329        let mut idx_vals = Vec::with_capacity(out_numel);
2330        let mut indices = vec![0usize; ndim];
2331
2332        for out_idx in 0..out_numel {
2333            let mut remaining = out_idx;
2334            for d in 0..ndim {
2335                indices[d] = remaining / out_strides[d];
2336                remaining %= out_strides[d];
2337            }
2338            let mut best_val = f64::INFINITY;
2339            let mut best_idx = 0usize;
2340            for k in 0..axis_len {
2341                let mut flat = self.offset;
2342                for d in 0..ndim {
2343                    let idx = if d == axis { k } else { indices[d] };
2344                    flat += idx * self.strides[d];
2345                }
2346                let v = data[flat];
2347                if v < best_val {
2348                    best_val = v;
2349                    best_idx = k;
2350                }
2351            }
2352            values.push(best_val);
2353            idx_vals.push(best_idx as f64);
2354        }
2355
2356        let final_shape = if keepdim {
2357            out_shape
2358        } else {
2359            let mut s: Vec<usize> = self.shape.iter().enumerate()
2360                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2361            if s.is_empty() { s.push(1); }
2362            s
2363        };
2364        Ok((
2365            Tensor::from_vec(values, &final_shape)?,
2366            Tensor::from_vec(idx_vals, &final_shape)?,
2367        ))
2368    }
2369
2370    /// Variance along an axis with optional keepdim.
2371    ///
2372    /// Computes population variance: `Var = sum((x - mean)^2) / N`.
2373    /// Uses [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64)
2374    /// for the squared-differences summation.
2375    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2376        let mean_t = self.mean_axis(axis, true)?;
2377        let ndim = self.ndim();
2378        if axis >= ndim {
2379            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2380        }
2381        let axis_len = self.shape[axis];
2382        let mut out_shape = self.shape.clone();
2383        out_shape[axis] = 1;
2384        let out_numel = Self::shape_numel(&out_shape);
2385        let out_strides = Self::compute_strides(&out_shape);
2386        let data = self.to_vec();
2387        let mean_data = mean_t.to_vec();
2388        let mut result = Vec::with_capacity(out_numel);
2389        let mut indices = vec![0usize; ndim];
2390
2391        for out_idx in 0..out_numel {
2392            let mut remaining = out_idx;
2393            for d in 0..ndim {
2394                indices[d] = remaining / out_strides[d];
2395                remaining %= out_strides[d];
2396            }
2397            let mu = mean_data[out_idx];
2398            let mut acc = BinnedAccumulatorF64::new();
2399            for k in 0..axis_len {
2400                let mut flat = self.offset;
2401                for d in 0..ndim {
2402                    let idx = if d == axis { k } else { indices[d] };
2403                    flat += idx * self.strides[d];
2404                }
2405                let diff = data[flat] - mu;
2406                acc.add(diff * diff);
2407            }
2408            result.push(acc.finalize() / axis_len as f64);
2409        }
2410
2411        let final_shape = if keepdim {
2412            out_shape
2413        } else {
2414            let mut s: Vec<usize> = self.shape.iter().enumerate()
2415                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2416            if s.is_empty() { s.push(1); }
2417            s
2418        };
2419        Tensor::from_vec(result, &final_shape)
2420    }
2421
2422    /// Standard deviation along an axis with optional keepdim.
2423    ///
2424    /// Computed as `sqrt(var_axis(axis, keepdim))`.
2425    pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2426        let var = self.var_axis(axis, keepdim)?;
2427        Ok(var.map(|x| x.sqrt()))
2428    }
2429
2430    /// Product along an axis with optional keepdim.
2431    ///
2432    /// Computes a simple sequential product (exact for integer-like values).
2433    pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2434        self.reduce_axis(axis, keepdim, |vals| {
2435            // Product via exp(sum(ln(abs))) for numerical stability is overkill here;
2436            // simple product is deterministic and exact for integer-like values.
2437            let mut product = 1.0f64;
2438            for &v in vals { product *= v; }
2439            product
2440        })
2441    }
2442
2443    // -----------------------------------------------------------------------
2444    // Phase 2: Sort Operations
2445    // -----------------------------------------------------------------------
2446
2447    /// Sort along an axis (stable sort). Return the sorted tensor.
2448    ///
2449    /// For N-D tensors, independently sorts each 1-D slice along the specified
2450    /// axis. Uses `f64::partial_cmp` with deterministic tie-breaking by
2451    /// original index position.
2452    pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2453        let ndim = self.ndim();
2454        if axis >= ndim {
2455            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2456        }
2457        let data = self.to_vec();
2458        let axis_len = self.shape[axis];
2459        let out_shape = self.shape.clone();
2460        let out_numel = Self::shape_numel(&out_shape);
2461
2462        // Build strides for iterating over all non-axis positions
2463        let mut iter_shape: Vec<usize> = Vec::new();
2464        for (i, &s) in self.shape.iter().enumerate() {
2465            if i != axis { iter_shape.push(s); }
2466        }
2467        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2468
2469        let mut result = vec![0.0f64; out_numel];
2470
2471        // We iterate over all positions with axis index = 0
2472        let mut pos = vec![0usize; ndim];
2473        for slice_idx in 0..n_slices {
2474            // Compute the N-D position (with axis dim = 0)
2475            let mut remaining = slice_idx;
2476            let mut dim_idx = 0;
2477            for d in 0..ndim {
2478                if d == axis {
2479                    pos[d] = 0;
2480                } else {
2481                    let stride = {
2482                        let mut s = 1usize;
2483                        let mut di = 0;
2484                        for d2 in 0..ndim {
2485                            if d2 == axis { continue; }
2486                            if di > dim_idx { s *= self.shape[d2]; }
2487                            di += 1;
2488                        }
2489                        s
2490                    };
2491                    pos[d] = remaining / stride;
2492                    remaining %= stride;
2493                    dim_idx += 1;
2494                }
2495            }
2496
2497            // Gather values along axis
2498            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2499            for k in 0..axis_len {
2500                let mut flat = self.offset;
2501                for d in 0..ndim {
2502                    let idx = if d == axis { k } else { pos[d] };
2503                    flat += idx * self.strides[d];
2504                }
2505                vals.push((data[flat], k));
2506            }
2507
2508            // Stable sort with deterministic tie-breaking by original index
2509            if descending {
2510                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2511                    .then(a.1.cmp(&b.1)));
2512            } else {
2513                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2514                    .then(a.1.cmp(&b.1)));
2515            }
2516
2517            // Scatter back
2518            for (k, &(v, _)) in vals.iter().enumerate() {
2519                let mut flat = 0;
2520                let out_strides_local = Self::compute_strides(&out_shape);
2521                for d in 0..ndim {
2522                    let idx = if d == axis { k } else { pos[d] };
2523                    flat += idx * out_strides_local[d];
2524                }
2525                result[flat] = v;
2526            }
2527        }
2528
2529        Tensor::from_vec(result, &out_shape)
2530    }
2531
2532    /// N-D argsort along an axis. Return a tensor of indices that would sort
2533    /// each slice along the given axis.
2534    ///
2535    /// Deterministic tie-breaking: ties are resolved by original index order.
2536    pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2537        let ndim = self.ndim();
2538        if axis >= ndim {
2539            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2540        }
2541        let data = self.to_vec();
2542        let axis_len = self.shape[axis];
2543        let out_shape = self.shape.clone();
2544        let out_numel = Self::shape_numel(&out_shape);
2545
2546        let mut iter_shape: Vec<usize> = Vec::new();
2547        for (i, &s) in self.shape.iter().enumerate() {
2548            if i != axis { iter_shape.push(s); }
2549        }
2550        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2551
2552        let mut result = vec![0.0f64; out_numel];
2553        let mut pos = vec![0usize; ndim];
2554
2555        for slice_idx in 0..n_slices {
2556            let mut remaining = slice_idx;
2557            let mut dim_idx = 0;
2558            for d in 0..ndim {
2559                if d == axis {
2560                    pos[d] = 0;
2561                } else {
2562                    let stride = {
2563                        let mut s = 1usize;
2564                        let mut di = 0;
2565                        for d2 in 0..ndim {
2566                            if d2 == axis { continue; }
2567                            if di > dim_idx { s *= self.shape[d2]; }
2568                            di += 1;
2569                        }
2570                        s
2571                    };
2572                    pos[d] = remaining / stride;
2573                    remaining %= stride;
2574                    dim_idx += 1;
2575                }
2576            }
2577
2578            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2579            for k in 0..axis_len {
2580                let mut flat = self.offset;
2581                for d in 0..ndim {
2582                    let idx = if d == axis { k } else { pos[d] };
2583                    flat += idx * self.strides[d];
2584                }
2585                vals.push((data[flat], k));
2586            }
2587
2588            if descending {
2589                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2590                    .then(a.1.cmp(&b.1)));
2591            } else {
2592                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2593                    .then(a.1.cmp(&b.1)));
2594            }
2595
2596            for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2597                let out_strides_local = Self::compute_strides(&out_shape);
2598                let mut flat = 0;
2599                for d in 0..ndim {
2600                    let idx = if d == axis { k } else { pos[d] };
2601                    flat += idx * out_strides_local[d];
2602                }
2603                result[flat] = orig_idx as f64;
2604            }
2605        }
2606
2607        Tensor::from_vec(result, &out_shape)
2608    }
2609
2610    // -----------------------------------------------------------------------
2611    // Phase 2: Einsum
2612    // -----------------------------------------------------------------------
2613
2614    /// Einstein summation notation.
2615    /// Supports patterns like "ij,jk->ik" (matmul), "ii->i" (diagonal),
2616    /// "ij->ji" (transpose), "ijk,ikl->ijl" (batched matmul).
2617    /// Uses BinnedAccumulator for all reductions.
2618    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
2619        // Parse notation: "subscripts->output" or just "subscripts" (implicit)
2620        let parts: Vec<&str> = notation.split("->").collect();
2621        if parts.len() != 2 {
2622            return Err(RuntimeError::InvalidOperation(
2623                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
2624            ));
2625        }
2626        let input_specs: Vec<&str> = parts[0].split(',').collect();
2627        let output_spec = parts[1];
2628
2629        if input_specs.len() != inputs.len() {
2630            return Err(RuntimeError::InvalidOperation(
2631                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
2632            ));
2633        }
2634
2635        // Build label → size mapping
2636        let mut label_size = std::collections::BTreeMap::new();
2637        for (i, &spec) in input_specs.iter().enumerate() {
2638            let chars: Vec<char> = spec.chars().collect();
2639            if chars.len() != inputs[i].ndim() {
2640                return Err(RuntimeError::InvalidOperation(
2641                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
2642                ));
2643            }
2644            for (d, &c) in chars.iter().enumerate() {
2645                let sz = inputs[i].shape()[d];
2646                if let Some(&prev) = label_size.get(&c) {
2647                    if prev != sz {
2648                        return Err(RuntimeError::InvalidOperation(
2649                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
2650                        ));
2651                    }
2652                } else {
2653                    label_size.insert(c, sz);
2654                }
2655            }
2656        }
2657
2658        // Determine output shape
2659        let output_chars: Vec<char> = output_spec.chars().collect();
2660        let output_shape: Vec<usize> = output_chars.iter()
2661            .map(|c| label_size.get(c).copied().ok_or_else(||
2662                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
2663            .collect::<Result<_, _>>()?;
2664        let out_numel = Self::shape_numel(&output_shape);
2665
2666        // Determine contraction labels (in input but not output)
2667        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
2668        let contract_labels: Vec<char> = label_size.keys()
2669            .filter(|c| !output_set.contains(c))
2670            .copied()
2671            .collect();
2672        let contract_sizes: Vec<usize> = contract_labels.iter()
2673            .map(|c| label_size[c])
2674            .collect();
2675        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);
2676
2677        // Precompute input spec chars
2678        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();
2679
2680        // For each output position, iterate over contraction indices
2681        let out_strides = Self::compute_strides(&output_shape);
2682        let mut result = vec![0.0f64; out_numel];
2683
2684        // Pre-read input data
2685        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
2686        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
2687        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();
2688
2689        for out_idx in 0..out_numel {
2690            // Compute output label values
2691            let mut label_vals = std::collections::BTreeMap::new();
2692            let mut remaining = out_idx;
2693            for (d, &c) in output_chars.iter().enumerate() {
2694                let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
2695                label_vals.insert(c, remaining / stride);
2696                remaining %= stride;
2697            }
2698
2699            let mut acc = BinnedAccumulatorF64::new();
2700            // Iterate over all contraction index combinations
2701            for cidx in 0..contract_numel {
2702                // Compute contraction label values
2703                let mut cr = cidx;
2704                for (ci, &cl) in contract_labels.iter().enumerate() {
2705                    let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
2706                    label_vals.insert(cl, cr / stride);
2707                    cr %= stride;
2708                }
2709
2710                // Compute product of input elements
2711                let mut product = 1.0f64;
2712                for (inp_idx, chars) in input_chars.iter().enumerate() {
2713                    let mut flat = input_offsets[inp_idx];
2714                    for (d, &c) in chars.iter().enumerate() {
2715                        flat += label_vals[&c] * input_strides[inp_idx][d];
2716                    }
2717                    product *= input_data[inp_idx][flat];
2718                }
2719                acc.add(product);
2720            }
2721            result[out_idx] = acc.finalize();
2722        }
2723
2724        if output_shape.is_empty() {
2725            Tensor::from_vec(result, &[1])
2726        } else {
2727            Tensor::from_vec(result, &output_shape)
2728        }
2729    }
2730
2731    // -----------------------------------------------------------------------
2732    // Phase 2: Reshape / View Enhancements
2733    // -----------------------------------------------------------------------
2734
2735    /// Add a dimension of size 1 at position `dim`.
2736    ///
2737    /// For a tensor of shape `[A, B]`, `unsqueeze(0)` yields `[1, A, B]`,
2738    /// `unsqueeze(1)` yields `[A, 1, B]`, etc.
2739    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2740        let ndim = self.ndim();
2741        if dim > ndim {
2742            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2743        }
2744        let mut new_shape = self.shape.clone();
2745        new_shape.insert(dim, 1);
2746        self.reshape(&new_shape)
2747    }
2748
2749    /// Remove a dimension of size 1 at position `dim`.
2750    /// If `dim` is `None`, removes all dimensions of size 1.
2751    pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2752        match dim {
2753            Some(d) => {
2754                if d >= self.ndim() {
2755                    return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2756                }
2757                if self.shape[d] != 1 {
2758                    return Err(RuntimeError::InvalidOperation(
2759                        format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2760                    ));
2761                }
2762                let mut new_shape = self.shape.clone();
2763                new_shape.remove(d);
2764                if new_shape.is_empty() {
2765                    new_shape.push(1); // scalar
2766                }
2767                self.reshape(&new_shape)
2768            }
2769            None => {
2770                let new_shape: Vec<usize> = self.shape.iter()
2771                    .filter(|&&s| s != 1)
2772                    .copied()
2773                    .collect();
2774                let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2775                self.reshape(&new_shape)
2776            }
2777        }
2778    }
2779
2780    /// Broadcast without copying. Return a view with `stride=0` for broadcasted dims.
2781    ///
2782    /// Alias for [`broadcast_to`](Tensor::broadcast_to).
2783    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2784        self.broadcast_to(target_shape)
2785    }
2786
2787    /// Flatten a range of dimensions [start_dim, end_dim] into a single dimension.
2788    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2789        if start_dim > end_dim || end_dim >= self.ndim() {
2790            return Err(RuntimeError::InvalidOperation(
2791                format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2792            ));
2793        }
2794        let mut new_shape = Vec::new();
2795        for i in 0..start_dim {
2796            new_shape.push(self.shape[i]);
2797        }
2798        let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2799        new_shape.push(flat_size);
2800        for i in (end_dim + 1)..self.ndim() {
2801            new_shape.push(self.shape[i]);
2802        }
2803        self.reshape(&new_shape)
2804    }
2805
2806    /// Split tensor into `n` roughly equal chunks along dimension `dim`.
2807    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2808        if dim >= self.ndim() {
2809            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2810        }
2811        if n == 0 {
2812            return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2813        }
2814        let dim_size = self.shape[dim];
2815        let chunk_size = (dim_size + n - 1) / n;
2816        let mut sizes = Vec::new();
2817        let mut remaining = dim_size;
2818        while remaining > 0 {
2819            let s = remaining.min(chunk_size);
2820            sizes.push(s);
2821            remaining -= s;
2822        }
2823        self.split(&sizes, dim)
2824    }
2825
2826    /// Split tensor along dimension `dim` according to the given sizes.
2827    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2828        if dim >= self.ndim() {
2829            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2830        }
2831        let total: usize = sizes.iter().sum();
2832        if total != self.shape[dim] {
2833            return Err(RuntimeError::InvalidOperation(
2834                format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2835            ));
2836        }
2837
2838        let mut results = Vec::new();
2839        let mut offset = 0;
2840
2841        for &sz in sizes {
2842            let ranges: Vec<(usize, usize)> = self.shape.iter()
2843                .enumerate()
2844                .map(|(i, &s)| {
2845                    if i == dim { (offset, offset + sz) } else { (0, s) }
2846                })
2847                .collect();
2848            let chunk = self.slice(&ranges)?;
2849            // Materialize as contiguous
2850            results.push(chunk.to_contiguous());
2851            offset += sz;
2852        }
2853
2854        Ok(results)
2855    }
2856
2857    /// Fused `alpha * self + beta * other` element-wise. Single pass, one allocation.
2858    ///
2859    /// Critical for LSTM/GRU gates where `f * c_prev + i * g` would otherwise
2860    /// create 3 intermediate tensors.
2861    pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2862        if self.shape != other.shape {
2863            return Err(RuntimeError::InvalidOperation(
2864                "scale_add: shape mismatch".to_string(),
2865            ));
2866        }
2867        let a = self.to_vec();
2868        let b = other.to_vec();
2869        let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2870        Tensor::from_vec(result, &self.shape)
2871    }
2872}
2873