axonml_tensor/
view.rs

//! Views and Slicing - Tensor Indexing Operations
//!
//! Provides functionality for creating views into tensors through slicing,
//! indexing, and masking operations. Views share storage with the original
//! tensor when possible, avoiding unnecessary copies.
//!
//! # Key Features
//! - Zero-copy slicing for contiguous ranges
//! - Advanced indexing with integer arrays
//! - Boolean masking
//! - Gather and scatter operations
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team
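//!
//! # Example
//!
//! A minimal sketch of zero-copy slicing (fenced as `ignore`: it assumes
//! `Tensor` is re-exported at the crate root, which this file does not show):
//!
//! ```ignore
//! use axonml_tensor::Tensor;
//!
//! let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
//! let row = t.select(0, 1).unwrap(); // view of the second row; shares storage
//! assert_eq!(row.to_vec(), vec![4.0, 5.0, 6.0]);
//! ```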

use axonml_core::dtype::{Numeric, Scalar};
use axonml_core::error::{Error, Result};

use crate::shape::{numel, Shape};
use crate::tensor::Tensor;

// =============================================================================
// Slice Specification
// =============================================================================

/// Specifies how to slice along a single dimension.
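///
/// # Example
///
/// A sketch of constructing specs (the consuming slice API is not shown in
/// this file); `range(1, 5)` corresponds to the Python-style `1:5`:
///
/// ```ignore
/// let specs = [
///     SliceSpec::Index(0),            // select index 0, dropping the dim
///     SliceSpec::range(1, 5),         // 1:5
///     SliceSpec::range_step(0, 8, 2), // 0:8:2
/// ];
/// ```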
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Select a single index, reducing dimensionality.
    Index(isize),
    /// Select a range [start, stop) with optional step.
    Range {
        /// Start index (inclusive), None = beginning
        start: Option<isize>,
        /// Stop index (exclusive), None = end
        stop: Option<isize>,
        /// Step size, default 1
        step: isize,
    },
    /// Keep all elements in this dimension.
    All,
    /// Add a new dimension of size 1.
    NewAxis,
}

impl SliceSpec {
    /// Creates a range slice from start to stop.
    #[must_use] pub fn range(start: isize, stop: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: Some(stop),
            step: 1,
        }
    }

    /// Creates a range slice with step.
    #[must_use] pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: Some(stop),
            step,
        }
    }

    /// Creates a slice from start to end.
    #[must_use] pub fn from(start: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: None,
            step: 1,
        }
    }

    /// Creates a slice from beginning to stop.
    #[must_use] pub fn to(stop: isize) -> Self {
        Self::Range {
            start: None,
            stop: Some(stop),
            step: 1,
        }
    }
}

// =============================================================================
// Slicing Implementation
// =============================================================================

impl<T: Scalar> Tensor<T> {
    /// Returns a slice of the tensor along the first dimension.
    ///
    /// # Arguments
    /// * `start` - Start index (inclusive)
    /// * `end` - End index (exclusive)
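    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the unit test at the bottom of this file
    /// (fenced as `ignore` since crate paths are assumed):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
    /// let s = t.slice_dim0(1, 3).unwrap(); // rows 1..3, zero-copy
    /// assert_eq!(s.shape(), &[2, 2]);
    /// ```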
    pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
        if self.ndim() == 0 {
            return Err(Error::invalid_operation("Cannot slice a scalar"));
        }

        let dim_size = self.shape[0];
        if start > end || end > dim_size {
            return Err(Error::IndexOutOfBounds {
                index: end,
                size: dim_size,
            });
        }

        let mut new_shape = self.shape.clone();
        new_shape[0] = end - start;

        let new_offset = self.offset + start * self.strides[0] as usize;

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: self.strides.clone(),
            offset: new_offset,
        })
    }

    /// Returns a view selecting a single index along a dimension.
    ///
    /// This reduces the dimensionality by 1.
    ///
    /// # Arguments
    /// * `dim` - Dimension to select from
    /// * `index` - Index to select
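    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the unit test below:
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
    /// let row = t.select(0, 1).unwrap(); // second row, shape [3]
    /// assert_eq!(row.to_vec(), vec![4.0, 5.0, 6.0]);
    /// ```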
    pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        if index >= self.shape[dim] {
            return Err(Error::IndexOutOfBounds {
                index,
                size: self.shape[dim],
            });
        }

        let mut new_shape = self.shape.clone();
        new_shape.remove(dim);

        let mut new_strides = self.strides.clone();
        new_strides.remove(dim);

        let new_offset = self.offset + index * self.strides[dim] as usize;

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: new_offset,
        })
    }

    /// Returns a narrow view along a dimension.
    ///
    /// # Arguments
    /// * `dim` - Dimension to narrow
    /// * `start` - Start index
    /// * `length` - Length of the narrow view
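    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the unit test below:
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
    /// let n = t.narrow(0, 1, 3).unwrap(); // elements 1..4 as a view
    /// assert_eq!(n.to_vec(), vec![2.0, 3.0, 4.0]);
    /// ```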
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        if start + length > self.shape[dim] {
            return Err(Error::IndexOutOfBounds {
                index: start + length,
                size: self.shape[dim],
            });
        }

        let mut new_shape = self.shape.clone();
        new_shape[dim] = length;

        let new_offset = self.offset + start * self.strides[dim] as usize;

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: self.strides.clone(),
            offset: new_offset,
        })
    }

    /// Splits the tensor into chunks along a dimension.
    ///
    /// If the dimension size is not evenly divisible by `chunks`, the last
    /// chunk is smaller than the rest.
    ///
    /// # Arguments
    /// * `chunks` - Number of chunks
    /// * `dim` - Dimension to split along
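    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the unit test below:
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();
    /// let chunks = t.chunk(3, 0).unwrap();
    /// assert_eq!(chunks[1].to_vec(), vec![3.0, 4.0]);
    /// ```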
    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
        if chunks == 0 {
            return Err(Error::invalid_operation("Number of chunks must be positive"));
        }
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        let dim_size = self.shape[dim];
        let chunk_size = dim_size.div_ceil(chunks);
        let mut result = Vec::with_capacity(chunks);

        let mut start = 0;
        while start < dim_size {
            let length = chunk_size.min(dim_size - start);
            result.push(self.narrow(dim, start, length)?);
            start += length;
        }

        Ok(result)
    }

    /// Splits the tensor into parts of specified sizes along a dimension.
    ///
    /// # Arguments
    /// * `sizes` - Size of each part
    /// * `dim` - Dimension to split along
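    ///
    /// # Example
    ///
    /// A minimal sketch of the intended behavior (values assumed for
    /// illustration):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
    /// let parts = t.split(&[2, 3], 0).unwrap();
    /// assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
    /// assert_eq!(parts[1].to_vec(), vec![3.0, 4.0, 5.0]);
    /// ```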
    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        let total: usize = sizes.iter().sum();
        if total != self.shape[dim] {
            return Err(Error::invalid_operation(format!(
                "Split sizes sum to {} but dimension size is {}",
                total, self.shape[dim]
            )));
        }

        let mut result = Vec::with_capacity(sizes.len());
        let mut start = 0;

        for &size in sizes {
            result.push(self.narrow(dim, start, size)?);
            start += size;
        }

        Ok(result)
    }
}

// =============================================================================
// Indexing Implementation
// =============================================================================

impl<T: Numeric> Tensor<T> {
    /// Gathers values along a dimension according to indices.
    ///
    /// Note: this is a simplified implementation; only the 1-D case is
    /// handled correctly (see the comments in the body).
    ///
    /// # Arguments
    /// * `dim` - Dimension to gather along
    /// * `indices` - Indices tensor
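    ///
    /// # Example
    ///
    /// A minimal 1-D sketch (the only case the current implementation
    /// handles; values assumed for illustration):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap();
    /// let idx = Tensor::<i64>::from_vec(vec![2, 0, 1], &[3]).unwrap();
    /// assert_eq!(t.gather(0, &idx).unwrap().to_vec(), vec![30.0, 10.0, 20.0]);
    /// ```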
    pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        // For simplicity, this is a basic implementation.
        // A full implementation would match PyTorch's semantics exactly.
        let output_shape = indices.shape();
        let mut output_data = vec![T::zero(); numel(output_shape)];

        let indices_data = indices.to_vec();
        let self_data = self.to_vec();

        for (out_idx, &index) in indices_data.iter().enumerate() {
            let index = index as usize;
            if index >= self.shape[dim] {
                return Err(Error::IndexOutOfBounds {
                    index,
                    size: self.shape[dim],
                });
            }
            // Simplified: assumes 1D case
            output_data[out_idx] = self_data[index];
        }

        Tensor::from_vec(output_data, output_shape)
    }

    /// Returns elements selected by a boolean mask.
    ///
    /// # Arguments
    /// * `mask` - Boolean mask slice, one entry per element
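    ///
    /// # Example
    ///
    /// A minimal sketch of the intended behavior (values assumed):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
    /// let m = t.masked_select(&[true, false, true, false]).unwrap();
    /// assert_eq!(m.to_vec(), vec![1.0, 3.0]);
    /// ```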
    pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
        if mask.len() != self.numel() {
            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
        }

        let data = self.to_vec();
        let selected: Vec<T> = data
            .into_iter()
            .zip(mask.iter())
            .filter(|(_, &m)| m)
            .map(|(v, _)| v)
            .collect();

        let len = selected.len();
        Tensor::from_vec(selected, &[len])
    }

    /// Sets elements to `value` where the mask is true, in place.
    ///
    /// # Arguments
    /// * `mask` - Boolean mask slice, one entry per element
    /// * `value` - Value to set where the mask is true
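    ///
    /// # Example
    ///
    /// A minimal sketch of the intended behavior (values assumed):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
    /// t.masked_fill_(&[false, true, false], 9.0).unwrap();
    /// assert_eq!(t.to_vec(), vec![1.0, 9.0, 3.0]);
    /// ```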
    pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
        if mask.len() != self.numel() {
            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
        }

        if !self.is_contiguous() {
            return Err(Error::NotContiguous);
        }

        {
            let mut guard = self.storage.as_slice_mut();
            for (idx, &m) in mask.iter().enumerate() {
                if m {
                    guard[self.offset + idx] = value;
                }
            }
        }

        Ok(())
    }
}

// =============================================================================
// Concatenation and Stacking
// =============================================================================

/// Concatenates tensors along an existing dimension.
///
/// # Arguments
/// * `tensors` - Slice of tensors to concatenate
/// * `dim` - Dimension along which to concatenate
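///
/// # Example
///
/// A minimal sketch, mirroring the unit test below:
///
/// ```ignore
/// let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();
/// let c = cat(&[a, b], 0).unwrap();
/// assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
/// ```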
pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(Error::invalid_operation("Cannot concatenate empty list"));
    }

    let first = &tensors[0];
    let ndim = first.ndim();

    if dim >= ndim {
        return Err(Error::InvalidDimension {
            index: dim as i64,
            ndim,
        });
    }

    // Validate shapes match except for concat dimension
    for t in tensors.iter().skip(1) {
        if t.ndim() != ndim {
            return Err(Error::invalid_operation(
                "All tensors must have same number of dimensions",
            ));
        }
        for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
            if d != dim && s1 != s2 {
                return Err(Error::shape_mismatch(first.shape(), t.shape()));
            }
        }
    }

    // Compute output shape
    let mut output_shape = Shape::from_slice(first.shape());
    output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();

    // Allocate output
    let total_numel = numel(&output_shape);
    let mut output_data = vec![T::zeroed(); total_numel];

    // Copy data block-by-block so that any `dim` is handled, not just dim 0.
    // Each tensor's row-major data is viewed as `outer` rows of
    // `dim_size * inner` elements, interleaved into the wider output rows at
    // an accumulating column offset. Assumes `to_vec` yields logical order.
    let outer: usize = first.shape()[..dim].iter().product();
    let inner: usize = first.shape()[dim + 1..].iter().product();
    let out_row = output_shape[dim] * inner;

    let mut col_offset = 0;
    for t in tensors {
        let data = t.to_vec();
        let t_row = t.shape()[dim] * inner;
        for o in 0..outer {
            let dst = o * out_row + col_offset;
            output_data[dst..dst + t_row].clone_from_slice(&data[o * t_row..(o + 1) * t_row]);
        }
        col_offset += t_row;
    }

    Tensor::from_vec(output_data, &output_shape)
}

/// Stacks tensors along a new dimension.
///
/// # Arguments
/// * `tensors` - Slice of tensors to stack
/// * `dim` - Dimension at which to insert the new axis
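///
/// # Example
///
/// A minimal sketch, mirroring the unit test below:
///
/// ```ignore
/// let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
/// let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();
/// let c = stack(&[a, b], 0).unwrap();
/// assert_eq!(c.shape(), &[2, 2]);
/// ```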
pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(Error::invalid_operation("Cannot stack empty list"));
    }

    // Unsqueeze each tensor and then concatenate
    let unsqueezed: Result<Vec<Tensor<T>>> =
        tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();

    cat(&unsqueezed?, dim)
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice_dim0() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let s = t.slice_dim0(1, 3).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.get(&[0, 0]).unwrap(), 3.0);
    }

    #[test]
    fn test_select() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let s = t.select(0, 1).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_narrow() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let n = t.narrow(0, 1, 3).unwrap();
        assert_eq!(n.shape(), &[3]);
        assert_eq!(n.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chunk() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let chunks = t.chunk(3, 0).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(chunks[1].to_vec(), vec![3.0, 4.0]);
        assert_eq!(chunks[2].to_vec(), vec![5.0, 6.0]);
    }

    #[test]
    fn test_cat() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = cat(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_stack() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
    }
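
    #[test]
    fn test_cat_dim1() {
        // Added sketch: exercises the block-copy path in `cat` for dim != 0;
        // values assumed for illustration.
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let c = cat(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 4]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]);
    }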
}