// axonml_tensor/view.rs
//! Views and Slicing - Tensor Indexing Operations
//!
//! Provides functionality for creating views into tensors through slicing,
//! indexing, and masking operations. Views share storage with the original
//! tensor when possible, avoiding unnecessary copies.
//!
//! # Key Features
//! - Zero-copy slicing for contiguous ranges
//! - Advanced indexing with integer arrays
//! - Boolean masking
//! - Gather and scatter operations
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

16use axonml_core::dtype::{Numeric, Scalar};
17use axonml_core::error::{Error, Result};
18
19use crate::shape::{numel, Shape};
20use crate::tensor::Tensor;
21
// =============================================================================
// Slice Specification
// =============================================================================

/// Describes the selection applied to a single tensor dimension.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Select a single index, reducing dimensionality.
    Index(isize),
    /// Select a range [start, stop) with optional step.
    Range {
        /// Start index (inclusive), None = beginning
        start: Option<isize>,
        /// Stop index (exclusive), None = end
        stop: Option<isize>,
        /// Step size, default 1
        step: isize,
    },
    /// Keep all elements in this dimension.
    All,
    /// Add a new dimension of size 1.
    NewAxis,
}
45
46impl SliceSpec {
47    /// Creates a range slice from start to stop.
48    #[must_use]
49    pub fn range(start: isize, stop: isize) -> Self {
50        Self::Range {
51            start: Some(start),
52            stop: Some(stop),
53            step: 1,
54        }
55    }
56
57    /// Creates a range slice with step.
58    #[must_use]
59    pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
60        Self::Range {
61            start: Some(start),
62            stop: Some(stop),
63            step,
64        }
65    }
66
67    /// Creates a slice from start to end.
68    #[must_use]
69    pub fn from(start: isize) -> Self {
70        Self::Range {
71            start: Some(start),
72            stop: None,
73            step: 1,
74        }
75    }
76
77    /// Creates a slice from beginning to stop.
78    #[must_use]
79    pub fn to(stop: isize) -> Self {
80        Self::Range {
81            start: None,
82            stop: Some(stop),
83            step: 1,
84        }
85    }
86}
87
88// =============================================================================
89// Slicing Implementation
90// =============================================================================
91
92impl<T: Scalar> Tensor<T> {
93    /// Returns a slice of the tensor along the first dimension.
94    ///
95    /// # Arguments
96    /// * `start` - Start index (inclusive)
97    /// * `end` - End index (exclusive)
98    pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
99        if self.ndim() == 0 {
100            return Err(Error::invalid_operation("Cannot slice a scalar"));
101        }
102
103        let dim_size = self.shape[0];
104        if start > end || end > dim_size {
105            return Err(Error::IndexOutOfBounds {
106                index: end,
107                size: dim_size,
108            });
109        }
110
111        let mut new_shape = self.shape.clone();
112        new_shape[0] = end - start;
113
114        let new_offset = self.offset + start * self.strides[0] as usize;
115
116        Ok(Self {
117            storage: self.storage.clone(),
118            shape: new_shape,
119            strides: self.strides.clone(),
120            offset: new_offset,
121        })
122    }
123
124    /// Returns a view selecting a single index along a dimension.
125    ///
126    /// This reduces the dimensionality by 1.
127    ///
128    /// # Arguments
129    /// * `dim` - Dimension to select from
130    /// * `index` - Index to select
131    pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
132        if dim >= self.ndim() {
133            return Err(Error::InvalidDimension {
134                index: dim as i64,
135                ndim: self.ndim(),
136            });
137        }
138
139        if index >= self.shape[dim] {
140            return Err(Error::IndexOutOfBounds {
141                index,
142                size: self.shape[dim],
143            });
144        }
145
146        let mut new_shape = self.shape.clone();
147        new_shape.remove(dim);
148
149        let mut new_strides = self.strides.clone();
150        new_strides.remove(dim);
151
152        let new_offset = self.offset + index * self.strides[dim] as usize;
153
154        Ok(Self {
155            storage: self.storage.clone(),
156            shape: new_shape,
157            strides: new_strides,
158            offset: new_offset,
159        })
160    }
161
162    /// Returns a narrow view along a dimension.
163    ///
164    /// # Arguments
165    /// * `dim` - Dimension to narrow
166    /// * `start` - Start index
167    /// * `length` - Length of the narrow view
168    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
169        if dim >= self.ndim() {
170            return Err(Error::InvalidDimension {
171                index: dim as i64,
172                ndim: self.ndim(),
173            });
174        }
175
176        if start + length > self.shape[dim] {
177            return Err(Error::IndexOutOfBounds {
178                index: start + length,
179                size: self.shape[dim],
180            });
181        }
182
183        let mut new_shape = self.shape.clone();
184        new_shape[dim] = length;
185
186        let new_offset = self.offset + start * self.strides[dim] as usize;
187
188        Ok(Self {
189            storage: self.storage.clone(),
190            shape: new_shape,
191            strides: self.strides.clone(),
192            offset: new_offset,
193        })
194    }
195
196    /// Splits the tensor into chunks along a dimension.
197    ///
198    /// # Arguments
199    /// * `chunks` - Number of chunks
200    /// * `dim` - Dimension to split along
201    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
202        if dim >= self.ndim() {
203            return Err(Error::InvalidDimension {
204                index: dim as i64,
205                ndim: self.ndim(),
206            });
207        }
208
209        let dim_size = self.shape[dim];
210        let chunk_size = dim_size.div_ceil(chunks);
211        let mut result = Vec::with_capacity(chunks);
212
213        let mut start = 0;
214        while start < dim_size {
215            let length = (chunk_size).min(dim_size - start);
216            result.push(self.narrow(dim, start, length)?);
217            start += length;
218        }
219
220        Ok(result)
221    }
222
223    /// Splits the tensor into parts of specified sizes along a dimension.
224    ///
225    /// # Arguments
226    /// * `sizes` - Size of each part
227    /// * `dim` - Dimension to split along
228    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
229        if dim >= self.ndim() {
230            return Err(Error::InvalidDimension {
231                index: dim as i64,
232                ndim: self.ndim(),
233            });
234        }
235
236        let total: usize = sizes.iter().sum();
237        if total != self.shape[dim] {
238            return Err(Error::invalid_operation(format!(
239                "Split sizes {} don't sum to dimension size {}",
240                total, self.shape[dim]
241            )));
242        }
243
244        let mut result = Vec::with_capacity(sizes.len());
245        let mut start = 0;
246
247        for &size in sizes {
248            result.push(self.narrow(dim, start, size)?);
249            start += size;
250        }
251
252        Ok(result)
253    }
254}
255
256// =============================================================================
257// Indexing Implementation
258// =============================================================================
259
260impl<T: Numeric> Tensor<T> {
261    /// Gathers values along a dimension according to indices.
262    ///
263    /// # Arguments
264    /// * `dim` - Dimension to gather along
265    /// * `indices` - Indices tensor
266    pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
267        if dim >= self.ndim() {
268            return Err(Error::InvalidDimension {
269                index: dim as i64,
270                ndim: self.ndim(),
271            });
272        }
273
274        // For simplicity, this is a basic implementation
275        // A full implementation would match PyTorch's semantics exactly
276        let output_shape = indices.shape();
277        let mut output_data = vec![T::zero(); numel(output_shape)];
278
279        let indices_data = indices.to_vec();
280        let self_data = self.to_vec();
281
282        for (out_idx, &index) in indices_data.iter().enumerate() {
283            let index = index as usize;
284            if index >= self.shape[dim] {
285                return Err(Error::IndexOutOfBounds {
286                    index,
287                    size: self.shape[dim],
288                });
289            }
290            // Simplified: assumes 1D case
291            output_data[out_idx] = self_data[index];
292        }
293
294        Tensor::from_vec(output_data, output_shape)
295    }
296
297    /// Returns elements selected by a boolean mask.
298    ///
299    /// # Arguments
300    /// * `mask` - Boolean mask tensor
301    pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
302        if mask.len() != self.numel() {
303            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
304        }
305
306        let data = self.to_vec();
307        let selected: Vec<T> = data
308            .into_iter()
309            .zip(mask.iter())
310            .filter(|(_, &m)| m)
311            .map(|(v, _)| v)
312            .collect();
313
314        let len = selected.len();
315        Tensor::from_vec(selected, &[len])
316    }
317
318    /// Sets elements according to a boolean mask.
319    ///
320    /// # Arguments
321    /// * `mask` - Boolean mask tensor
322    /// * `value` - Value to set where mask is true
323    pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
324        if mask.len() != self.numel() {
325            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
326        }
327
328        if !self.is_contiguous() {
329            return Err(Error::NotContiguous);
330        }
331
332        {
333            let mut guard = self.storage.as_slice_mut();
334            for (idx, &m) in mask.iter().enumerate() {
335                if m {
336                    guard[self.offset + idx] = value;
337                }
338            }
339        }
340
341        Ok(())
342    }
343}
344
345// =============================================================================
346// Concatenation and Stacking
347// =============================================================================
348
349/// Concatenates tensors along an existing dimension.
350///
351/// # Arguments
352/// * `tensors` - Slice of tensors to concatenate
353/// * `dim` - Dimension along which to concatenate
354pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
355    if tensors.is_empty() {
356        return Err(Error::invalid_operation("Cannot concatenate empty list"));
357    }
358
359    let first = &tensors[0];
360    let ndim = first.ndim();
361
362    if dim >= ndim {
363        return Err(Error::InvalidDimension {
364            index: dim as i64,
365            ndim,
366        });
367    }
368
369    // Validate shapes match except for concat dimension
370    for t in tensors.iter().skip(1) {
371        if t.ndim() != ndim {
372            return Err(Error::invalid_operation(
373                "All tensors must have same number of dimensions",
374            ));
375        }
376        for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
377            if d != dim && s1 != s2 {
378                return Err(Error::shape_mismatch(first.shape(), t.shape()));
379            }
380        }
381    }
382
383    // Compute output shape
384    let mut output_shape = Shape::from_slice(first.shape());
385    output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
386
387    // Allocate output
388    let total_numel = numel(&output_shape);
389    let mut output_data = vec![T::zeroed(); total_numel];
390
391    // Copy data - simplified for contiguous case
392    let mut offset = 0;
393    for t in tensors {
394        let data = t.to_vec();
395        for val in data {
396            output_data[offset] = val;
397            offset += 1;
398        }
399    }
400
401    Tensor::from_vec(output_data, &output_shape)
402}
403
404/// Stacks tensors along a new dimension.
405///
406/// # Arguments
407/// * `tensors` - Slice of tensors to stack
408/// * `dim` - Dimension at which to insert the new axis
409pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
410    if tensors.is_empty() {
411        return Err(Error::invalid_operation("Cannot stack empty list"));
412    }
413
414    // Unsqueeze each tensor and then concatenate
415    let unsqueezed: Result<Vec<Tensor<T>>> =
416        tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();
417
418    cat(&unsqueezed?, dim)
419}
420
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Rows [1, 3) of a 3x2 tensor form a 2x2 view starting at element 3.0.
    #[test]
    fn test_slice_dim0() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let view = base.slice_dim0(1, 3).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[0, 0]).unwrap(), 3.0);
    }

    // Selecting row 1 of a 2x3 tensor drops the dimension and yields row data.
    #[test]
    fn test_select() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let row = base.select(0, 1).unwrap();
        assert_eq!(row.shape(), &[3]);
        assert_eq!(row.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    // A narrow of length 3 starting at index 1 keeps the middle elements.
    #[test]
    fn test_narrow() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let window = base.narrow(0, 1, 3).unwrap();
        assert_eq!(window.shape(), &[3]);
        assert_eq!(window.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    // Six elements split into 3 chunks gives three equal pairs.
    #[test]
    fn test_chunk() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let parts = base.chunk(3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(parts[1].to_vec(), vec![3.0, 4.0]);
        assert_eq!(parts[2].to_vec(), vec![5.0, 6.0]);
    }

    // Concatenating two length-2 vectors yields one length-4 vector.
    #[test]
    fn test_cat() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let joined = cat(&[lhs, rhs], 0).unwrap();
        assert_eq!(joined.shape(), &[4]);
        assert_eq!(joined.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    // Stacking two length-2 vectors at dim 0 produces a 2x2 tensor.
    #[test]
    fn test_stack() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let stacked = stack(&[lhs, rhs], 0).unwrap();
        assert_eq!(stacked.shape(), &[2, 2]);
    }
}
485}