// cjc_runtime/tensor.rs

//! N-dimensional tensor with element-wise, reduction, linalg, and NN operations.
//!
//! [`Tensor`] is the primary numerical type in CJC. It is backed by a
//! [`Buffer<f64>`] with COW (copy-on-write) semantics, so cloning a tensor
//! is O(1) and mutation triggers a deep copy only when the buffer is shared.
//!
//! # Determinism Guarantees
//!
//! - **Reductions** (`sum`, `mean`, `sum_axis`) use [`BinnedAccumulatorF64`]
//!   for order-invariant, bit-identical results.
//! - **Matmul** uses Kahan-compensated accumulation (sequential path) or
//!   tiled + parallel strategies (large matrices) with deterministic per-row
//!   thread assignment.
//! - **SIMD** kernels (via [`tensor_simd`]) avoid hardware FMA to preserve
//!   cross-platform bit-identity.
//! - **No** `HashMap`/`HashSet` anywhere -- all ordering is deterministic.
//!
//! # Layout
//!
//! Tensors use row-major (C-order) layout with explicit strides. Non-contiguous
//! views (from `slice`, `transpose`, `broadcast_to`) share the underlying
//! buffer with adjusted strides and offset. Call [`Tensor::to_contiguous`] to
//! materialize a contiguous copy when needed.
//!
//! # Operation Categories
//!
//! | Category | Methods |
//! |----------|---------|
//! | Construction | `zeros`, `ones`, `randn`, `from_vec`, `from_bytes` |
//! | Shape | `shape`, `ndim`, `len`, `reshape`, `transpose`, `slice`, `unsqueeze`, `squeeze`, `flatten` |
//! | Element-wise | `add`, `sub`, `mul_elem`, `div_elem`, `elem_pow`, `map`, `map_simd` |
//! | Reductions | `sum`, `mean`, `sum_axis`, `mean_axis`, `var_axis`, `std_axis` |
//! | Linalg | `matmul`, `bmm`, `linear`, `einsum` |
//! | NN | `softmax`, `layer_norm`, `relu`, `sigmoid`, `gelu`, `conv1d`, `conv2d`, `maxpool2d` |
//! | Indexing | `get`, `set`, `gather`, `scatter`, `index_select`, `argsort` |
//! | Attention | `scaled_dot_product_attention`, `split_heads`, `merge_heads` |
//!
//! [`Buffer<f64>`]: crate::buffer::Buffer
//! [`BinnedAccumulatorF64`]: crate::accumulator::BinnedAccumulatorF64
//! [`tensor_simd`]: crate::tensor_simd

use cjc_repro::{KahanAccumulatorF64, Rng};

use crate::accumulator::{self, binned_sum_f64, BinnedAccumulatorF64};
use crate::buffer::Buffer;
use crate::dispatch;
use crate::error::RuntimeError;
use crate::kernel as kernel_fns;
use crate::tensor_simd::{self, BinOp, UnaryOp};
use crate::tensor_tiled::TiledMatmul;
54
// ---------------------------------------------------------------------------
// 2. Tensor Runtime
// ---------------------------------------------------------------------------

59/// An N-dimensional tensor backed by a [`Buffer<f64>`](crate::buffer::Buffer).
60///
61/// Supports element-wise arithmetic (SIMD-accelerated), matrix multiplication
62/// (tiled + parallel), numerically-stable reductions via
63/// [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64),
64/// and neural network operations (softmax, layer norm, attention).
65///
66/// # Memory
67///
68/// The underlying [`Buffer<f64>`](crate::buffer::Buffer) uses copy-on-write
69/// semantics. Cloning a `Tensor` is O(1). Operations like `reshape` and
70/// `transpose` return zero-copy views when possible. Mutation via `set`
71/// triggers a deep copy only when the buffer is shared.
72#[derive(Debug, Clone)]
73pub struct Tensor {
74    /// The underlying COW data buffer.
75    pub buffer: Buffer<f64>,
76    /// Dimensions of the tensor (e.g., `[3, 4]` for a 3x4 matrix).
77    pub(crate) shape: Vec<usize>,
78    /// Row-major strides for each dimension. Non-contiguous views may have
79    /// strides that differ from the standard row-major pattern (e.g., stride=0
80    /// for broadcast dimensions).
81    pub(crate) strides: Vec<usize>,
82    /// Byte offset into the buffer where this tensor's data begins.
83    /// Non-zero for slice views.
84    pub(crate) offset: usize,
85}
86
87impl Tensor {
    // -- Construction -------------------------------------------------------

90    /// Compute row-major strides for a given shape.
91    pub(crate) fn compute_strides(shape: &[usize]) -> Vec<usize> {
92        let mut strides = vec![1usize; shape.len()];
93        for i in (0..shape.len().saturating_sub(1)).rev() {
94            strides[i] = strides[i + 1] * shape[i + 1];
95        }
96        strides
97    }
98
99    /// Total number of elements implied by `shape`.
100    fn shape_numel(shape: &[usize]) -> usize {
101        shape.iter().product()
102    }
103
104    /// Create a tensor filled with zeros.
105    pub fn zeros(shape: &[usize]) -> Self {
106        let numel = Self::shape_numel(shape);
107        Tensor {
108            buffer: Buffer::alloc(numel, 0.0),
109            shape: shape.to_vec(),
110            strides: Self::compute_strides(shape),
111            offset: 0,
112        }
113    }
114
115    /// Create a tensor filled with ones.
116    pub fn ones(shape: &[usize]) -> Self {
117        let numel = Self::shape_numel(shape);
118        Tensor {
119            buffer: Buffer::alloc(numel, 1.0),
120            shape: shape.to_vec(),
121            strides: Self::compute_strides(shape),
122            offset: 0,
123        }
124    }
125
126    /// Create a tensor filled with samples from the standard normal
127    /// distribution, drawn deterministically from `rng`.
128    pub fn randn(shape: &[usize], rng: &mut Rng) -> Self {
129        let numel = Self::shape_numel(shape);
130        let data: Vec<f64> = (0..numel).map(|_| rng.next_normal_f64()).collect();
131        Tensor {
132            buffer: Buffer::from_vec(data),
133            shape: shape.to_vec(),
134            strides: Self::compute_strides(shape),
135            offset: 0,
136        }
137    }
138
139    /// Create a tensor from raw data and a shape. Returns an error if the
140    /// number of elements does not match the shape.
141    pub fn from_vec(data: Vec<f64>, shape: &[usize]) -> Result<Self, RuntimeError> {
142        let numel = Self::shape_numel(shape);
143        if data.len() != numel {
144            return Err(RuntimeError::ShapeMismatch {
145                expected: numel,
146                got: data.len(),
147            });
148        }
149        Ok(Tensor {
150            buffer: Buffer::from_vec(data),
151            shape: shape.to_vec(),
152            strides: Self::compute_strides(shape),
153            offset: 0,
154        })
155    }
156
    // -- Accessors ----------------------------------------------------------

159    /// The shape of this tensor.
160    pub fn shape(&self) -> &[usize] {
161        &self.shape
162    }
163
164    /// Number of dimensions.
165    pub fn ndim(&self) -> usize {
166        self.shape.len()
167    }
168
169    /// Total number of elements.
170    pub fn len(&self) -> usize {
171        Self::shape_numel(&self.shape)
172    }
173
174    /// Whether the tensor has zero elements.
175    pub fn is_empty(&self) -> bool {
176        self.len() == 0
177    }
178
179    /// Flatten a multi-dimensional index to a linear offset in the buffer.
180    fn linear_index(&self, indices: &[usize]) -> Result<usize, RuntimeError> {
181        if indices.len() != self.shape.len() {
182            return Err(RuntimeError::DimensionMismatch {
183                expected: self.shape.len(),
184                got: indices.len(),
185            });
186        }
187        let mut off = self.offset;
188        for (i, &idx) in indices.iter().enumerate() {
189            if idx >= self.shape[i] {
190                return Err(RuntimeError::IndexOutOfBounds {
191                    index: idx,
192                    length: self.shape[i],
193                });
194            }
195            off += idx * self.strides[i];
196        }
197        Ok(off)
198    }
199
200    /// Whether this tensor is contiguous in memory (row-major, no offset).
201    pub fn is_contiguous(&self) -> bool {
202        if self.offset != 0 {
203            return false;
204        }
205        let expected = Self::compute_strides(&self.shape);
206        self.strides == expected
207    }
208
209    /// Create a zero-copy slice (view) of this tensor.
210    /// `ranges` contains `(start, end)` for each dimension.
211    pub fn slice(&self, ranges: &[(usize, usize)]) -> Result<Tensor, RuntimeError> {
212        if ranges.len() != self.shape.len() {
213            return Err(RuntimeError::DimensionMismatch {
214                expected: self.shape.len(),
215                got: ranges.len(),
216            });
217        }
218        let mut new_offset = self.offset;
219        let mut new_shape = Vec::with_capacity(ranges.len());
220        for (i, &(start, end)) in ranges.iter().enumerate() {
221            if end > self.shape[i] || start > end {
222                return Err(RuntimeError::IndexOutOfBounds {
223                    index: end,
224                    length: self.shape[i],
225                });
226            }
227            new_offset += start * self.strides[i];
228            new_shape.push(end - start);
229        }
230        Ok(Tensor {
231            buffer: self.buffer.clone(), // shared — zero copy
232            shape: new_shape,
233            strides: self.strides.clone(),
234            offset: new_offset,
235        })
236    }
237
238    /// Materialize a contiguous copy if this tensor is non-contiguous.
239    pub fn to_contiguous(&self) -> Tensor {
240        if self.is_contiguous() {
241            return self.clone();
242        }
243        let data = self.to_vec();
244        Tensor {
245            buffer: Buffer::from_vec(data),
246            shape: self.shape.clone(),
247            strides: Self::compute_strides(&self.shape),
248            offset: 0,
249        }
250    }
251
252    /// Create a broadcast view of this tensor to `target_shape`.
253    /// Uses stride=0 for dimensions that need broadcasting (size 1 -> target size).
254    pub fn broadcast_to(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
255        let src_ndim = self.shape.len();
256        let tgt_ndim = target_shape.len();
257        if tgt_ndim < src_ndim {
258            return Err(RuntimeError::InvalidOperation(
259                "cannot broadcast to a smaller rank".to_string(),
260            ));
261        }
262        let pad = tgt_ndim - src_ndim;
263        let mut new_strides = vec![0usize; tgt_ndim];
264        for i in 0..tgt_ndim {
265            if i < pad {
266                // Padded dimension: stride = 0 (broadcast)
267                new_strides[i] = 0;
268            } else {
269                let src_i = i - pad;
270                if self.shape[src_i] == target_shape[i] {
271                    new_strides[i] = self.strides[src_i];
272                } else if self.shape[src_i] == 1 {
273                    new_strides[i] = 0; // broadcast
274                } else {
275                    return Err(RuntimeError::ShapeMismatch {
276                        expected: target_shape[i],
277                        got: self.shape[src_i],
278                    });
279                }
280            }
281        }
282        Ok(Tensor {
283            buffer: self.buffer.clone(),
284            shape: target_shape.to_vec(),
285            strides: new_strides,
286            offset: self.offset,
287        })
288    }
289
290    /// Read the element at the given multi-dimensional index.
291    pub fn get(&self, indices: &[usize]) -> Result<f64, RuntimeError> {
292        let offset = self.linear_index(indices)?;
293        self.buffer
294            .get(offset)
295            .ok_or(RuntimeError::IndexOutOfBounds {
296                index: offset,
297                length: self.buffer.len(),
298            })
299    }
300
301    /// Write the element at the given multi-dimensional index.
302    pub fn set(&mut self, indices: &[usize], val: f64) -> Result<(), RuntimeError> {
303        let offset = self.linear_index(indices)?;
304        self.buffer.set(offset, val)
305    }
306
307    /// Extract the raw data as a `Vec<f64>`, respecting strides and offset.
308    pub fn to_vec(&self) -> Vec<f64> {
309        if self.is_contiguous() {
310            let full = self.buffer.borrow_data();
311            let numel = self.len();
312            if full.len() == numel {
313                return full.to_vec();
314            }
315            // Buffer may be larger than the tensor's logical size
316            // (e.g. Scratchpad pre-allocates extra capacity)
317            return full[..numel].to_vec();
318        }
319        // Non-contiguous: iterate via strided access
320        let numel = self.len();
321        let mut result = Vec::with_capacity(numel);
322        let ndim = self.shape.len();
323        let mut indices = vec![0usize; ndim];
324        for _ in 0..numel {
325            let mut off = self.offset;
326            for d in 0..ndim {
327                off += indices[d] * self.strides[d];
328            }
329            result.push(self.buffer.get(off).unwrap_or(0.0));
330            // Increment multi-index (row-major order)
331            for d in (0..ndim).rev() {
332                indices[d] += 1;
333                if indices[d] < self.shape[d] {
334                    break;
335                }
336                indices[d] = 0;
337            }
338        }
339        result
340    }
341
    // -- Reshape ------------------------------------------------------------

344    /// Reshape to `new_shape`. The new shape must have the same total number
345    /// of elements. The returned tensor **shares** the underlying buffer.
346    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
347        let new_numel = Self::shape_numel(new_shape);
348        if new_numel != self.len() {
349            return Err(RuntimeError::ShapeMismatch {
350                expected: self.len(),
351                got: new_numel,
352            });
353        }
354        // Reshape requires contiguous data; materialize if needed
355        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
356        Ok(Tensor {
357            buffer: tensor.buffer,
358            shape: new_shape.to_vec(),
359            strides: Self::compute_strides(new_shape),
360            offset: 0,
361        })
362    }
363
    // -- Element-wise operations --------------------------------------------

366    /// Apply a binary operation element-wise with broadcasting support.
367    fn elementwise_binop(
368        &self,
369        other: &Tensor,
370        op: impl Fn(f64, f64) -> f64,
371    ) -> Result<Tensor, RuntimeError> {
372        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
373            // Fast path: same shape, both contiguous — borrow without cloning
374            let a = self.buffer.borrow_data();
375            let b = other.buffer.borrow_data();
376            let data: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
377            return Ok(Tensor {
378                buffer: Buffer::from_vec(data),
379                shape: self.shape.clone(),
380                strides: Self::compute_strides(&self.shape),
381                offset: 0,
382            });
383        }
384
385        // Broadcasting path: compute result shape
386        let result_shape = Self::broadcast_result_shape(&self.shape, &other.shape)?;
387        let a_broadcast = self.broadcast_to(&result_shape)?;
388        let b_broadcast = other.broadcast_to(&result_shape)?;
389
390        let numel = Self::shape_numel(&result_shape);
391        let ndim = result_shape.len();
392        let mut data = Vec::with_capacity(numel);
393        let mut indices = vec![0usize; ndim];
394
395        for _ in 0..numel {
396            let mut off_a = a_broadcast.offset;
397            let mut off_b = b_broadcast.offset;
398            for d in 0..ndim {
399                off_a += indices[d] * a_broadcast.strides[d];
400                off_b += indices[d] * b_broadcast.strides[d];
401            }
402            let va = a_broadcast.buffer.get(off_a).ok_or_else(|| {
403                RuntimeError::InvalidOperation(format!(
404                    "broadcast binop: left operand index {} out of bounds (buffer len {})",
405                    off_a,
406                    a_broadcast.buffer.len()
407                ))
408            })?;
409            let vb = b_broadcast.buffer.get(off_b).ok_or_else(|| {
410                RuntimeError::InvalidOperation(format!(
411                    "broadcast binop: right operand index {} out of bounds (buffer len {})",
412                    off_b,
413                    b_broadcast.buffer.len()
414                ))
415            })?;
416            data.push(op(va, vb));
417
418            for d in (0..ndim).rev() {
419                indices[d] += 1;
420                if indices[d] < result_shape[d] {
421                    break;
422                }
423                indices[d] = 0;
424            }
425        }
426
427        Ok(Tensor {
428            buffer: Buffer::from_vec(data),
429            shape: result_shape.clone(),
430            strides: Self::compute_strides(&result_shape),
431            offset: 0,
432        })
433    }
434
435    /// Compute the broadcast result shape for two shapes (NumPy rules).
436    fn broadcast_result_shape(a: &[usize], b: &[usize]) -> Result<Vec<usize>, RuntimeError> {
437        let max_ndim = a.len().max(b.len());
438        let mut result = Vec::with_capacity(max_ndim);
439        for i in 0..max_ndim {
440            let da = if i < max_ndim - a.len() { 1 } else { a[i - (max_ndim - a.len())] };
441            let db = if i < max_ndim - b.len() { 1 } else { b[i - (max_ndim - b.len())] };
442            if da == db {
443                result.push(da);
444            } else if da == 1 {
445                result.push(db);
446            } else if db == 1 {
447                result.push(da);
448            } else {
449                return Err(RuntimeError::ShapeMismatch {
450                    expected: da,
451                    got: db,
452                });
453            }
454        }
455        Ok(result)
456    }
457
458    /// SIMD-accelerated element-wise binary operation for known ops.
459    ///
460    /// For same-shape contiguous tensors, uses AVX2 (4-wide f64) when available.
461    /// Falls back to the generic closure path for broadcast cases.
462    fn elementwise_binop_simd(
463        &self,
464        other: &Tensor,
465        op: BinOp,
466        fallback: impl Fn(f64, f64) -> f64,
467    ) -> Result<Tensor, RuntimeError> {
468        if self.shape == other.shape && self.is_contiguous() && other.is_contiguous() {
469            // SIMD fast path: same shape, both contiguous
470            let a = self.buffer.borrow_data();
471            let b = other.buffer.borrow_data();
472            let data = tensor_simd::simd_binop(&a, &b, op);
473            return Ok(Tensor {
474                buffer: Buffer::from_vec(data),
475                shape: self.shape.clone(),
476                strides: Self::compute_strides(&self.shape),
477                offset: 0,
478            });
479        }
480        // Broadcast path: fall through to generic
481        self.elementwise_binop(other, fallback)
482    }
483
484    /// Element-wise addition (SIMD-accelerated for contiguous same-shape tensors).
485    pub fn add(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
486        self.elementwise_binop_simd(other, BinOp::Add, |a, b| a + b)
487    }
488
489    /// Element-wise subtraction (SIMD-accelerated for contiguous same-shape tensors).
490    pub fn sub(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
491        self.elementwise_binop_simd(other, BinOp::Sub, |a, b| a - b)
492    }
493
494    /// Element-wise (Hadamard) multiplication (SIMD-accelerated for contiguous same-shape tensors).
495    pub fn mul_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
496        self.elementwise_binop_simd(other, BinOp::Mul, |a, b| a * b)
497    }
498
499    /// Element-wise division (SIMD-accelerated for contiguous same-shape tensors).
500    pub fn div_elem(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
501        self.elementwise_binop_simd(other, BinOp::Div, |a, b| a / b)
502    }
503
504    /// Fused multiply-add: `self * b + c` element-wise in a single pass.
505    ///
506    /// Eliminates the intermediate tensor that separate mul + add would create.
507    /// Uses software FMA (`a * b + c` with two roundings, not hardware FMA)
508    /// to preserve bit-identity with the non-fused path.
509    pub fn fused_mul_add(&self, b: &Tensor, c: &Tensor) -> Result<Tensor, RuntimeError> {
510        if self.shape != b.shape || self.shape != c.shape {
511            return Err(RuntimeError::InvalidOperation(
512                "broadcast_fma: all three tensors must have the same shape".to_string(),
513            ));
514        }
515        if self.is_contiguous() && b.is_contiguous() && c.is_contiguous() {
516            let a_data = self.buffer.borrow_data();
517            let b_data = b.buffer.borrow_data();
518            let c_data = c.buffer.borrow_data();
519            let n = a_data.len();
520            let mut out = vec![0.0f64; n];
521            // Software FMA: a*b + c (two roundings — NOT hardware FMA which uses one rounding).
522            // This produces identical results to separate broadcast2("mul") + broadcast2("add").
523            for i in 0..n {
524                out[i] = a_data[i] * b_data[i] + c_data[i];
525            }
526            return Ok(Tensor {
527                buffer: Buffer::from_vec(out),
528                shape: self.shape.clone(),
529                strides: Self::compute_strides(&self.shape),
530                offset: 0,
531            });
532        }
533        // Non-contiguous fallback: mul then add
534        let temp = self.mul_elem(b)?;
535        temp.add(c)
536    }
537
    // ── v0.1 Broadcasting: additional element-wise binary ops ──

540    /// Element-wise power: `a^b`.
541    pub fn elem_pow(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
542        self.elementwise_binop(other, |a, b| a.powf(b))
543    }
544
545    /// Element-wise minimum.
546    pub fn elem_min(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
547        self.elementwise_binop(other, |a, b| a.min(b))
548    }
549
550    /// Element-wise maximum.
551    pub fn elem_max(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
552        self.elementwise_binop(other, |a, b| a.max(b))
553    }
554
555    /// Element-wise atan2(self, other).
556    pub fn elem_atan2(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
557        self.elementwise_binop(other, |a, b| a.atan2(b))
558    }
559
560    /// Element-wise hypot(self, other).
561    pub fn elem_hypot(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
562        self.elementwise_binop(other, |a, b| a.hypot(b))
563    }
564
565    /// Apply a unary function to every element, returning a new contiguous tensor.
566    pub fn map(&self, f: impl Fn(f64) -> f64) -> Tensor {
567        let data: Vec<f64> = self.to_vec().iter().map(|&x| f(x)).collect();
568        Tensor {
569            buffer: Buffer::from_vec(data),
570            shape: self.shape.clone(),
571            strides: Self::compute_strides(&self.shape),
572            offset: 0,
573        }
574    }
575
576    /// SIMD-accelerated unary map for known operations (sqrt, abs, neg, relu).
577    ///
578    /// Uses AVX2 (4-wide f64) when available, scalar fallback otherwise.
579    /// Bit-identical to `map(f)` for the supported operations.
580    pub fn map_simd(&self, op: UnaryOp) -> Tensor {
581        let src = self.to_vec();
582        let data = tensor_simd::simd_unary(&src, op);
583        Tensor {
584            buffer: Buffer::from_vec(data),
585            shape: self.shape.clone(),
586            strides: Self::compute_strides(&self.shape),
587            offset: 0,
588        }
589    }
590
    // -- Reductions (using BinnedAccumulator) --------------------------------

593    /// Sum of all elements (binned accumulation — order-invariant, deterministic).
594    pub fn sum(&self) -> f64 {
595        let data = self.buffer.borrow_data();
596        binned_sum_f64(&data)
597    }
598
599    /// Sum of all elements using BinnedAccumulator (order-invariant, deterministic).
600    ///
601    /// Bit-identical results regardless of element ordering or reduction schedule.
602    pub fn binned_sum(&self) -> f64 {
603        let data = self.buffer.borrow_data();
604        accumulator::binned_sum_f64(&data)
605    }
606
607    /// Sum with dispatched strategy based on execution context.
608    ///
609    /// Uses Kahan in serial mode, Binned in parallel/@nogc/strict/linalg mode.
610    pub fn dispatched_sum(&self, ctx: &dispatch::ReductionContext) -> f64 {
611        let data = self.buffer.borrow_data();
612        dispatch::dispatch_sum_f64(&data, ctx)
613    }
614
615    /// Mean of all elements (binned sum / count).
616    pub fn mean(&self) -> f64 {
617        let n = self.len();
618        if n == 0 {
619            return 0.0;
620        }
621        self.sum() / n as f64
622    }
623
624    /// Mean with dispatched strategy based on execution context.
625    pub fn dispatched_mean(&self, ctx: &dispatch::ReductionContext) -> f64 {
626        let n = self.len();
627        if n == 0 {
628            return 0.0;
629        }
630        self.dispatched_sum(ctx) / n as f64
631    }
632
633    /// Sum along a specific axis, returning a tensor with that dimension reduced.
634    ///
635    /// Supports N-D tensors. The reduced axis becomes size 1 in the output.
636    /// Uses BinnedAccumulator for order-invariant, deterministic summation.
637    ///
638    /// Examples:
639    /// - 2D [M, N] with axis=0: result [1, N] (sum columns)
640    /// - 2D [M, N] with axis=1: result [M, 1] (sum rows)
641    /// - 3D [A, B, C] with axis=1: result [A, 1, C]
642    pub fn sum_axis(&self, axis: usize) -> Result<Tensor, RuntimeError> {
643        let ndim = self.ndim();
644        if axis >= ndim {
645            return Err(RuntimeError::IndexOutOfBounds {
646                index: axis,
647                length: ndim,
648            });
649        }
650
651        // Build output shape: same as input but with axis dimension = 1
652        let mut out_shape = self.shape.clone();
653        out_shape[axis] = 1;
654        let out_numel = Self::shape_numel(&out_shape);
655        let out_strides = Self::compute_strides(&out_shape);
656
657        let data = self.to_vec();
658        let axis_len = self.shape[axis];
659        let mut result = vec![0.0f64; out_numel];
660
661        // For each output position, sum over the reduced axis with binned accumulation.
662        let mut indices = vec![0usize; ndim];
663        for out_idx in 0..out_numel {
664            // Compute the N-D index from flat output index
665            {
666                let mut remaining = out_idx;
667                for d in 0..ndim {
668                    indices[d] = remaining / out_strides[d];
669                    remaining %= out_strides[d];
670                }
671            }
672
673            let mut acc = BinnedAccumulatorF64::new();
674            for k in 0..axis_len {
675                // Compute input flat index with indices[axis] = k
676                let mut flat = self.offset;
677                for d in 0..ndim {
678                    let idx = if d == axis { k } else { indices[d] };
679                    flat += idx * self.strides[d];
680                }
681                acc.add(data[flat]);
682            }
683            result[out_idx] = acc.finalize();
684        }
685
686        Tensor::from_vec(result, &out_shape)
687    }
688
    // -- Matrix multiplication (2-D only) -----------------------------------

691    /// Negate every element, returning a new tensor.
692    pub fn neg(&self) -> Tensor {
693        self.map(|x| -x)
694    }
695
696    /// Transpose a tensor. For 2-D: swaps rows and columns (zero-copy view).
697    /// For N-D: reverses all axes (zero-copy view).
698    pub fn transpose(&self) -> Tensor {
699        let ndim = self.ndim();
700        if ndim <= 1 {
701            return self.clone();
702        }
703        // Reverse shape and strides — zero-copy view
704        let mut new_shape = self.shape.clone();
705        let mut new_strides = self.strides.clone();
706        new_shape.reverse();
707        new_strides.reverse();
708        Tensor {
709            buffer: self.buffer.clone(), // shared — zero copy
710            shape: new_shape,
711            strides: new_strides,
712            offset: self.offset,
713        }
714    }
715
716    /// Transpose with explicit axis permutation (N-D). Zero-copy view.
717    ///
718    /// `axes` must be a permutation of `[0, 1, ..., ndim-1]`.
719    pub fn transpose_axes(&self, axes: &[usize]) -> Result<Tensor, RuntimeError> {
720        let ndim = self.ndim();
721        if axes.len() != ndim {
722            return Err(RuntimeError::InvalidOperation(
723                format!("transpose_axes: expected {} axes, got {}", ndim, axes.len()),
724            ));
725        }
726        // Validate permutation
727        let mut seen = vec![false; ndim];
728        for &ax in axes {
729            if ax >= ndim {
730                return Err(RuntimeError::IndexOutOfBounds { index: ax, length: ndim });
731            }
732            if seen[ax] {
733                return Err(RuntimeError::InvalidOperation(
734                    format!("transpose_axes: duplicate axis {ax}"),
735                ));
736            }
737            seen[ax] = true;
738        }
739        let new_shape: Vec<usize> = axes.iter().map(|&ax| self.shape[ax]).collect();
740        let new_strides: Vec<usize> = axes.iter().map(|&ax| self.strides[ax]).collect();
741        Ok(Tensor {
742            buffer: self.buffer.clone(),
743            shape: new_shape,
744            strides: new_strides,
745            offset: self.offset,
746        })
747    }
748
749    /// Multiply every element by a scalar, returning a new tensor.
750    pub fn scalar_mul(&self, s: f64) -> Tensor {
751        self.map(|x| x * s)
752    }
753
    // ── Panicking convenience constructors (used by AD engine) --------

756    /// Create a tensor from raw data and shape.
757    /// **Panics** if `data.len()` does not match the shape.
758    pub fn from_vec_unchecked(data: Vec<f64>, shape: &[usize]) -> Tensor {
759        Self::from_vec(data, shape).expect("Tensor::from_vec_unchecked: shape mismatch")
760    }
761
762    /// Element-wise addition. **Panics** on shape mismatch.
763    pub fn add_unchecked(&self, other: &Tensor) -> Tensor {
764        self.add(other).expect("Tensor::add shape mismatch")
765    }
766
767    /// Element-wise subtraction. **Panics** on shape mismatch.
768    pub fn sub_unchecked(&self, other: &Tensor) -> Tensor {
769        self.sub(other).expect("Tensor::sub shape mismatch")
770    }
771
772    /// Element-wise multiplication. **Panics** on shape mismatch.
773    pub fn mul_elem_unchecked(&self, other: &Tensor) -> Tensor {
774        self.mul_elem(other).expect("Tensor::mul_elem shape mismatch")
775    }
776
777    /// Element-wise division. **Panics** on shape mismatch.
778    pub fn div_elem_unchecked(&self, other: &Tensor) -> Tensor {
779        self.div_elem(other).expect("Tensor::div_elem shape mismatch")
780    }
781
782    /// Matrix multiplication. **Panics** on dimension mismatch.
783    pub fn matmul_unchecked(&self, other: &Tensor) -> Tensor {
784        self.matmul(other).expect("Tensor::matmul dimension mismatch")
785    }
786
    /// Matrix multiplication for 2-D tensors.
    ///
    /// `self` is (M, K), `other` is (K, N) => result is (M, N).
    ///
    /// Dispatches by size, largest threshold first. The three kernels use
    /// *different* floating-point accumulation strategies, so results may
    /// differ in the last bits across size classes (each class is itself
    /// deterministic for a given input):
    /// 1. `matmul_parallel_mode_a` — `parallel` feature, any dim >= 256.
    /// 2. `matmul_tiled` — any dim >= 64, naive accumulation.
    /// 3. `matmul_sequential` — Kahan-compensated reference path.
    ///
    /// # Errors
    /// - `InvalidOperation` if either operand is not 2-D.
    /// - `DimensionMismatch` if `self.shape[1] != other.shape[0]`.
    pub fn matmul(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 2 || other.ndim() != 2 {
            return Err(RuntimeError::InvalidOperation(
                "matmul requires 2-D tensors".to_string(),
            ));
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let k2 = other.shape[0];
        let n = other.shape[1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        // Materialize both operands as contiguous row-major buffers.
        let a = self.to_vec();
        let b = other.to_vec();

        // Parallel path (Mode A): parallelize over output rows when the parallel
        // feature is enabled and the matrix is large enough (>= 256 in any dim).
        #[cfg(feature = "parallel")]
        {
            if m >= 256 || n >= 256 || k >= 256 {
                return Self::matmul_parallel_mode_a(&a, &b, m, n, k);
            }
        }

        // Tiled path: use L2-friendly tiled matmul for medium-to-large matrices.
        // Threshold: any dimension >= 64 (the default tile size).
        // NOTE: tiled path uses naive accumulation (not binned) — different
        // numerical path for large matrices, but better cache locality.
        if m >= 64 || n >= 64 || k >= 64 {
            return Self::matmul_tiled(&a, &b, m, n, k);
        }

        // Sequential path: single-threaded with Kahan-compensated accumulation
        // (see `matmul_sequential`; despite the earlier comment, it is Kahan,
        // not binned).
        Self::matmul_sequential(&a, &b, m, n, k)
    }
830
831    /// Sequential matmul (always available, deterministic reference).
832    fn matmul_sequential(
833        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
834    ) -> Result<Tensor, RuntimeError> {
835        let mut result = vec![0.0f64; m * n];
836        for i in 0..m {
837            for j in 0..n {
838                let mut acc = KahanAccumulatorF64::new();
839                for p in 0..k {
840                    acc.add(a[i * k + p] * b[p * n + j]);
841                }
842                result[i * n + j] = acc.finalize();
843            }
844        }
845        Tensor::from_vec(result, &[m, n])
846    }
847
848    /// Tiled matmul: delegates to `TiledMatmul` for L2-cache-friendly tiling.
849    ///
850    /// Used for medium matrices (any dimension >= 64) where cache locality
851    /// matters but parallel overhead isn't justified. The tiled path uses
852    /// naive accumulation (not binned accumulation), trading a small amount of
853    /// floating-point precision for better cache behavior.
854    fn matmul_tiled(
855        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
856    ) -> Result<Tensor, RuntimeError> {
857        let engine = TiledMatmul::new();
858        let result = engine.matmul(a, m, k, b, n);
859        Tensor::from_vec(result, &[m, n])
860    }
861
    /// Parallel matmul Mode A: parallelize over output rows using tiled
    /// micro-kernels for cache locality.
    ///
    /// Deterministic because:
    /// - Each output row is computed by exactly one thread.
    /// - Within each row, accumulation uses tiled AXPY (deterministic order).
    /// - No cross-thread reduction or merge of partial sums.
    ///
    /// Uses KahanAccumulatorF64 (lightweight) instead of BinnedAccumulatorF64
    /// (32KB per accumulator) to avoid massive stack pressure in parallel mode.
    #[cfg(feature = "parallel")]
    fn matmul_parallel_mode_a(
        a: &[f64], b: &[f64], m: usize, n: usize, k: usize,
    ) -> Result<Tensor, RuntimeError> {
        use rayon::prelude::*;
        use cjc_repro::KahanAccumulatorF64;

        // For large matrices, use tiled matmul (sequential but cache-friendly).
        // The tiling provides far better cache locality than parallel row-wise
        // with column-strided B access, and avoids the 32KB-per-element
        // BinnedAccumulator overhead that caused the 128→256 regression.
        if m >= 512 && n >= 512 {
            // Only use rayon for very large matrices where thread overhead
            // is amortized. Split into row-bands, each processed with tiled matmul.
            // Ceiling division: every row lands in exactly one band.
            let band_size = (m + rayon::current_num_threads() - 1) / rayon::current_num_threads();
            let band_size = band_size.max(64); // At least 64 rows per band
            let mut result = vec![0.0f64; m * n];

            // Bands are disjoint &mut slices of the output, so they can be
            // filled in parallel with no cross-thread accumulation.
            result
                .par_chunks_mut(band_size * n)
                .enumerate()
                .for_each(|(band_idx, band)| {
                    let i_start = band_idx * band_size;
                    let i_end = (i_start + band_size).min(m);
                    let band_m = i_end - i_start;
                    let a_band = &a[i_start * k .. i_end * k];
                    let engine = crate::tensor_tiled::TiledMatmul::new();
                    let tiled_result = engine.matmul(a_band, band_m, k, b, n);
                    // The final chunk may be shorter than band_size * n rows.
                    band[..band_m * n].copy_from_slice(&tiled_result);
                });

            return Tensor::from_vec(result, &[m, n]);
        }

        // For medium matrices (256-511), use Kahan accumulator (16 bytes, not 32KB).
        // One output row per rayon task; within a row, accumulation runs in the
        // same ascending-p order as the sequential kernel.
        let mut result = vec![0.0f64; m * n];
        result
            .par_chunks_mut(n)
            .enumerate()
            .for_each(|(i, row)| {
                for j in 0..n {
                    let mut acc = KahanAccumulatorF64::new();
                    for p in 0..k {
                        acc.add(a[i * k + p] * b[p * n + j]);
                    }
                    row[j] = acc.finalize();
                }
            });

        Tensor::from_vec(result, &[m, n])
    }
923
924    // -- Transformer Kernels ------------------------------------------------
925
    /// Batched matrix multiplication.
    ///
    /// `self` is `[..., M, K]`, `other` is `[..., K, N]` => result is `[..., M, N]`.
    /// The batch dimensions must be identical (no broadcast).
    /// For 2-D inputs, delegates to `matmul`.
    ///
    /// Per-batch kernel: tiled matmul when any of M/N/K >= 64 (naive
    /// accumulation), otherwise a Kahan-compensated triple loop. With the
    /// `parallel` feature, batches are distributed across threads when the
    /// workload is large enough; each batch writes a disjoint output slice
    /// using the same per-batch kernel, so results match the sequential path.
    ///
    /// # Errors
    /// `InvalidOperation` for <2-D inputs, ndim mismatch, or batch-shape
    /// mismatch; `DimensionMismatch` when the inner dimensions disagree.
    pub fn bmm(&self, other: &Tensor) -> Result<Tensor, RuntimeError> {
        if self.ndim() < 2 || other.ndim() < 2 {
            return Err(RuntimeError::InvalidOperation(
                "bmm requires at least 2-D tensors".to_string(),
            ));
        }
        if self.ndim() == 2 && other.ndim() == 2 {
            return self.matmul(other);
        }
        if self.ndim() != other.ndim() {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm requires same number of dimensions, got {} and {}",
                    self.ndim(),
                    other.ndim()
                ),
            ));
        }
        let nd = self.ndim();
        // All dims except the trailing two are batch dims; they must match exactly.
        let batch_dims_a = &self.shape[..nd - 2];
        let batch_dims_b = &other.shape[..nd - 2];
        if batch_dims_a != batch_dims_b {
            return Err(RuntimeError::InvalidOperation(
                format!(
                    "bmm batch dimensions mismatch: {:?} vs {:?}",
                    batch_dims_a, batch_dims_b
                ),
            ));
        }
        let m = self.shape[nd - 2];
        let k = self.shape[nd - 1];
        let k2 = other.shape[nd - 2];
        let n = other.shape[nd - 1];
        if k != k2 {
            return Err(RuntimeError::DimensionMismatch {
                expected: k,
                got: k2,
            });
        }

        let batch_size: usize = batch_dims_a.iter().product();
        let a = self.to_vec();
        let b = other.to_vec();
        // Flat strides between consecutive (M,K)/(K,N)/(M,N) matrices.
        let mat_a_stride = m * k;
        let mat_b_stride = k * n;
        let mat_c_stride = m * n;
        let mut result = vec![0.0f64; batch_size * mat_c_stride];

        // Helper closure: compute one batch into c_slice. Shared by the
        // parallel and sequential paths so both produce identical results.
        let compute_batch = |batch: usize, c_slice: &mut [f64]| {
            let a_slice = &a[batch * mat_a_stride..(batch + 1) * mat_a_stride];
            let b_slice = &b[batch * mat_b_stride..(batch + 1) * mat_b_stride];

            if m >= 64 || n >= 64 || k >= 64 {
                // Medium/large per-batch matrices: L2-friendly tiled kernel.
                let engine = crate::tensor_tiled::TiledMatmul::new();
                let tiled = engine.matmul(a_slice, m, k, b_slice, n);
                c_slice.copy_from_slice(&tiled);
            } else {
                // Small matrices: Kahan-compensated dot products.
                for i in 0..m {
                    for j in 0..n {
                        let mut acc = KahanAccumulatorF64::new();
                        for p in 0..k {
                            acc.add(a_slice[i * k + p] * b_slice[p * n + j]);
                        }
                        c_slice[i * n + j] = acc.finalize();
                    }
                }
            }
        };

        // Parallel path: parallelize over batches when workload is large enough
        #[cfg(feature = "parallel")]
        {
            if batch_size > 1 && m * k >= 4096 {
                use rayon::prelude::*;
                result
                    .par_chunks_mut(mat_c_stride)
                    .enumerate()
                    .for_each(|(batch, c_slice)| {
                        compute_batch(batch, c_slice);
                    });

                let mut out_shape = batch_dims_a.to_vec();
                out_shape.push(m);
                out_shape.push(n);
                return Tensor::from_vec(result, &out_shape);
            }
        }

        // Sequential fallback
        for batch in 0..batch_size {
            let c_off = batch * mat_c_stride;
            compute_batch(batch, &mut result[c_off..c_off + mat_c_stride]);
        }

        let mut out_shape = batch_dims_a.to_vec();
        out_shape.push(m);
        out_shape.push(n);
        Tensor::from_vec(result, &out_shape)
    }
1031
1032    /// Softmax along the last dimension (two-pass stable algorithm).
1033    ///
1034    /// Pass 1: find max per row (prevents overflow in exp)
1035    /// Pass 2: compute exp(x - max), accumulate sum, normalize
1036    ///
1037    /// For a tensor of shape `[..., N]`, softmax is applied independently
1038    /// to each length-N slice along the last axis.
1039    pub fn softmax(&self) -> Result<Tensor, RuntimeError> {
1040        if self.ndim() == 0 {
1041            return Err(RuntimeError::InvalidOperation(
1042                "softmax requires at least 1-D tensor".to_string(),
1043            ));
1044        }
1045        // Avoid allocation when tensor is already contiguous and starts at offset 0
1046        let data_ref;
1047        let data_vec;
1048        let data: &[f64] = if self.is_contiguous() && self.offset == 0 {
1049            data_ref = self.buffer.borrow_data();
1050            &data_ref
1051        } else {
1052            data_vec = self.to_vec();
1053            &data_vec
1054        };
1055        let n = *self.shape.last().unwrap(); // last dimension size
1056        let outer: usize = data.len() / n;  // product of all dims except last
1057        let mut result = vec![0.0f64; data.len()];
1058
1059        for row in 0..outer {
1060            let start = row * n;
1061            let end = start + n;
1062            let slice = &data[start..end];
1063
1064            // Pass 1: find max for numerical stability
1065            let mut max_val = f64::NEG_INFINITY;
1066            for &v in slice {
1067                if v > max_val {
1068                    max_val = v;
1069                }
1070            }
1071
1072            // Pass 2: exp(x - max) and accumulate sum
1073            let mut exp_vals = vec![0.0f64; n];
1074            let mut sum = 0.0f64;
1075            let mut comp = 0.0f64; // Kahan compensation
1076            for i in 0..n {
1077                let e = (slice[i] - max_val).exp();
1078                exp_vals[i] = e;
1079                // Kahan summation for the denominator
1080                let y = e - comp;
1081                let t = sum + y;
1082                comp = (t - sum) - y;
1083                sum = t;
1084            }
1085
1086            // Normalize
1087            if sum == 0.0 {
1088                // Degenerate case: all -inf inputs → uniform
1089                let uniform = 1.0 / n as f64;
1090                for i in 0..n {
1091                    result[start + i] = uniform;
1092                }
1093            } else {
1094                for i in 0..n {
1095                    result[start + i] = exp_vals[i] / sum;
1096                }
1097            }
1098        }
1099
1100        Tensor::from_vec(result, &self.shape)
1101    }
1102
1103    /// Layer normalization over the last dimension.
1104    ///
1105    /// For each length-D slice along the last axis:
1106    ///   1. mean = Σx / D  (BinnedAccumulator)
1107    ///   2. var  = Σ(x - mean)² / D  (BinnedAccumulator)
1108    ///   3. normalized = (x - mean) / √(var + eps)
1109    ///   4. output = gamma * normalized + beta
1110    ///
1111    /// `gamma` and `beta` are 1-D tensors of shape `[D]`.
1112    /// `eps` is a small constant (typically 1e-5).
1113    pub fn layer_norm(
1114        &self,
1115        gamma: &Tensor,
1116        beta: &Tensor,
1117        eps: f64,
1118    ) -> Result<Tensor, RuntimeError> {
1119        if self.ndim() == 0 {
1120            return Err(RuntimeError::InvalidOperation(
1121                "layer_norm requires at least 1-D tensor".to_string(),
1122            ));
1123        }
1124        let d = *self.shape.last().unwrap();
1125        if gamma.len() != d || beta.len() != d {
1126            return Err(RuntimeError::InvalidOperation(
1127                format!(
1128                    "layer_norm: gamma/beta length {} must match last dim {}",
1129                    gamma.len(),
1130                    d
1131                ),
1132            ));
1133        }
1134
1135        let data = self.to_vec();
1136        let gamma_data = gamma.to_vec();
1137        let beta_data = beta.to_vec();
1138        let outer = data.len() / d;
1139        let mut result = vec![0.0f64; data.len()];
1140
1141        for row in 0..outer {
1142            let start = row * d;
1143            let slice = &data[start..start + d];
1144
1145            // Pass 1: compute mean via BinnedAccumulator
1146            let mean = binned_sum_f64(slice) / d as f64;
1147
1148            // Pass 2: compute variance via BinnedAccumulator
1149            let diffs: Vec<f64> = slice.iter().map(|&x| {
1150                let diff = x - mean;
1151                diff * diff
1152            }).collect();
1153            let variance = binned_sum_f64(&diffs) / d as f64;
1154
1155            // Normalize, scale, shift
1156            let inv_std = 1.0 / (variance + eps).sqrt();
1157            for i in 0..d {
1158                let normalized = (slice[i] - mean) * inv_std;
1159                result[start + i] = gamma_data[i] * normalized + beta_data[i];
1160            }
1161        }
1162
1163        Tensor::from_vec(result, &self.shape)
1164    }
1165
    /// Apply a function element-wise, returning a new tensor.
    ///
    /// Fast path: when the tensor is contiguous, starts at offset 0, and the
    /// buffer refcount is 1, the buffer's data is cloned wholesale and mapped
    /// directly — this skips the strided gather that `to_vec` performs, but
    /// note it still allocates one full copy (the `.clone()` below), so it is
    /// NOT the zero-allocation in-place mutation the original comment claimed.
    /// Otherwise falls back to `to_vec` + collect into a fresh buffer.
    ///
    /// NOTE(review): the fast path assumes the buffer length equals the view's
    /// element count — if the buffer out-sized the view, `from_vec` would
    /// reject the shape and the `unwrap` would panic. Confirm that invariant
    /// upstream.
    fn map_elementwise(&self, f: impl Fn(f64) -> f64) -> Tensor {
        if self.is_contiguous() && self.offset == 0 && self.buffer.refcount() == 1 {
            // Sole-owner fast path: clone the whole buffer once and transform.
            let mut data = self.buffer.borrow_data().clone();
            for x in data.iter_mut() {
                *x = f(*x);
            }
            Tensor::from_vec(data, &self.shape).unwrap()
        } else {
            // Fallback: allocate new buffer
            let data = self.to_vec();
            let result: Vec<f64> = data.iter().map(|&x| f(x)).collect();
            Tensor::from_vec(result, &self.shape).unwrap()
        }
    }
1186
1187    /// ReLU activation: `max(0, x)` element-wise.
1188    pub fn relu(&self) -> Tensor {
1189        self.map_elementwise(|x| if x > 0.0 { x } else { 0.0 })
1190    }
1191
1192    /// Sigmoid activation: 1 / (1 + exp(-x)) element-wise.
1193    pub fn sigmoid(&self) -> Tensor {
1194        self.map_elementwise(|x| 1.0 / (1.0 + (-x).exp()))
1195    }
1196
1197    /// Tanh activation element-wise.
1198    pub fn tanh_activation(&self) -> Tensor {
1199        self.map_elementwise(|x| x.tanh())
1200    }
1201
1202    /// Leaky ReLU activation: max(alpha*x, x) element-wise.
1203    pub fn leaky_relu(&self, alpha: f64) -> Tensor {
1204        self.map_elementwise(move |x| if x > 0.0 { x } else { alpha * x })
1205    }
1206
1207    /// SiLU (Swish) activation: x * sigmoid(x) element-wise.
1208    pub fn silu(&self) -> Tensor {
1209        let data = self.to_vec();
1210        let result: Vec<f64> = data.iter().map(|&x| x / (1.0 + (-x).exp())).collect();
1211        Tensor::from_vec(result, &self.shape).unwrap()
1212    }
1213
1214    /// Mish activation: x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
1215    pub fn mish(&self) -> Tensor {
1216        let data = self.to_vec();
1217        let result: Vec<f64> = data.iter().map(|&x| {
1218            let sp = (1.0 + x.exp()).ln();
1219            x * sp.tanh()
1220        }).collect();
1221        Tensor::from_vec(result, &self.shape).unwrap()
1222    }
1223
1224    /// Argmax: index of the maximum element (first occurrence, deterministic).
1225    pub fn argmax(&self) -> usize {
1226        let data = self.to_vec();
1227        let mut best_idx = 0;
1228        let mut best_val = f64::NEG_INFINITY;
1229        for (i, &v) in data.iter().enumerate() {
1230            if v > best_val || (v == best_val && i < best_idx) {
1231                best_val = v;
1232                best_idx = i;
1233            }
1234        }
1235        best_idx
1236    }
1237
1238    /// Argmin: index of the minimum element (first occurrence, deterministic).
1239    pub fn argmin(&self) -> usize {
1240        let data = self.to_vec();
1241        let mut best_idx = 0;
1242        let mut best_val = f64::INFINITY;
1243        for (i, &v) in data.iter().enumerate() {
1244            if v < best_val || (v == best_val && i < best_idx) {
1245                best_val = v;
1246                best_idx = i;
1247            }
1248        }
1249        best_idx
1250    }
1251
1252    /// Clamp all elements to [min, max].
1253    pub fn clamp(&self, min: f64, max: f64) -> Tensor {
1254        let data = self.to_vec();
1255        let result: Vec<f64> = data.iter().map(|&x| x.max(min).min(max)).collect();
1256        Tensor::from_vec(result, &self.shape).unwrap()
1257    }
1258
1259    /// One-hot encoding: given a 1D tensor of integer indices and a depth,
1260    /// returns a 2D tensor of shape [len, depth].
1261    pub fn one_hot(indices: &[usize], depth: usize) -> Result<Tensor, RuntimeError> {
1262        let n = indices.len();
1263        let mut data = vec![0.0; n * depth];
1264        for (i, &idx) in indices.iter().enumerate() {
1265            if idx >= depth {
1266                return Err(RuntimeError::InvalidOperation(format!(
1267                    "one_hot: index {idx} >= depth {depth}"
1268                )));
1269            }
1270            data[i * depth + idx] = 1.0;
1271        }
1272        Tensor::from_vec(data, &[n, depth])
1273    }
1274
1275    // -----------------------------------------------------------------------
1276    // Phase B4: Tensor extensions (cat, stack, topk)
1277    // -----------------------------------------------------------------------
1278
    /// Concatenate tensors along existing axis.
    ///
    /// All inputs must share ndim and agree on every dimension except `axis`;
    /// the output's extent along `axis` is the sum of the inputs' extents.
    ///
    /// # Errors
    /// `InvalidOperation` if `tensors` is empty, `axis` is out of bounds, or
    /// the shapes are incompatible.
    pub fn cat(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("cat: no tensors".to_string()));
        }
        let ndim = tensors[0].ndim();
        if axis >= ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("cat: axis {axis} out of bounds for {ndim}D tensor"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.ndim() != ndim {
                return Err(RuntimeError::InvalidOperation(
                    format!("cat: tensor {i} has different ndim"),
                ));
            }
            for d in 0..ndim {
                if d != axis && t.shape[d] != tensors[0].shape[d] {
                    return Err(RuntimeError::InvalidOperation(
                        format!("cat: shape mismatch at dim {d}"),
                    ));
                }
            }
        }
        // Output shape: identical to the inputs except the concat axis, which sums.
        let mut out_shape = tensors[0].shape.clone();
        for t in tensors.iter().skip(1) {
            out_shape[axis] += t.shape[axis];
        }
        let total = out_shape.iter().product::<usize>();
        let mut result = vec![0.0; total];
        // Row-major (C-order) strides for the output.
        let mut out_strides = vec![1usize; ndim];
        for d in (0..ndim - 1).rev() {
            out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
        }
        // Running offset along `axis` where the current tensor's block begins.
        let mut offset = 0;
        for t in tensors {
            let t_data = t.to_vec();
            let t_total: usize = t.shape.iter().product();
            // Row-major strides for this input tensor.
            let mut t_strides = vec![1usize; ndim];
            for d in (0..ndim - 1).rev() {
                t_strides[d] = t_strides[d + 1] * t.shape[d + 1];
            }
            for idx in 0..t_total {
                // Decode the flat input index into coordinates, shift the
                // `axis` coordinate by the running offset, and re-encode it
                // as a flat output index.
                let mut remaining = idx;
                let mut out_flat = 0;
                for d in 0..ndim {
                    let coord = remaining / t_strides[d];
                    remaining %= t_strides[d];
                    let out_coord = if d == axis { coord + offset } else { coord };
                    out_flat += out_coord * out_strides[d];
                }
                result[out_flat] = t_data[idx];
            }
            offset += t.shape[axis];
        }
        Tensor::from_vec(result, &out_shape)
    }
1337
    /// Stack tensors along a new axis.
    ///
    /// All inputs must have identical shapes. The result gains one extra
    /// dimension of extent `tensors.len()` inserted at `axis` (note `axis`
    /// may equal `ndim`, appending the new axis at the end).
    ///
    /// # Errors
    /// `InvalidOperation` if `tensors` is empty, `axis > ndim`, or any shape
    /// differs from the first tensor's.
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, RuntimeError> {
        if tensors.is_empty() {
            return Err(RuntimeError::InvalidOperation("stack: no tensors".to_string()));
        }
        let base_shape = &tensors[0].shape;
        let ndim = base_shape.len();
        if axis > ndim {
            return Err(RuntimeError::InvalidOperation(
                format!("stack: axis {axis} out of bounds"),
            ));
        }
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if &t.shape != base_shape {
                return Err(RuntimeError::InvalidOperation(
                    format!("stack: tensor {i} shape mismatch"),
                ));
            }
        }
        // Build the output shape by splicing the new axis into base_shape.
        let mut out_shape = Vec::with_capacity(ndim + 1);
        for d in 0..axis { out_shape.push(base_shape[d]); }
        out_shape.push(tensors.len());
        for d in axis..ndim { out_shape.push(base_shape[d]); }
        let total: usize = out_shape.iter().product();
        let mut result = vec![0.0; total];
        // View each flat row-major input as [outer, inner] split at `axis`;
        // output element (outer, t_idx, inner) = input t_idx's (outer, inner).
        // `.max(1)` keeps the loops well-formed for empty prefix/suffix products.
        let inner_size: usize = base_shape[axis..].iter().product::<usize>().max(1);
        let outer_size: usize = base_shape[..axis].iter().product::<usize>().max(1);
        for (t_idx, t) in tensors.iter().enumerate() {
            let t_data = t.to_vec();
            for outer in 0..outer_size {
                for inner in 0..inner_size {
                    let src = outer * inner_size + inner;
                    let dst = outer * (tensors.len() * inner_size) + t_idx * inner_size + inner;
                    // NOTE(review): this guard silently skips out-of-range
                    // pairs; with `.max(1)` above it appears to fire only for
                    // degenerate zero-sized shapes — confirm that intent.
                    if src < t_data.len() && dst < result.len() {
                        result[dst] = t_data[src];
                    }
                }
            }
        }
        Tensor::from_vec(result, &out_shape)
    }
1379
1380    /// Top-k values and indices (largest k values from flat data).
1381    pub fn topk(&self, k: usize) -> Result<(Tensor, Vec<usize>), RuntimeError> {
1382        let data = self.to_vec();
1383        let n = data.len();
1384        if k > n {
1385            return Err(RuntimeError::InvalidOperation(
1386                format!("topk: k={k} exceeds data length {n}"),
1387            ));
1388        }
1389        let mut indexed: Vec<(usize, f64)> = data.into_iter().enumerate().collect();
1390        indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
1391        let top_k: Vec<(usize, f64)> = indexed[..k].to_vec();
1392        let values: Vec<f64> = top_k.iter().map(|&(_, v)| v).collect();
1393        let indices: Vec<usize> = top_k.iter().map(|&(i, _)| i).collect();
1394        Ok((Tensor::from_vec(values, &[k])?, indices))
1395    }
1396
1397    /// GELU activation (approximate): x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
1398    pub fn gelu(&self) -> Tensor {
1399        let data = self.to_vec();
1400        let sqrt_2_over_pi = (2.0_f64 / std::f64::consts::PI).sqrt();
1401        let result: Vec<f64> = data.iter().map(|&x| {
1402            let inner = sqrt_2_over_pi * (x + 0.044715 * x * x * x);
1403            0.5 * x * (1.0 + inner.tanh())
1404        }).collect();
1405        Tensor::from_vec(result, &self.shape).unwrap()
1406    }
1407
1408    /// Linear layer: output = input @ weight^T + bias
1409    ///
1410    /// `self` is `[..., in_features]`, `weight` is `[out_features, in_features]`,
1411    /// `bias` is `[out_features]`.
1412    /// Result is `[..., out_features]`.
1413    pub fn linear(
1414        &self,
1415        weight: &Tensor,
1416        bias: &Tensor,
1417    ) -> Result<Tensor, RuntimeError> {
1418        if weight.ndim() != 2 {
1419            return Err(RuntimeError::InvalidOperation(
1420                "linear: weight must be 2-D [out_features, in_features]".to_string(),
1421            ));
1422        }
1423        let out_features = weight.shape[0];
1424        let in_features = weight.shape[1];
1425        let last_dim = *self.shape.last().ok_or_else(|| {
1426            RuntimeError::InvalidOperation("linear: input must be at least 1-D".to_string())
1427        })?;
1428        if last_dim != in_features {
1429            return Err(RuntimeError::DimensionMismatch {
1430                expected: in_features,
1431                got: last_dim,
1432            });
1433        }
1434        if bias.len() != out_features {
1435            return Err(RuntimeError::InvalidOperation(
1436                format!(
1437                    "linear: bias length {} must match out_features {}",
1438                    bias.len(),
1439                    out_features
1440                ),
1441            ));
1442        }
1443
1444        let data = self.to_vec();
1445        let w = weight.to_vec();
1446        let b = bias.to_vec();
1447        let outer = data.len() / in_features;
1448        let mut result = vec![0.0f64; outer * out_features];
1449
1450        for row in 0..outer {
1451            let x_start = row * in_features;
1452            let x_slice = &data[x_start..x_start + in_features];
1453            let y_start = row * out_features;
1454            for j in 0..out_features {
1455                let w_start = j * in_features;
1456                let mut acc = BinnedAccumulatorF64::new();
1457                for p in 0..in_features {
1458                    acc.add(x_slice[p] * w[w_start + p]);
1459                }
1460                result[y_start + j] = acc.finalize() + b[j];
1461            }
1462        }
1463
1464        let mut out_shape = self.shape[..self.shape.len() - 1].to_vec();
1465        out_shape.push(out_features);
1466        Tensor::from_vec(result, &out_shape)
1467    }
1468
1469    /// 1D convolution: signal `[signal_len]` * filters `[out_ch, kernel_size]` + bias
1470    ///
1471    /// Returns `[out_ch, signal_len - kernel_size + 1]` (valid mode, stride=1).
1472    pub fn conv1d(
1473        &self,
1474        filters: &Tensor,
1475        bias: &Tensor,
1476    ) -> Result<Tensor, RuntimeError> {
1477        if self.ndim() != 1 {
1478            return Err(RuntimeError::InvalidOperation(
1479                "conv1d: input must be 1-D [signal_len]".to_string(),
1480            ));
1481        }
1482        if filters.ndim() != 2 {
1483            return Err(RuntimeError::InvalidOperation(
1484                "conv1d: filters must be 2-D [out_channels, kernel_size]".to_string(),
1485            ));
1486        }
1487        let signal_len = self.shape[0];
1488        let out_channels = filters.shape[0];
1489        let kernel_size = filters.shape[1];
1490        if signal_len < kernel_size {
1491            return Err(RuntimeError::InvalidOperation(
1492                format!(
1493                    "conv1d: signal_len {} < kernel_size {}",
1494                    signal_len, kernel_size
1495                ),
1496            ));
1497        }
1498        if bias.len() != out_channels {
1499            return Err(RuntimeError::InvalidOperation(
1500                format!(
1501                    "conv1d: bias length {} must match out_channels {}",
1502                    bias.len(), out_channels
1503                ),
1504            ));
1505        }
1506        let out_len = signal_len - kernel_size + 1;
1507        let s = self.to_vec();
1508        let f = filters.to_vec();
1509        let b = bias.to_vec();
1510        let mut result = vec![0.0; out_channels * out_len];
1511        kernel_fns::conv1d_raw(&s, &f, &b, &mut result, signal_len, out_channels, kernel_size);
1512        Tensor::from_vec(result, &[out_channels, out_len])
1513    }
1514
    /// 2D convolution — NCHW layout, valid mode, configurable stride.
    ///
    /// # Arguments
    /// - `self`:    `[N, C_in, H, W]` input tensor
    /// - `filters`: `[C_out, C_in, kH, kW]`
    /// - `bias`:    `[C_out]`
    /// - `stride`:  spatial stride (must be >= 1)
    ///
    /// # Returns
    /// `[N, C_out, H_out, W_out]` where `H_out = (H - kH) / stride + 1`.
    ///
    /// This method validates shapes and materializes contiguous buffers; all
    /// arithmetic is delegated to `kernel_fns::conv2d_raw` (documented
    /// elsewhere as using `BinnedAccumulatorF64` for every dot product —
    /// the determinism guarantee lives in that kernel, not here).
    ///
    /// # Errors
    /// `InvalidOperation` for wrong ranks, `stride == 0`, channel mismatch,
    /// kernel larger than input, or wrong bias length.
    pub fn conv2d(
        &self,
        filters: &Tensor,
        bias: &Tensor,
        stride: usize,
    ) -> Result<Tensor, RuntimeError> {
        if self.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: input must be 4-D [N, C_in, H, W]".to_string(),
            ));
        }
        if filters.ndim() != 4 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: filters must be 4-D [C_out, C_in, kH, kW]".to_string(),
            ));
        }
        if stride == 0 {
            return Err(RuntimeError::InvalidOperation(
                "conv2d: stride must be >= 1".to_string(),
            ));
        }

        let n    = self.shape[0];
        let c_in = self.shape[1];
        let h_in = self.shape[2];
        let w_in = self.shape[3];

        let c_out      = filters.shape[0];
        let c_in_check = filters.shape[1];
        let kh         = filters.shape[2];
        let kw         = filters.shape[3];

        if c_in != c_in_check {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input C_in={} does not match filter C_in={}",
                c_in, c_in_check
            )));
        }
        if h_in < kh || w_in < kw {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: input spatial [{}, {}] is smaller than kernel [{}, {}]",
                h_in, w_in, kh, kw
            )));
        }
        if bias.len() != c_out {
            return Err(RuntimeError::InvalidOperation(format!(
                "conv2d: bias length {} must match C_out={}",
                bias.len(), c_out
            )));
        }

        // Valid-mode output extents; floor division handles stride > 1.
        let h_out = (h_in - kh) / stride + 1;
        let w_out = (w_in - kw) / stride + 1;

        let inp = self.to_vec();
        let flt = filters.to_vec();
        let b   = bias.to_vec();
        let mut result = vec![0.0f64; n * c_out * h_out * w_out];

        kernel_fns::conv2d_raw(&inp, &flt, &b, &mut result,
                           n, c_in, h_in, w_in, c_out, kh, kw, stride);

        Tensor::from_vec(result, &[n, c_out, h_out, w_out])
    }
1592
1593    /// 2D max-pooling — NCHW layout, non-overlapping windows.
1594    ///
1595    /// - `self`: `[N, C, H, W]`
1596    /// - `ph`, `pw`: pool height/width (stride = window size)
1597    ///
1598    /// Returns `[N, C, H/ph, W/pw]`.
1599    pub fn maxpool2d(&self, ph: usize, pw: usize) -> Result<Tensor, RuntimeError> {
1600        if self.ndim() != 4 {
1601            return Err(RuntimeError::InvalidOperation(
1602                "maxpool2d: input must be 4-D [N, C, H, W]".to_string(),
1603            ));
1604        }
1605        if ph == 0 || pw == 0 {
1606            return Err(RuntimeError::InvalidOperation(
1607                "maxpool2d: pool size must be >= 1".to_string(),
1608            ));
1609        }
1610
1611        let n    = self.shape[0];
1612        let c    = self.shape[1];
1613        let h_in = self.shape[2];
1614        let w_in = self.shape[3];
1615
1616        if h_in < ph || w_in < pw {
1617            return Err(RuntimeError::InvalidOperation(format!(
1618                "maxpool2d: input [{}, {}] smaller than pool [{}, {}]",
1619                h_in, w_in, ph, pw
1620            )));
1621        }
1622
1623        let h_out = h_in / ph;
1624        let w_out = w_in / pw;
1625
1626        let inp = self.to_vec();
1627        let mut result = vec![0.0f64; n * c * h_out * w_out];
1628
1629        kernel_fns::maxpool2d_raw(&inp, &mut result, n, c, h_in, w_in, ph, pw);
1630
1631        Tensor::from_vec(result, &[n, c, h_out, w_out])
1632    }
1633
1634    /// Applies 2-D average pooling over a `[C, H, W]` tensor.
1635    ///
1636    /// # Arguments
1637    ///
1638    /// * `kernel_h` / `kernel_w` - Pooling window size
1639    /// * `stride_h` / `stride_w` - Stride for the pooling window
1640    ///
1641    /// # Returns
1642    ///
1643    /// Tensor of shape `[C, out_h, out_w]` where `out_h = (H - kernel_h) / stride_h + 1`.
1644    ///
1645    /// # Errors
1646    ///
1647    /// Returns an error if the tensor is not 3-D or if kernel/stride produce invalid output dimensions.
1648    pub fn avgpool2d(&self, kernel_h: usize, kernel_w: usize, stride_h: usize, stride_w: usize) -> Result<Tensor, RuntimeError> {
1649        let shape = self.shape();
1650        if shape.len() != 3 {
1651            return Err(RuntimeError::InvalidOperation(format!("avgpool2d requires 3-D [C,H,W], got {:?}", shape)));
1652        }
1653        let (c, h, w) = (shape[0], shape[1], shape[2]);
1654        if kernel_h > h || kernel_w > w {
1655            return Err(RuntimeError::InvalidOperation("avgpool2d: kernel larger than input".into()));
1656        }
1657        let out_h = (h - kernel_h) / stride_h + 1;
1658        let out_w = (w - kernel_w) / stride_w + 1;
1659        let data = self.to_vec();
1660        let mut out = Vec::with_capacity(c * out_h * out_w);
1661        let pool_size = (kernel_h * kernel_w) as f64;
1662
1663        for ch in 0..c {
1664            for oh in 0..out_h {
1665                for ow in 0..out_w {
1666                    let mut sum = 0.0f64;
1667                    for kh in 0..kernel_h {
1668                        for kw in 0..kernel_w {
1669                            let ih = oh * stride_h + kh;
1670                            let iw = ow * stride_w + kw;
1671                            sum += data[ch * h * w + ih * w + iw];
1672                        }
1673                    }
1674                    out.push(sum / pool_size);
1675                }
1676            }
1677        }
1678        Tensor::from_vec(out, &[c, out_h, out_w])
1679    }
1680
1681    /// Scaled dot-product attention (single head).
1682    ///
1683    /// `queries` is `[..., T, d_k]`
1684    /// `keys`    is `[..., S, d_k]`
1685    /// `values`  is `[..., S, d_v]`
1686    ///
1687    /// Computes: softmax(Q × Kᵀ / √d_k) × V
1688    /// Returns `[..., T, d_v]`.
1689    pub fn scaled_dot_product_attention(
1690        queries: &Tensor,
1691        keys: &Tensor,
1692        values: &Tensor,
1693    ) -> Result<Tensor, RuntimeError> {
1694        if queries.ndim() < 2 || keys.ndim() < 2 || values.ndim() < 2 {
1695            return Err(RuntimeError::InvalidOperation(
1696                "attention: Q, K, V must be at least 2-D".to_string(),
1697            ));
1698        }
1699        let nd = queries.ndim();
1700        let d_k = queries.shape[nd - 1];
1701        let scale = 1.0 / (d_k as f64).sqrt();
1702
1703        // Transpose keys: swap last two dims
1704        let keys_t = keys.transpose_last_two()?;
1705
1706        // Q × K^T → [... T, S]
1707        let scores = queries.bmm(&keys_t)?;
1708
1709        // Scale
1710        let scores_scaled = scores.scalar_mul(scale);
1711
1712        // Softmax along last dim
1713        let attn_weights = scores_scaled.softmax()?;
1714
1715        // Attn × V → [... T, d_v]
1716        attn_weights.bmm(values)
1717    }
1718
1719    /// Transpose the last two dimensions of a tensor.
1720    ///
1721    /// `[..., A, B]` → `[..., B, A]`
1722    pub fn transpose_last_two(&self) -> Result<Tensor, RuntimeError> {
1723        if self.ndim() < 2 {
1724            return Err(RuntimeError::InvalidOperation(
1725                "transpose_last_two requires at least 2-D tensor".to_string(),
1726            ));
1727        }
1728        let nd = self.ndim();
1729        let rows = self.shape[nd - 2];
1730        let cols = self.shape[nd - 1];
1731        let data = self.to_vec();
1732        let batch_size: usize = self.shape[..nd - 2].iter().product::<usize>().max(1);
1733        let mat_size = rows * cols;
1734        let mut result = vec![0.0f64; data.len()];
1735
1736        for b in 0..batch_size {
1737            let off = b * mat_size;
1738            for i in 0..rows {
1739                for j in 0..cols {
1740                    result[off + j * rows + i] = data[off + i * cols + j];
1741                }
1742            }
1743        }
1744
1745        let mut out_shape = self.shape.clone();
1746        out_shape[nd - 2] = cols;
1747        out_shape[nd - 1] = rows;
1748        Tensor::from_vec(result, &out_shape)
1749    }
1750
1751    // -- Zero-Copy Weight Mapping -------------------------------------------
1752
1753    /// Create a tensor view from raw bytes — **zero allocation**.
1754    ///
1755    /// Interprets `bytes` as a contiguous block of `f64` (8 bytes each) or
1756    /// `f32` (4 bytes each, promoted to f64) values and maps them into a
1757    /// `Tensor` with the given shape.
1758    ///
1759    /// `dtype` must be `"f64"` or `"f32"`.
1760    ///
1761    /// For f64: bytes.len() must equal shape_numel * 8.
1762    /// For f32: bytes.len() must equal shape_numel * 4.
1763    ///
1764    /// The returned tensor **owns** its buffer (copied from the raw bytes)
1765    /// but performs exactly one allocation for the data vector.
1766    pub fn from_bytes(bytes: &[u8], shape: &[usize], dtype: &str) -> Result<Tensor, RuntimeError> {
1767        let numel = Self::shape_numel(shape);
1768        match dtype {
1769            "f64" => {
1770                let expected = numel * 8;
1771                if bytes.len() != expected {
1772                    return Err(RuntimeError::ShapeMismatch {
1773                        expected,
1774                        got: bytes.len(),
1775                    });
1776                }
1777                let mut data = Vec::with_capacity(numel);
1778                for i in 0..numel {
1779                    let off = i * 8;
1780                    let mut buf = [0u8; 8];
1781                    buf.copy_from_slice(&bytes[off..off + 8]);
1782                    data.push(f64::from_le_bytes(buf));
1783                }
1784                Ok(Tensor {
1785                    buffer: Buffer::from_vec(data),
1786                    shape: shape.to_vec(),
1787                    strides: Self::compute_strides(shape),
1788                    offset: 0,
1789                })
1790            }
1791            "f32" => {
1792                let expected = numel * 4;
1793                if bytes.len() != expected {
1794                    return Err(RuntimeError::ShapeMismatch {
1795                        expected,
1796                        got: bytes.len(),
1797                    });
1798                }
1799                let mut data = Vec::with_capacity(numel);
1800                for i in 0..numel {
1801                    let off = i * 4;
1802                    let mut buf = [0u8; 4];
1803                    buf.copy_from_slice(&bytes[off..off + 4]);
1804                    data.push(f32::from_le_bytes(buf) as f64);
1805                }
1806                Ok(Tensor {
1807                    buffer: Buffer::from_vec(data),
1808                    shape: shape.to_vec(),
1809                    strides: Self::compute_strides(shape),
1810                    offset: 0,
1811                })
1812            }
1813            _ => Err(RuntimeError::InvalidOperation(
1814                format!("from_bytes: unsupported dtype '{}', expected 'f32' or 'f64'", dtype),
1815            )),
1816        }
1817    }
1818
1819    // -- Multi-Head Attention Splitting -------------------------------------
1820
1821    /// Reshape a 3D tensor `[batch, seq, model_dim]` into 4D
1822    /// `[batch, num_heads, seq, head_dim]` by splitting the last dimension.
1823    ///
1824    /// This is a **zero-copy view** — it only changes shape/strides metadata.
1825    /// `model_dim` must be divisible by `num_heads`.
1826    pub fn split_heads(&self, num_heads: usize) -> Result<Tensor, RuntimeError> {
1827        if self.ndim() != 3 {
1828            return Err(RuntimeError::DimensionMismatch {
1829                expected: 3,
1830                got: self.ndim(),
1831            });
1832        }
1833        let batch = self.shape[0];
1834        let seq = self.shape[1];
1835        let model_dim = self.shape[2];
1836        if model_dim % num_heads != 0 {
1837            return Err(RuntimeError::InvalidOperation(
1838                format!(
1839                    "split_heads: model_dim {} not divisible by num_heads {}",
1840                    model_dim, num_heads
1841                ),
1842            ));
1843        }
1844        let head_dim = model_dim / num_heads;
1845        // Need contiguous data for the reshape
1846        let tensor = if self.is_contiguous() { self.clone() } else { self.to_contiguous() };
1847        // Reshape [B, S, H*D] -> [B, S, H, D] then transpose to [B, H, S, D]
1848        let reshaped = Tensor {
1849            buffer: tensor.buffer.clone(),
1850            shape: vec![batch, seq, num_heads, head_dim],
1851            strides: Self::compute_strides(&[batch, seq, num_heads, head_dim]),
1852            offset: 0,
1853        };
1854        // Transpose dims 1 and 2: [B, S, H, D] -> [B, H, S, D]
1855        // New strides: swap strides[1] and strides[2]
1856        Ok(Tensor {
1857            buffer: reshaped.buffer,
1858            shape: vec![batch, num_heads, seq, head_dim],
1859            strides: vec![
1860                reshaped.strides[0], // batch stride unchanged
1861                reshaped.strides[2], // head stride (was dim 2)
1862                reshaped.strides[1], // seq stride (was dim 1)
1863                reshaped.strides[3], // head_dim stride unchanged
1864            ],
1865            offset: 0,
1866        })
1867    }
1868
1869    /// Merge heads back: reshape 4D `[batch, num_heads, seq, head_dim]` into
1870    /// 3D `[batch, seq, model_dim]`. Materializes if non-contiguous.
1871    pub fn merge_heads(&self) -> Result<Tensor, RuntimeError> {
1872        if self.ndim() != 4 {
1873            return Err(RuntimeError::DimensionMismatch {
1874                expected: 4,
1875                got: self.ndim(),
1876            });
1877        }
1878        let batch = self.shape[0];
1879        let num_heads = self.shape[1];
1880        let seq = self.shape[2];
1881        let head_dim = self.shape[3];
1882        // Need [B, H, S, D] -> [B, S, H, D] -> [B, S, H*D]
1883        // Transpose dims 1 and 2 first
1884        let transposed = Tensor {
1885            buffer: self.buffer.clone(),
1886            shape: vec![batch, seq, num_heads, head_dim],
1887            strides: vec![
1888                self.strides[0],
1889                self.strides[2], // seq stride
1890                self.strides[1], // head stride
1891                self.strides[3],
1892            ],
1893            offset: self.offset,
1894        };
1895        // Materialize contiguous then reshape
1896        let contig = transposed.to_contiguous();
1897        let model_dim = num_heads * head_dim;
1898        Ok(Tensor {
1899            buffer: contig.buffer,
1900            shape: vec![batch, seq, model_dim],
1901            strides: Self::compute_strides(&[batch, seq, model_dim]),
1902            offset: 0,
1903        })
1904    }
1905
    /// Reshape to `new_shape`, delegating directly to [`Tensor::reshape`].
    ///
    /// Kept as a named alias so call sites can express view-like intent.
    /// NOTE(review): any zero-copy behavior depends entirely on `reshape`'s
    /// implementation (not visible here) — presumably it avoids copying for
    /// contiguous tensors and materializes otherwise; confirm against
    /// `reshape` before relying on aliasing semantics.
    pub fn view_reshape(&self, new_shape: &[usize]) -> Result<Tensor, RuntimeError> {
        self.reshape(new_shape)
    }
1911
1912    // -----------------------------------------------------------------------
1913    // Phase C4: Sorting & Tensor Indexing
1914    // -----------------------------------------------------------------------
1915
1916    /// Returns indices that would sort the flattened tensor in ascending order.
1917    /// Uses f64::total_cmp for deterministic ordering of NaN.
1918    pub fn argsort(&self) -> Tensor {
1919        let data = self.to_vec();
1920        let mut indices: Vec<usize> = (0..data.len()).collect();
1921        indices.sort_by(|&a, &b| data[a].total_cmp(&data[b]));
1922        let result: Vec<f64> = indices.iter().map(|&i| i as f64).collect();
1923        Tensor::from_vec_unchecked(result, &[data.len()])
1924    }
1925
1926    /// Gather elements from the tensor along a dimension using index tensor.
1927    /// For 1D: result[i] = self[indices[i]]
1928    /// For 2D dim=0: result[i][j] = self[indices[i][j]][j]
1929    /// For 2D dim=1: result[i][j] = self[i][indices[i][j]]
1930    pub fn gather(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
1931        let data = self.to_vec();
1932        let idx_data = indices.to_vec();
1933        if self.ndim() == 1 {
1934            let mut result = Vec::with_capacity(idx_data.len());
1935            for &idx in &idx_data {
1936                let i = idx as usize;
1937                if i >= data.len() {
1938                    return Err(RuntimeError::InvalidOperation(
1939                        format!("gather: index {} out of bounds for size {}", i, data.len()),
1940                    ));
1941                }
1942                result.push(data[i]);
1943            }
1944            Ok(Tensor::from_vec_unchecked(result, indices.shape()))
1945        } else if self.ndim() == 2 {
1946            let rows = self.shape[0];
1947            let cols = self.shape[1];
1948            let idx_shape = indices.shape();
1949            let out_rows = idx_shape[0];
1950            let out_cols = idx_shape[1];
1951            let mut result = vec![0.0; out_rows * out_cols];
1952            for i in 0..out_rows {
1953                for j in 0..out_cols {
1954                    let idx = idx_data[i * out_cols + j] as usize;
1955                    let val = if dim == 0 {
1956                        if idx >= rows {
1957                            return Err(RuntimeError::InvalidOperation(
1958                                format!("gather dim=0: index {} out of bounds for {} rows", idx, rows),
1959                            ));
1960                        }
1961                        data[idx * cols + j]
1962                    } else {
1963                        if idx >= cols {
1964                            return Err(RuntimeError::InvalidOperation(
1965                                format!("gather dim=1: index {} out of bounds for {} cols", idx, cols),
1966                            ));
1967                        }
1968                        data[i * cols + idx]
1969                    };
1970                    result[i * out_cols + j] = val;
1971                }
1972            }
1973            Ok(Tensor::from_vec_unchecked(result, idx_shape))
1974        } else {
1975            Err(RuntimeError::InvalidOperation(
1976                "gather: only 1D and 2D tensors supported".into(),
1977            ))
1978        }
1979    }
1980
1981    /// Scatter src values into a tensor of given shape at indices along a dimension.
1982    /// For 1D: result[indices[i]] = src[i]
1983    /// For 2D dim=0: result[indices[i][j]][j] = src[i][j]
1984    /// For 2D dim=1: result[i][indices[i][j]] = src[i][j]
1985    pub fn scatter(&self, dim: usize, indices: &Tensor, src: &Tensor) -> Result<Tensor, RuntimeError> {
1986        let mut result = self.to_vec();
1987        let idx_data = indices.to_vec();
1988        let src_data = src.to_vec();
1989        if self.ndim() == 1 {
1990            for (k, &idx) in idx_data.iter().enumerate() {
1991                let i = idx as usize;
1992                if i >= result.len() {
1993                    return Err(RuntimeError::InvalidOperation(
1994                        format!("scatter: index {} out of bounds for size {}", i, result.len()),
1995                    ));
1996                }
1997                result[i] = src_data[k];
1998            }
1999            Ok(Tensor::from_vec_unchecked(result, self.shape()))
2000        } else if self.ndim() == 2 {
2001            let cols = self.shape[1];
2002            let idx_shape = indices.shape();
2003            let out_cols = idx_shape[1];
2004            let out_rows = idx_shape[0];
2005            for i in 0..out_rows {
2006                for j in 0..out_cols {
2007                    let idx = idx_data[i * out_cols + j] as usize;
2008                    let src_val = src_data[i * out_cols + j];
2009                    if dim == 0 {
2010                        if idx >= self.shape[0] {
2011                            return Err(RuntimeError::InvalidOperation(
2012                                format!("scatter dim=0: index {} out of bounds for {} rows", idx, self.shape[0]),
2013                            ));
2014                        }
2015                        result[idx * cols + j] = src_val;
2016                    } else {
2017                        if idx >= cols {
2018                            return Err(RuntimeError::InvalidOperation(
2019                                format!("scatter dim=1: index {} out of bounds for {} cols", idx, cols),
2020                            ));
2021                        }
2022                        result[i * cols + idx] = src_val;
2023                    }
2024                }
2025            }
2026            Ok(Tensor::from_vec_unchecked(result, self.shape()))
2027        } else {
2028            Err(RuntimeError::InvalidOperation(
2029                "scatter: only 1D and 2D tensors supported".into(),
2030            ))
2031        }
2032    }
2033
2034    /// Select slices along a dimension by index.
2035    /// For 2D dim=0: selects rows
2036    /// For 2D dim=1: selects columns
2037    pub fn index_select(&self, dim: usize, indices: &Tensor) -> Result<Tensor, RuntimeError> {
2038        let data = self.to_vec();
2039        let idx_data = indices.to_vec();
2040        if self.ndim() == 1 {
2041            let mut result = Vec::with_capacity(idx_data.len());
2042            for &idx in &idx_data {
2043                let i = idx as usize;
2044                if i >= data.len() {
2045                    return Err(RuntimeError::InvalidOperation(
2046                        format!("index_select: index {} out of bounds for size {}", i, data.len()),
2047                    ));
2048                }
2049                result.push(data[i]);
2050            }
2051            Ok(Tensor::from_vec_unchecked(result, &[idx_data.len()]))
2052        } else if self.ndim() == 2 {
2053            let rows = self.shape[0];
2054            let cols = self.shape[1];
2055            let n = idx_data.len();
2056            if dim == 0 {
2057                let mut result = Vec::with_capacity(n * cols);
2058                for &idx in &idx_data {
2059                    let i = idx as usize;
2060                    if i >= rows {
2061                        return Err(RuntimeError::InvalidOperation(
2062                            format!("index_select dim=0: index {} out of bounds for {} rows", i, rows),
2063                        ));
2064                    }
2065                    for j in 0..cols {
2066                        result.push(data[i * cols + j]);
2067                    }
2068                }
2069                Ok(Tensor::from_vec_unchecked(result, &[n, cols]))
2070            } else {
2071                let mut result = Vec::with_capacity(rows * n);
2072                for i in 0..rows {
2073                    for &idx in &idx_data {
2074                        let j = idx as usize;
2075                        if j >= cols {
2076                            return Err(RuntimeError::InvalidOperation(
2077                                format!("index_select dim=1: index {} out of bounds for {} cols", j, cols),
2078                            ));
2079                        }
2080                        result.push(data[i * cols + j]);
2081                    }
2082                }
2083                Ok(Tensor::from_vec_unchecked(result, &[rows, n]))
2084            }
2085        } else {
2086            Err(RuntimeError::InvalidOperation(
2087                "index_select: only 1D and 2D tensors supported".into(),
2088            ))
2089        }
2090    }
2091
2092    // -----------------------------------------------------------------------
2093    // Phase 2: Boolean / Masking Ops
2094    // -----------------------------------------------------------------------
2095
2096    /// Element-wise conditional select: `where(condition, other)`.
2097    /// For each element, returns `self[i]` if `condition[i] != 0.0`, else `other[i]`.
2098    pub fn tensor_where(&self, condition: &Tensor, other: &Tensor) -> Result<Tensor, RuntimeError> {
2099        if self.shape() != condition.shape() || self.shape() != other.shape() {
2100            return Err(RuntimeError::InvalidOperation(
2101                format!("where: shape mismatch self={:?} cond={:?} other={:?}",
2102                    self.shape(), condition.shape(), other.shape()),
2103            ));
2104        }
2105        let s = self.to_vec();
2106        let c = condition.to_vec();
2107        let o = other.to_vec();
2108        let result: Vec<f64> = s.iter().zip(c.iter()).zip(o.iter())
2109            .map(|((&sv, &cv), &ov)| if cv != 0.0 { sv } else { ov })
2110            .collect();
2111        Tensor::from_vec(result, self.shape())
2112    }
2113
2114    /// Return `true` if any element is non-zero.
2115    pub fn any(&self) -> bool {
2116        let data = self.to_vec();
2117        data.iter().any(|&x| x != 0.0)
2118    }
2119
2120    /// Return `true` if all elements are non-zero.
2121    pub fn all(&self) -> bool {
2122        let data = self.to_vec();
2123        data.iter().all(|&x| x != 0.0)
2124    }
2125
2126    /// Return a 1-D tensor of flat indices where elements are non-zero.
2127    ///
2128    /// If no elements are non-zero, returns an empty tensor of shape `[0]`.
2129    pub fn nonzero(&self) -> Tensor {
2130        let data = self.to_vec();
2131        let indices: Vec<f64> = data.iter().enumerate()
2132            .filter(|(_, &v)| v != 0.0)
2133            .map(|(i, _)| i as f64)
2134            .collect();
2135        let len = indices.len();
2136        if len == 0 {
2137            Tensor::from_vec(vec![], &[0]).unwrap()
2138        } else {
2139            Tensor::from_vec(indices, &[len]).unwrap()
2140        }
2141    }
2142
2143    /// Fill elements where `mask` is non-zero with `value`.
2144    pub fn masked_fill(&self, mask: &Tensor, value: f64) -> Result<Tensor, RuntimeError> {
2145        if self.shape() != mask.shape() {
2146            return Err(RuntimeError::InvalidOperation(
2147                format!("masked_fill: shape mismatch self={:?} mask={:?}",
2148                    self.shape(), mask.shape()),
2149            ));
2150        }
2151        let data = self.to_vec();
2152        let m = mask.to_vec();
2153        let result: Vec<f64> = data.iter().zip(m.iter())
2154            .map(|(&d, &mv)| if mv != 0.0 { value } else { d })
2155            .collect();
2156        Tensor::from_vec(result, self.shape())
2157    }
2158
2159    // -----------------------------------------------------------------------
2160    // Phase 2: Axis Reductions with keepdim
2161    // -----------------------------------------------------------------------
2162
2163    /// Generic axis reduction using a caller-provided reduction function.
2164    ///
2165    /// Gathers values along the specified axis for each output position
2166    /// and applies `reduce_fn`. If `keepdim` is `true`, the reduced axis
2167    /// retains size 1; otherwise it is removed from the output shape.
2168    fn reduce_axis<F>(&self, axis: usize, keepdim: bool, reduce_fn: F)
2169        -> Result<Tensor, RuntimeError>
2170    where
2171        F: Fn(&[f64]) -> f64,
2172    {
2173        let ndim = self.ndim();
2174        if axis >= ndim {
2175            return Err(RuntimeError::IndexOutOfBounds {
2176                index: axis,
2177                length: ndim,
2178            });
2179        }
2180
2181        let axis_len = self.shape[axis];
2182        // Build output shape
2183        let mut out_shape: Vec<usize> = self.shape.clone();
2184        out_shape[axis] = 1;
2185        let out_numel = Self::shape_numel(&out_shape);
2186        let out_strides = Self::compute_strides(&out_shape);
2187
2188        let data = self.to_vec();
2189        let mut result = Vec::with_capacity(out_numel);
2190        let mut indices = vec![0usize; ndim];
2191
2192        for out_idx in 0..out_numel {
2193            // Compute N-D index from flat output index
2194            {
2195                let mut remaining = out_idx;
2196                for d in 0..ndim {
2197                    indices[d] = remaining / out_strides[d];
2198                    remaining %= out_strides[d];
2199                }
2200            }
2201
2202            // Gather values along the reduction axis
2203            let mut vals = Vec::with_capacity(axis_len);
2204            for k in 0..axis_len {
2205                let mut flat = self.offset;
2206                for d in 0..ndim {
2207                    let idx = if d == axis { k } else { indices[d] };
2208                    flat += idx * self.strides[d];
2209                }
2210                vals.push(data[flat]);
2211            }
2212            result.push(reduce_fn(&vals));
2213        }
2214
2215        let final_shape = if keepdim {
2216            out_shape
2217        } else {
2218            // Remove the axis dimension
2219            let mut s: Vec<usize> = self.shape.iter().enumerate()
2220                .filter(|&(i, _)| i != axis)
2221                .map(|(_, &v)| v)
2222                .collect();
2223            if s.is_empty() {
2224                s.push(1); // scalar result
2225            }
2226            s
2227        };
2228
2229        Tensor::from_vec(result, &final_shape)
2230    }
2231
2232    /// Mean along an axis with optional keepdim.
2233    ///
2234    /// Uses [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64)
2235    /// for deterministic summation before dividing by the axis length.
2236    pub fn mean_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2237        self.reduce_axis(axis, keepdim, |vals| {
2238            let mut acc = BinnedAccumulatorF64::new();
2239            for &v in vals { acc.add(v); }
2240            acc.finalize() / vals.len() as f64
2241        })
2242    }
2243
2244    /// Max along an axis with optional keepdim. Return `(values, indices)`.
2245    ///
2246    /// Ties are broken by choosing the first occurrence (smallest index).
2247    pub fn max_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2248        let ndim = self.ndim();
2249        if axis >= ndim {
2250            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2251        }
2252        let axis_len = self.shape[axis];
2253        let mut out_shape = self.shape.clone();
2254        out_shape[axis] = 1;
2255        let out_numel = Self::shape_numel(&out_shape);
2256        let out_strides = Self::compute_strides(&out_shape);
2257        let data = self.to_vec();
2258        let mut values = Vec::with_capacity(out_numel);
2259        let mut idx_vals = Vec::with_capacity(out_numel);
2260        let mut indices = vec![0usize; ndim];
2261
2262        for out_idx in 0..out_numel {
2263            let mut remaining = out_idx;
2264            for d in 0..ndim {
2265                indices[d] = remaining / out_strides[d];
2266                remaining %= out_strides[d];
2267            }
2268            let mut best_val = f64::NEG_INFINITY;
2269            let mut best_idx = 0usize;
2270            for k in 0..axis_len {
2271                let mut flat = self.offset;
2272                for d in 0..ndim {
2273                    let idx = if d == axis { k } else { indices[d] };
2274                    flat += idx * self.strides[d];
2275                }
2276                let v = data[flat];
2277                if v > best_val {
2278                    best_val = v;
2279                    best_idx = k;
2280                }
2281            }
2282            values.push(best_val);
2283            idx_vals.push(best_idx as f64);
2284        }
2285
2286        let final_shape = if keepdim {
2287            out_shape
2288        } else {
2289            let mut s: Vec<usize> = self.shape.iter().enumerate()
2290                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2291            if s.is_empty() { s.push(1); }
2292            s
2293        };
2294        Ok((
2295            Tensor::from_vec(values, &final_shape)?,
2296            Tensor::from_vec(idx_vals, &final_shape)?,
2297        ))
2298    }
2299
2300    /// Min along an axis with optional keepdim. Return `(values, indices)`.
2301    ///
2302    /// Ties are broken by choosing the first occurrence (smallest index).
2303    pub fn min_axis(&self, axis: usize, keepdim: bool) -> Result<(Tensor, Tensor), RuntimeError> {
2304        let ndim = self.ndim();
2305        if axis >= ndim {
2306            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2307        }
2308        let axis_len = self.shape[axis];
2309        let mut out_shape = self.shape.clone();
2310        out_shape[axis] = 1;
2311        let out_numel = Self::shape_numel(&out_shape);
2312        let out_strides = Self::compute_strides(&out_shape);
2313        let data = self.to_vec();
2314        let mut values = Vec::with_capacity(out_numel);
2315        let mut idx_vals = Vec::with_capacity(out_numel);
2316        let mut indices = vec![0usize; ndim];
2317
2318        for out_idx in 0..out_numel {
2319            let mut remaining = out_idx;
2320            for d in 0..ndim {
2321                indices[d] = remaining / out_strides[d];
2322                remaining %= out_strides[d];
2323            }
2324            let mut best_val = f64::INFINITY;
2325            let mut best_idx = 0usize;
2326            for k in 0..axis_len {
2327                let mut flat = self.offset;
2328                for d in 0..ndim {
2329                    let idx = if d == axis { k } else { indices[d] };
2330                    flat += idx * self.strides[d];
2331                }
2332                let v = data[flat];
2333                if v < best_val {
2334                    best_val = v;
2335                    best_idx = k;
2336                }
2337            }
2338            values.push(best_val);
2339            idx_vals.push(best_idx as f64);
2340        }
2341
2342        let final_shape = if keepdim {
2343            out_shape
2344        } else {
2345            let mut s: Vec<usize> = self.shape.iter().enumerate()
2346                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2347            if s.is_empty() { s.push(1); }
2348            s
2349        };
2350        Ok((
2351            Tensor::from_vec(values, &final_shape)?,
2352            Tensor::from_vec(idx_vals, &final_shape)?,
2353        ))
2354    }
2355
2356    /// Variance along an axis with optional keepdim.
2357    ///
2358    /// Computes population variance: `Var = sum((x - mean)^2) / N`.
2359    /// Uses [`BinnedAccumulatorF64`](crate::accumulator::BinnedAccumulatorF64)
2360    /// for the squared-differences summation.
2361    pub fn var_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2362        let mean_t = self.mean_axis(axis, true)?;
2363        let ndim = self.ndim();
2364        if axis >= ndim {
2365            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2366        }
2367        let axis_len = self.shape[axis];
2368        let mut out_shape = self.shape.clone();
2369        out_shape[axis] = 1;
2370        let out_numel = Self::shape_numel(&out_shape);
2371        let out_strides = Self::compute_strides(&out_shape);
2372        let data = self.to_vec();
2373        let mean_data = mean_t.to_vec();
2374        let mut result = Vec::with_capacity(out_numel);
2375        let mut indices = vec![0usize; ndim];
2376
2377        for out_idx in 0..out_numel {
2378            let mut remaining = out_idx;
2379            for d in 0..ndim {
2380                indices[d] = remaining / out_strides[d];
2381                remaining %= out_strides[d];
2382            }
2383            let mu = mean_data[out_idx];
2384            let mut acc = BinnedAccumulatorF64::new();
2385            for k in 0..axis_len {
2386                let mut flat = self.offset;
2387                for d in 0..ndim {
2388                    let idx = if d == axis { k } else { indices[d] };
2389                    flat += idx * self.strides[d];
2390                }
2391                let diff = data[flat] - mu;
2392                acc.add(diff * diff);
2393            }
2394            result.push(acc.finalize() / axis_len as f64);
2395        }
2396
2397        let final_shape = if keepdim {
2398            out_shape
2399        } else {
2400            let mut s: Vec<usize> = self.shape.iter().enumerate()
2401                .filter(|&(i, _)| i != axis).map(|(_, &v)| v).collect();
2402            if s.is_empty() { s.push(1); }
2403            s
2404        };
2405        Tensor::from_vec(result, &final_shape)
2406    }
2407
2408    /// Standard deviation along an axis with optional keepdim.
2409    ///
2410    /// Computed as `sqrt(var_axis(axis, keepdim))`.
2411    pub fn std_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2412        let var = self.var_axis(axis, keepdim)?;
2413        Ok(var.map(|x| x.sqrt()))
2414    }
2415
2416    /// Product along an axis with optional keepdim.
2417    ///
2418    /// Computes a simple sequential product (exact for integer-like values).
2419    pub fn prod_axis(&self, axis: usize, keepdim: bool) -> Result<Tensor, RuntimeError> {
2420        self.reduce_axis(axis, keepdim, |vals| {
2421            // Product via exp(sum(ln(abs))) for numerical stability is overkill here;
2422            // simple product is deterministic and exact for integer-like values.
2423            let mut product = 1.0f64;
2424            for &v in vals { product *= v; }
2425            product
2426        })
2427    }
2428
2429    // -----------------------------------------------------------------------
2430    // Phase 2: Sort Operations
2431    // -----------------------------------------------------------------------
2432
2433    /// Sort along an axis (stable sort). Return the sorted tensor.
2434    ///
2435    /// For N-D tensors, independently sorts each 1-D slice along the specified
2436    /// axis. Uses `f64::partial_cmp` with deterministic tie-breaking by
2437    /// original index position.
2438    pub fn sort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2439        let ndim = self.ndim();
2440        if axis >= ndim {
2441            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2442        }
2443        let data = self.to_vec();
2444        let axis_len = self.shape[axis];
2445        let out_shape = self.shape.clone();
2446        let out_numel = Self::shape_numel(&out_shape);
2447
2448        // Build strides for iterating over all non-axis positions
2449        let mut iter_shape: Vec<usize> = Vec::new();
2450        for (i, &s) in self.shape.iter().enumerate() {
2451            if i != axis { iter_shape.push(s); }
2452        }
2453        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2454
2455        let mut result = vec![0.0f64; out_numel];
2456
2457        // We iterate over all positions with axis index = 0
2458        let mut pos = vec![0usize; ndim];
2459        for slice_idx in 0..n_slices {
2460            // Compute the N-D position (with axis dim = 0)
2461            let mut remaining = slice_idx;
2462            let mut dim_idx = 0;
2463            for d in 0..ndim {
2464                if d == axis {
2465                    pos[d] = 0;
2466                } else {
2467                    let stride = {
2468                        let mut s = 1usize;
2469                        let mut di = 0;
2470                        for d2 in 0..ndim {
2471                            if d2 == axis { continue; }
2472                            if di > dim_idx { s *= self.shape[d2]; }
2473                            di += 1;
2474                        }
2475                        s
2476                    };
2477                    pos[d] = remaining / stride;
2478                    remaining %= stride;
2479                    dim_idx += 1;
2480                }
2481            }
2482
2483            // Gather values along axis
2484            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2485            for k in 0..axis_len {
2486                let mut flat = self.offset;
2487                for d in 0..ndim {
2488                    let idx = if d == axis { k } else { pos[d] };
2489                    flat += idx * self.strides[d];
2490                }
2491                vals.push((data[flat], k));
2492            }
2493
2494            // Stable sort with deterministic tie-breaking by original index
2495            if descending {
2496                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2497                    .then(a.1.cmp(&b.1)));
2498            } else {
2499                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2500                    .then(a.1.cmp(&b.1)));
2501            }
2502
2503            // Scatter back
2504            for (k, &(v, _)) in vals.iter().enumerate() {
2505                let mut flat = 0;
2506                let out_strides_local = Self::compute_strides(&out_shape);
2507                for d in 0..ndim {
2508                    let idx = if d == axis { k } else { pos[d] };
2509                    flat += idx * out_strides_local[d];
2510                }
2511                result[flat] = v;
2512            }
2513        }
2514
2515        Tensor::from_vec(result, &out_shape)
2516    }
2517
2518    /// N-D argsort along an axis. Return a tensor of indices that would sort
2519    /// each slice along the given axis.
2520    ///
2521    /// Deterministic tie-breaking: ties are resolved by original index order.
2522    pub fn argsort_axis(&self, axis: usize, descending: bool) -> Result<Tensor, RuntimeError> {
2523        let ndim = self.ndim();
2524        if axis >= ndim {
2525            return Err(RuntimeError::IndexOutOfBounds { index: axis, length: ndim });
2526        }
2527        let data = self.to_vec();
2528        let axis_len = self.shape[axis];
2529        let out_shape = self.shape.clone();
2530        let out_numel = Self::shape_numel(&out_shape);
2531
2532        let mut iter_shape: Vec<usize> = Vec::new();
2533        for (i, &s) in self.shape.iter().enumerate() {
2534            if i != axis { iter_shape.push(s); }
2535        }
2536        let n_slices: usize = iter_shape.iter().product::<usize>().max(1);
2537
2538        let mut result = vec![0.0f64; out_numel];
2539        let mut pos = vec![0usize; ndim];
2540
2541        for slice_idx in 0..n_slices {
2542            let mut remaining = slice_idx;
2543            let mut dim_idx = 0;
2544            for d in 0..ndim {
2545                if d == axis {
2546                    pos[d] = 0;
2547                } else {
2548                    let stride = {
2549                        let mut s = 1usize;
2550                        let mut di = 0;
2551                        for d2 in 0..ndim {
2552                            if d2 == axis { continue; }
2553                            if di > dim_idx { s *= self.shape[d2]; }
2554                            di += 1;
2555                        }
2556                        s
2557                    };
2558                    pos[d] = remaining / stride;
2559                    remaining %= stride;
2560                    dim_idx += 1;
2561                }
2562            }
2563
2564            let mut vals: Vec<(f64, usize)> = Vec::with_capacity(axis_len);
2565            for k in 0..axis_len {
2566                let mut flat = self.offset;
2567                for d in 0..ndim {
2568                    let idx = if d == axis { k } else { pos[d] };
2569                    flat += idx * self.strides[d];
2570                }
2571                vals.push((data[flat], k));
2572            }
2573
2574            if descending {
2575                vals.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)
2576                    .then(a.1.cmp(&b.1)));
2577            } else {
2578                vals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
2579                    .then(a.1.cmp(&b.1)));
2580            }
2581
2582            for (k, &(_, orig_idx)) in vals.iter().enumerate() {
2583                let out_strides_local = Self::compute_strides(&out_shape);
2584                let mut flat = 0;
2585                for d in 0..ndim {
2586                    let idx = if d == axis { k } else { pos[d] };
2587                    flat += idx * out_strides_local[d];
2588                }
2589                result[flat] = orig_idx as f64;
2590            }
2591        }
2592
2593        Tensor::from_vec(result, &out_shape)
2594    }
2595
2596    // -----------------------------------------------------------------------
2597    // Phase 2: Einsum
2598    // -----------------------------------------------------------------------
2599
2600    /// Einstein summation notation.
2601    /// Supports patterns like "ij,jk->ik" (matmul), "ii->i" (diagonal),
2602    /// "ij->ji" (transpose), "ijk,ikl->ijl" (batched matmul).
2603    /// Uses BinnedAccumulator for all reductions.
2604    pub fn einsum(notation: &str, inputs: &[&Tensor]) -> Result<Tensor, RuntimeError> {
2605        // Parse notation: "subscripts->output" or just "subscripts" (implicit)
2606        let parts: Vec<&str> = notation.split("->").collect();
2607        if parts.len() != 2 {
2608            return Err(RuntimeError::InvalidOperation(
2609                format!("einsum: expected 'subscripts->output' notation, got '{}'", notation),
2610            ));
2611        }
2612        let input_specs: Vec<&str> = parts[0].split(',').collect();
2613        let output_spec = parts[1];
2614
2615        if input_specs.len() != inputs.len() {
2616            return Err(RuntimeError::InvalidOperation(
2617                format!("einsum: {} input specs but {} tensors", input_specs.len(), inputs.len()),
2618            ));
2619        }
2620
2621        // Build label → size mapping
2622        let mut label_size = std::collections::BTreeMap::new();
2623        for (i, &spec) in input_specs.iter().enumerate() {
2624            let chars: Vec<char> = spec.chars().collect();
2625            if chars.len() != inputs[i].ndim() {
2626                return Err(RuntimeError::InvalidOperation(
2627                    format!("einsum: spec '{}' has {} dims but tensor has {}", spec, chars.len(), inputs[i].ndim()),
2628                ));
2629            }
2630            for (d, &c) in chars.iter().enumerate() {
2631                let sz = inputs[i].shape()[d];
2632                if let Some(&prev) = label_size.get(&c) {
2633                    if prev != sz {
2634                        return Err(RuntimeError::InvalidOperation(
2635                            format!("einsum: label '{}' has conflicting sizes {} vs {}", c, prev, sz),
2636                        ));
2637                    }
2638                } else {
2639                    label_size.insert(c, sz);
2640                }
2641            }
2642        }
2643
2644        // Determine output shape
2645        let output_chars: Vec<char> = output_spec.chars().collect();
2646        let output_shape: Vec<usize> = output_chars.iter()
2647            .map(|c| label_size.get(c).copied().ok_or_else(||
2648                RuntimeError::InvalidOperation(format!("einsum: unknown label '{}' in output", c))))
2649            .collect::<Result<_, _>>()?;
2650        let out_numel = Self::shape_numel(&output_shape);
2651
2652        // Determine contraction labels (in input but not output)
2653        let output_set: std::collections::BTreeSet<char> = output_chars.iter().copied().collect();
2654        let contract_labels: Vec<char> = label_size.keys()
2655            .filter(|c| !output_set.contains(c))
2656            .copied()
2657            .collect();
2658        let contract_sizes: Vec<usize> = contract_labels.iter()
2659            .map(|c| label_size[c])
2660            .collect();
2661        let contract_numel: usize = contract_sizes.iter().product::<usize>().max(1);
2662
2663        // Precompute input spec chars
2664        let input_chars: Vec<Vec<char>> = input_specs.iter().map(|s| s.chars().collect()).collect();
2665
2666        // For each output position, iterate over contraction indices
2667        let out_strides = Self::compute_strides(&output_shape);
2668        let mut result = vec![0.0f64; out_numel];
2669
2670        // Pre-read input data
2671        let input_data: Vec<Vec<f64>> = inputs.iter().map(|t| t.to_vec()).collect();
2672        let input_strides: Vec<Vec<usize>> = inputs.iter().map(|t| t.strides.clone()).collect();
2673        let input_offsets: Vec<usize> = inputs.iter().map(|t| t.offset).collect();
2674
2675        for out_idx in 0..out_numel {
2676            // Compute output label values
2677            let mut label_vals = std::collections::BTreeMap::new();
2678            let mut remaining = out_idx;
2679            for (d, &c) in output_chars.iter().enumerate() {
2680                let stride = if d < out_strides.len() { out_strides[d] } else { 1 };
2681                label_vals.insert(c, remaining / stride);
2682                remaining %= stride;
2683            }
2684
2685            let mut acc = BinnedAccumulatorF64::new();
2686            // Iterate over all contraction index combinations
2687            for cidx in 0..contract_numel {
2688                // Compute contraction label values
2689                let mut cr = cidx;
2690                for (ci, &cl) in contract_labels.iter().enumerate() {
2691                    let stride: usize = contract_sizes[ci+1..].iter().product::<usize>().max(1);
2692                    label_vals.insert(cl, cr / stride);
2693                    cr %= stride;
2694                }
2695
2696                // Compute product of input elements
2697                let mut product = 1.0f64;
2698                for (inp_idx, chars) in input_chars.iter().enumerate() {
2699                    let mut flat = input_offsets[inp_idx];
2700                    for (d, &c) in chars.iter().enumerate() {
2701                        flat += label_vals[&c] * input_strides[inp_idx][d];
2702                    }
2703                    product *= input_data[inp_idx][flat];
2704                }
2705                acc.add(product);
2706            }
2707            result[out_idx] = acc.finalize();
2708        }
2709
2710        if output_shape.is_empty() {
2711            Tensor::from_vec(result, &[1])
2712        } else {
2713            Tensor::from_vec(result, &output_shape)
2714        }
2715    }
2716
2717    // -----------------------------------------------------------------------
2718    // Phase 2: Reshape / View Enhancements
2719    // -----------------------------------------------------------------------
2720
2721    /// Add a dimension of size 1 at position `dim`.
2722    ///
2723    /// For a tensor of shape `[A, B]`, `unsqueeze(0)` yields `[1, A, B]`,
2724    /// `unsqueeze(1)` yields `[A, 1, B]`, etc.
2725    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor, RuntimeError> {
2726        let ndim = self.ndim();
2727        if dim > ndim {
2728            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: ndim + 1 });
2729        }
2730        let mut new_shape = self.shape.clone();
2731        new_shape.insert(dim, 1);
2732        self.reshape(&new_shape)
2733    }
2734
2735    /// Remove a dimension of size 1 at position `dim`.
2736    /// If `dim` is `None`, removes all dimensions of size 1.
2737    pub fn squeeze(&self, dim: Option<usize>) -> Result<Tensor, RuntimeError> {
2738        match dim {
2739            Some(d) => {
2740                if d >= self.ndim() {
2741                    return Err(RuntimeError::IndexOutOfBounds { index: d, length: self.ndim() });
2742                }
2743                if self.shape[d] != 1 {
2744                    return Err(RuntimeError::InvalidOperation(
2745                        format!("squeeze: dimension {} has size {}, not 1", d, self.shape[d]),
2746                    ));
2747                }
2748                let mut new_shape = self.shape.clone();
2749                new_shape.remove(d);
2750                if new_shape.is_empty() {
2751                    new_shape.push(1); // scalar
2752                }
2753                self.reshape(&new_shape)
2754            }
2755            None => {
2756                let new_shape: Vec<usize> = self.shape.iter()
2757                    .filter(|&&s| s != 1)
2758                    .copied()
2759                    .collect();
2760                let new_shape = if new_shape.is_empty() { vec![1] } else { new_shape };
2761                self.reshape(&new_shape)
2762            }
2763        }
2764    }
2765
2766    /// Broadcast without copying. Return a view with `stride=0` for broadcasted dims.
2767    ///
2768    /// Alias for [`broadcast_to`](Tensor::broadcast_to).
2769    pub fn expand(&self, target_shape: &[usize]) -> Result<Tensor, RuntimeError> {
2770        self.broadcast_to(target_shape)
2771    }
2772
2773    /// Flatten a range of dimensions [start_dim, end_dim] into a single dimension.
2774    pub fn flatten(&self, start_dim: usize, end_dim: usize) -> Result<Tensor, RuntimeError> {
2775        if start_dim > end_dim || end_dim >= self.ndim() {
2776            return Err(RuntimeError::InvalidOperation(
2777                format!("flatten: invalid dim range [{}, {}] for {}D tensor", start_dim, end_dim, self.ndim()),
2778            ));
2779        }
2780        let mut new_shape = Vec::new();
2781        for i in 0..start_dim {
2782            new_shape.push(self.shape[i]);
2783        }
2784        let flat_size: usize = self.shape[start_dim..=end_dim].iter().product();
2785        new_shape.push(flat_size);
2786        for i in (end_dim + 1)..self.ndim() {
2787            new_shape.push(self.shape[i]);
2788        }
2789        self.reshape(&new_shape)
2790    }
2791
2792    /// Split tensor into `n` roughly equal chunks along dimension `dim`.
2793    pub fn chunk(&self, n: usize, dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2794        if dim >= self.ndim() {
2795            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2796        }
2797        if n == 0 {
2798            return Err(RuntimeError::InvalidOperation("chunk: n must be > 0".into()));
2799        }
2800        let dim_size = self.shape[dim];
2801        let chunk_size = (dim_size + n - 1) / n;
2802        let mut sizes = Vec::new();
2803        let mut remaining = dim_size;
2804        while remaining > 0 {
2805            let s = remaining.min(chunk_size);
2806            sizes.push(s);
2807            remaining -= s;
2808        }
2809        self.split(&sizes, dim)
2810    }
2811
2812    /// Split tensor along dimension `dim` according to the given sizes.
2813    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Tensor>, RuntimeError> {
2814        if dim >= self.ndim() {
2815            return Err(RuntimeError::IndexOutOfBounds { index: dim, length: self.ndim() });
2816        }
2817        let total: usize = sizes.iter().sum();
2818        if total != self.shape[dim] {
2819            return Err(RuntimeError::InvalidOperation(
2820                format!("split: sizes sum {} != dim size {}", total, self.shape[dim]),
2821            ));
2822        }
2823
2824        let mut results = Vec::new();
2825        let mut offset = 0;
2826
2827        for &sz in sizes {
2828            let ranges: Vec<(usize, usize)> = self.shape.iter()
2829                .enumerate()
2830                .map(|(i, &s)| {
2831                    if i == dim { (offset, offset + sz) } else { (0, s) }
2832                })
2833                .collect();
2834            let chunk = self.slice(&ranges)?;
2835            // Materialize as contiguous
2836            results.push(chunk.to_contiguous());
2837            offset += sz;
2838        }
2839
2840        Ok(results)
2841    }
2842
2843    /// Fused `alpha * self + beta * other` element-wise. Single pass, one allocation.
2844    ///
2845    /// Critical for LSTM/GRU gates where `f * c_prev + i * g` would otherwise
2846    /// create 3 intermediate tensors.
2847    pub fn scale_add(&self, alpha: f64, other: &Tensor, beta: f64) -> Result<Tensor, RuntimeError> {
2848        if self.shape != other.shape {
2849            return Err(RuntimeError::InvalidOperation(
2850                "scale_add: shape mismatch".to_string(),
2851            ));
2852        }
2853        let a = self.to_vec();
2854        let b = other.to_vec();
2855        let result: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| alpha * x + beta * y).collect();
2856        Tensor::from_vec(result, &self.shape)
2857    }
2858}
2859