//! Views and Slicing - Tensor Indexing Operations
//!
//! # File
//! `crates/axonml-tensor/src/view.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_core::dtype::{Numeric, Scalar};
use axonml_core::error::{Error, Result};

use crate::shape::{Shape, numel};
use crate::tensor::Tensor;
// =============================================================================
// Slice Specification
// =============================================================================

/// Describes the slicing behavior for a single tensor dimension.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Pick exactly one position along the dimension; the dimension
    /// disappears from the result.
    Index(isize),
    /// Take the half-open interval `[start, stop)`, visiting every
    /// `step`-th element.
    Range {
        /// First index to include; `None` means the start of the dimension.
        start: Option<isize>,
        /// First index to exclude; `None` means the end of the dimension.
        stop: Option<isize>,
        /// Distance between consecutive selected indices (default 1).
        step: isize,
    },
    /// Take every element of the dimension unchanged.
    All,
    /// Insert a fresh dimension of length 1 at this position.
    NewAxis,
}

47impl SliceSpec {
48    /// Creates a range slice from start to stop.
49    #[must_use]
50    pub fn range(start: isize, stop: isize) -> Self {
51        Self::Range {
52            start: Some(start),
53            stop: Some(stop),
54            step: 1,
55        }
56    }
57
58    /// Creates a range slice with step.
59    #[must_use]
60    pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
61        Self::Range {
62            start: Some(start),
63            stop: Some(stop),
64            step,
65        }
66    }
67
68    /// Creates a slice from start to end.
69    #[must_use]
70    pub fn from(start: isize) -> Self {
71        Self::Range {
72            start: Some(start),
73            stop: None,
74            step: 1,
75        }
76    }
77
78    /// Creates a slice from beginning to stop.
79    #[must_use]
80    pub fn to(stop: isize) -> Self {
81        Self::Range {
82            start: None,
83            stop: Some(stop),
84            step: 1,
85        }
86    }
87}

// =============================================================================
// Slicing Implementation
// =============================================================================

93impl<T: Scalar> Tensor<T> {
94    /// Returns a slice of the tensor along the first dimension.
95    ///
96    /// # Arguments
97    /// * `start` - Start index (inclusive)
98    /// * `end` - End index (exclusive)
99    pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
100        if self.ndim() == 0 {
101            return Err(Error::invalid_operation("Cannot slice a scalar"));
102        }
103
104        let dim_size = self.shape[0];
105        if start > end || end > dim_size {
106            return Err(Error::IndexOutOfBounds {
107                index: end,
108                size: dim_size,
109            });
110        }
111
112        let mut new_shape = self.shape.clone();
113        new_shape[0] = end - start;
114
115        let new_offset = self.offset + start * self.strides[0] as usize;
116
117        Ok(Self {
118            storage: self.storage.clone(),
119            shape: new_shape,
120            strides: self.strides.clone(),
121            offset: new_offset,
122        })
123    }
124
125    /// Returns a view selecting a single index along a dimension.
126    ///
127    /// This reduces the dimensionality by 1.
128    ///
129    /// # Arguments
130    /// * `dim` - Dimension to select from
131    /// * `index` - Index to select
132    pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
133        if dim >= self.ndim() {
134            return Err(Error::InvalidDimension {
135                index: dim as i64,
136                ndim: self.ndim(),
137            });
138        }
139
140        if index >= self.shape[dim] {
141            return Err(Error::IndexOutOfBounds {
142                index,
143                size: self.shape[dim],
144            });
145        }
146
147        let mut new_shape = self.shape.clone();
148        new_shape.remove(dim);
149
150        let mut new_strides = self.strides.clone();
151        new_strides.remove(dim);
152
153        let new_offset = self.offset + index * self.strides[dim] as usize;
154
155        Ok(Self {
156            storage: self.storage.clone(),
157            shape: new_shape,
158            strides: new_strides,
159            offset: new_offset,
160        })
161    }
162
163    /// Returns a narrow view along a dimension.
164    ///
165    /// # Arguments
166    /// * `dim` - Dimension to narrow
167    /// * `start` - Start index
168    /// * `length` - Length of the narrow view
169    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
170        if dim >= self.ndim() {
171            return Err(Error::InvalidDimension {
172                index: dim as i64,
173                ndim: self.ndim(),
174            });
175        }
176
177        if start + length > self.shape[dim] {
178            return Err(Error::IndexOutOfBounds {
179                index: start + length,
180                size: self.shape[dim],
181            });
182        }
183
184        let mut new_shape = self.shape.clone();
185        new_shape[dim] = length;
186
187        let new_offset = self.offset + start * self.strides[dim] as usize;
188
189        Ok(Self {
190            storage: self.storage.clone(),
191            shape: new_shape,
192            strides: self.strides.clone(),
193            offset: new_offset,
194        })
195    }
196
197    /// Splits the tensor into chunks along a dimension.
198    ///
199    /// # Arguments
200    /// * `chunks` - Number of chunks
201    /// * `dim` - Dimension to split along
202    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
203        if dim >= self.ndim() {
204            return Err(Error::InvalidDimension {
205                index: dim as i64,
206                ndim: self.ndim(),
207            });
208        }
209
210        let dim_size = self.shape[dim];
211        let chunk_size = dim_size.div_ceil(chunks);
212        let mut result = Vec::with_capacity(chunks);
213
214        let mut start = 0;
215        while start < dim_size {
216            let length = (chunk_size).min(dim_size - start);
217            result.push(self.narrow(dim, start, length)?);
218            start += length;
219        }
220
221        Ok(result)
222    }
223
224    /// Splits the tensor into parts of specified sizes along a dimension.
225    ///
226    /// # Arguments
227    /// * `sizes` - Size of each part
228    /// * `dim` - Dimension to split along
229    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
230        if dim >= self.ndim() {
231            return Err(Error::InvalidDimension {
232                index: dim as i64,
233                ndim: self.ndim(),
234            });
235        }
236
237        let total: usize = sizes.iter().sum();
238        if total != self.shape[dim] {
239            return Err(Error::invalid_operation(format!(
240                "Split sizes {} don't sum to dimension size {}",
241                total, self.shape[dim]
242            )));
243        }
244
245        let mut result = Vec::with_capacity(sizes.len());
246        let mut start = 0;
247
248        for &size in sizes {
249            result.push(self.narrow(dim, start, size)?);
250            start += size;
251        }
252
253        Ok(result)
254    }
255}

// =============================================================================
// Indexing Implementation
// =============================================================================

261impl<T: Numeric> Tensor<T> {
262    /// Gathers values along a dimension according to indices.
263    ///
264    /// # Arguments
265    /// * `dim` - Dimension to gather along
266    /// * `indices` - Indices tensor
267    pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
268        if dim >= self.ndim() {
269            return Err(Error::InvalidDimension {
270                index: dim as i64,
271                ndim: self.ndim(),
272            });
273        }
274
275        // For simplicity, this is a basic implementation
276        // A full implementation would match PyTorch's semantics exactly
277        let output_shape = indices.shape();
278        let mut output_data = vec![T::zero(); numel(output_shape)];
279
280        let indices_data = indices.to_vec();
281        let self_data = self.to_vec();
282
283        for (out_idx, &index) in indices_data.iter().enumerate() {
284            let index = index as usize;
285            if index >= self.shape[dim] {
286                return Err(Error::IndexOutOfBounds {
287                    index,
288                    size: self.shape[dim],
289                });
290            }
291            // Simplified: assumes 1D case
292            output_data[out_idx] = self_data[index];
293        }
294
295        Tensor::from_vec(output_data, output_shape)
296    }
297
298    /// Returns elements selected by a boolean mask.
299    ///
300    /// # Arguments
301    /// * `mask` - Boolean mask tensor
302    pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
303        if mask.len() != self.numel() {
304            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
305        }
306
307        let data = self.to_vec();
308        let selected: Vec<T> = data
309            .into_iter()
310            .zip(mask.iter())
311            .filter(|(_, m)| **m)
312            .map(|(v, _)| v)
313            .collect();
314
315        let len = selected.len();
316        Tensor::from_vec(selected, &[len])
317    }
318
319    /// Sets elements according to a boolean mask.
320    ///
321    /// # Arguments
322    /// * `mask` - Boolean mask tensor
323    /// * `value` - Value to set where mask is true
324    pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
325        if mask.len() != self.numel() {
326            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
327        }
328
329        if !self.is_contiguous() {
330            return Err(Error::NotContiguous);
331        }
332
333        {
334            let mut guard = self.storage.as_slice_mut();
335            for (idx, &m) in mask.iter().enumerate() {
336                if m {
337                    guard[self.offset + idx] = value;
338                }
339            }
340        }
341
342        Ok(())
343    }
344}

// =============================================================================
// Concatenation and Stacking
// =============================================================================

350/// Concatenates tensors along an existing dimension.
351///
352/// # Arguments
353/// * `tensors` - Slice of tensors to concatenate
354/// * `dim` - Dimension along which to concatenate
355pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
356    if tensors.is_empty() {
357        return Err(Error::invalid_operation("Cannot concatenate empty list"));
358    }
359
360    let first = &tensors[0];
361    let ndim = first.ndim();
362
363    if dim >= ndim {
364        return Err(Error::InvalidDimension {
365            index: dim as i64,
366            ndim,
367        });
368    }
369
370    // Validate shapes match except for concat dimension
371    for t in tensors.iter().skip(1) {
372        if t.ndim() != ndim {
373            return Err(Error::invalid_operation(
374                "All tensors must have same number of dimensions",
375            ));
376        }
377        for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
378            if d != dim && s1 != s2 {
379                return Err(Error::shape_mismatch(first.shape(), t.shape()));
380            }
381        }
382    }
383
384    // Compute output shape
385    let mut output_shape = Shape::from_slice(first.shape());
386    output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
387
388    // Allocate output
389    let total_numel = numel(&output_shape);
390    let mut output_data = vec![T::zeroed(); total_numel];
391
392    // Copy data - simplified for contiguous case
393    let mut offset = 0;
394    for t in tensors {
395        let data = t.to_vec();
396        for val in data {
397            output_data[offset] = val;
398            offset += 1;
399        }
400    }
401
402    Tensor::from_vec(output_data, &output_shape)
403}
404
405/// Stacks tensors along a new dimension.
406///
407/// # Arguments
408/// * `tensors` - Slice of tensors to stack
409/// * `dim` - Dimension at which to insert the new axis
410pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
411    if tensors.is_empty() {
412        return Err(Error::invalid_operation("Cannot stack empty list"));
413    }
414
415    // Unsqueeze each tensor and then concatenate
416    let unsqueezed: Result<Vec<Tensor<T>>> =
417        tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();
418
419    cat(&unsqueezed?, dim)
420}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice_dim0() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let view = tensor.slice_dim0(1, 3).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[0, 0]).unwrap(), 3.0);
    }

    #[test]
    fn test_select() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let row = tensor.select(0, 1).unwrap();
        assert_eq!(row.shape(), &[3]);
        assert_eq!(row.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_narrow() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let window = tensor.narrow(0, 1, 3).unwrap();
        assert_eq!(window.shape(), &[3]);
        assert_eq!(window.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chunk() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let pieces = tensor.chunk(3, 0).unwrap();
        assert_eq!(pieces.len(), 3);
        let expected = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        for (piece, want) in pieces.iter().zip(expected.iter()) {
            assert_eq!(&piece.to_vec(), want);
        }
    }

    #[test]
    fn test_cat() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let joined = cat(&[lhs, rhs], 0).unwrap();
        assert_eq!(joined.shape(), &[4]);
        assert_eq!(joined.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_stack() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let stacked = stack(&[lhs, rhs], 0).unwrap();
        assert_eq!(stacked.shape(), &[2, 2]);
    }
}