1use crate::tensor::Tensor;
7use anyhow::{Result, anyhow};
8
9pub trait ShapeOps {
11 fn reshape(&self, new_shape: &[usize]) -> Result<Tensor>;
13
14 fn flatten(&self) -> Result<Tensor>;
16
17 fn flatten_from(&self, start_dim: usize) -> Result<Tensor>;
19
20 fn squeeze(&self) -> Result<Tensor>;
22
23 fn squeeze_dim(&self, dim: usize) -> Result<Tensor>;
25
26 fn unsqueeze(&self, dim: usize) -> Result<Tensor>;
28
29 fn permute(&self, dims: &[usize]) -> Result<Tensor>;
31
32 fn expand(&self, new_shape: &[usize]) -> Result<Tensor>;
34
35 fn view(&self, new_shape: &[usize]) -> Result<Tensor>;
37}
38
39impl ShapeOps for Tensor {
40 fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
41 let current_elements = self.numel();
43 let new_elements: usize = new_shape.iter().product();
44
45 if current_elements != new_elements {
46 return Err(anyhow!(
47 "Cannot reshape tensor with {} elements to shape {:?} ({} elements)",
48 current_elements,
49 new_shape,
50 new_elements
51 ));
52 }
53
54 let result_candle = self.candle_tensor().reshape(new_shape)?;
55
56 Ok(Tensor::from_candle(
57 result_candle,
58 self.dtype(),
59 self.layout(),
60 ))
61 }
62
63 fn flatten(&self) -> Result<Tensor> {
64 let total_elements = self.numel();
65 self.reshape(&[total_elements])
66 }
67
68 fn flatten_from(&self, start_dim: usize) -> Result<Tensor> {
69 let shape = self.shape();
70
71 if start_dim >= shape.len() {
72 return Err(anyhow!(
73 "start_dim {} is out of bounds for tensor with {} dimensions",
74 start_dim,
75 shape.len()
76 ));
77 }
78
79 if start_dim == 0 {
80 return self.flatten();
81 }
82
83 let mut new_shape = shape[..start_dim].to_vec();
85 let remaining_elements: usize = shape[start_dim..].iter().product();
86 new_shape.push(remaining_elements);
87
88 self.reshape(&new_shape)
89 }
90
91 fn squeeze(&self) -> Result<Tensor> {
92 let shape = self.shape();
93 let new_shape: Vec<usize> = shape.into_iter().filter(|&dim| dim != 1).collect();
94
95 if new_shape.is_empty() {
97 return self.reshape(&[1]);
98 }
99
100 self.reshape(&new_shape)
101 }
102
103 fn squeeze_dim(&self, dim: usize) -> Result<Tensor> {
104 let shape = self.shape();
105
106 if dim >= shape.len() {
107 return Err(anyhow!(
108 "Dimension {} is out of bounds for tensor with {} dimensions",
109 dim,
110 shape.len()
111 ));
112 }
113
114 if shape[dim] != 1 {
115 return Err(anyhow!(
116 "Cannot squeeze dimension {} with size {}",
117 dim,
118 shape[dim]
119 ));
120 }
121
122 let mut new_shape = shape;
123 new_shape.remove(dim);
124
125 if new_shape.is_empty() {
126 new_shape.push(1);
127 }
128
129 self.reshape(&new_shape)
130 }
131
132 fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
133 let shape = self.shape();
134
135 if dim > shape.len() {
136 return Err(anyhow!(
137 "Dimension {} is out of bounds for unsqueeze (max {})",
138 dim,
139 shape.len()
140 ));
141 }
142
143 let mut new_shape = shape;
144 new_shape.insert(dim, 1);
145
146 self.reshape(&new_shape)
147 }
148
149 fn permute(&self, dims: &[usize]) -> Result<Tensor> {
150 let shape = self.shape();
151
152 if dims.len() != shape.len() {
153 return Err(anyhow!(
154 "Number of dimensions in permutation ({}) doesn't match tensor dimensions ({})",
155 dims.len(),
156 shape.len()
157 ));
158 }
159
160 let mut sorted_dims = dims.to_vec();
162 sorted_dims.sort_unstable();
163 let expected_dims: Vec<usize> = (0..shape.len()).collect();
164
165 if sorted_dims != expected_dims {
166 return Err(anyhow!("Invalid permutation: {:?}", dims));
167 }
168
169 let result_candle = self.candle_tensor().permute(dims)?;
170
171 Ok(Tensor::from_candle(
172 result_candle,
173 self.dtype(),
174 self.layout(),
175 ))
176 }
177
178 fn expand(&self, new_shape: &[usize]) -> Result<Tensor> {
179 let current_shape = self.shape();
180
181 if new_shape.len() < current_shape.len() {
183 return Err(anyhow!(
184 "Cannot expand tensor with {} dimensions to {} dimensions",
185 current_shape.len(),
186 new_shape.len()
187 ));
188 }
189
190 let offset = new_shape.len() - current_shape.len();
192 for (i, ¤t_dim) in current_shape.iter().enumerate() {
193 let new_dim = new_shape[offset + i];
194 if current_dim != 1 && current_dim != new_dim {
195 return Err(anyhow!(
196 "Cannot expand dimension {} from {} to {}",
197 offset + i,
198 current_dim,
199 new_dim
200 ));
201 }
202 }
203
204 let result_candle = self.candle_tensor().expand(new_shape)?;
205
206 Ok(Tensor::from_candle(
207 result_candle,
208 self.dtype(),
209 self.layout(),
210 ))
211 }
212
213 fn view(&self, new_shape: &[usize]) -> Result<Tensor> {
214 self.reshape(new_shape)
216 }
217}
218
219impl Tensor {
221 pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Tensor>> {
223 let shape = self.shape();
224
225 if dim >= shape.len() {
226 return Err(anyhow!(
227 "Dimension {} is out of bounds for tensor with {} dimensions",
228 dim,
229 shape.len()
230 ));
231 }
232
233 let dim_size = shape[dim];
234 let chunk_size = (dim_size + chunks - 1) / chunks; let mut result = Vec::new();
237
238 for i in 0..chunks {
239 let start = i * chunk_size;
240 let end = std::cmp::min(start + chunk_size, dim_size);
241
242 if start >= dim_size {
243 break;
244 }
245
246 let chunk_tensor = self.slice(dim, start, end)?;
247 result.push(chunk_tensor);
248 }
249
250 Ok(result)
251 }
252
253 pub fn slice(&self, dim: usize, start: usize, end: usize) -> Result<Tensor> {
255 let shape = self.shape();
256
257 if dim >= shape.len() {
258 return Err(anyhow!(
259 "Dimension {} is out of bounds for tensor with {} dimensions",
260 dim,
261 shape.len()
262 ));
263 }
264
265 if start >= end || end > shape[dim] {
266 return Err(anyhow!(
267 "Invalid slice range: {}:{} for dimension of size {}",
268 start,
269 end,
270 shape[dim]
271 ));
272 }
273
274 let result_candle = self.candle_tensor().narrow(dim, start, end - start)?;
275
276 Ok(Tensor::from_candle(
277 result_candle,
278 self.dtype(),
279 self.layout(),
280 ))
281 }
282
283 pub fn concat(tensors: &[&Tensor], dim: usize) -> Result<Tensor> {
285 if tensors.is_empty() {
286 return Err(anyhow!("Cannot concatenate empty list of tensors"));
287 }
288
289 let first_tensor = tensors[0];
290 let first_shape = first_tensor.shape();
291
292 if dim >= first_shape.len() {
293 return Err(anyhow!(
294 "Dimension {} is out of bounds for tensor with {} dimensions",
295 dim,
296 first_shape.len()
297 ));
298 }
299
300 for (i, tensor) in tensors.iter().enumerate() {
302 let tensor_shape = tensor.shape();
303 if tensor_shape.len() != first_shape.len() {
304 return Err(anyhow!(
305 "Tensor {} has {} dimensions, expected {}",
306 i,
307 tensor_shape.len(),
308 first_shape.len()
309 ));
310 }
311
312 for (j, (&dim_size, &expected_size)) in
313 tensor_shape.iter().zip(first_shape.iter()).enumerate()
314 {
315 if j != dim && dim_size != expected_size {
316 return Err(anyhow!(
317 "Tensor {} has size {} in dimension {}, expected {}",
318 i,
319 dim_size,
320 j,
321 expected_size
322 ));
323 }
324 }
325 }
326
327 let candle_tensors: Vec<&candle_core::Tensor> =
328 tensors.iter().map(|t| t.candle_tensor()).collect();
329
330 let result_candle = candle_core::Tensor::cat(&candle_tensors, dim)?;
331
332 Ok(Tensor::from_candle(
333 result_candle,
334 first_tensor.dtype(),
335 first_tensor.layout(),
336 ))
337 }
338
339 pub fn repeat(&self, repeats: &[usize]) -> Result<Tensor> {
341 let shape = self.shape();
342
343 if repeats.len() != shape.len() {
344 return Err(anyhow!(
345 "Number of repeats ({}) must match tensor dimensions ({})",
346 repeats.len(),
347 shape.len()
348 ));
349 }
350
351 let result_candle = self.candle_tensor().repeat(repeats)?;
352
353 Ok(Tensor::from_candle(
354 result_candle,
355 self.dtype(),
356 self.layout(),
357 ))
358 }
359
360 pub fn tile(&self, multiples: &[usize]) -> Result<Tensor> {
362 self.repeat(multiples)
364 }
365
366 pub fn pad_zeros(&self, padding: &[(usize, usize)]) -> Result<Tensor> {
368 let shape = self.shape();
369
370 if padding.len() != shape.len() {
371 return Err(anyhow!(
372 "Padding length ({}) must match tensor dimensions ({})",
373 padding.len(),
374 shape.len()
375 ));
376 }
377
378 let new_shape: Vec<usize> = shape
380 .iter()
381 .zip(padding.iter())
382 .map(|(&dim, &(pad_before, pad_after))| dim + pad_before + pad_after)
383 .collect();
384
385 let _padded = Tensor::zeros(new_shape, self.dtype(), self.layout())?;
387
388 let _slice_ranges: Vec<(usize, usize)> = padding
390 .iter()
391 .zip(shape.iter())
392 .map(|(&(pad_before, _), &dim)| (pad_before, pad_before + dim))
393 .collect();
394
395 if padding
398 .iter()
399 .all(|&(before, after)| before == 0 && after == 0)
400 {
401 return Ok(self.clone());
403 }
404
405 Err(anyhow!(
408 "Complex padding operations not yet fully implemented"
409 ))
410 }
411}
412
#[cfg(test)]
mod tests {
    //! Shape-level tests for `ShapeOps` and the slicing / joining helpers.
    //!
    //! `squeeze` and `unsqueeze` are invoked through the fully qualified
    //! `ShapeOps::…` path so that the calls match the trait signatures
    //! declared in this module (`squeeze(&self)`, `unsqueeze(&self, dim)`)
    //! even if `Tensor` also has inherent methods with the same names.

    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Reshaping keeps the element count and yields the requested shape.
    #[test]
    fn test_reshape() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let reshaped = a.reshape(&[3, 2])?;
        assert_eq!(reshaped.shape(), vec![3, 2]);

        let reshaped_1d = a.reshape(&[6])?;
        assert_eq!(reshaped_1d.shape(), vec![6]);

        Ok(())
    }

    /// Full and partial flattening.
    #[test]
    fn test_flatten() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let flattened = a.flatten()?;
        assert_eq!(flattened.shape(), vec![8]);

        let flat_from = a.flatten_from(1)?;
        assert_eq!(flat_from.shape(), vec![2, 4]);

        Ok(())
    }

    /// Removing and inserting size-1 dimensions.
    #[test]
    fn test_squeeze_unsqueeze() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![1, 2, 2, 1],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let squeezed = ShapeOps::squeeze(&a)?;
        assert_eq!(squeezed.shape(), vec![2, 2]);

        let squeeze_dim = a.squeeze_dim(0)?;
        assert_eq!(squeeze_dim.shape(), vec![2, 2, 1]);

        let unsqueezed = ShapeOps::unsqueeze(&squeezed, 0)?;
        assert_eq!(unsqueezed.shape(), vec![1, 2, 2]);

        Ok(())
    }

    /// Swapping the two dimensions of a matrix.
    #[test]
    fn test_permute() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let permuted = a.permute(&[1, 0])?;
        assert_eq!(permuted.shape(), vec![3, 2]);

        Ok(())
    }

    /// Slicing a half-open range along one dimension.
    #[test]
    fn test_slice() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let sliced = a.slice(1, 1, 3)?;
        assert_eq!(sliced.shape(), vec![2, 2]);

        Ok(())
    }

    /// Concatenation along each dimension of two equally shaped matrices.
    #[test]
    fn test_concat() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let b = Tensor::from_data(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let concat_0 = Tensor::concat(&[&a, &b], 0)?;
        assert_eq!(concat_0.shape(), vec![4, 2]);

        let concat_1 = Tensor::concat(&[&a, &b], 1)?;
        assert_eq!(concat_1.shape(), vec![2, 4]);

        Ok(())
    }

    /// Stacking two matrices along a new dimension.
    ///
    /// NOTE(review): `Tensor::stack` is not defined in this module; presumably
    /// it lives with the `Tensor` type — confirm.
    #[test]
    fn test_stack() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let b = Tensor::from_data(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let stacked_0 = Tensor::stack(&[&a, &b], 0)?;
        assert_eq!(stacked_0.shape(), vec![2, 2, 2]);

        let stacked_1 = Tensor::stack(&[&a, &b], 1)?;
        assert_eq!(stacked_1.shape(), vec![2, 2, 2]);

        Ok(())
    }

    /// Even split of a 1-D tensor into three chunks.
    #[test]
    fn test_chunk() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let chunks = a.chunk(3, 0)?;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].shape(), vec![2]);
        assert_eq!(chunks[1].shape(), vec![2]);
        assert_eq!(chunks[2].shape(), vec![2]);

        Ok(())
    }

    /// Repeating a 1-D tensor tiles its contents.
    #[test]
    fn test_repeat() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let repeated = a.repeat(&[3])?;
        assert_eq!(repeated.shape(), vec![6]);

        let repeated_data = repeated.to_vec()?;
        assert_eq!(repeated_data, vec![1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);

        Ok(())
    }

    /// Invalid arguments are reported as errors rather than panics.
    #[test]
    fn test_error_handling() {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();

        // Element count mismatch.
        assert!(a.reshape(&[3, 2]).is_err());

        // Dimension index out of bounds.
        assert!(a.squeeze_dim(5).is_err());

        // Dimension exists but its size is not 1.
        assert!(a.squeeze_dim(0).is_err());

        // Insertion position beyond the rank.
        assert!(ShapeOps::unsqueeze(&a, 5).is_err());

        // Repeated index, then wrong number of indices.
        assert!(a.permute(&[0, 0]).is_err());
        assert!(a.permute(&[0, 1, 2]).is_err());

        // Range past the end, then reversed range.
        assert!(a.slice(0, 5, 6).is_err());
        assert!(a.slice(0, 2, 1).is_err());

        // Nothing to concatenate.
        let empty_tensors: Vec<&Tensor> = vec![];
        assert!(Tensor::concat(&empty_tensors, 0).is_err());

        // Rank mismatch between the inputs.
        let b = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(Tensor::concat(&[&a, &b], 0).is_err());
    }
}