Skip to main content

ronn_core/ops/
shape.rs

//! Shape manipulation operations for tensors.
//!
//! This module provides operations for manipulating tensor shapes, including
//! reshape, flatten, squeeze, unsqueeze, permute, expand, and view, plus
//! chunk, slice, concat, repeat, tile, and zero-padding helpers on `Tensor`.
5
6use crate::tensor::Tensor;
7use anyhow::{Result, anyhow};
8
9/// Trait for shape manipulation operations on tensors.
10pub trait ShapeOps {
11    /// Reshape the tensor to a new shape.
12    fn reshape(&self, new_shape: &[usize]) -> Result<Tensor>;
13
14    /// Flatten the tensor to 1D.
15    fn flatten(&self) -> Result<Tensor>;
16
17    /// Flatten starting from a specific dimension.
18    fn flatten_from(&self, start_dim: usize) -> Result<Tensor>;
19
20    /// Remove dimensions of size 1.
21    fn squeeze(&self) -> Result<Tensor>;
22
23    /// Remove a specific dimension of size 1.
24    fn squeeze_dim(&self, dim: usize) -> Result<Tensor>;
25
26    /// Add a dimension of size 1.
27    fn unsqueeze(&self, dim: usize) -> Result<Tensor>;
28
29    /// Permute the dimensions of the tensor.
30    fn permute(&self, dims: &[usize]) -> Result<Tensor>;
31
32    /// Expand the tensor to a new shape (broadcasting).
33    fn expand(&self, new_shape: &[usize]) -> Result<Tensor>;
34
35    /// View the tensor with a new shape (no data copy).
36    fn view(&self, new_shape: &[usize]) -> Result<Tensor>;
37}
38
39impl ShapeOps for Tensor {
40    fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
41        // Calculate total elements to verify shape compatibility
42        let current_elements = self.numel();
43        let new_elements: usize = new_shape.iter().product();
44
45        if current_elements != new_elements {
46            return Err(anyhow!(
47                "Cannot reshape tensor with {} elements to shape {:?} ({} elements)",
48                current_elements,
49                new_shape,
50                new_elements
51            ));
52        }
53
54        let result_candle = self.candle_tensor().reshape(new_shape)?;
55
56        Ok(Tensor::from_candle(
57            result_candle,
58            self.dtype(),
59            self.layout(),
60        ))
61    }
62
63    fn flatten(&self) -> Result<Tensor> {
64        let total_elements = self.numel();
65        self.reshape(&[total_elements])
66    }
67
68    fn flatten_from(&self, start_dim: usize) -> Result<Tensor> {
69        let shape = self.shape();
70
71        if start_dim >= shape.len() {
72            return Err(anyhow!(
73                "start_dim {} is out of bounds for tensor with {} dimensions",
74                start_dim,
75                shape.len()
76            ));
77        }
78
79        if start_dim == 0 {
80            return self.flatten();
81        }
82
83        // Keep dimensions before start_dim, flatten the rest
84        let mut new_shape = shape[..start_dim].to_vec();
85        let remaining_elements: usize = shape[start_dim..].iter().product();
86        new_shape.push(remaining_elements);
87
88        self.reshape(&new_shape)
89    }
90
91    fn squeeze(&self) -> Result<Tensor> {
92        let shape = self.shape();
93        let new_shape: Vec<usize> = shape.into_iter().filter(|&dim| dim != 1).collect();
94
95        // If all dimensions were 1, result should be scalar (empty shape)
96        if new_shape.is_empty() {
97            return self.reshape(&[1]);
98        }
99
100        self.reshape(&new_shape)
101    }
102
103    fn squeeze_dim(&self, dim: usize) -> Result<Tensor> {
104        let shape = self.shape();
105
106        if dim >= shape.len() {
107            return Err(anyhow!(
108                "Dimension {} is out of bounds for tensor with {} dimensions",
109                dim,
110                shape.len()
111            ));
112        }
113
114        if shape[dim] != 1 {
115            return Err(anyhow!(
116                "Cannot squeeze dimension {} with size {}",
117                dim,
118                shape[dim]
119            ));
120        }
121
122        let mut new_shape = shape;
123        new_shape.remove(dim);
124
125        if new_shape.is_empty() {
126            new_shape.push(1);
127        }
128
129        self.reshape(&new_shape)
130    }
131
132    fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
133        let shape = self.shape();
134
135        if dim > shape.len() {
136            return Err(anyhow!(
137                "Dimension {} is out of bounds for unsqueeze (max {})",
138                dim,
139                shape.len()
140            ));
141        }
142
143        let mut new_shape = shape;
144        new_shape.insert(dim, 1);
145
146        self.reshape(&new_shape)
147    }
148
149    fn permute(&self, dims: &[usize]) -> Result<Tensor> {
150        let shape = self.shape();
151
152        if dims.len() != shape.len() {
153            return Err(anyhow!(
154                "Number of dimensions in permutation ({}) doesn't match tensor dimensions ({})",
155                dims.len(),
156                shape.len()
157            ));
158        }
159
160        // Check that all dimensions are valid and unique
161        let mut sorted_dims = dims.to_vec();
162        sorted_dims.sort_unstable();
163        let expected_dims: Vec<usize> = (0..shape.len()).collect();
164
165        if sorted_dims != expected_dims {
166            return Err(anyhow!("Invalid permutation: {:?}", dims));
167        }
168
169        let result_candle = self.candle_tensor().permute(dims)?;
170
171        Ok(Tensor::from_candle(
172            result_candle,
173            self.dtype(),
174            self.layout(),
175        ))
176    }
177
178    fn expand(&self, new_shape: &[usize]) -> Result<Tensor> {
179        let current_shape = self.shape();
180
181        // Check if expansion is valid
182        if new_shape.len() < current_shape.len() {
183            return Err(anyhow!(
184                "Cannot expand tensor with {} dimensions to {} dimensions",
185                current_shape.len(),
186                new_shape.len()
187            ));
188        }
189
190        // Align dimensions from the right
191        let offset = new_shape.len() - current_shape.len();
192        for (i, &current_dim) in current_shape.iter().enumerate() {
193            let new_dim = new_shape[offset + i];
194            if current_dim != 1 && current_dim != new_dim {
195                return Err(anyhow!(
196                    "Cannot expand dimension {} from {} to {}",
197                    offset + i,
198                    current_dim,
199                    new_dim
200                ));
201            }
202        }
203
204        let result_candle = self.candle_tensor().expand(new_shape)?;
205
206        Ok(Tensor::from_candle(
207            result_candle,
208            self.dtype(),
209            self.layout(),
210        ))
211    }
212
213    fn view(&self, new_shape: &[usize]) -> Result<Tensor> {
214        // View is similar to reshape but requires compatible strides
215        self.reshape(new_shape)
216    }
217}
218
219/// Additional shape manipulation methods.
220impl Tensor {
221    /// Split the tensor into chunks along a specific dimension.
222    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Tensor>> {
223        let shape = self.shape();
224
225        if dim >= shape.len() {
226            return Err(anyhow!(
227                "Dimension {} is out of bounds for tensor with {} dimensions",
228                dim,
229                shape.len()
230            ));
231        }
232
233        let dim_size = shape[dim];
234        let chunk_size = (dim_size + chunks - 1) / chunks; // Ceiling division
235
236        let mut result = Vec::new();
237
238        for i in 0..chunks {
239            let start = i * chunk_size;
240            let end = std::cmp::min(start + chunk_size, dim_size);
241
242            if start >= dim_size {
243                break;
244            }
245
246            let chunk_tensor = self.slice(dim, start, end)?;
247            result.push(chunk_tensor);
248        }
249
250        Ok(result)
251    }
252
253    /// Slice the tensor along a specific dimension.
254    pub fn slice(&self, dim: usize, start: usize, end: usize) -> Result<Tensor> {
255        let shape = self.shape();
256
257        if dim >= shape.len() {
258            return Err(anyhow!(
259                "Dimension {} is out of bounds for tensor with {} dimensions",
260                dim,
261                shape.len()
262            ));
263        }
264
265        if start >= end || end > shape[dim] {
266            return Err(anyhow!(
267                "Invalid slice range: {}:{} for dimension of size {}",
268                start,
269                end,
270                shape[dim]
271            ));
272        }
273
274        let result_candle = self.candle_tensor().narrow(dim, start, end - start)?;
275
276        Ok(Tensor::from_candle(
277            result_candle,
278            self.dtype(),
279            self.layout(),
280        ))
281    }
282
283    /// Concatenate tensors along a specific dimension.
284    pub fn concat(tensors: &[&Tensor], dim: usize) -> Result<Tensor> {
285        if tensors.is_empty() {
286            return Err(anyhow!("Cannot concatenate empty list of tensors"));
287        }
288
289        let first_tensor = tensors[0];
290        let first_shape = first_tensor.shape();
291
292        if dim >= first_shape.len() {
293            return Err(anyhow!(
294                "Dimension {} is out of bounds for tensor with {} dimensions",
295                dim,
296                first_shape.len()
297            ));
298        }
299
300        // Check that all tensors have compatible shapes
301        for (i, tensor) in tensors.iter().enumerate() {
302            let tensor_shape = tensor.shape();
303            if tensor_shape.len() != first_shape.len() {
304                return Err(anyhow!(
305                    "Tensor {} has {} dimensions, expected {}",
306                    i,
307                    tensor_shape.len(),
308                    first_shape.len()
309                ));
310            }
311
312            for (j, (&dim_size, &expected_size)) in
313                tensor_shape.iter().zip(first_shape.iter()).enumerate()
314            {
315                if j != dim && dim_size != expected_size {
316                    return Err(anyhow!(
317                        "Tensor {} has size {} in dimension {}, expected {}",
318                        i,
319                        dim_size,
320                        j,
321                        expected_size
322                    ));
323                }
324            }
325        }
326
327        let candle_tensors: Vec<&candle_core::Tensor> =
328            tensors.iter().map(|t| t.candle_tensor()).collect();
329
330        let result_candle = candle_core::Tensor::cat(&candle_tensors, dim)?;
331
332        Ok(Tensor::from_candle(
333            result_candle,
334            first_tensor.dtype(),
335            first_tensor.layout(),
336        ))
337    }
338
339    /// Repeat the tensor along specified dimensions.
340    pub fn repeat(&self, repeats: &[usize]) -> Result<Tensor> {
341        let shape = self.shape();
342
343        if repeats.len() != shape.len() {
344            return Err(anyhow!(
345                "Number of repeats ({}) must match tensor dimensions ({})",
346                repeats.len(),
347                shape.len()
348            ));
349        }
350
351        let result_candle = self.candle_tensor().repeat(repeats)?;
352
353        Ok(Tensor::from_candle(
354            result_candle,
355            self.dtype(),
356            self.layout(),
357        ))
358    }
359
360    /// Tile the tensor with the given multiples.
361    pub fn tile(&self, multiples: &[usize]) -> Result<Tensor> {
362        // Tile is similar to repeat but with different semantics
363        self.repeat(multiples)
364    }
365
366    /// Pad the tensor with zeros.
367    pub fn pad_zeros(&self, padding: &[(usize, usize)]) -> Result<Tensor> {
368        let shape = self.shape();
369
370        if padding.len() != shape.len() {
371            return Err(anyhow!(
372                "Padding length ({}) must match tensor dimensions ({})",
373                padding.len(),
374                shape.len()
375            ));
376        }
377
378        // Calculate new shape after padding
379        let new_shape: Vec<usize> = shape
380            .iter()
381            .zip(padding.iter())
382            .map(|(&dim, &(pad_before, pad_after))| dim + pad_before + pad_after)
383            .collect();
384
385        // Create a zero tensor with the new shape
386        let _padded = Tensor::zeros(new_shape, self.dtype(), self.layout())?;
387
388        // Calculate slice ranges for placing the original tensor
389        let _slice_ranges: Vec<(usize, usize)> = padding
390            .iter()
391            .zip(shape.iter())
392            .map(|(&(pad_before, _), &dim)| (pad_before, pad_before + dim))
393            .collect();
394
395        // This is a simplified implementation - in practice, we'd need more complex indexing
396        // For now, this is a placeholder that works for simple cases
397        if padding
398            .iter()
399            .all(|&(before, after)| before == 0 && after == 0)
400        {
401            // No padding needed
402            return Ok(self.clone());
403        }
404
405        // For non-zero padding, we'd need to implement tensor slicing assignment
406        // This is a complex operation that would require more advanced indexing
407        Err(anyhow!(
408            "Complex padding operations not yet fully implemented"
409        ))
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Reshaping a 2x3 tensor to 3x2 and to 6 yields the requested shapes.
    #[test]
    fn test_reshape() -> Result<()> {
        let source = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        assert_eq!(source.reshape(&[3, 2])?.shape(), vec![3, 2]);
        assert_eq!(source.reshape(&[6])?.shape(), vec![6]);

        Ok(())
    }

    /// Full flatten gives one dimension; `flatten_from(1)` keeps the first.
    #[test]
    fn test_flatten() -> Result<()> {
        let cube = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        assert_eq!(cube.flatten()?.shape(), vec![8]);
        assert_eq!(cube.flatten_from(1)?.shape(), vec![2, 4]);

        Ok(())
    }

    /// Squeezing removes size-1 dimensions and unsqueezing adds one back.
    ///
    /// NOTE(review): `squeeze(None)` and `unsqueeze(&[0])` do not match the
    /// `ShapeOps` signatures declared in this module; presumably they resolve
    /// to inherent `Tensor` methods defined elsewhere - confirm.
    #[test]
    fn test_squeeze_unsqueeze() -> Result<()> {
        let padded = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![1, 2, 2, 1],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let all_squeezed = padded.squeeze(None)?;
        assert_eq!(all_squeezed.shape(), vec![2, 2]);

        let first_squeezed = padded.squeeze_dim(0)?;
        assert_eq!(first_squeezed.shape(), vec![2, 2, 1]);

        let restored = all_squeezed.unsqueeze(&[0])?;
        assert_eq!(restored.shape(), vec![1, 2, 2]);

        Ok(())
    }

    /// Swapping the two axes of a 2x3 tensor gives a 3x2 tensor.
    #[test]
    fn test_permute() -> Result<()> {
        let matrix = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        assert_eq!(matrix.permute(&[1, 0])?.shape(), vec![3, 2]);

        Ok(())
    }

    /// Slicing columns 1..3 of a 2x3 tensor leaves a 2x2 tensor.
    #[test]
    fn test_slice() -> Result<()> {
        let matrix = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        assert_eq!(matrix.slice(1, 1, 3)?.shape(), vec![2, 2]);

        Ok(())
    }

    /// Concatenating two 2x2 tensors grows only the chosen axis.
    #[test]
    fn test_concat() -> Result<()> {
        let left = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let right = Tensor::from_data(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let along_rows = Tensor::concat(&[&left, &right], 0)?;
        assert_eq!(along_rows.shape(), vec![4, 2]);

        let along_cols = Tensor::concat(&[&left, &right], 1)?;
        assert_eq!(along_cols.shape(), vec![2, 4]);

        Ok(())
    }

    /// Stacking two 2x2 tensors adds a new axis of size 2.
    ///
    /// NOTE(review): `Tensor::stack` is not defined in this module;
    /// presumably it lives elsewhere on `Tensor` - confirm.
    #[test]
    fn test_stack() -> Result<()> {
        let left = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let right = Tensor::from_data(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let on_axis_0 = Tensor::stack(&[&left, &right], 0)?;
        assert_eq!(on_axis_0.shape(), vec![2, 2, 2]);

        let on_axis_1 = Tensor::stack(&[&left, &right], 1)?;
        assert_eq!(on_axis_1.shape(), vec![2, 2, 2]);

        Ok(())
    }

    /// Six elements split into three chunks gives three chunks of two.
    #[test]
    fn test_chunk() -> Result<()> {
        let line = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let pieces = line.chunk(3, 0)?;
        assert_eq!(pieces.len(), 3);
        assert_eq!(pieces[0].shape(), vec![2]);
        assert_eq!(pieces[1].shape(), vec![2]);
        assert_eq!(pieces[2].shape(), vec![2]);

        Ok(())
    }

    /// Repeating a 2-element tensor three times cycles its data.
    #[test]
    fn test_repeat() -> Result<()> {
        let pair = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let tripled = pair.repeat(&[3])?;
        assert_eq!(tripled.shape(), vec![6]);

        let tripled_values = tripled.to_vec()?;
        assert_eq!(tripled_values, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);

        Ok(())
    }

    /// Each invalid request is reported as an `Err` rather than succeeding.
    #[test]
    fn test_error_handling() {
        let square = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();

        // Reshape to a shape with a different element count.
        assert!(square.reshape(&[3, 2]).is_err());

        // Squeeze a dimension index that does not exist.
        assert!(square.squeeze_dim(5).is_err());

        // Squeeze a dimension whose size is 2, not 1.
        assert!(square.squeeze_dim(0).is_err());

        // Unsqueeze at an index past the end.
        assert!(square.unsqueeze(&[5]).is_err());

        // Permutation with a repeated axis, then one with the wrong length.
        assert!(square.permute(&[0, 0]).is_err());
        assert!(square.permute(&[0, 1, 2]).is_err());

        // Slice range past the end, then a reversed range.
        assert!(square.slice(0, 5, 6).is_err());
        assert!(square.slice(0, 2, 1).is_err());

        // Concatenating nothing.
        let no_tensors: Vec<&Tensor> = vec![];
        assert!(Tensor::concat(&no_tensors, 0).is_err());

        // Concatenating tensors with different numbers of dimensions.
        let line = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(Tensor::concat(&[&square, &line], 0).is_err());
    }
}