Skip to main content

matten/tensor/
ops.rs

//! Shape operations, element access, slicing, and boundary integration
//! (JSON/CSV parsing and file loading) methods for `Tensor`
//! (RFC-007, RFC-008, RFC-009). Split from `tensor.rs` per the 300-ELOC guideline.
3
4use crate::{MattenError, Tensor};
5
6impl Tensor {
7    // ---- Shape operations (M4 / RFC-007) ------------------------------------
8
9    /// Reshapes the tensor to `new_shape`, returning a new owned tensor.
10    ///
11    /// The total element count must be unchanged. Data order is preserved
12    /// (row-major flat order).
13    ///
14    /// # Panics
15    ///
16    /// Panics on element-count mismatch or invalid shape. Use
17    /// [`try_reshape`](Tensor::try_reshape) for recoverable construction.
18    ///
19    /// ```
20    /// use matten::Tensor;
21    /// let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
22    /// let flat = t.reshape(&[4]);
23    /// assert_eq!(flat.shape(), &[4]);
24    /// assert_eq!(flat.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
25    /// ```
26    #[must_use]
27    pub fn reshape(&self, new_shape: &[usize]) -> Tensor {
28        crate::reshape::try_reshape_impl(self, new_shape).unwrap_or_else(|e| panic!("{e}"))
29    }
30
31    /// Reshapes the tensor, returning an error instead of panicking.
32    ///
33    /// # Errors
34    ///
35    /// Returns [`MattenError::Shape`] on element-count mismatch or invalid shape.
36    pub fn try_reshape(&self, new_shape: &[usize]) -> Result<Tensor, MattenError> {
37        crate::reshape::try_reshape_impl(self, new_shape)
38    }
39
40    /// Flattens the tensor to a 1-D tensor, preserving row-major order.
41    ///
42    /// A scalar (shape `[]`) is returned as shape `[1]`.
43    ///
44    /// ```
45    /// use matten::Tensor;
46    /// let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
47    /// let flat = t.flatten();
48    /// assert_eq!(flat.shape(), &[4]);
49    /// ```
50    #[must_use]
51    pub fn flatten(&self) -> Tensor {
52        #[cfg(feature = "dynamic")]
53        if self.is_dynamic() {
54            panic!(
55                "matten unsupported error in flatten: dynamic tensors do not support flatten; call try_numeric() first to convert to a numeric tensor"
56            );
57        }
58        let len = self.data.len();
59        Tensor {
60            data: self.data.clone(),
61            shape: vec![len],
62            #[cfg(feature = "dynamic")]
63            dynamic: None,
64        }
65    }
66
67    /// Transposes the tensor by reversing the axis order.
68    ///
69    /// For a rank-2 tensor this swaps rows and columns. For rank > 2 the axis
70    /// order is reversed: `[d0, d1, d2] → [d2, d1, d0]`.
71    ///
72    /// # Panics
73    ///
74    /// Panics for a rank-0 scalar (no axes to transpose).
75    ///
76    /// ```
77    /// use matten::Tensor;
78    /// let m = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
79    /// let mt = m.transpose();
80    /// assert_eq!(mt.shape(), &[3, 2]);
81    /// assert_eq!(mt.as_slice(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
82    /// ```
83    #[must_use]
84    pub fn transpose(&self) -> Tensor {
85        let ndim = self.ndim();
86        if ndim == 0 {
87            panic!("matten shape error in transpose: cannot transpose a scalar (rank 0)");
88        }
89        let perm: Vec<usize> = (0..ndim).rev().collect();
90        crate::reshape::permute_axes(self, &perm)
91    }
92
93    /// Alias for [`transpose`](Tensor::transpose).
94    #[must_use]
95    pub fn t(&self) -> Tensor {
96        self.transpose()
97    }
98
99    /// Returns a new tensor with `axis1` and `axis2` swapped.
100    ///
101    /// # Panics
102    ///
103    /// Panics if either axis is out of range.
104    ///
105    /// ```
106    /// use matten::Tensor;
107    /// let t = Tensor::new((1..=24).map(|x| x as f64).collect(), &[2, 3, 4]);
108    /// let s = t.swap_axes(0, 2);
109    /// assert_eq!(s.shape(), &[4, 3, 2]);
110    /// ```
111    #[must_use]
112    pub fn swap_axes(&self, axis1: usize, axis2: usize) -> Tensor {
113        crate::reshape::validate_axes(axis1, axis2, self.ndim(), "swap_axes")
114            .unwrap_or_else(|e| panic!("{e}"));
115        let mut perm: Vec<usize> = (0..self.ndim()).collect();
116        perm.swap(axis1, axis2);
117        crate::reshape::permute_axes(self, &perm)
118    }
119
120    /// Removes all axes of length `1`, returning a new owned tensor.
121    ///
122    /// Data order is unchanged. A scalar stays a scalar, and a tensor whose every
123    /// axis is `1` (e.g. `[1, 1]`) becomes a scalar (shape `[]`).
124    ///
125    /// # Panics
126    ///
127    /// Panics on a dynamic tensor; call `try_numeric()` first.
128    ///
129    /// ```
130    /// use matten::Tensor;
131    /// let t = Tensor::new(vec![1.0, 2.0, 3.0], &[1, 3, 1]);
132    /// assert_eq!(t.squeeze().shape(), &[3]);
133    /// ```
134    #[must_use]
135    pub fn squeeze(&self) -> Tensor {
136        #[cfg(feature = "dynamic")]
137        if self.is_dynamic() {
138            panic!(
139                "matten unsupported error in squeeze: dynamic tensors do not support squeeze; call try_numeric() first to convert to a numeric tensor"
140            );
141        }
142        let shape: Vec<usize> = self.shape.iter().copied().filter(|&d| d != 1).collect();
143        Tensor {
144            data: self.data.clone(),
145            shape,
146            #[cfg(feature = "dynamic")]
147            dynamic: None,
148        }
149    }
150
151    /// Inserts a new axis of length `1` at `axis`, returning a new owned tensor.
152    ///
153    /// `axis` may be `0..=ndim` (inserting at `ndim` appends a trailing axis).
154    /// Data order is unchanged.
155    ///
156    /// # Panics
157    ///
158    /// Panics if `axis > ndim`, or on a dynamic tensor. Use
159    /// [`try_expand_dims`](Tensor::try_expand_dims) for the non-panicking form.
160    ///
161    /// ```
162    /// use matten::Tensor;
163    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
164    /// assert_eq!(t.expand_dims(0).shape(), &[1, 3]);
165    /// assert_eq!(t.expand_dims(1).shape(), &[3, 1]);
166    /// ```
167    #[must_use]
168    pub fn expand_dims(&self, axis: usize) -> Tensor {
169        self.try_expand_dims(axis).unwrap_or_else(|e| panic!("{e}"))
170    }
171
172    /// Non-panicking [`expand_dims`](Tensor::expand_dims).
173    ///
174    /// # Errors
175    ///
176    /// Returns [`MattenError::InvalidArgument`] if `axis > ndim`, or
177    /// [`MattenError::Unsupported`] on a dynamic tensor.
178    ///
179    /// ```
180    /// use matten::Tensor;
181    /// let t = Tensor::from_vec(vec![1.0, 2.0, 3.0]);
182    /// assert!(t.try_expand_dims(5).is_err());
183    /// ```
184    pub fn try_expand_dims(&self, axis: usize) -> Result<Tensor, MattenError> {
185        #[cfg(feature = "dynamic")]
186        if self.is_dynamic() {
187            return Err(MattenError::Unsupported {
188                operation: "expand_dims",
189                message: "dynamic tensors do not support expand_dims; call try_numeric() first"
190                    .to_string(),
191            });
192        }
193        let ndim = self.shape.len();
194        if axis > ndim {
195            return Err(MattenError::InvalidArgument {
196                operation: "expand_dims",
197                argument: "axis",
198                message: format!(
199                    "axis {axis} is out of range for a rank-{ndim} tensor (valid 0..={ndim})"
200                ),
201            });
202        }
203        let mut shape = self.shape.clone();
204        shape.insert(axis, 1);
205        Ok(Tensor {
206            data: self.data.clone(),
207            shape,
208            #[cfg(feature = "dynamic")]
209            dynamic: None,
210        })
211    }
212
213    /// Returns the element at the multidimensional `coord`, or `None` if the
214    /// coordinate rank doesn't match or any component is out of bounds.
215    ///
216    /// ```
217    /// use matten::Tensor;
218    /// let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
219    /// assert_eq!(t.get(&[0, 1]), Some(2.0));
220    /// assert_eq!(t.get(&[5, 0]), None);
221    /// ```
222    pub fn get(&self, coord: &[usize]) -> Option<f64> {
223        #[cfg(feature = "dynamic")]
224        self.panic_if_dynamic("get");
225        let flat = crate::shape::coord_to_flat(coord, &self.shape)?;
226        self.data.get(flat).copied()
227    }
228
229    /// Returns the element at flat row-major `index`, or `None` if out of bounds.
230    ///
231    /// This is the flat-index companion to [`get`](Tensor::get). The index
232    /// follows the same row-major layout as [`as_slice`](Tensor::as_slice).
233    ///
234    /// ```
235    /// use matten::Tensor;
236    /// let t = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
237    /// assert_eq!(t.get_flat(1), Some(2.0));
238    /// assert_eq!(t.get_flat(10), None);
239    /// ```
240    pub fn get_flat(&self, index: usize) -> Option<f64> {
241        #[cfg(feature = "dynamic")]
242        self.panic_if_dynamic("get_flat");
243        self.data.get(index).copied()
244    }
245
246    // ---- Slicing (M4 / RFC-008) ---------------------------------------------
247
248    /// Starts a slice builder for this tensor. The builder is the canonical
249    /// slicing API; [`slice_str`](Tensor::slice_str) is a convenience wrapper.
250    ///
251    /// ```
252    /// use matten::Tensor;
253    /// let t = Tensor::new(vec![1.0,2.0,3.0,4.0,5.0,6.0], &[2, 3]);
254    /// let row = t.slice().index(0).all().build().unwrap();
255    /// assert_eq!(row.as_slice(), &[1.0, 2.0, 3.0]);
256    /// ```
257    pub fn slice(&self) -> crate::slice::SliceBuilder<'_> {
258        crate::slice::SliceBuilder::new(self)
259    }
260
261    /// Slices this tensor using a NumPy-like string specification.
262    ///
263    /// This is a convenience wrapper over the builder API. It always returns
264    /// `Result` and never panics on malformed input.
265    ///
266    /// # Errors
267    ///
268    /// Returns [`MattenError::Slice`] for any parse or bounds error.
269    ///
270    /// ```
271    /// use matten::Tensor;
272    /// let t = Tensor::new(vec![1.0,2.0,3.0,4.0,5.0,6.0], &[2, 3]);
273    /// let top = t.slice_str("0, :").unwrap();
274    /// assert_eq!(top.as_slice(), &[1.0, 2.0, 3.0]);
275    /// ```
276    pub fn slice_str(&self, spec: &str) -> Result<Tensor, MattenError> {
277        let specs = crate::slice::parse_slice_str(spec)?;
278        crate::slice::execute_slice(self, &specs, "slice_str")
279    }
280}
281
// ---- Boundary integration (M5 / RFC-009) --------------------------------
283
284impl Tensor {
285    /// Parses a JSON string into a `Tensor`.
286    ///
287    /// Accepts the canonical `{"shape":[…],"data":[…]}` object form and the
288    /// convenience nested-array form (rank 1 and 2). Returns
289    /// [`MattenError::Parse`] for any error; never panics.
290    ///
291    /// ```
292    /// use matten::Tensor;
293    ///
294    /// // Canonical object form
295    /// let t = Tensor::from_json(r#"{"shape":[2,2],"data":[1.0,2.0,3.0,4.0]}"#).unwrap();
296    /// assert_eq!(t.shape(), &[2, 2]);
297    ///
298    /// // Nested-array convenience form
299    /// let t = Tensor::from_json("[[1.0,2.0],[3.0,4.0]]").unwrap();
300    /// assert_eq!(t.shape(), &[2, 2]);
301    /// ```
302    #[cfg(feature = "json")]
303    pub fn from_json(input: &str) -> Result<Tensor, MattenError> {
304        crate::parse::json::from_json_str(input)
305    }
306
307    /// Parses a CSV string into a `Tensor` with shape `[rows, cols]`.
308    ///
309    /// All fields must be valid `f64` values. Returns [`MattenError::Parse`]
310    /// for ragged rows or non-numeric fields; never panics.
311    ///
312    /// ```
313    /// use matten::Tensor;
314    ///
315    /// let t = Tensor::from_csv("1.0,2.0\n3.0,4.0\n").unwrap();
316    /// assert_eq!(t.shape(), &[2, 2]);
317    /// assert_eq!(t.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
318    /// ```
319    #[cfg(feature = "csv")]
320    pub fn from_csv(input: &str) -> Result<Tensor, MattenError> {
321        crate::parse::csv::from_csv_str(input)
322    }
323
324    /// Loads and parses a JSON file into a `Tensor`.
325    ///
326    /// Returns [`MattenError::Io`] for file errors, [`MattenError::Parse`] for
327    /// parse errors.
328    ///
329    /// # Errors
330    ///
331    /// Returns an error if the file cannot be read or the content is invalid.
332    #[cfg(feature = "json")]
333    pub fn load_json(path: impl AsRef<std::path::Path>) -> Result<Tensor, MattenError> {
334        let path = path.as_ref();
335        let content = std::fs::read_to_string(path).map_err(|e| MattenError::Io {
336            path: path.to_path_buf(),
337            source: e,
338        })?;
339        crate::parse::json::from_json_str(&content)
340    }
341
342    /// Loads and parses a CSV file into a `Tensor` with shape `[rows, cols]`.
343    ///
344    /// Returns [`MattenError::Io`] for file errors, [`MattenError::Parse`] for
345    /// parse errors.
346    ///
347    /// # Errors
348    ///
349    /// Returns an error if the file cannot be read or the content is invalid.
350    #[cfg(feature = "csv")]
351    pub fn load_csv(path: impl AsRef<std::path::Path>) -> Result<Tensor, MattenError> {
352        let path = path.as_ref();
353        let content = std::fs::read_to_string(path).map_err(|e| MattenError::Io {
354            path: path.to_path_buf(),
355            source: e,
356        })?;
357        crate::parse::csv::from_csv_str(&content)
358    }
359}
360
361#[cfg(test)]
362mod tests;