// axonml_tensor — view.rs
//! Views, slicing, and advanced indexing operations on `Tensor<T>`.
//!
//! Provides the `SliceSpec` slice descriptor and the tensor methods
//! `slice_dim0`, `select` (index one dim away), `narrow` (sub-range along a
//! dim), `chunk` (split into N roughly equal parts), `split` (split by
//! explicit sizes), `gather` (index into data along a dim), `masked_select`,
//! and `masked_fill_`, plus the free functions `cat` and `stack`.
//! `slice_dim0`/`select`/`narrow` return strided zero-copy views; the rest
//! materialize new `Tensor<T>` instances.
//!
//! # File
//! `crates/axonml-tensor/src/view.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_core::dtype::{Numeric, Scalar};
use axonml_core::error::{Error, Result};

use crate::shape::{Shape, numel};
use crate::tensor::Tensor;

// =============================================================================
// Slice Specification
// =============================================================================

/// Specifies how to slice along a single dimension.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Select a single index, reducing dimensionality.
    Index(isize),
    /// Select a range [start, stop) with optional step.
    Range {
        /// Start index (inclusive), None = beginning
        start: Option<isize>,
        /// Stop index (exclusive), None = end
        stop: Option<isize>,
        /// Step size, default 1
        step: isize,
    },
    /// Keep all elements in this dimension.
    All,
    /// Add a new dimension of size 1.
    NewAxis,
}

impl SliceSpec {
    /// Creates a range slice covering `[start, stop)` with unit step.
    #[must_use]
    pub fn range(start: isize, stop: isize) -> Self {
        // Delegate to the general constructor with the default step.
        Self::range_step(start, stop, 1)
    }

    /// Creates a range slice covering `[start, stop)` with the given step.
    #[must_use]
    pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: Some(stop),
            step,
        }
    }

    /// Creates a slice running from `start` to the end of the dimension.
    #[must_use]
    pub fn from(start: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: None,
            step: 1,
        }
    }

    /// Creates a slice running from the beginning up to `stop` (exclusive).
    #[must_use]
    pub fn to(stop: isize) -> Self {
        Self::Range {
            start: None,
            stop: Some(stop),
            step: 1,
        }
    }
}

// =============================================================================
// Slicing Implementation
// =============================================================================

102impl<T: Scalar> Tensor<T> {
103    /// Returns a slice of the tensor along the first dimension.
104    ///
105    /// # Arguments
106    /// * `start` - Start index (inclusive)
107    /// * `end` - End index (exclusive)
108    pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
109        if self.ndim() == 0 {
110            return Err(Error::invalid_operation("Cannot slice a scalar"));
111        }
112
113        let dim_size = self.shape[0];
114        if start > end || end > dim_size {
115            return Err(Error::IndexOutOfBounds {
116                index: end,
117                size: dim_size,
118            });
119        }
120
121        let mut new_shape = self.shape.clone();
122        new_shape[0] = end - start;
123
124        let new_offset = self.offset + start * self.strides[0] as usize;
125
126        Ok(Self {
127            storage: self.storage.clone(),
128            shape: new_shape,
129            strides: self.strides.clone(),
130            offset: new_offset,
131        })
132    }
133
134    /// Returns a view selecting a single index along a dimension.
135    ///
136    /// This reduces the dimensionality by 1.
137    ///
138    /// # Arguments
139    /// * `dim` - Dimension to select from
140    /// * `index` - Index to select
141    pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
142        if dim >= self.ndim() {
143            return Err(Error::InvalidDimension {
144                index: dim as i64,
145                ndim: self.ndim(),
146            });
147        }
148
149        if index >= self.shape[dim] {
150            return Err(Error::IndexOutOfBounds {
151                index,
152                size: self.shape[dim],
153            });
154        }
155
156        let mut new_shape = self.shape.clone();
157        new_shape.remove(dim);
158
159        let mut new_strides = self.strides.clone();
160        new_strides.remove(dim);
161
162        let new_offset = self.offset + index * self.strides[dim] as usize;
163
164        Ok(Self {
165            storage: self.storage.clone(),
166            shape: new_shape,
167            strides: new_strides,
168            offset: new_offset,
169        })
170    }
171
172    /// Returns a narrow view along a dimension.
173    ///
174    /// # Arguments
175    /// * `dim` - Dimension to narrow
176    /// * `start` - Start index
177    /// * `length` - Length of the narrow view
178    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
179        if dim >= self.ndim() {
180            return Err(Error::InvalidDimension {
181                index: dim as i64,
182                ndim: self.ndim(),
183            });
184        }
185
186        if start + length > self.shape[dim] {
187            return Err(Error::IndexOutOfBounds {
188                index: start + length,
189                size: self.shape[dim],
190            });
191        }
192
193        let mut new_shape = self.shape.clone();
194        new_shape[dim] = length;
195
196        let new_offset = self.offset + start * self.strides[dim] as usize;
197
198        Ok(Self {
199            storage: self.storage.clone(),
200            shape: new_shape,
201            strides: self.strides.clone(),
202            offset: new_offset,
203        })
204    }
205
206    /// Splits the tensor into chunks along a dimension.
207    ///
208    /// # Arguments
209    /// * `chunks` - Number of chunks
210    /// * `dim` - Dimension to split along
211    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
212        if dim >= self.ndim() {
213            return Err(Error::InvalidDimension {
214                index: dim as i64,
215                ndim: self.ndim(),
216            });
217        }
218
219        let dim_size = self.shape[dim];
220        let chunk_size = dim_size.div_ceil(chunks);
221        let mut result = Vec::with_capacity(chunks);
222
223        let mut start = 0;
224        while start < dim_size {
225            let length = (chunk_size).min(dim_size - start);
226            result.push(self.narrow(dim, start, length)?);
227            start += length;
228        }
229
230        Ok(result)
231    }
232
233    /// Splits the tensor into parts of specified sizes along a dimension.
234    ///
235    /// # Arguments
236    /// * `sizes` - Size of each part
237    /// * `dim` - Dimension to split along
238    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
239        if dim >= self.ndim() {
240            return Err(Error::InvalidDimension {
241                index: dim as i64,
242                ndim: self.ndim(),
243            });
244        }
245
246        let total: usize = sizes.iter().sum();
247        if total != self.shape[dim] {
248            return Err(Error::invalid_operation(format!(
249                "Split sizes {} don't sum to dimension size {}",
250                total, self.shape[dim]
251            )));
252        }
253
254        let mut result = Vec::with_capacity(sizes.len());
255        let mut start = 0;
256
257        for &size in sizes {
258            result.push(self.narrow(dim, start, size)?);
259            start += size;
260        }
261
262        Ok(result)
263    }
264}

// =============================================================================
// Indexing Implementation
// =============================================================================

270impl<T: Numeric> Tensor<T> {
271    /// Gathers values along a dimension according to indices.
272    ///
273    /// # Arguments
274    /// * `dim` - Dimension to gather along
275    /// * `indices` - Indices tensor
276    pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
277        if dim >= self.ndim() {
278            return Err(Error::InvalidDimension {
279                index: dim as i64,
280                ndim: self.ndim(),
281            });
282        }
283
284        // For simplicity, this is a basic implementation
285        // A full implementation would match PyTorch's semantics exactly
286        let output_shape = indices.shape();
287        let mut output_data = vec![T::zero(); numel(output_shape)];
288
289        let indices_data = indices.to_vec();
290        let self_data = self.to_vec();
291
292        for (out_idx, &index) in indices_data.iter().enumerate() {
293            let index = index as usize;
294            if index >= self.shape[dim] {
295                return Err(Error::IndexOutOfBounds {
296                    index,
297                    size: self.shape[dim],
298                });
299            }
300            // Simplified: assumes 1D case
301            output_data[out_idx] = self_data[index];
302        }
303
304        Tensor::from_vec(output_data, output_shape)
305    }
306
307    /// Returns elements selected by a boolean mask.
308    ///
309    /// # Arguments
310    /// * `mask` - Boolean mask tensor
311    pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
312        if mask.len() != self.numel() {
313            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
314        }
315
316        let data = self.to_vec();
317        let selected: Vec<T> = data
318            .into_iter()
319            .zip(mask.iter())
320            .filter(|(_, m)| **m)
321            .map(|(v, _)| v)
322            .collect();
323
324        let len = selected.len();
325        Tensor::from_vec(selected, &[len])
326    }
327
328    /// Sets elements according to a boolean mask.
329    ///
330    /// # Arguments
331    /// * `mask` - Boolean mask tensor
332    /// * `value` - Value to set where mask is true
333    pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
334        if mask.len() != self.numel() {
335            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
336        }
337
338        if !self.is_contiguous() {
339            return Err(Error::NotContiguous);
340        }
341
342        {
343            let mut guard = self.storage.as_slice_mut();
344            for (idx, &m) in mask.iter().enumerate() {
345                if m {
346                    guard[self.offset + idx] = value;
347                }
348            }
349        }
350
351        Ok(())
352    }
353}

// =============================================================================
// Concatenation and Stacking
// =============================================================================

359/// Concatenates tensors along an existing dimension.
360///
361/// # Arguments
362/// * `tensors` - Slice of tensors to concatenate
363/// * `dim` - Dimension along which to concatenate
364pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
365    if tensors.is_empty() {
366        return Err(Error::invalid_operation("Cannot concatenate empty list"));
367    }
368
369    let first = &tensors[0];
370    let ndim = first.ndim();
371
372    if dim >= ndim {
373        return Err(Error::InvalidDimension {
374            index: dim as i64,
375            ndim,
376        });
377    }
378
379    // Validate shapes match except for concat dimension
380    for t in tensors.iter().skip(1) {
381        if t.ndim() != ndim {
382            return Err(Error::invalid_operation(
383                "All tensors must have same number of dimensions",
384            ));
385        }
386        for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
387            if d != dim && s1 != s2 {
388                return Err(Error::shape_mismatch(first.shape(), t.shape()));
389            }
390        }
391    }
392
393    // Compute output shape
394    let mut output_shape = Shape::from_slice(first.shape());
395    output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
396
397    // Allocate output
398    let total_numel = numel(&output_shape);
399    let mut output_data = vec![T::zeroed(); total_numel];
400
401    // Copy data - simplified for contiguous case
402    let mut offset = 0;
403    for t in tensors {
404        let data = t.to_vec();
405        for val in data {
406            output_data[offset] = val;
407            offset += 1;
408        }
409    }
410
411    Tensor::from_vec(output_data, &output_shape)
412}
414/// Stacks tensors along a new dimension.
415///
416/// # Arguments
417/// * `tensors` - Slice of tensors to stack
418/// * `dim` - Dimension at which to insert the new axis
419pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
420    if tensors.is_empty() {
421        return Err(Error::invalid_operation("Cannot stack empty list"));
422    }
423
424    // Unsqueeze each tensor and then concatenate
425    let unsqueezed: Result<Vec<Tensor<T>>> =
426        tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();
427
428    cat(&unsqueezed?, dim)
429}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // slice_dim0: view keeps trailing dims; element [0, 0] maps to row `start`.
    #[test]
    fn test_slice_dim0() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let s = t.slice_dim0(1, 3).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.get(&[0, 0]).unwrap(), 3.0);
    }

    // select: picking row 1 of a 2x3 drops the first dimension.
    #[test]
    fn test_select() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let s = t.select(0, 1).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    // narrow: a window of length 3 starting at index 1 of a 1-D tensor.
    #[test]
    fn test_narrow() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let n = t.narrow(0, 1, 3).unwrap();
        assert_eq!(n.shape(), &[3]);
        assert_eq!(n.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    // chunk: 6 elements into 3 chunks yields three equal pairs.
    #[test]
    fn test_chunk() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let chunks = t.chunk(3, 0).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(chunks[1].to_vec(), vec![3.0, 4.0]);
        assert_eq!(chunks[2].to_vec(), vec![5.0, 6.0]);
    }

    // cat along dim 0 of two 1-D tensors appends them end-to-end.
    #[test]
    fn test_cat() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = cat(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    // stack inserts a new leading axis: two [2] tensors become [2, 2].
    #[test]
    fn test_stack() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
    }
}